
Commit abdccf9

pranavjon and pranav jonnalagadda-SJ1 Eng_ML authored
[Relay][DefuseOps pass] bug fix: To support function body types other than call node (#10069)
Co-authored-by: pranav jonnalagadda-SJ1 Eng_ML <[email protected]>
1 parent ffff8dd · commit abdccf9

File tree: 2 files changed, +157 −12 lines

src/relay/transforms/defuse_ops.cc

Lines changed: 8 additions & 10 deletions
@@ -55,17 +55,15 @@ class DefuseOpsMutator : public ExprMutator {
     if (const auto* call = new_n.as<CallNode>()) {
       if (const auto* func = call->op.as<FunctionNode>()) {
-        if (func->body->IsInstance<CallNode>()) {
-          std::unordered_map<std::string, Expr> name_to_args;
-          for (size_t i = 0; i < func->params.size(); ++i) {
-            const std::string& pname = func->params[i]->name_hint();
-            ICHECK(name_to_args.cend() == name_to_args.find(pname))
-                << "Found multiple parameters share the same variable name `" << pname
-                << "` which introduces uncertainty in DefuseOps pass";
-            name_to_args[pname] = call->args[i];
-          }
-          return FuncBodyMutator(std::move(name_to_args)).Mutate(func->body);
+        std::unordered_map<std::string, Expr> name_to_args;
+        for (size_t i = 0; i < func->params.size(); ++i) {
+          const std::string& pname = func->params[i]->name_hint();
+          ICHECK(name_to_args.cend() == name_to_args.find(pname))
+              << "Found multiple parameters share the same variable name `" << pname
+              << "` which introduces uncertainty in DefuseOps pass";
+          name_to_args[pname] = call->args[i];
         }
+        return FuncBodyMutator(std::move(name_to_args)).Mutate(func->body);
       }
     }
     return new_n;
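
Why the removed guard was a bug: FuseOps does not guarantee that a fused function's body is a CallNode. When the last op in the group returns a tuple (as batch_norm does), the body ends in a TupleGetItem, and the old `func->body->IsInstance<CallNode>()` check made DefuseOps skip such functions entirely. Below is a minimal sketch of the previously skipped shape, assuming a standard TVM build; the variable names are illustrative, not taken from the patch:

    import numpy
    import tvm
    from tvm import relay
    from tvm.relay import transform

    # A hand-built fused function whose body is a TupleGetItem, not a Call:
    # batch_norm returns a tuple, and the function yields element 0 of it.
    data = relay.var("data", shape=(1, 56, 56, 64))
    gamma = relay.var("gamma", relay.TensorType((64,), "float32"))
    beta = relay.var("beta", relay.TensorType((64,), "float32"))
    mean = relay.var("mean", relay.TensorType((64,), "float32"))
    var = relay.var("var", relay.TensorType((64,), "float32"))
    bn = relay.nn.batch_norm(data, gamma, beta, mean, var, axis=3)
    fused = relay.Function([data, gamma, beta, mean, var], bn[0])

    # Call the fused function inline, as FuseOps would, then defuse it.
    x = relay.var("x", shape=(1, 56, 56, 64))
    g = relay.const(tvm.nd.array(numpy.ones((64,), "float32")))
    b = relay.const(tvm.nd.array(numpy.zeros((64,), "float32")))
    m = relay.const(tvm.nd.array(numpy.zeros((64,), "float32")))
    v = relay.const(tvm.nd.array(numpy.ones((64,), "float32")))
    body = relay.Call(fused, [x, g, b, m, v])
    mod = tvm.IRModule.from_expr(relay.Function([x], body))
    mod = transform.InferType()(mod)
    mod = transform.DefuseOps()(mod)
    print(mod)

Before this commit, a module like this came back from DefuseOps with the fused function still in place; with the guard removed, FuncBodyMutator substitutes the call arguments into the TupleGetItem body exactly as it does for plain call bodies.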

tests/python/relay/test_pass_defuse_ops.py

Lines changed: 149 additions & 2 deletions
@@ -14,6 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy
+import pytest
 import tvm
 from tvm import relay
 from tvm.relay import transform
@@ -63,6 +65,151 @@ def before(dshape):
     assert tvm.ir.structural_equal(x, defused)
 
 
+def test_defuse_complex():
+    """Complex defuse testcase"""
+
+    def fused_conv2d_batch_norm(w):
+        data = relay.var("data", shape=(1, 224, 224, 3))
+        bn_gamma0 = relay.var("bn_gamma0", relay.TensorType((64,), "float32"))
+        bn_beta0 = relay.var("bn_beta0", relay.TensorType((64,), "float32"))
+        bn_mmean0 = relay.var("bn_mean0", relay.TensorType((64,), "float32"))
+        bn_mvar0 = relay.var("bn_var0", relay.TensorType((64,), "float32"))
+        c0 = relay.nn.conv2d(
+            data,
+            w,
+            strides=(2, 2),
+            padding=(3, 3, 3, 3),
+            channels=64,
+            kernel_size=(7, 7),
+            data_layout="NHWC",
+            kernel_layout="OHWI",
+            out_layout="NHWC",
+        )
+        c1 = relay.nn.batch_norm(c0, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0, axis=3)
+        c2 = c1[0]
+        return relay.Function(relay.analysis.free_vars(c2), c2)
+
+    def fused_conv2d_batch_norm_relu(z):
+        data2 = relay.var("data2", shape=(1, 56, 56, 64))
+        bn_gamma0 = relay.var("bn_gamma0", relay.TensorType((64,), "float32"))
+        bn_beta0 = relay.var("bn_beta0", relay.TensorType((64,), "float32"))
+        bn_mmean0 = relay.var("bn_mean0", relay.TensorType((64,), "float32"))
+        bn_mvar0 = relay.var("bn_var0", relay.TensorType((64,), "float32"))
+        c0 = relay.nn.conv2d(
+            data2,
+            z,
+            padding=(1, 1, 1, 1),
+            channels=64,
+            kernel_size=(3, 3),
+            data_layout="NHWC",
+            kernel_layout="OHWI",
+            out_layout="NHWC",
+        )
+        c1 = relay.nn.batch_norm(c0, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0, axis=3)
+        c2 = c1[0]
+        c3 = relay.nn.relu(data=c2)
+        return relay.Function(relay.analysis.free_vars(c3), c3)
+
+    def fused_max_pool2d():
+        data1 = relay.var("data1", shape=(1, 112, 112, 64))
+        a1 = relay.nn.max_pool2d(
+            data1,
+            pool_size=(3, 3),
+            strides=(2, 2),
+            padding=(1, 1, 1, 1),
+            layout="NHWC",
+            out_layout="NHWC",
+        )
+        return relay.Function(relay.analysis.free_vars(a1), a1)
+
+    def fused_add_relu():
+        data1 = relay.var("data1", shape=(1, 56, 56, 64))
+        data2 = relay.var("data2", shape=(1, 56, 56, 64))
+        a0 = relay.add(data1, data2)
+        a1 = relay.nn.relu(a0)
+        return relay.Function(relay.analysis.free_vars(a1), a1)
+
+    def before_fused(conv_layer1_weight, conv_layer2_weight):
+        data = relay.var("data", shape=(1, 3, 224, 224))
+        data1 = relay.layout_transform(data, src_layout="NCHW", dst_layout="NHWC")
+        bn_gamma0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32")))
+        bn_beta0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32")))
+        bn_mmean0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32")))
+        bn_mvar0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32")))
+        a0 = fused_conv2d_batch_norm(conv_layer1_weight)
+        a1 = fused_max_pool2d()
+        a2 = fused_conv2d_batch_norm_relu(conv_layer2_weight)
+        a3 = fused_add_relu()
+        y0 = relay.Call(a0, [data1, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0])
+        y1 = relay.Call(a1, [y0])
+        y2 = relay.Call(a2, [y1, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0])
+        y3 = relay.Call(a3, [y1, y2])
+        return relay.Function(relay.analysis.free_vars(y3), y3)
+
+    def golden_defused(conv_layer1_weight, conv_layer2_weight):
+        data = relay.var("data", shape=(1, 3, 224, 224))
+        data1 = relay.layout_transform(data, src_layout="NCHW", dst_layout="NHWC")
+        bn_gamma0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32")))
+        bn_beta0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32")))
+        bn_mmean0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32")))
+        bn_mvar0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32")))
+        c0 = relay.nn.conv2d(
+            data1,
+            conv_layer1_weight,
+            strides=(2, 2),
+            padding=(3, 3, 3, 3),
+            channels=64,
+            kernel_size=(7, 7),
+            data_layout="NHWC",
+            kernel_layout="OHWI",
+            out_layout="NHWC",
+        )
+        c1 = relay.nn.batch_norm(c0, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0, axis=3)
+        c2 = c1[0]
+        c3 = relay.nn.max_pool2d(
+            c2,
+            pool_size=(3, 3),
+            strides=(2, 2),
+            padding=(1, 1, 1, 1),
+            layout="NHWC",
+            out_layout="NHWC",
+        )
+        c4 = relay.nn.conv2d(
+            c3,
+            conv_layer2_weight,
+            padding=(1, 1, 1, 1),
+            channels=64,
+            kernel_size=(3, 3),
+            data_layout="NHWC",
+            kernel_layout="OHWI",
+            out_layout="NHWC",
+        )
+        c5 = relay.nn.batch_norm(c4, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0, axis=3)
+        c6 = c5[0]
+        c7 = relay.nn.relu(c6)
+        c8 = relay.add(c3, c7)
+        c9 = relay.nn.relu(c8)
+        return relay.Function(relay.analysis.free_vars(c9), c9)
+
+    # creating weight constants for the two convolution layers
+    # in the input fused model and the golden defused model.
+    conv_layer1_weight = relay.nn.Constant(
+        tvm.nd.array(numpy.ndarray(shape=(64, 7, 7, 3), dtype="float32"))
+    )
+    conv_layer2_weight = relay.nn.Constant(
+        tvm.nd.array(numpy.ndarray(shape=(64, 3, 3, 64), dtype="float32"))
+    )
+    x = before_fused(conv_layer1_weight, conv_layer2_weight)
+    x = run_opt_pass(x, transform.InferType())
+    defused = run_opt_pass(x, transform.DefuseOps())
+
+    golden1 = golden_defused(conv_layer1_weight, conv_layer2_weight)
+    golden1 = run_opt_pass(golden1, transform.InferType())
+
+    assert tvm.ir.structural_equal(defused, golden1), (
+        "Actual = \n" + str(defused) + "\nGolden = \n" + str(golden1)
+    )
+
+
 if __name__ == "__main__":
-    test_defuse_simple()
-    test_inception_like()
+    pytest.main([__file__])
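
A note on the new `__main__` block: `pytest.main([__file__])` collects every `test_*` function in the file, so `python tests/python/relay/test_pass_defuse_ops.py` now runs `test_defuse_simple`, `test_inception_like`, and the new `test_defuse_complex` without a hand-maintained call list. A single case can still be selected with pytest's node-id syntax, e.g. `pytest tests/python/relay/test_pass_defuse_ops.py::test_defuse_complex`.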
