diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 9a556b6ce960..6c2a759db471 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -551,6 +551,22 @@ class ScheduleNode : public Node {
   /*! \brief Invalidate temp cache. */
   void InvalidateCache();
 
+  /*!
+   * \brief Check if the schedule contains an Operation.
+   * \param op The candidate Operation.
+   * \return true if the schedule has the Operation. Otherwise, false.
+   */
+  EXPORT bool Contain(const Operation& op) const;
+
+  /*!
+   * \brief Check if the schedule contains a Tensor.
+   * \param tensor The candidate tensor.
+   * \return true if the schedule has the tensor. Otherwise, false.
+   */
+  EXPORT bool Contain(const Tensor& tensor) const {
+    return Contain(tensor->op);
+  }
+
   /*!
    * \brief Create a schedule for array of ops(and their dependencies).
    * \param ops The ops to be scheduled.
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 43515105bd94..4b5842c36020 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -127,7 +127,9 @@ class ScheduleGetter :
       schedule =
           fschedule[master_op_](master_attrs_, tensor_outs, target_);
       for (const auto& scalar : scalars_) {
-        schedule[scalar].compute_inline();
+        if (schedule->Contain(scalar)) {
+          schedule[scalar].compute_inline();
+        }
       }
     }
     return std::make_pair(schedule, cfunc);
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index 55d609872929..fc7aad6ce515 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -715,10 +715,13 @@ class GraphPartitioner {
       // The final terminal node can already be fused to a OutEWiseFusable group.
       auto fcond = [](OpPatternKind kind, bool is_sink) {
         if (!is_sink) {
-          return kind <= kBroadcast;
+          // Elemwise, broadcast, and injective ops on the parallel branches
+          // are allowed to be fused to the elemwise/broadcast master.
+          return kind <= kInjective;
         } else {
           return (kind <= kBroadcast ||
                   kind == kCommReduce ||
+                  kind == kInjective ||
                   kind == kOutEWiseFusable);
         }
       };
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index ffee804198b6..e1cb4c5f9bdc 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -712,6 +712,10 @@ void ScheduleNode::InitCache() {
   CHECK_EQ(op2stage_cache_.size(), stages.size());
 }
 
+bool ScheduleNode::Contain(const Operation& op) const {
+  return stage_map.find(op) != stage_map.end();
+}
+
 Schedule ScheduleNode::make(Array<Operation> ops) {
   auto n = make_node<ScheduleNode>();
   Schedule sch(n);
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index bdffdf7c129f..6d6781046a10 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -23,13 +23,15 @@ def before():
         x = relay.var("x", shape=(10, 20))
         y = relay.add(x, relay.const(1, "float32"))
         z = relay.exp(y)
-        return relay.Function([x], z)
+        w = relay.squeeze(z)
+        return relay.Function([x], w)
 
     def expected():
         x = relay.var("p", shape=(10, 20))
         y = relay.add(x, relay.const(1, "float32"))
         z = relay.exp(y)
-        f1 = relay.Function([x], z)
+        w = relay.squeeze(z)
+        f1 = relay.Function([x], w)
         x = relay.var("x", shape=(10, 20))
         y = relay.Call(f1, [x])
         return relay.Function([x], y)
@@ -503,6 +505,38 @@ def expected(dshape):
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
+def test_fuse_parallel_injective():
+    """Test fusing parallel injective ops to an elemwise op."""
+    def before():
+        x = relay.var("x", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.squeeze(y)
+        u = relay.transpose(y, axes=[0, 1])
+        w = relay.left_shift(z, u)
+        return relay.Function([x], w)
+
+    def expected():
+        x = relay.var("p", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.squeeze(y)
+        u = relay.transpose(y, axes=[0, 1])
+        w = relay.left_shift(z, u)
+        f1 = relay.Function([x], w)
+        x = relay.var("x", shape=(10, 20))
+        y = relay.Call(f1, [x])
+        return relay.Function([x], y)
+
+    z = before()
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected())
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -515,3 +549,4 @@ def expected(dshape):
     test_tuple_intermediate()
     test_tuple_consecutive()
     test_inception_like()
+    test_fuse_parallel_injective()
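Note: the new `ScheduleNode::Contain` helper is just a membership test on the schedule's `stage_map`, which is what lets `compile_engine.cc` skip `compute_inline()` for scalars that the master op's schedule never staged. Below is a minimal sketch of the same check from the Python side, assuming the `stage_map` field and `Map` membership test are exposed through TVM's node reflection as in mainline; the tensors and names here are illustrative only, not part of this patch.

```python
import tvm

# Build a trivial schedule over a single compute op.
A = tvm.placeholder((10,), name="A")
B = tvm.compute((10,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)

# B's op was scheduled, so it has an entry in the schedule's stage map;
# this is the membership test ScheduleNode::Contain performs.
assert B.op in s.stage_map

# An op created outside the schedule has no stage, so a bare stage
# lookup on it would fail; guarding on membership first mirrors the
# schedule->Contain(scalar) check added in compile_engine.cc.
C = tvm.compute((10,), lambda i: A[i] * 2.0, name="C")
assert C.op not in s.stage_map
```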