
Commit 26e788b

masahi authored and wweic committed
[Relay, OpFusion] Better tuple fusion implementation (apache#3092)
1 parent 447a76b commit 26e788b
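
For context, a minimal sketch of the kind of graph this change targets, written against the same pre-0.6 `relay.ir_pass` API used by the tests in this commit (the shapes and op choices here are illustrative, not taken from the patch): an intermediate tuple (the implicit tuple inside `concatenate`) sandwiched between injective ops. With this patch the whole chain can end up in a single fused function instead of the tuple being split out on its own.

    from tvm import relay

    dshape = (1, 16, 64, 64)  # illustrative shape
    x = relay.var("x", shape=dshape)
    # injective ops feeding a tuple (concatenate), followed by another injective op
    y1 = relay.add(x, relay.const(1, "float32"))
    y2 = relay.add(x, relay.const(1, "float32"))
    concat = relay.concatenate((y1, y2), axis=1)
    out = relay.add(concat, relay.const(1, "float32"))
    func = relay.Function(relay.ir_pass.free_vars(out), out)

    func = relay.ir_pass.infer_type(func)
    fused = relay.ir_pass.fuse_ops(func, opt_level=2)
    relay.build(fused, 'llvm')  # with this change the chain compiles as one fused function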

File tree

5 files changed: +214, -67 lines changed


include/tvm/relay/op_attr_types.h

Lines changed: 3 additions & 0 deletions
@@ -49,6 +49,9 @@ enum OpPatternKind {
   // Complex operation, can still fuse elemwise operations into its output.
   // but cannot chain another complex op
   kOutEWiseFusable = 4,
+  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
+  // but treated specially
+  kTuple = 7,
   // Opaque operation, cannot fuse anything.
   kOpaque = 8
 };

python/tvm/relay/op/op.py

Lines changed: 2 additions & 0 deletions
@@ -112,6 +112,8 @@ class OpPattern(object):
     COMM_REDUCE = 3
     # Complex op, can still fuse ewise into it
     OUT_ELEMWISE_FUSABLE = 4
+    # Represents tuple node
+    TUPLE = 7
     # Not fusable opaque op
     OPAQUE = 8
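
A small sketch of how these Python-side values relate to the C++ enum above: the fusion pass compares pattern kinds numerically, so TUPLE deliberately sits between the fusable kinds and OPAQUE. Note that TUPLE is never registered on an operator; the pass assigns kTuple directly to tuple nodes (see fuse_ops.cc below).

    from tvm.relay.op.op import OpPattern

    # Values mirror OpPatternKind in include/tvm/relay/op_attr_types.h; fusion
    # decisions are numeric comparisons such as "pattern <= kInjective".
    assert OpPattern.INJECTIVE < OpPattern.OUT_ELEMWISE_FUSABLE
    assert OpPattern.OUT_ELEMWISE_FUSABLE < OpPattern.TUPLE < OpPattern.OPAQUE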

src/relay/pass/fuse_ops.cc

Lines changed: 30 additions & 24 deletions
@@ -267,7 +267,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   void VisitExpr_(const TupleNode* op) final {
     CHECK(graph_.node_map.count(op));
     Node* tuple_node = graph_.node_map.at(op);
-    tuple_node->pattern = kInjective;
+    tuple_node->pattern = kTuple;
     for (const Expr& field : op->fields) {
       if (field->checked_type().as<TensorTypeNode>()) {
         this->Update(field, tuple_node, kInjective);
@@ -661,12 +661,36 @@ class GraphPartitioner {
       // no actions needed if the current node have no dominator
       if (dom_node->parent == nullptr) continue;
       CHECK(!graph_node->extern_ref);
-      // Skip if current node is already fused to the parent.
       size_t dom_parent_gindex = dom_node->parent->gnode->index;
+
+      if (phase == 2) {
+        // Fuse injective ops into intermediate tuples, if any
+        if (group_node->pattern > kInjective) continue;
+        Group* dom_parent_group = groups_[dom_parent_gindex];
+        Group* dom_root_group = dom_parent_group->FindRoot();
+        // If dom node group has a tuple as its root, we do not fuse tuple fields into it
+        if (dom_root_group->pattern == kTuple) continue;
+        if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
+          // Now we know the tuple has been fused into subsequent injective ops
+          auto fcond = [](OpPatternKind kind, bool is_sink) {
+            return kind <= kInjective;
+          };
+          // dom_root_group can also be tuple, as in inception layers
+          // CheckPath is needed to avoid fusing two intermediate tuples
+          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+            CommitFuse(graph_node, dom_node->parent->gnode);
+          }
+        }
+        continue;
+      }
+
+      // Skip if current node is already fused to the parent.
       if (groups_[dom_parent_gindex] != nullptr &&
           group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
         continue;
       }
+      // Do not fuse into tuple for now
+      if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
       // Try to fuse current node to its post-dominator.
       if (group_node->pattern == kOutEWiseFusable) {
         if (phase != 0) continue;
@@ -702,7 +726,7 @@ class GraphPartitioner {
            CommitFuse(graph_node, dom_node->parent->gnode);
          }
        }
-      } else if (group_node->pattern == kInjective) {
+      } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
        // defer injective fusion to second phase.
        // so conv2d always finishes fusing.
        if (phase != 1) continue;
@@ -728,7 +752,7 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
   // get post dominator tree
   auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
   // run fusion algorithm.
-  for (int phase = 0; phase < 2; ++phase) {
+  for (int phase = 0; phase < 3; ++phase) {
     this->RunFuse(graph, post_dom_tree, phase);
   }
   return std::move(groups_);
@@ -821,29 +845,11 @@ class FuseMutator : private ExprMutator {
 
   Expr VisitExpr_(const TupleNode* tuple) {
     auto* ret_group = gmap_.at(tuple)->FindRoot();
-    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
     if (ret_group == gmap_.at(tuple)) {
-      // This tuple is the root of its group. Check if all fields come from other groups.
-      bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
-      for (size_t i = 0; i < new_fields.size() && isolated; ++i) {
-        isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
-      }
-      if (isolated) {
-        // Do not put a isolated tuple into a function
-        return ExprMutator::VisitExpr_(tuple);
-      }
-      // This tuple has been fused with other ops before it
-      for (size_t i = 0; i < new_fields.size(); i++) {
-        // Copy function arguments to tuple field of the output because currently graph memory
-        // planer doesn't support inplace operations
-        if (new_fields[i].as<VarNode>()) {
-          auto copy = Copy(new_fields[i]);
-          new_fields.Set(i, copy);
-        }
-      }
-      return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
+      return ExprMutator::VisitExpr_(tuple);
     }
     // This tuple is an intermediate node in the group
+    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
     return TupleNode::make(new_fields);
   }
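
The FuseMutator change above means a tuple that is the root of its own group is no longer wrapped in a separate fused function (and no copy of its argument fields is inserted any more); the output tuple simply groups the calls to the fused sub-functions. A small sketch of such a graph, mirroring test_tuple_root further down (shapes illustrative):

    from tvm import relay

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
    upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
    out = relay.Tuple((upsampled, x))  # the tuple is the root of the function
    func = relay.Function([x], out)

    func = relay.ir_pass.infer_type(func)
    fused = relay.ir_pass.fuse_ops(func, opt_level=2)
    # After this patch the output tuple stays outside the fused functions; compare
    # the simplified expected() of test_tuple_root below.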

tests/python/relay/test_backend_compile_engine.py

Lines changed: 9 additions & 1 deletion
@@ -69,8 +69,16 @@ def test_compile_injective_with_tuple():
     relay.build(func, 'llvm')
 
 
+def test_compile_tuple_dup():
+    x = relay.var("data", shape=(16, 16))
+    log = relay.log(x)
+    output = relay.Tuple([log, log])
+    f = relay.Function([x], output)
+    relay.build(f, 'llvm')
+
+
 if __name__ == "__main__":
     test_compile_engine()
     test_compile_placeholder_bypass()
     test_compile_injective_with_tuple()
-
+    test_compile_tuple_dup()

tests/python/relay/test_pass_fuse_ops.py

Lines changed: 170 additions & 42 deletions
@@ -176,16 +176,14 @@ def expected(dshape):
         f0 = relay.Function([x], pooled)
 
         p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
-        p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
-        p1_copy = relay.copy(p1)
         upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
-        out = relay.Tuple((upsampled, p1_copy))
-        f1 = relay.Function([p0, p1], out)
+        f1 = relay.Function([p0], upsampled)
 
         x = relay.var("x", shape=dshape)
         y = relay.Call(f0, [x])
-        z = relay.Call(f1, [y, x])
-        return relay.Function([x], z)
+        z = relay.Call(f1, [y])
+        tup = relay.Tuple((z, x))
+        return relay.Function([x], tup)
 
     dshape = (1, 16, 64, 64)
     z = before(dshape)
@@ -199,41 +197,6 @@ def expected(dshape):
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
-def test_tuple_strided_slice():
-    """
-    Test fusion case where the number of fields of tuple and
-    the number of parameters to the function containing the tuple are different
-    """
-
-    def before(dshape):
-        x = relay.var("x", shape=dshape)
-        slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
-        slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
-        out = relay.Tuple((slice1, slice2))
-        return relay.Function([x], out)
-
-    def expected(dshape):
-        x = relay.var("x", shape=dshape)
-        slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
-        slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
-        out = relay.Tuple((slice1, slice2))
-        f0 = relay.Function([x], out)
-
-        x = relay.var("x", shape=dshape)
-        y = relay.Call(f0, [x])
-        return relay.Function([x], y)
-
-    dshape = (64, 64)
-    z = before(dshape)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
-    after = relay.ir_pass.infer_type(expected(dshape))
-    assert relay.ir_pass.alpha_equal(zz, after)
-
 
 def test_stop_fusion():
     def before(dshape):
@@ -377,13 +340,178 @@ def expected(dim):
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
+def test_tuple_intermediate():
+    def before(x):
+        inj = relay.squeeze(x)
+        y1 = relay.add(inj, relay.const(1, "float32"))
+        tmp = relay.squeeze(inj)
+        tmp = relay.add(tmp, relay.const(1, "float32"))
+        y2 = relay.add(tmp, relay.const(1, "float32"))
+        y3 = relay.add(inj, relay.const(1, "float32"))
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        out_inj = relay.squeeze(concat)
+        out = relay.add(out_inj, relay.const(1, "float32"))
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    def expected(p0):
+        f0 = before(p0)
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        return relay.Function([x], y)
+
+    dshape = (1, 16, 64, 64)
+    x = relay.var("x", shape=dshape)
+    z = before(x)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(x))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_tuple_consecutive():
+    def gen_intermediate_tuple(x):
+        y1 = relay.add(x, relay.const(1, "float32"))
+        y2 = relay.add(x, relay.const(1, "float32"))
+        y3 = relay.add(x, relay.const(1, "float32"))
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        out = relay.add(concat, relay.const(1, "float32"))
+        return out
+
+    def gen_consecutive_tuple(x):
+        y1 = gen_intermediate_tuple(x)
+        y2 = gen_intermediate_tuple(x)
+        y3 = gen_intermediate_tuple(x)
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        return concat
+
+    def before(x):
+        concat = gen_consecutive_tuple(x)
+        pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        out = relay.add(pooled, relay.const(1, "float32"))
+        out2 = relay.add(out, relay.const(1, "float32"))
+        out_tup = relay.Tuple((out, out2))
+        return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup)
+
+    def expected(dshape):
+        p0 = relay.var("p0", shape=dshape)
+        concat = gen_consecutive_tuple(p0)
+        f0 = relay.Function([p0], concat)
+
+        p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
+        pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        out = relay.add(pooled, relay.const(1, "float32"))
+        f1 = relay.Function([p01], out)
+
+        p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
+        out = relay.add(p02, relay.const(1, "float32"))
+        f2 = relay.Function([p02], out)
+
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        z = relay.Call(f1, [y])
+        z2 = relay.Call(f2, [z])
+
+        return relay.Function([x], relay.Tuple((z, z2)))
+
+    dshape = (1, 16, 64, 64)
+    x = relay.var("x", shape=dshape)
+    z = before(x)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_inception_like():
+    def conv(data):
+        y = relay.nn.conv2d(data, relay.var("w"),
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            channels=16)
+        return relay.nn.relu(data=y)
+
+    def inception_like(data):
+        c0 = conv(data)
+        c1 = conv(data)
+        return relay.concatenate((c0, c1), axis=1)
+
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        in1 = inception_like(x)
+        in2 = inception_like(in1)
+        return relay.Function(relay.ir_pass.free_vars(in2), in2)
+
+    def expected(dshape):
+        p0 = relay.var("p0", shape=dshape)
+        c = conv(p0)
+        f0 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p01 = relay.var("p01", shape=dshape)
+        c = conv(p01)
+        f1 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p02 = relay.var("p02", shape=dshape)
+        p12 = relay.var("p12", shape=dshape)
+        concat1 = relay.concatenate((p02, p12), axis=1)
+        f_concat1 = relay.Function([p02, p12], concat1)
+
+        dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])
+
+        p03 = relay.var("p03", shape=dshape2)
+        c = conv(p03)
+        f2 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p04 = relay.var("p04", shape=dshape2)
+        c = conv(p04)
+        f3 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p05 = relay.var("p05", shape=dshape)
+        p15 = relay.var("p15", shape=dshape)
+        concat2 = relay.concatenate((p05, p15), axis=1)
+        f_concat2 = relay.Function([p05, p15], concat2)
+
+        x = relay.var("x", shape=dshape)
+        c1 = relay.Call(f0, [x, relay.var("w1")])
+        c2 = relay.Call(f1, [x, relay.var("w2")])
+        concat = relay.Call(f_concat1, [c1, c2])
+        c3 = relay.Call(f2, [concat, relay.var("w3")])
+        c4 = relay.Call(f3, [concat, relay.var("w4")])
+        out = relay.Call(f_concat2, [c3, c4])
+
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    dshape = (1, 16, 64, 64)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
     test_concatenate()
     test_tuple_root()
-    test_tuple_strided_slice()
     test_stop_fusion()
     test_fuse_myia_regression()
     test_fuse_tuple_get_elemwise()
     test_tuple_get_root()
+    test_tuple_intermediate()
+    test_tuple_consecutive()
+    test_inception_like()
