|
9 | 9 | #include <tvm/relay/expr.h> |
10 | 10 | #include <tvm/relay/op.h> |
11 | 11 | #include <tvm/relay/op_attr_types.h> |
| 12 | +#include <topi/elemwise.h> |
12 | 13 |
|
13 | 14 | #include "../type_relations.h" |
14 | 15 | #include "../../pass/alter_op_layout.h" |
@@ -37,6 +38,31 @@ RELAY_REGISTER_OP("on_device") |
37 | 38 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout", |
38 | 39 | ElemwiseArbitraryLayout); |
39 | 40 |
|
| 41 | +Expr StopFusion(Expr data) { |
| 42 | + static const Op& op = Op::Get("annotation.stop_fusion"); |
| 43 | + return CallNode::make(op, {data}, Attrs{}, {}); |
| 44 | +} |
| 45 | + |
| 46 | +TVM_REGISTER_API("relay.op.annotation._make.stop_fusion") |
| 47 | +.set_body_typed<Expr(Expr)>([](Expr data) { |
| 48 | + return StopFusion(data); |
| 49 | +}); |
| 50 | + |
| 51 | +RELAY_REGISTER_OP("annotation.stop_fusion") |
| 52 | +.describe(R"code(Annotate an expression to prevent it being fused with previous expressions.)code" |
| 53 | +TVM_ADD_FILELINE) |
| 54 | +.set_num_inputs(1) |
| 55 | +.add_argument("data", "Tensor", "The input data.") |
| 56 | +.add_type_rel("Identity", IdentityRel) |
| 57 | +.set_support_level(10) |
| 58 | +.set_attr<TOpPattern>("TOpPattern", kOpaque) |
| 59 | +.set_attr<TOpIsStateful>("TOpIsStateful", false) |
| 60 | +.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout) |
| 61 | +.set_attr<FTVMCompute>("FTVMCompute", |
| 62 | + [](const Attrs& attrs, const Array<Tensor>& inputs, |
| 63 | + const Type& out_dtype, const Target& target) -> Array<Tensor> { |
| 64 | + return {topi::identity(inputs[0])}; |
| 65 | + }); |
40 | 66 |
|
41 | 67 | } // namespace relay |
42 | 68 | } // namespace tvm |
0 commit comments