
Commit f1de26f

vinx13 authored and AWS Neo committed
[RELAY] Stop_fusion annotation (apache#2624)
1 parent 3edb51a · commit f1de26f

File tree

5 files changed: +80 additions, −0 deletions

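In brief: this commit introduces an annotation.stop_fusion op in Relay. It adds a Python wrapper (annotation.py), the C++ op registration plus a StopFusion helper (annotation.cc), handling in the operator-fusion pass so the annotation acts as a fusion boundary (fuse_ops.cc), a declaration for use by other passes (pattern_util.h), and a regression test (test_pass_fuse_ops.py).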

python/tvm/relay/op/annotation/annotation.py

Lines changed: 16 additions & 0 deletions
@@ -29,3 +29,19 @@ def on_device(data, device):
         raise ValueError("device is expected to be the type of TVMContext or "
                          "str, but received %s" % (type(device)))
     return _make.on_device(data, device)
+
+
+def stop_fusion(data):
+    """Annotate an expression to prevent it being fused with previous expressions.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The expression to be annotated.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The annotated expression.
+    """
+    return _make.stop_fusion(data)
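For reference, a minimal usage sketch of the new wrapper (it mirrors the test added below and assumes a TVM build from this era, where passes are invoked via relay.ir_pass):

    from tvm import relay

    x = relay.var("x", shape=(10, 20))
    y = relay.add(x, relay.const(1, "float32"))
    y = relay.annotation.stop_fusion(y)  # mark a fusion boundary after the add
    z = relay.exp(y)
    func = relay.Function([x], z)        # add and exp will not be fused together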

src/relay/op/annotation/annotation.cc

Lines changed: 26 additions & 0 deletions
@@ -9,6 +9,7 @@
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
+#include <topi/elemwise.h>
 
 #include "../type_relations.h"
 #include "../../pass/alter_op_layout.h"
@@ -37,6 +38,31 @@ RELAY_REGISTER_OP("on_device")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                ElemwiseArbitraryLayout);
 
+Expr StopFusion(Expr data) {
+  static const Op& op = Op::Get("annotation.stop_fusion");
+  return CallNode::make(op, {data}, Attrs{}, {});
+}
+
+TVM_REGISTER_API("relay.op.annotation._make.stop_fusion")
+.set_body_typed<Expr(Expr)>([](Expr data) {
+  return StopFusion(data);
+});
+
+RELAY_REGISTER_OP("annotation.stop_fusion")
+.describe(R"code(Annotate an expression to prevent it being fused with previous expressions.)code"
+TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input data.")
+.add_type_rel("Identity", IdentityRel)
+.set_support_level(10)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+.set_attr<FTVMCompute>("FTVMCompute",
+                       [](const Attrs& attrs, const Array<Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+  return {topi::identity(inputs[0])};
+});
 
 }  // namespace relay
 }  // namespace tvm
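A note on the registration above: marking the op kOpaque is what makes it a fusion boundary, since the fuser never merges opaque ops into a group, while the FTVMCompute returning topi::identity(inputs[0]) keeps the annotation a no-op at runtime. Together with the "Identity" type relation, stop_fusion simply passes its input through unchanged.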

src/relay/pass/fuse_ops.cc

Lines changed: 4 additions & 0 deletions
@@ -741,10 +741,14 @@ class FuseMutator : private ExprMutator {
   }
   // Transform calls.
   Expr VisitExpr_(const CallNode* call) {
+    static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
     if (call->op.as<OpNode>()) {
       // If it is a primitive op call
       // then we must have a group assignment for it already.
       CHECK(gmap_.count(call));
+      if (call->op.same_as(stop_fusion)) {
+        return ExprMutator::VisitExpr(call->args[0]);
+      }
       auto* ret_group = gmap_.at(call)->FindRoot();
       Array<Expr> new_args = GetNewArguments(call->args, ret_group);
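In other words, when FuseMutator reaches a stop_fusion call it rewrites the call to just its (already-mutated) argument, so the annotation is erased from the fused output; by that point the op's kOpaque pattern has already forced the partitioner to end a group at the annotation. The expected() function in the test below shows the resulting structure: two separate fused functions instead of one.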

src/relay/pass/pattern_util.h

Lines changed: 2 additions & 0 deletions
@@ -329,6 +329,8 @@ Expr MakeConcatenate(Expr data, int axis);
 
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
 
+Expr StopFusion(Expr data);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
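Declaring StopFusion in pattern_util.h exposes the constructor to other C++ passes, presumably so they can insert the annotation programmatically instead of going through the Python API.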

tests/python/relay/test_pass_fuse_ops.py

Lines changed: 32 additions & 0 deletions
@@ -220,9 +220,41 @@ def expected(dshape):
     print(zz.astext())
 
 
+def test_stop_fusion():
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        y = relay.add(x, relay.const(1, "float32"))
+        y = relay.annotation.stop_fusion(y)
+        z = relay.exp(y)
+        return relay.Function([x], z)
+
+    def expected(dshape):
+        x = relay.var("p0", shape=dshape)
+        y = relay.add(x, relay.const(1, "float32"))
+        f1 = relay.Function([x], y)
+
+        x = relay.var("p01", shape=dshape)
+        y = relay.exp(x)
+        f2 = relay.Function([x], y)
+
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f1, [x])
+        z = relay.Call(f2, [y])
+        return relay.Function([x], z)
+
+    dshape = (10, 20)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    z = relay.ir_pass.fuse_ops(z)
+    z = relay.ir_pass.infer_type(z)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(z, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
     test_concatenate()
     test_tuple_root()
     test_tuple_strided_slice()
+    test_stop_fusion()
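To inspect the effect directly, a quick sketch using the same era-specific relay.ir_pass API as the test above:

    from tvm import relay

    x = relay.var("x", shape=(10, 20))
    y = relay.annotation.stop_fusion(relay.add(x, relay.const(1, "float32")))
    f = relay.Function([x], relay.exp(y))
    f = relay.ir_pass.fuse_ops(relay.ir_pass.infer_type(f))
    print(f.astext())  # shows two fused functions: one for add, one for exp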
