From 6272d6d2ade8723db58c91bce73b2542cca9e1d8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Apr 2019 22:55:50 +0900 Subject: [PATCH 1/4] remove root tuple --- src/relay/pass/fuse_ops.cc | 71 ++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 12e3174dcade..48d622bff806 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -734,18 +734,45 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) { return std::move(groups_); } +class RemoveRootTupleVisitor : ExprVisitor { + public: + RemoveRootTupleVisitor(const std::unordered_map& gmap, + IndexedForwardGraph& graph) + : gmap_(gmap), graph_(graph) {} + + void UpdateNodeEntries(const Expr& body) { this->VisitExpr(body); } + + void VisitExpr_(const TupleNode* tuple) { + const GraphPartitioner::Group* tuple_group = gmap_.at(tuple)->FindRoot(); + if (tuple_group == gmap_.at(tuple)) { + IndexedForwardGraph::Node* tuple_node = graph_.node_map[tuple]; + tuple_node->pattern = kOpaque; + for (auto field : tuple->fields) { + IndexedForwardGraph::Node* tuple_filed_node = graph_.node_map[field.get()]; + tuple_filed_node->extern_ref = true; + } + } + } + + private: + const std::unordered_map& gmap_; + IndexedForwardGraph& graph_; +}; + + class FuseMutator : private ExprMutator { public: // Run the transform Expr Transform(const Expr& body, int fuse_opt_level) { // setup the group map. auto graph = IndexedForwardGraph::Create(&arena_, body); - auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition( - graph); - for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { - CHECK(graph.post_dfs_order[nid]->ref != nullptr); - gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; - } + AssignGroups(graph, fuse_opt_level); + // Detect and remove tuple nodes that are roots in their groups + // This is to prevent making a fused function that returns a tuple + RemoveRootTupleVisitor(gmap_, graph).UpdateNodeEntries(body); + // Reassign new groups where detected tuples are no longer roots + AssignGroups(graph, fuse_opt_level); + // The following line can be used for debug. // this->DebugDumpGroup(body); return this->Mutate(body); @@ -821,29 +848,12 @@ class FuseMutator : private ExprMutator { Expr VisitExpr_(const TupleNode* tuple) { auto* ret_group = gmap_.at(tuple)->FindRoot(); - Array new_fields = GetNewArguments(tuple->fields, ret_group); if (ret_group == gmap_.at(tuple)) { - // This tuple is the root of its group. Check if all fields come from other groups. - bool isolated = new_fields.size() == ginfo_[ret_group].params.size(); - for (size_t i = 0; i < new_fields.size() && isolated; ++i) { - isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i])); - } - if (isolated) { - // Do not put a isolated tuple into a function - return ExprMutator::VisitExpr_(tuple); - } - // This tuple has been fused with other ops before it - for (size_t i = 0; i < new_fields.size(); i++) { - // Copy function arguments to tuple field of the output because currently graph memory - // planer doesn't support inplace operations - if (new_fields[i].as()) { - auto copy = Copy(new_fields[i]); - new_fields.Set(i, copy); - } - } - return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields)); + // Do not fuse a tuple if it is the return value + return ExprMutator::VisitExpr_(tuple); } // This tuple is an intermediate node in the group + Array new_fields = GetNewArguments(tuple->fields, ret_group); return TupleNode::make(new_fields); } @@ -864,6 +874,15 @@ class FuseMutator : private ExprMutator { return new_node; } + void AssignGroups(const IndexedForwardGraph& graph, int fuse_opt_level) { + auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph); + gmap_.clear(); + for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { + CHECK(graph.post_dfs_order[nid]->ref != nullptr); + gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; + } + } + Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { // If the function has no call, it is not a primitive function. struct HasCallVisitor : ExprVisitor { From ad1b757608a699471d624d624ee256c0e50eb44c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Apr 2019 23:27:21 +0900 Subject: [PATCH 2/4] update tests --- .../relay/test_backend_compile_engine.py | 10 +++- tests/python/relay/test_pass_fuse_ops.py | 47 ++----------------- 2 files changed, 13 insertions(+), 44 deletions(-) diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py index 3b479b847619..ca4619c97886 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_backend_compile_engine.py @@ -69,8 +69,16 @@ def test_compile_injective_with_tuple(): relay.build(func, 'llvm') +def test_compile_tuple_dup(): + x = relay.var("data", shape=(16, 16)) + log = relay.log(x) + output = relay.Tuple([log, log]) + f = relay.Function([x], output) + relay.build(f, 'llvm') + + if __name__ == "__main__": test_compile_engine() test_compile_placeholder_bypass() test_compile_injective_with_tuple() - + test_compile_tuple_dup() diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index baafbeebd560..a44b8051b222 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -176,16 +176,14 @@ def expected(dshape): f0 = relay.Function([x], pooled) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) - p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3])) - p1_copy = relay.copy(p1) upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW") - out = relay.Tuple((upsampled, p1_copy)) - f1 = relay.Function([p0, p1], out) + f1 = relay.Function([p0], upsampled) x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) - z = relay.Call(f1, [y, x]) - return relay.Function([x], z) + z = relay.Call(f1, [y]) + tup = relay.Tuple((z, x)) + return relay.Function([x], tup) dshape = (1, 16, 64, 64) z = before(dshape) @@ -199,42 +197,6 @@ def expected(dshape): assert relay.ir_pass.alpha_equal(zz, after) -def test_tuple_strided_slice(): - """ - Test fusion case where the number of fields of tuple and - the number of parameters to the function containing the tuple are different - """ - - def before(dshape): - x = relay.var("x", shape=dshape) - slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1]) - slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1]) - out = relay.Tuple((slice1, slice2)) - return relay.Function([x], out) - - def expected(dshape): - x = relay.var("x", shape=dshape) - slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1]) - slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1]) - out = relay.Tuple((slice1, slice2)) - f0 = relay.Function([x], out) - - x = relay.var("x", shape=dshape) - y = relay.Call(f0, [x]) - return relay.Function([x], y) - - dshape = (64, 64) - z = before(dshape) - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) - after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(zz, after) - - def test_stop_fusion(): def before(dshape): x = relay.var("x", shape=dshape) @@ -382,7 +344,6 @@ def expected(dim): test_conv2d_fuse() test_concatenate() test_tuple_root() - test_tuple_strided_slice() test_stop_fusion() test_fuse_myia_regression() test_fuse_tuple_get_elemwise() From d7bd4aa7d430e97666eee8ffccd387e0ac3e024d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 20 Apr 2019 00:09:55 +0900 Subject: [PATCH 3/4] fix lint --- src/relay/pass/fuse_ops.cc | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 48d622bff806..0ba7107beb8f 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -737,26 +737,30 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) { class RemoveRootTupleVisitor : ExprVisitor { public: RemoveRootTupleVisitor(const std::unordered_map& gmap, - IndexedForwardGraph& graph) - : gmap_(gmap), graph_(graph) {} + const Expr& body) + : body_(body), gmap_(gmap) {} - void UpdateNodeEntries(const Expr& body) { this->VisitExpr(body); } + void UpdateNodeEntries(IndexedForwardGraph* graph) { + graph_ = graph; + this->VisitExpr(body_); + } void VisitExpr_(const TupleNode* tuple) { const GraphPartitioner::Group* tuple_group = gmap_.at(tuple)->FindRoot(); if (tuple_group == gmap_.at(tuple)) { - IndexedForwardGraph::Node* tuple_node = graph_.node_map[tuple]; + IndexedForwardGraph::Node* tuple_node = graph_->node_map[tuple]; tuple_node->pattern = kOpaque; for (auto field : tuple->fields) { - IndexedForwardGraph::Node* tuple_filed_node = graph_.node_map[field.get()]; + IndexedForwardGraph::Node* tuple_filed_node = graph_->node_map[field.get()]; tuple_filed_node->extern_ref = true; } } } private: + const Expr& body_; const std::unordered_map& gmap_; - IndexedForwardGraph& graph_; + IndexedForwardGraph* graph_; }; @@ -769,7 +773,7 @@ class FuseMutator : private ExprMutator { AssignGroups(graph, fuse_opt_level); // Detect and remove tuple nodes that are roots in their groups // This is to prevent making a fused function that returns a tuple - RemoveRootTupleVisitor(gmap_, graph).UpdateNodeEntries(body); + RemoveRootTupleVisitor(gmap_, body).UpdateNodeEntries(&graph); // Reassign new groups where detected tuples are no longer roots AssignGroups(graph, fuse_opt_level); From a0416748fde4e1d891efd5ed39ae18ceb91bac0d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 20 Apr 2019 00:39:45 +0900 Subject: [PATCH 4/4] use at --- src/relay/pass/fuse_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 0ba7107beb8f..9d5dc67bd8e2 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -748,10 +748,10 @@ class RemoveRootTupleVisitor : ExprVisitor { void VisitExpr_(const TupleNode* tuple) { const GraphPartitioner::Group* tuple_group = gmap_.at(tuple)->FindRoot(); if (tuple_group == gmap_.at(tuple)) { - IndexedForwardGraph::Node* tuple_node = graph_->node_map[tuple]; + IndexedForwardGraph::Node* tuple_node = graph_->node_map.at(tuple); tuple_node->pattern = kOpaque; for (auto field : tuple->fields) { - IndexedForwardGraph::Node* tuple_filed_node = graph_->node_map[field.get()]; + IndexedForwardGraph::Node* tuple_filed_node = graph_->node_map.at(field.get()); tuple_filed_node->extern_ref = true; } }