diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py
index 3eb41360578e..8b11bfde7fe7 100644
--- a/python/tvm/relay/op/annotation/annotation.py
+++ b/python/tvm/relay/op/annotation/annotation.py
@@ -29,3 +29,19 @@ def on_device(data, device):
         raise ValueError("device is expected to be the type of TVMContext or "
                          "str, but received %s" % (type(device)))
     return _make.on_device(data, device)
+
+
+def stop_fusion(data):
+    """Annotate an expression to prevent it being fused with previous expressions.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The expression to be annotated.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The annotated expression.
+    """
+    return _make.stop_fusion(data)
diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc
index 4069512247de..5dbec5e58c1c 100644
--- a/src/relay/op/annotation/annotation.cc
+++ b/src/relay/op/annotation/annotation.cc
@@ -9,6 +9,7 @@
 #include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
+#include <topi/elemwise.h>
 
 #include "../type_relations.h"
 #include "../../pass/alter_op_layout.h"
@@ -37,6 +38,31 @@ RELAY_REGISTER_OP("on_device")
 .set_attr<TOpIsStateful>("TOpIsStateful", false)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
+// Construct a call to the annotation.stop_fusion op wrapping `data`.
+Expr StopFusion(Expr data) {
+  static const Op& op = Op::Get("annotation.stop_fusion");
+  return CallNode::make(op, {data}, Attrs{}, {});
+}
+
+TVM_REGISTER_API("relay.op.annotation._make.stop_fusion")
+.set_body_typed<Expr(Expr)>([](Expr data) {
+  return StopFusion(data);
+});
+
+RELAY_REGISTER_OP("annotation.stop_fusion")
+.describe(R"code(Annotate an expression to prevent it being fused with previous expressions.)code"
+TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input data.")
+.add_type_rel("Identity", IdentityRel)
+.set_support_level(10)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+.set_attr<FTVMCompute>("FTVMCompute",
+                       [](const Attrs& attrs, const Array<Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                         return {topi::identity(inputs[0])};
+                       });
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index 169aef3b6a4a..3227a70f3e7c 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -741,10 +741,14 @@ class FuseMutator : private ExprMutator {
   }
   // Transform calls.
   Expr VisitExpr_(const CallNode* call) {
+    static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
     if (call->op.as<OpNode>()) {
       // If it is a primitive op call
       // then we must have a group assignment for it already.
       CHECK(gmap_.count(call));
+      if (call->op.same_as(stop_fusion)) {
+        return ExprMutator::VisitExpr(call->args[0]);
+      }
       auto* ret_group = gmap_.at(call)->FindRoot();
       Array<Expr> new_args = GetNewArguments(call->args, ret_group);
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index 83a0bb9157ee..08fc017f41eb 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -329,6 +329,8 @@ Expr MakeConcatenate(Expr data, int axis);
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
 
+Expr StopFusion(Expr data);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index f5c11d14745b..1d926a325b1a 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -220,9 +220,41 @@ def expected(dshape):
     print(zz.astext())
 
 
+def test_stop_fusion():
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        y = relay.add(x, relay.const(1, "float32"))
+        y = relay.annotation.stop_fusion(y)
+        z = relay.exp(y)
+        return relay.Function([x], z)
+
+    def expected(dshape):
+        x = relay.var("p0", shape=dshape)
+        y = relay.add(x, relay.const(1, "float32"))
+        f1 = relay.Function([x], y)
+
+        x = relay.var("p01", shape=dshape)
+        y = relay.exp(x)
+        f2 = relay.Function([x], y)
+
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f1, [x])
+        z = relay.Call(f2, [y])
+        return relay.Function([x], z)
+
+    dshape = (10, 20)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    z = relay.ir_pass.fuse_ops(z)
+    z = relay.ir_pass.infer_type(z)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(z, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
     test_concatenate()
     test_tuple_root()
     test_tuple_strided_slice()
+    test_stop_fusion()