python/tvm/relay/op/annotation/annotation.py (16 additions, 0 deletions)
@@ -29,3 +29,19 @@ def on_device(data, device):
raise ValueError("device is expected to be the type of TVMContext or "
"str, but received %s" % (type(device)))
return _make.on_device(data, device)


def stop_fusion(data):
    """Annotate an expression to prevent it from being fused with previous expressions.

    Parameters
    ----------
    data : tvm.relay.Expr
        The expression to be annotated.

    Returns
    -------
    result : tvm.relay.Expr
        The annotated expression.
    """
    return _make.stop_fusion(data)
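
For context, a minimal usage sketch (not part of the diff; it mirrors the new test at the bottom of this change). Placing stop_fusion between two elementwise ops keeps them in separate fused functions:

    from tvm import relay

    x = relay.var("x", shape=(10, 20))
    y = relay.add(x, relay.const(1, "float32"))
    y = relay.annotation.stop_fusion(y)  # fusion will not cross this call
    z = relay.exp(y)
    func = relay.Function([x], z)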
src/relay/op/annotation/annotation.cc (26 additions, 0 deletions)
@@ -9,6 +9,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/elemwise.h>

#include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
@@ -37,6 +38,31 @@ RELAY_REGISTER_OP("on_device")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                               ElemwiseArbitraryLayout);

Expr StopFusion(Expr data) {
  static const Op& op = Op::Get("annotation.stop_fusion");
  return CallNode::make(op, {data}, Attrs{}, {});
}

TVM_REGISTER_API("relay.op.annotation._make.stop_fusion")
.set_body_typed<Expr(Expr)>([](Expr data) {
  return StopFusion(data);
});

RELAY_REGISTER_OP("annotation.stop_fusion")
.describe(R"code(Annotate an expression to prevent it from being fused with previous expressions.)code"
          TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input data.")
.add_type_rel("Identity", IdentityRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
                       [](const Attrs& attrs, const Array<Tensor>& inputs,
                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
                         return {topi::identity(inputs[0])};
                       });

} // namespace relay
} // namespace tvm
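
A quick sketch of what the attributes above buy, under the assumption that the op behaves as registered: the "Identity" type relation makes the annotation type-transparent, kOpaque stops the fusion pass from fusing across it, and the FTVMCompute lowers it to a plain copy at runtime.

    from tvm import relay

    # Hypothetical check, assuming the ir_pass API used elsewhere in this PR:
    x = relay.var("x", shape=(4,), dtype="float32")
    f = relay.Function([x], relay.annotation.stop_fusion(x))
    f = relay.ir_pass.infer_type(f)
    # The Identity type relation preserves the argument's type:
    print(f.ret_type)  # same TensorType as x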
src/relay/pass/fuse_ops.cc (4 additions, 0 deletions)
@@ -741,10 +741,14 @@ class FuseMutator : private ExprMutator {
  }
  // Transform calls.
  Expr VisitExpr_(const CallNode* call) {
    static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
    if (call->op.as<OpNode>()) {
      // If it is a primitive op call
      // then we must have a group assignment for it already.
      CHECK(gmap_.count(call));
      if (call->op.same_as(stop_fusion)) {
        return ExprMutator::VisitExpr(call->args[0]);
      }
      auto* ret_group = gmap_.at(call)->FindRoot();
      Array<Expr> new_args = GetNewArguments(call->args, ret_group);

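The observable effect of this early return, sketched with the same ir_pass calls the new test uses: the mutator replaces the stop_fusion call with its rewritten argument, so the annotation splits the graph into two fused functions during partitioning and then vanishes from the output IR.

    from tvm import relay

    x = relay.var("x", shape=(10, 20))
    y = relay.annotation.stop_fusion(relay.add(x, relay.const(1, "float32")))
    func = relay.Function([x], relay.exp(y))
    func = relay.ir_pass.infer_type(func)
    fused = relay.ir_pass.fuse_ops(func)
    print(fused.astext())  # two fused functions; no annotation.stop_fusion call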
src/relay/pass/pattern_util.h (2 additions, 0 deletions)
@@ -329,6 +329,8 @@ Expr MakeConcatenate(Expr data, int axis);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);

Expr StopFusion(Expr data);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
tests/python/relay/test_pass_fuse_ops.py (32 additions, 0 deletions)
@@ -220,9 +220,41 @@ def expected(dshape):
    print(zz.astext())


def test_stop_fusion():
    def before(dshape):
        x = relay.var("x", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        y = relay.annotation.stop_fusion(y)
        z = relay.exp(y)
        return relay.Function([x], z)

    def expected(dshape):
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f1 = relay.Function([x], y)

        x = relay.var("p01", shape=dshape)
        y = relay.exp(x)
        f2 = relay.Function([x], y)

        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        z = relay.Call(f2, [y])
        return relay.Function([x], z)

    dshape = (10, 20)
    z = before(dshape)
    z = relay.ir_pass.infer_type(z)
    z = relay.ir_pass.fuse_ops(z)
    z = relay.ir_pass.infer_type(z)
    after = relay.ir_pass.infer_type(expected(dshape))
    assert relay.ir_pass.alpha_equal(z, after)


if __name__ == "__main__":
    test_fuse_simple()
    test_conv2d_fuse()
    test_concatenate()
    test_tuple_root()
    test_tuple_strided_slice()
    test_stop_fusion()