pytorch 1.8.2
About: PyTorch provides Tensor computation (like NumPy) with strong GPU acceleration and Deep Neural Networks (in Python) built on a tape-based autograd system. LTS (Long Term Support) release.
Fossies Dox: pytorch-1.8.2.tar.gz ("unofficial" and still experimental doxygen-generated source code documentation)

cc_bmm_bg_op.h
Go to the documentation of this file.
1#ifndef CAFFE2_FB_OPERATORS_CC_BMM_BG_H_
2#define CAFFE2_FB_OPERATORS_CC_BMM_BG_H_
3
#include <cstring>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"
8
9namespace caffe2 {
10
11using T = float;
12using TInd = int;
14
15template <class Context>
16class ConcatBatchMatMulBatchGatherOp final : public Operator<Context> {
17 public:
19
20 ConcatBatchMatMulBatchGatherOp(const OperatorDef& operator_def, Workspace* ws)
21 : Operator<Context>(operator_def, ws) {}
22
23 bool RunOnDevice() override;
24
25 protected:
26 int axis_ = 1;
27 int add_axis_ = 1;
28
29 bool trans_a_ = 0;
30 bool trans_b_ = 1;
31 bool broadcast_ = 0;
32};
33
34template <class Context>
36 auto& indices = Input(0);
37 auto& input_zero = Input(1);
38 int adj_size = input_zero.dim() + 1;
39 int canonical_axis = 1;
40 CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
41 for (int i = 2; i < InputSize(); ++i) {
43 Input(i).dtype() == input_zero.dtype(),
44 "All inputs must have the same type, expected: ",
45 input_zero.dtype().name(),
46 " but got: ",
47 Input(i).dtype().name(),
48 " for input: ",
49 i);
50 }
51
52 int before = 1, after = 1;
53 for (int i = 0; i < input_zero.dim(); ++i) {
54 int dim = input_zero.dim32(i);
55 if (i < canonical_axis) {
56 before *= dim;
57 } else { // i > canonical_axis || i == canonical_axis && add_axis_
58 after *= dim;
59 }
60 // check the input dims are compatible.
61 for (int j = 2; j < InputSize(); ++j) {
62 int dim_j = Input(j).dim32(i);
64 dim == dim_j,
65 "Expect dimension = ",
66 dim,
67 " got ",
68 dim_j,
69 " at axis = ",
70 i,
71 " for input: ",
72 j,
73 ". The input tensors can only have different dimensions "
74 "when arg 'add_axis' = 0 and along the axis = ",
76 " <",
77 input_zero.sizes(),
78 "> vs <",
79 Input(j).sizes(),
80 ">.");
81 }
82 }
83
84 auto ndata = InputSize() - 1;
85 auto batch_size = before;
86 auto embed_size = after;
87 auto gather_size = indices.sizes()[0];
88
89 vector<int64_t> output_dims;
90 output_dims.push_back(batch_size);
91 output_dims.insert(
92 output_dims.begin() + 1, indices.sizes().begin(), indices.sizes().end());
93 auto* output = Output(0, output_dims, at::dtype<T>());
94 // std::stringstream ss;
95 // ss << "[";
96 // for(int i = 0; i < output_dims.size(); i++) ss << output_dims[i];
97 // ss << "]";
98 // LOG(INFO) << "output size: " << ss.str();
99
100 auto* output_data = output->template mutable_data<T>();
101 auto* indices_data = indices.template data<TInd>();
102#pragma omp parallel
103 {
104 std::vector<T> scratch_input(ndata * embed_size);
105 std::vector<T> scratch_output(ndata * ndata);
106
107#pragma omp for
108 for (int b = 0; b < batch_size; ++b) {
109 // concat input to scratch
110 for (int i = 1; i < InputSize(); ++i) {
111 auto* input_data = Input(i).template data<T>();
112 memcpy(
113 &scratch_input[(i - 1) * embed_size],
114 input_data + b * embed_size,
115 embed_size * Input(i).itemsize());
116 }
117 // call mkl gemm
118 math::Gemm<T, Context, Engine>(
121 ndata,
122 ndata,
123 embed_size,
124 1,
125 &scratch_input[0],
126 &scratch_input[0],
127 0,
128 &scratch_output[0],
129 &context_);
130 // do gather
131
132 int64_t output_offset = b * gather_size;
133 for (int i = 0; i < gather_size; i++) {
134 output_data[output_offset + i] = scratch_output[indices_data[i]];
135 }
136 }
137 }
138 return true;
139}
140
141} // namespace caffe2
142
143#endif // CAFFE2_FB_OPERATORS_CC_BMM_BG_H_
#define CAFFE_ENFORCE_LT(x, y,...)
Definition: Logging.h:259
@ CblasNoTrans
Definition: cblas.h:14
@ CblasTrans
Definition: cblas.h:14
ConcatBatchMatMulBatchGatherOp(const OperatorDef &operator_def, Workspace *ws)
Definition: cc_bmm_bg_op.h:20
Workspace is a class that holds all the related objects created during runtime: (1) all blobs,...
Definition: workspace.h:47
CPUContext * context_
std::string name
Copyright (c) 2016-present, Facebook, Inc.
Definition: blob.h:13
const auto canonical_axis
float T
Definition: cc_bmm_bg_op.h:11
SparseLengths8BitsRowwiseOp< CPUContext, 0, 1 >::LENGTHS uint8 tensor obtained with Vector with the same sum of elements as the first dimension of DATA Input(3, "scale_bias", "Matrix of floats, each row r_i of which stores a pair " "s_i, b_i -- scale and bias for i-th row") .Output(0
for each weights are accessed by indices[0..L-1]
vector< int > output_dims
true SparseLengthsFused4BitRowwiseFakeFP16Op< CPUContext, true >::WEIGHTS uint8 tensor obtained with Vector with the same sum of elements as the first dimension of DATA output
INT_MAX float
d int long tensor contains the length in each of the output N dim Tensor where dim(1) is the max length" "
*type depends on dtype
int TInd
Definition: cc_bmm_bg_op.h:12
required base learning rate default used only for inv policy type default sampling rate on iterations default True in alter policy int64_t
SparseLengths8BitsRowwiseOp< CPUContext, 1, 1 >::LENGTHS uint8 tensor obtained with Integer vector containing indices of the first dimension of DATA for the slices that are being aggregated Matrix of each row r_i of which stores a pair b_i scale and bias for i th row Output(0, "output", "output")
INT_MAX int
stride sizes
Definition: lp_pool_op.cc:233
CAFFE_ENFORCE(dims.front() >=0, "Dimension ids must be non-negative.")