
Commit 93a272c

masahi authored and wweic committed
[Relay] Add support for TupleGetItem in op fusion (apache#2914)
1 parent f4e3837 commit 93a272c

3 files changed: +157, -4 lines changed

src/relay/pass/fuse_ops.cc

Lines changed: 41 additions & 3 deletions

@@ -261,9 +261,30 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   }

   void VisitExpr_(const TupleGetItemNode* op) final {
-    CHECK(graph_.node_map.count(op));
-    Node* node = graph_.node_map.at(op);
-    this->Update(op->tuple, node, kOpaque);
+    auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
+    CHECK(tuple_type);
+    // If this tuple contains a reference type and we fuse TupleGetItem with
+    // the reference, the fused function will have a tuple containing a reference
+    // in its parameters. But when TVM lowers a fused function, it expects all
+    // arguments to be a Tensor or a tuple containing only Tensors.
+    // To avoid modifying the codegen logic, we do not allow fusing through a reference.
+    // The reference itself will be recursively visited via the call to
+    // ExprVisitor::VisitExpr_(op) below and the corresponding visitor methods.
+    bool has_reference = false;
+    for (auto ty : tuple_type->fields) {
+      if (ty.as<RefTypeNode>()) {
+        has_reference = true;
+        break;
+      }
+    }
+    if (has_reference) {
+      this->Update(op->tuple, nullptr, kOpaque);
+    } else {
+      CHECK(graph_.node_map.count(op));
+      Node* node = graph_.node_map.at(op);
+      node->pattern = kInjective;
+      this->Update(op->tuple, node, kInjective);
+    }
     ExprVisitor::VisitExpr_(op);
     this->AddNode(op);
   }
@@ -809,6 +830,23 @@ class FuseMutator : private ExprMutator {
     return TupleNode::make(new_fields);
   }

+  Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
+    auto* ret_group = gmap_.at(tuple_get)->FindRoot();
+    auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
+    auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
+    if (ret_group == gmap_.at(tuple_get)) {
+      if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
+        // Isolated. This case occurs when the tuple is created by an opaque op,
+        // e.g. multibox_transform_loc.
+        return ExprMutator::VisitExpr_(tuple_get);
+      }
+      // A new function whose output is a tuple field access
+      return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
+    }
+    // This is an intermediate node in the group
+    return new_node;
+  }
+
   Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
     const GroupInfo& ginfo = ginfo_[group];
     auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
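
Not part of the commit, but for orientation: a minimal sketch of the pattern this change enables, written against the same relay.ir_pass API the tests below use (the function name split_then_elemwise and the dim value are illustrative only). With this patch, the TupleGetItem nodes produced by relay.split are treated as kInjective instead of kOpaque, so they can be fused into the downstream elementwise group, while the dense stays in its own fused function, matching test_fuse_tuple_get_elemwise.

    from tvm import relay

    def split_then_elemwise(dim=10):
        # dense yields a (1, 3*dim) matrix; split turns it into a 3-field tuple,
        # and each field is read through a TupleGetItem node.
        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function([X, W], out)

    f = relay.ir_pass.infer_type(split_then_elemwise())
    fused = relay.ir_pass.fuse_ops(f, opt_level=2)
    # After this commit, split + TupleGetItem + the elementwise ops land in one
    # fused function; print the IR to inspect the grouping.
    print(relay.ir_pass.infer_type(fused).astext())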

tests/python/relay/test_backend_graph_runtime.py

Lines changed: 39 additions & 0 deletions

@@ -7,6 +7,7 @@
 from tvm.relay.scope_builder import ScopeBuilder
 from tvm.relay.op import add
 from tvm.relay.module import Module
+from tvm.relay.testing.config import ctx_list

 # @tq, @jr should we put this in testing ns?
 def check_rts(expr, args, expected_result, mod=None):
@@ -127,9 +128,47 @@ def test_plan_memory():
     assert len(device_types) == 1


+def test_gru_like():
+    def unit(rnn_dim):
+        X = relay.var("X", shape=(1, rnn_dim))
+        W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
+        matmul = relay.nn.dense(X, W)
+        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
+        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
+        return relay.Function([X, W], out)
+
+    def sigmoid(x):
+        return 1 / (1 + np.exp(-x))
+
+    def unit_numpy(X, W):
+        prod = np.dot(X, W.transpose())
+        splits = np.split(prod, indices_or_sections=3, axis=1)
+        return sigmoid(splits[0]) + np.tanh(splits[1]) * np.exp(splits[2])
+
+    dtype = "float32"
+    rnn_dim = 1000
+    x = np.random.rand(1, rnn_dim).astype(dtype)
+    y = np.random.rand(3 * rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
+    out_shape = (1, rnn_dim)
+    z = unit(rnn_dim)
+
+    for target, ctx in ctx_list():
+        with relay.build_config(opt_level=2):
+            graph, lib, params = relay.build(z, target)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input("X", tvm.nd.array(x.astype(dtype)))
+        m.set_input("y", tvm.nd.array(y.astype(dtype)))
+        m.set_input(**params)
+        m.run()
+        out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+        ref = unit_numpy(x, y)
+        tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     test_plan_memory()
     test_with_params()
     test_add_op_scalar()
     test_add_op_tensor()
     test_add_op_broadcast()
+    test_gru_like()

tests/python/relay/test_pass_fuse_ops.py

Lines changed: 77 additions & 1 deletion

@@ -217,7 +217,6 @@ def expected(dshape):
     assert not relay.ir_pass.free_vars(zz)
     after = relay.ir_pass.infer_type(expected(dshape))
     assert relay.ir_pass.alpha_equal(zz, after)
-    print(zz.astext())


 def test_stop_fusion():
@@ -287,6 +286,81 @@ def expected(dshape, dtype):
     assert relay.ir_pass.alpha_equal(f, after)


+def test_fuse_tuple_get_elemwise():
+    def before(dim):
+        X = relay.var("X", shape=(1, dim))
+        W = relay.var("W", shape=(3 * dim, dim))
+        matmul = relay.nn.dense(X, W)
+        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
+        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
+        return relay.Function([X, W], out)
+
+    def expected(dim):
+        p0 = relay.var("p0", shape=(1, dim))
+        p1 = relay.var("p1", shape=(3 * dim, dim))
+        matmul = relay.nn.dense(p0, p1)
+        f0 = relay.Function([p0, p1], matmul)
+
+        p01 = relay.var("p01", shape=(1, 3 * dim))
+        splitted = relay.split(p01, indices_or_sections=3, axis=1)
+        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
+        f1 = relay.Function([p01], out)
+
+        X = relay.var("X", shape=(1, dim))
+        W = relay.var("W", shape=(3 * dim, dim))
+        y = relay.Call(f0, [X, W])
+        z = relay.Call(f1, [y])
+        return relay.Function([X, W], z)
+
+    dim = 10
+    z = before(dim)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dim))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_tuple_get_root():
+    def before(dim):
+        X = relay.var("X", shape=(1, 3 * dim))
+        W = relay.var("W", shape=(dim, dim))
+        splitted = relay.split(X, indices_or_sections=3, axis=1)
+        out = relay.nn.dense(splitted[0], W)
+        return relay.Function([X, W], out)
+
+    def expected(dim):
+        p0 = relay.var("p0", shape=(1, 3 * dim))
+        splitted = relay.split(p0, indices_or_sections=3, axis=1)
+        out = splitted[0]
+        f0 = relay.Function([p0], out)
+
+        p01 = relay.var("p01", shape=(1, dim))
+        p1 = relay.var("p1", shape=(dim, dim))
+        out = relay.nn.dense(p01, p1)
+        f1 = relay.Function([p01, p1], out)
+
+        X = relay.var("X", shape=(1, 3 * dim))
+        W = relay.var("W", shape=(dim, dim))
+        y = relay.Call(f0, [X])
+        z = relay.Call(f1, [y, W])
+        return relay.Function([X, W], z)
+
+    dim = 10
+    z = before(dim)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dim))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -295,3 +369,5 @@ def expected(dshape, dtype):
     test_tuple_strided_slice()
     test_stop_fusion()
     test_fuse_myia_regression()
+    test_fuse_tuple_get_elemwise()
+    test_tuple_get_root()
