Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,22 @@ class ScheduleNode : public Node {
/*! \brief Invalidate temp cache. */
void InvalidateCache();

/*!
* \brief Check if the schedule contains an Operation.
* \param op The candidate Operation.
* \return true if the schedule has the Operation. Otherwise, false.
*/
EXPORT bool Contain(const Operation& op) const;

/*!
* \brief Check if the schedule contains a Tensor.
* \param tensor The candidate tensor.
* \return true if the schedule has the tensor. Otherwise, false.
*/
EXPORT bool Contain(const Tensor& tensor) const {
return Contain(tensor->op);
}

/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
Expand Down
4 changes: 3 additions & 1 deletion src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ class ScheduleGetter :
schedule =
fschedule[master_op_](master_attrs_, tensor_outs, target_);
for (const auto& scalar : scalars_) {
schedule[scalar].compute_inline();
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
return std::make_pair(schedule, cfunc);
Expand Down
5 changes: 4 additions & 1 deletion src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,13 @@ class GraphPartitioner {
// The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
return kind <= kBroadcast;
// Elemwise, broadcast, and injective ops on the parallel branches
// are allowed to be fused to the elemwise/broadcast master.
return kind <= kInjective;
} else {
return (kind <= kBroadcast ||
kind == kCommReduce ||
kind == kInjective ||
kind == kOutEWiseFusable);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this fix is correct, but in this case we can just simplify this condition to kind <= kOutEWiseFusable
I wonder what is the original reason for leaving out kInjective @tqchen

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@masahi yes, we can. I wasn’t quite sure if it was intentionally left out. If not, I’ll change to <=

Copy link
Member

@tqchen tqchen Mar 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something I overlooked. However, it is good to leave things as they are because this way it is more clear. We also need to add a case to fuse into multiple injective ops, i.e. need to enhance condition in the if branch as well. Please add a test-case on that as well
ewise->parallel{injective, injective}->injective

};
Expand Down
4 changes: 4 additions & 0 deletions src/schedule/schedule_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,10 @@ void ScheduleNode::InitCache() {
CHECK_EQ(op2stage_cache_.size(), stages.size());
}

// An Operation belongs to this schedule iff a stage was created for it,
// i.e. it has an entry in stage_map.
bool ScheduleNode::Contain(const Operation& op) const {
  const auto it = stage_map.find(op);
  return it != stage_map.end();
}

Schedule ScheduleNode::make(Array<Operation> ops) {
auto n = make_node<ScheduleNode>();
Schedule sch(n);
Expand Down
39 changes: 37 additions & 2 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ def before():
x = relay.var("x", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
return relay.Function([x], z)
w = relay.squeeze(z)
return relay.Function([x], w)

def expected():
x = relay.var("p", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
f1 = relay.Function([x], z)
w = relay.squeeze(z)
f1 = relay.Function([x], w)
x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x])
return relay.Function([x], y)
Expand Down Expand Up @@ -503,6 +505,38 @@ def expected(dshape):
assert relay.ir_pass.alpha_equal(zz, after)


def test_fuse_parallel_injective():
    """Test fusing parallel injective ops to an elemwise op."""
    def before():
        # elemwise add feeding two parallel injective branches
        # (squeeze and transpose) that re-join at left_shift
        inp = relay.var("x", shape=(10, 20))
        added = relay.add(inp, relay.const(1, "float32"))
        squeezed = relay.squeeze(added)
        transposed = relay.transpose(added, axes=[0, 1])
        joined = relay.left_shift(squeezed, transposed)
        return relay.Function([inp], joined)

    def expected():
        # the whole diamond collapses into a single fused function
        param = relay.var("p", shape=(10, 20))
        added = relay.add(param, relay.const(1, "float32"))
        squeezed = relay.squeeze(added)
        transposed = relay.transpose(added, axes=[0, 1])
        joined = relay.left_shift(squeezed, transposed)
        fused_fn = relay.Function([param], joined)
        inp = relay.var("x", shape=(10, 20))
        body = relay.Call(fused_fn, [inp])
        return relay.Function([inp], body)

    graph = before()
    graph = relay.ir_pass.infer_type(graph)
    fused = relay.ir_pass.fuse_ops(graph, opt_level=0)
    assert not relay.ir_pass.free_vars(fused)
    fused = relay.ir_pass.fuse_ops(graph, opt_level=2)
    fused = relay.ir_pass.infer_type(fused)
    assert not relay.ir_pass.free_vars(fused)
    reference = relay.ir_pass.infer_type(expected())
    assert relay.ir_pass.alpha_equal(fused, reference)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -515,3 +549,4 @@ def expected(dshape):
test_tuple_intermediate()
test_tuple_consecutive()
test_inception_like()
test_fuse_parallel_injective()