// Copyright (c) 2016-present, Facebook, Inc.

#ifndef CAFFE2_FB_OPERATORS_CC_BMM_BG_H_
#define CAFFE2_FB_OPERATORS_CC_BMM_BG_H_

#include <cstring>
#include <vector>

#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {
// ConcatBatchMatMulBatchGatherOp fuses three steps: it concatenates the
// feature inputs (Input(1)..Input(N-1)), multiplies the concatenated matrix
// with its own transpose, and gathers entries of the resulting pairwise
// product matrix at the positions given by the indices input (Input(0)).
template <class Context>
class ConcatBatchMatMulBatchGatherOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  ConcatBatchMatMulBatchGatherOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override;

  template <typename T, typename TInd, typename Engine = DefaultEngine>
  bool DoRunWithType();
};

template <class Context>
template <typename T, typename TInd, typename Engine>
bool ConcatBatchMatMulBatchGatherOp<Context>::DoRunWithType() {
  auto& indices = Input(0);
  auto& input_zero = Input(1);
  int adj_size = input_zero.dim() + 1;

  // Concat axis: dimensions before it form the batch, dimensions after it
  // the embedding. A value of 1 is assumed here (2-D feature inputs).
  const auto canonical_axis = 1;
  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");

  // Every feature input must have the same dtype as the first one.
  for (int i = 2; i < InputSize(); ++i) {
    CAFFE_ENFORCE(
        Input(i).dtype() == input_zero.dtype(),
        "All inputs must have the same type, expected: ",
        input_zero.dtype().name(),
        " but got: ",
        Input(i).dtype().name(),
        " for input: ",
        i);
  }

  int before = 1, after = 1;
  for (int i = 0; i < input_zero.dim(); ++i) {
    int dim = input_zero.dim32(i);
    if (i < canonical_axis) {
      before *= dim;
    } else {
      after *= dim;
    }
    // Every feature input must also have the same shape as the first one.
    for (int j = 2; j < InputSize(); ++j) {
      int dim_j = Input(j).dim32(i);
      CAFFE_ENFORCE(
          dim == dim_j,
          "Expect dimension = ",
          dim,
          " got ",
          dim_j,
          " at axis = ",
          i,
          " for input: ",
          j,
          ". The input tensors can only have different dimensions "
          "when arg 'add_axis' = 0 and along the axis = ",
          canonical_axis);
    }
  }
  auto ndata = InputSize() - 1;
  auto batch_size = before;
  auto embed_size = after;
  auto gather_size = indices.sizes()[0];

  vector<int64_t> output_dims = {batch_size, gather_size};
  auto* output = Output(0, output_dims, at::dtype<T>());
  auto* output_data = output->template mutable_data<T>();
  auto* indices_data = indices.template data<TInd>();

  std::vector<T> scratch_input(ndata * embed_size);
  std::vector<T> scratch_output(ndata * ndata);

  for (int b = 0; b < batch_size; ++b) {
    // Concatenate row b of every feature input into the scratch matrix.
    for (int i = 1; i < InputSize(); ++i) {
      auto* input_data = Input(i).template data<T>();
      memcpy(
          &scratch_input[(i - 1) * embed_size],
          input_data + b * embed_size,
          embed_size * Input(i).itemsize());
    }
    math::Gemm<T, Context, Engine>(
        CblasNoTrans,
        CblasTrans,
        ndata,
        ndata,
        embed_size,
        1,
        &scratch_input[0],
        &scratch_input[0],
        0,
        &scratch_output[0],
        &context_);
    // Gather the requested entries of the flattened (ndata x ndata) result.
    int64_t output_offset = b * gather_size;
    for (int i = 0; i < gather_size; i++) {
      output_data[output_offset + i] = scratch_output[indices_data[i]];
    }
  }
  return true;
}
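// A minimal sketch of how this operator might be registered and given a
// schema, assuming the usual Caffe2 macros; the operator name
// "ConcatBatchMatMulBatchGather" and the input/output arity below are
// illustrative, not taken from this header, and the registration would
// normally live in a companion .cc file:
//
//   REGISTER_CPU_OPERATOR(
//       ConcatBatchMatMulBatchGather,
//       ConcatBatchMatMulBatchGatherOp<CPUContext>);
//
//   OPERATOR_SCHEMA(ConcatBatchMatMulBatchGather)
//       .NumInputs(3, INT_MAX) // indices plus at least two feature tensors
//       .NumOutputs(1);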
} // namespace caffe2

#endif // CAFFE2_FB_OPERATORS_CC_BMM_BG_H_