
Commit f88f458

zhiics authored and masahi committed
[RELAY][FUSION] Enhance fusion rule that starts from elemwise and broadcast (#2932)
* [relay][bugfix] fuse injective to elemwise and broadcast
* enhance fusion for parallel injective
* check if tensor in schedule
* fix codegen
* fix lint
* update
* lint
1 parent 977896c commit f88f458
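
For orientation, the graph shape this change targets is a diamond: an elemwise or broadcast op whose result feeds several parallel injective ops (e.g. squeeze, transpose) that are then consumed by another elemwise op. A minimal Python sketch, mirroring the new test_fuse_parallel_injective case added below; with this change, fuse_ops at opt_level=2 collapses the whole diamond into a single primitive function:

from tvm import relay

x = relay.var("x", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))   # elemwise master
z = relay.squeeze(y)                          # injective branch 1
u = relay.transpose(y, axes=[0, 1])           # injective branch 2
w = relay.left_shift(z, u)                    # elemwise sink
func = relay.ir_pass.infer_type(relay.Function([x], w))
print(relay.ir_pass.fuse_ops(func, opt_level=2))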

5 files changed: 64 additions, 4 deletions

include/tvm/schedule.h

Lines changed: 16 additions & 0 deletions
@@ -551,6 +551,22 @@ class ScheduleNode : public Node {
   /*! \brief Invalidate temp cache. */
   void InvalidateCache();
 
+  /*!
+   * \brief Check if the schedule contains an Operation.
+   * \param op The candidate Operation.
+   * \return true if the schedule has the Operation. Otherwise, false.
+   */
+  EXPORT bool Contain(const Operation& op) const;
+
+  /*!
+   * \brief Check if the schedule contains a Tensor.
+   * \param tensor The candidate tensor.
+   * \return true if the schedule has the tensor. Otherwise, false.
+   */
+  EXPORT bool Contain(const Tensor& tensor) const {
+    return Contain(tensor->op);
+  }
+
   /*!
    * \brief Create a schedule for array of ops(and their dependencies).
    * \param ops The ops to be scheduled.

src/relay/backend/compile_engine.cc

Lines changed: 3 additions & 1 deletion
@@ -127,7 +127,9 @@ class ScheduleGetter :
       schedule =
           fschedule[master_op_](master_attrs_, tensor_outs, target_);
       for (const auto& scalar : scalars_) {
-        schedule[scalar].compute_inline();
+        if (schedule->Contain(scalar)) {
+          schedule[scalar].compute_inline();
+        }
       }
     }
     return std::make_pair(schedule, cfunc);
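
The new guard avoids calling compute_inline on a lifted scalar whose op is not part of the schedule returned by fschedule (the schedule[scalar] lookup would otherwise fail), a situation the broader fusion rule below can now produce. A rough Python analogue of the same defensive pattern at the TE schedule level; the membership test over s.stages is an illustrative stand-in for the new C++ Contain helper, not an existing Python API:

import tvm

n = tvm.var("n")
a = tvm.placeholder((n,), name="a")
scalar = tvm.compute((1,), lambda i: tvm.const(1, "float32"), name="scalar")
b = tvm.compute((n,), lambda i: a[i] * 2.0, name="b")

s = tvm.create_schedule(b.op)
# Inline the scalar only if its op actually belongs to this schedule;
# indexing s[scalar] with a foreign op would raise, like the unguarded C++ code.
if any(stage.op.same_as(scalar.op) for stage in s.stages):
    s[scalar].compute_inline()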

src/relay/pass/fuse_ops.cc

Lines changed: 4 additions & 1 deletion
@@ -715,10 +715,13 @@ class GraphPartitioner {
       // The final terminal node can already be fused to a OutEWiseFusable group.
       auto fcond = [](OpPatternKind kind, bool is_sink) {
         if (!is_sink) {
-          return kind <= kBroadcast;
+          // Elemwise, broadcast, and injective ops on the parallel branches
+          // are allowed to be fused to the elemwise/broadcast master.
+          return kind <= kInjective;
         } else {
           return (kind <= kBroadcast ||
                   kind == kCommReduce ||
+                  kind == kInjective ||
                   kind == kOutEWiseFusable);
         }
       };
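
The comparisons in fcond rely on the ordering of OpPatternKind, where kElemWise < kBroadcast < kInjective < kCommReduce < kOutEWiseFusable. A small Python sketch of the updated predicate, using plain integers that mirror that ordering; the constants and the fcond helper here are illustrative only, not part of the TVM API:

# Illustrative constants mirroring the C++ OpPatternKind ordering.
K_ELEMWISE, K_BROADCAST, K_INJECTIVE, K_COMM_REDUCE, K_OUT_EWISE_FUSABLE = range(5)

def fcond(kind, is_sink):
    """Updated fusion check for a group whose master op is elemwise/broadcast."""
    if not is_sink:
        # Parallel (non-sink) branches may now carry injective ops as well.
        return kind <= K_INJECTIVE
    # The sink may also be a reduction, an injective op, or an out-elemwise-fusable op.
    return (kind <= K_BROADCAST or
            kind in (K_COMM_REDUCE, K_INJECTIVE, K_OUT_EWISE_FUSABLE))

assert fcond(K_INJECTIVE, False)        # newly allowed on parallel branches
assert not fcond(K_COMM_REDUCE, False)  # reductions are still only allowed at the sink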

src/schedule/schedule_lang.cc

Lines changed: 4 additions & 0 deletions
@@ -712,6 +712,10 @@ void ScheduleNode::InitCache() {
   CHECK_EQ(op2stage_cache_.size(), stages.size());
 }
 
+bool ScheduleNode::Contain(const Operation& op) const {
+  return stage_map.find(op) != stage_map.end();
+}
+
 Schedule ScheduleNode::make(Array<Operation> ops) {
   auto n = make_node<ScheduleNode>();
   Schedule sch(n);

tests/python/relay/test_pass_fuse_ops.py

Lines changed: 37 additions & 2 deletions
@@ -23,13 +23,15 @@ def before():
         x = relay.var("x", shape=(10, 20))
         y = relay.add(x, relay.const(1, "float32"))
         z = relay.exp(y)
-        return relay.Function([x], z)
+        w = relay.squeeze(z)
+        return relay.Function([x], w)
 
     def expected():
         x = relay.var("p", shape=(10, 20))
         y = relay.add(x, relay.const(1, "float32"))
         z = relay.exp(y)
-        f1 = relay.Function([x], z)
+        w = relay.squeeze(z)
+        f1 = relay.Function([x], w)
         x = relay.var("x", shape=(10, 20))
         y = relay.Call(f1, [x])
         return relay.Function([x], y)
@@ -503,6 +505,38 @@ def expected(dshape):
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
+def test_fuse_parallel_injective():
+    """Test fusing parallel injective ops to an elemwise op."""
+    def before():
+        x = relay.var("x", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.squeeze(y)
+        u = relay.transpose(y, axes=[0, 1])
+        w = relay.left_shift(z, u)
+        return relay.Function([x], w)
+
+    def expected():
+        x = relay.var("p", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.squeeze(y)
+        u = relay.transpose(y, axes=[0, 1])
+        w = relay.left_shift(z, u)
+        f1 = relay.Function([x], w)
+        x = relay.var("x", shape=(10, 20))
+        y = relay.Call(f1, [x])
+        return relay.Function([x], y)
+
+    z = before()
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected())
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -515,3 +549,4 @@ def expected(dshape):
     test_tuple_intermediate()
     test_tuple_consecutive()
     test_inception_like()
+    test_fuse_parallel_injective()
