pytorch  1.8.2
About: PyTorch provides Tensor computation (like NumPy) with strong GPU acceleration and Deep Neural Networks (in Python) built on a tape-based autograd system. LTS (Long Term Support) release.
  Fossies Dox: pytorch-1.8.2.tar.gz  ("unofficial" and yet experimental doxygen-generated source code documentation)  

boolean_mask_ops.h
Go to the documentation of this file.
1#ifndef CAFFE2_OPERATORS_BOOLEAN_MASK_OPS_H_
2#define CAFFE2_OPERATORS_BOOLEAN_MASK_OPS_H_
3
8
9namespace caffe2 {
10
// BooleanMaskOp: final operator class templated on the execution Context.
// It holds no data members of its own; RunOnDevice() is only declared here
// and its definition is not part of this listing.
// NOTE(review): the fused line numbers jump from 13 to 15, so source line 14
// is missing from this listing (presumably an operator-context helper macro
// such as USE_OPERATOR_CONTEXT_FUNCTIONS) — confirm against upstream.
11template <class Context>
12class BooleanMaskOp final : public Operator<Context> {
13 public:
 // Variadic constructor: forwards every argument unchanged to the
 // Operator<Context> base constructor; it initializes nothing else.
15 template <class... Args>
16 explicit BooleanMaskOp(Args&&... args)
17 : Operator<Context>(std::forward<Args>(args)...) {}
18
 // Declared override; the implementation lives outside this listing.
19 bool RunOnDevice() override;
20};
21
// BooleanMaskOpGradient: final operator class templated on the execution
// Context, computing the gradient of the Boolean Mask operator (per the
// block comment on RunOnDevice below).
// NOTE(review): this listing skips source lines 25, 35 and 40 (see the fused
// line numbers); the notes below mark where each gap falls.
22template <class Context>
23class BooleanMaskOpGradient final : public Operator<Context> {
24 public:
 // NOTE(review): source line 25 is absent here — presumably an
 // operator-context helper macro; confirm against upstream.
 // Passes the OperatorDef and Workspace straight to the Operator<Context>
 // base; no extra arguments are read and no members are initialized.
26 BooleanMaskOpGradient(const OperatorDef& operator_def, Workspace* ws)
27 : Operator<Context>(operator_def, ws) {}
28
29 /* Calculating the gradient of the Boolean Mask operator
30 * requires access to the original mask that's passed in,
31 * and the gradient to backpropagate.
32 */
 // Dispatches on the element type of Input(1) through
 // DispatchHelper<...>::call(this, ...), which presumably invokes a
 // type-templated member of this class for the matched type.
 // NOTE(review): source line 35 — the DispatchHelper template argument list
 // (the set of supported element types) together with the closing `>::` —
 // is missing below, so this text does not compile as shown.
33 bool RunOnDevice() override {
34 return DispatchHelper<
36 call(this, Input(1));
37 }
38
 // NOTE(review): source line 40 — the declaration introduced by this
 // template header (presumably the per-type worker reached through
 // DispatchHelper) — is missing from this listing.
39 template <typename T>
41};
42
// SequenceMaskOp: final operator class templated on the execution Context.
// Its constructor reads all configuration from operator arguments and stores
// it in private members. RunOnDevice() is defined outside this listing.
43template <class Context>
44class SequenceMaskOp final : public Operator<Context> {
45 public:
 // NOTE(review): source line 46 is absent from this listing — presumably an
 // operator-context helper macro; confirm against upstream.
 // Reads these arguments (defaults in parentheses): "axis" (1),
 // "radius" (10), "grad" (false), "fill_val" (negative infinity), the
 // required string "mode", and the optional ints "batch" and
 // "repeat_from_axis".
47 explicit SequenceMaskOp(const OperatorDef& operator_def, Workspace* ws)
48 : Operator<Context>(operator_def, ws),
49 axis_(this->template GetSingleArgument<int>("axis", 1)),
50 radius_(this->template GetSingleArgument<int>("radius", 10)),
51 grad_(this->template GetSingleArgument<bool>("grad", false)),
52 fill_val_(this->template GetSingleArgument<float>(
53 "fill_val",
54 -1.0f * std::numeric_limits<float>::infinity())) {
55 // Mode argument is required
56 mode_ = GetArgument(operator_def, "mode").s();
57 // batch argument is optional, but if not given, we don't want a default val
58 if (HasArgument("batch")) {
59 batch_ = GetArgument(operator_def, "batch").i();
60 }
 // NOTE(review): batch_ has no initializer, so its value is indeterminate
 // when "batch" is absent; any reader must check HasArgument("batch") first.
61
62 if (HasArgument("repeat_from_axis")) {
 // NOTE(review): source lines 63, 66 and 69 are missing below. Judging by
 // the surviving message strings, 63 and 66 presumably open enforcement
 // checks (repeat_from_axis allowed only in "sequence" mode, and never
 // together with "batch"). Line 69 presumably names the member that
 // receives the "repeat_from_axis" value (default -1). Confirm upstream.
64 mode_ == "sequence",
65 "repeat_from_axis currently only supported in sequence mode.");
67 !HasArgument("batch"),
68 "repeat_from_axis and batch not currently supported together.");
70 this->template GetSingleArgument<int>("repeat_from_axis", -1);
71 }
72 }
73
 // Declared override; the implementation lives outside this listing.
74 bool RunOnDevice() override;
75
 // NOTE(review): source line 77 — the declaration introduced by this
 // template header — is missing from this listing.
76 template <typename T>
78
79 private:
 // NOTE(review): source lines 81 and 86 are missing from this member list.
 // radius_ is initialized in the constructor but not declared among the
 // visible members, so line 81 presumably declares it. Line 86 presumably
 // declares the repeat_from_axis target. Confirm against upstream.
80 int axis_;
82 std::string mode_;
83 bool grad_;
84 float fill_val_;
85 int batch_;
87};
88
89} // namespace caffe2
90
91#endif
Args({2<< 5}) -> Args({2<< 8}) ->Args({2<< 12}) ->Args({2<< 14})
BooleanMaskOpGradient(const OperatorDef &operator_def, Workspace *ws)
BooleanMaskOp(Args &&... args)
bool RunOnDevice() override
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:99
T GetSingleArgument(const string &name, const T &default_value) const
Definition: operator.h:110
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator.
Definition: operator.h:836
bool RunOnDevice() override
SequenceMaskOp(const OperatorDef &operator_def, Workspace *ws)
Workspace is a class that holds all the related objects created during runtime: (1) all blobs,...
Definition: workspace.h:47
int call(int id)
void forward(int64_t offset)
Copyright (c) 2016-present, Facebook, Inc.
Definition: blob.h:13
d int long tensor contains the length in each of the output N dim Tensor where dim boolean false where packed_tensor is true otherwise Padding number in the packed segments Use true to pad infinity
C10_EXPORT const Argument & GetArgument(const OperatorDef &def, const string &name)
Definition: proto_utils.cc:522
INT_MAX Subnet with blob bindings Indices of corresponding outer workspace in List of blobs from the forward Do int out bool
Definition: do_op.cc:26
INT_MAX float
const ArgumentHelper args(def)
INT_MAX int
CAFFE_ENFORCE(dims.front() >=0, "Dimension ids must be non-negative.")
STL namespace.