75 changes: 49 additions & 26 deletions src/relay/pass/fuse_ops.cc
@@ -734,18 +734,49 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
  return std::move(groups_);
}

+class RemoveRootTupleVisitor : ExprVisitor {
+ public:
+  RemoveRootTupleVisitor(const std::unordered_map<const Node*, GraphPartitioner::Group*>& gmap,
+                         const Expr& body)
+      : body_(body), gmap_(gmap) {}
+
+  void UpdateNodeEntries(IndexedForwardGraph* graph) {
+    graph_ = graph;
+    this->VisitExpr(body_);
+  }
+
+  void VisitExpr_(const TupleNode* tuple) {
+    const GraphPartitioner::Group* tuple_group = gmap_.at(tuple)->FindRoot();
+    if (tuple_group == gmap_.at(tuple)) {
+      IndexedForwardGraph::Node* tuple_node = graph_->node_map.at(tuple);
+      tuple_node->pattern = kOpaque;
+      for (auto field : tuple->fields) {
+        IndexedForwardGraph::Node* tuple_field_node = graph_->node_map.at(field.get());
+        tuple_field_node->extern_ref = true;
+      }
+    }
+  }
+
+ private:
+  const Expr& body_;
+  const std::unordered_map<const Node*, GraphPartitioner::Group*>& gmap_;
+  IndexedForwardGraph* graph_;
+};
+
+
class FuseMutator : private ExprMutator {
 public:
  // Run the transform
  Expr Transform(const Expr& body, int fuse_opt_level) {
    // setup the group map.
    auto graph = IndexedForwardGraph::Create(&arena_, body);
-    auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(
-        graph);
-    for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
-      CHECK(graph.post_dfs_order[nid]->ref != nullptr);
-      gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
-    }
+    AssignGroups(graph, fuse_opt_level);
+    // Detect tuple nodes that are roots of their groups and mark them opaque.
+    // This prevents creating a fused function that returns a tuple.
+    RemoveRootTupleVisitor(gmap_, body).UpdateNodeEntries(&graph);
+    // Reassign groups now that the detected tuples are no longer roots.
+    AssignGroups(graph, fuse_opt_level);

    // The following line can be used for debug.
    // this->DebugDumpGroup(body);
    return this->Mutate(body);
@@ -821,29 +852,12 @@ class FuseMutator : private ExprMutator {

  Expr VisitExpr_(const TupleNode* tuple) {
    auto* ret_group = gmap_.at(tuple)->FindRoot();
-    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
    if (ret_group == gmap_.at(tuple)) {
-      // This tuple is the root of its group. Check if all fields come from other groups.
-      bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
-      for (size_t i = 0; i < new_fields.size() && isolated; ++i) {
-        isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
-      }
-      if (isolated) {
-        // Do not put a isolated tuple into a function
-        return ExprMutator::VisitExpr_(tuple);
-      }
-      // This tuple has been fused with other ops before it
-      for (size_t i = 0; i < new_fields.size(); i++) {
-        // Copy function arguments to tuple field of the output because currently graph memory
-        // planer doesn't support inplace operations
-        if (new_fields[i].as<VarNode>()) {
-          auto copy = Copy(new_fields[i]);
-          new_fields.Set(i, copy);
-        }
-      }
-      return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
+      // Do not fuse a tuple if it is the return value
+      return ExprMutator::VisitExpr_(tuple);
    }
    // This tuple is an intermediate node in the group
+    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
    return TupleNode::make(new_fields);
  }

@@ -864,6 +878,15 @@
    return new_node;
  }

+  void AssignGroups(const IndexedForwardGraph& graph, int fuse_opt_level) {
+    auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph);
+    gmap_.clear();
+    for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
+      CHECK(graph.post_dfs_order[nid]->ref != nullptr);
+      gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
+    }
+  }
+
  Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
    // If the function has no call, it is not a primitive function.
    struct HasCallVisitor : ExprVisitor {
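For illustration, here is a minimal sketch (not part of the diff) of the pattern this change targets, assuming the relay.ir_pass API of this era. A tuple that is a function's return value is the root of its fusion group, so after this change it stays outside any fused primitive function:

from tvm import relay

x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.nn.relu(x)
out = relay.Tuple([y, x])  # the tuple is the root of its fusion group
func = relay.ir_pass.infer_type(relay.Function([x], out))

fused = relay.ir_pass.fuse_ops(func, opt_level=2)
# Only relu is wrapped in a fused (primitive) function; the tuple is
# constructed outside of it, from the fused call's result and x.
print(fused)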
10 changes: 9 additions & 1 deletion tests/python/relay/test_backend_compile_engine.py
@@ -69,8 +69,16 @@ def test_compile_injective_with_tuple():
    relay.build(func, 'llvm')


+def test_compile_tuple_dup():
+    x = relay.var("data", shape=(16, 16))
+    log = relay.log(x)
+    output = relay.Tuple([log, log])
+    f = relay.Function([x], output)
+    relay.build(f, 'llvm')
+

if __name__ == "__main__":
    test_compile_engine()
    test_compile_placeholder_bypass()
    test_compile_injective_with_tuple()
+    test_compile_tuple_dup()
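For reference, a usage sketch (not part of the diff, assuming the era's tvm.contrib.graph_runtime API) showing that the duplicated-field tuple from test_compile_tuple_dup also runs end to end:

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

x = relay.var("data", shape=(16, 16))
log = relay.log(x)
f = relay.Function([x], relay.Tuple([log, log]))  # same node used twice
graph, lib, params = relay.build(f, 'llvm')

m = graph_runtime.create(graph, lib, tvm.cpu())
m.set_input("data", np.random.uniform(1, 2, size=(16, 16)).astype("float32"))
m.run()
# Both tuple fields refer to the same computed value.
np.testing.assert_allclose(m.get_output(0).asnumpy(),
                           m.get_output(1).asnumpy())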
47 changes: 4 additions & 43 deletions tests/python/relay/test_pass_fuse_ops.py
@@ -176,16 +176,14 @@ def expected(dshape):
        f0 = relay.Function([x], pooled)

        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
-        p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
-        p1_copy = relay.copy(p1)
        upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
-        out = relay.Tuple((upsampled, p1_copy))
-        f1 = relay.Function([p0, p1], out)
+        f1 = relay.Function([p0], upsampled)

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
-        z = relay.Call(f1, [y, x])
-        return relay.Function([x], z)
+        z = relay.Call(f1, [y])
+        tup = relay.Tuple((z, x))
+        return relay.Function([x], tup)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
@@ -199,42 +197,6 @@ def expected(dshape):
    assert relay.ir_pass.alpha_equal(zz, after)


-def test_tuple_strided_slice():

    Review comment from @masahi (Member, Author), Apr 19, 2019:
    this test case is removed, because it is now basically the same test
    case as the test above (test_tuple_root).
"""
Test fusion case where the number of fields of tuple and
the number of parameters to the function containing the tuple are different
"""

def before(dshape):
x = relay.var("x", shape=dshape)
slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
out = relay.Tuple((slice1, slice2))
return relay.Function([x], out)

def expected(dshape):
x = relay.var("x", shape=dshape)
slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
out = relay.Tuple((slice1, slice2))
f0 = relay.Function([x], out)

x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
return relay.Function([x], y)

dshape = (64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)


def test_stop_fusion():
    def before(dshape):
        x = relay.var("x", shape=dshape)
@@ -382,7 +344,6 @@ def expected(dim):
    test_conv2d_fuse()
    test_concatenate()
    test_tuple_root()
-    test_tuple_strided_slice()
    test_stop_fusion()
    test_fuse_myia_regression()
    test_fuse_tuple_get_elemwise()