Commit d6c8862

Author: Ivy

    add unit test for byoc-dnnl

1 parent: b35fc83

File tree

  python/tvm/relay/op/contrib/dnnl.py
  tests/python/contrib/test_dnnl.py

2 files changed: +258 −0 lines


python/tvm/relay/op/contrib/dnnl.py

Lines changed: 35 additions & 0 deletions
@@ -35,6 +35,8 @@
 import tvm.ir
 from ...dataflow_pattern import wildcard, is_op
 from .register import register_pattern_table
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name


 def _register_external_op_helper(op_name, supported=True):
@@ -85,3 +87,36 @@ def pattern_table():
     conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
     dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
     return dnnl_patterns
+
+
+def partition_for_dnnl(
+    mod,
+    params=None,
+):
+    """Partition the graph greedily offloading supported operators to DNNL.
+
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+
+    Returns
+    -------
+    mod : Module
+        Annotated and partitioned module.
+    """
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)
+    seq = tvm.transform.Sequential(
+        [
+            transform.MergeComposite(pattern_table()),
+            transform.AnnotateTarget("dnnl"),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    return mod
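
For downstream users, partition_for_dnnl is the single entry point to call before building. A minimal usage sketch (not part of the commit; it assumes a TVM build with the DNNL codegen enabled, and the shapes are made up for illustration):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import partition_for_dnnl

    # A small conv2d -> relu graph, which matches the "dnnl.conv2d_relu" pattern.
    x = relay.var("x", shape=(1, 32, 8, 8), dtype="float32")
    w = relay.var("w", shape=(16, 32, 3, 3), dtype="float32")
    out = relay.nn.relu(relay.nn.conv2d(x, w, kernel_size=(3, 3)))
    mod = tvm.IRModule.from_expr(relay.Function([x, w], out))

    # Passing the weight as a param lets bind_params_by_name fold it in as a constant.
    params = {"w": tvm.nd.array(np.ones((16, 32, 3, 3), dtype="float32"))}
    mod = partition_for_dnnl(mod, params)
    print(mod)  # conv2d + relu now sit in a composite function annotated for "dnnl"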

tests/python/contrib/test_dnnl.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import itertools
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.relay.testing
+import tvm.testing
+from tvm import relay
+from tvm.relay.op.contrib import dnnl
+
+has_dnnl_codegen = pytest.mark.skipif(
+    not tvm.get_global_func("relay.ext.dnnl", True), reason="DNNL codegen not available"
+)
+
+# Each test runs twice: "compile" only builds the module, "run" also executes
+# it and compares the outputs.
+run_module = tvm.testing.parameter(
+    pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]),
+    pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]),
+    ids=["compile", "run"],
+)
+
+
+def vmobj_to_list(o):
+    if isinstance(o, tvm.nd.NDArray):
+        return [o.numpy()]
+    elif isinstance(o, (tvm.runtime.container.ADT, list)):
+        return [vmobj_to_list(f) for f in o]
+    else:
+        raise RuntimeError("Unknown object type: %s" % type(o))
+
+
+def assert_result_dict_holds(result_dict):
+    # Every pair of results (graph/vm, with/without DNNL) must agree.
+    for k1, k2 in itertools.combinations(result_dict, 2):
+        res1 = vmobj_to_list(result_dict[k1])
+        res2 = vmobj_to_list(result_dict[k2])
+        for r1, r2 in zip(res1, res2):
+            tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
+
+
+def run_and_verify_func(config, target="llvm", run_module=True):
+    """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs.
+
+    Parameters
+    ----------
+    config : Tuple[relay.Function, Dict[str, NDArray], List[str]]
+        A tuple of 1) the function to test, 2) a dict mapping var names to input
+        shapes, and 3) a list of which vars should be treated as params.
+
+    run_module : bool
+        If True, the built module will be run after being compiled.
+    """
+    f, input_shapes, is_param = config
+    params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(np.float32) for x in is_param}
+    input_dict = {
+        k: np.random.uniform(-1, 1, v).astype(np.float32)
+        for k, v in input_shapes.items()
+        if k not in is_param
+    }
+    dev = tvm.device(target)
+
+    result_dict = dict()
+    for mode in ["graph", "vm"]:
+        for use_dnnl in [False, True]:
+            mod = tvm.IRModule()
+            mod["main"] = f
+            result_key = mode + ("_dnnl" if use_dnnl else "")
+            if use_dnnl:
+                mod = dnnl.partition_for_dnnl(mod, params)
+            # Executor construction is identical with and without DNNL, so it
+            # is shared rather than duplicated across both branches.
+            with tvm.transform.PassContext(opt_level=3):
+                func = relay.create_executor(
+                    mode, mod=mod, device=dev, target=target
+                ).evaluate()
+            if run_module:
+                result_dict[result_key] = func(**input_dict, **params)
+
+    if run_module:
+        assert_result_dict_holds(result_dict)
+
+
+def test_dnnl_not_compatible(run_module):
+    dtype = "float32"
+    xshape = (1, 32, 14, 14)
+    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+
+    x = relay.var("x", shape=xshape, dtype=dtype)
+    y = relay.add(x, x)
+    z = relay.cast(relay.cast(y, "int32"), "float32")
+    out = relay.nn.relu(z)
+    f = relay.Function([x], out)
+    mod = tvm.IRModule()
+    mod["main"] = f
+    # relay.cast is unsupported by the DNNL codegen; partitioning must still
+    # yield a module that compiles and runs.
+    mod = dnnl.partition_for_dnnl(mod)
+    for mode in ["graph", "vm"]:
+        with tvm.transform.PassContext(opt_level=3):
+            func = relay.create_executor(
+                mode, mod=mod, device=tvm.cpu(0), target="llvm"
+            ).evaluate()
+        if run_module:
+            func(x_data)
+
+
+def test_conv2d(run_module):
+    def get_graph(
+        x_shape=(1, 32, 8, 8),
+        k_shape=(16, 32, 3, 3),
+        groups=1,
+        padding=(0, 0),
+        strides=(1, 1),
+        dilation=(1, 1),
+        channels=None,
+    ):
+        x = relay.var("x", shape=x_shape, dtype="float32")
+        kernel = relay.var("kernel", shape=k_shape, dtype="float32")
+        out = relay.nn.conv2d(
+            x,
+            kernel,
+            kernel_size=k_shape[2:4],
+            groups=groups,
+            padding=padding,
+            strides=strides,
+            dilation=dilation,
+            channels=channels,
+        )
+        f = relay.Function([x, kernel], out)
+        return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
+
+    # Sweep normal and depthwise convolutions across paddings and strides.
+    for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]:
+        for padding in [(0, 0), (1, 1)]:
+            for strides in [(1, 1), (2, 2)]:
+                for dilation in [(1, 1)]:
+                    run_and_verify_func(
+                        get_graph(
+                            k_shape=k_shape,
+                            groups=groups,
+                            padding=padding,
+                            strides=strides,
+                            dilation=dilation,
+                        ),
+                        run_module=run_module,
+                    )
+
+
+def test_conv2d_weights_const(run_module):
+    def get_graph(
+        x_shape=(1, 32, 8, 8),
+        k_shape=(16, 32, 3, 3),
+        groups=1,
+        padding=(0, 0),
+        strides=(1, 1),
+        dilation=(1, 1),
+    ):
+        x = relay.var("x", shape=x_shape, dtype="float32")
+        kernel = relay.const(np.ones(k_shape).astype("float32"))
+        out = relay.nn.conv2d(
+            x,
+            kernel,
+            channels=k_shape[0],
+            kernel_size=k_shape[2:4],
+            groups=groups,
+            padding=padding,
+            strides=strides,
+            dilation=dilation,
+        )
+        f = relay.Function([x], out)
+        return f, {"x": x_shape}, []
+
+    run_and_verify_func(get_graph(), run_module=run_module)
+
+
+def test_dense(run_module):
+    def get_graph(x_shape=(1, 16), k_shape=(32, 16)):
+        x = relay.var("x", shape=x_shape, dtype="float32")
+        kernel = relay.var("kernel", shape=k_shape, dtype="float32")
+        out = relay.nn.dense(x, kernel, units=k_shape[0])
+        f = relay.Function([x, kernel], out)
+        return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
+
+    run_and_verify_func(get_graph(), run_module=run_module)
+    run_and_verify_func(get_graph(k_shape=(1, 16)), run_module=run_module)
+
+
+def test_multiple_outputs(run_module):
+    def get_graph():
+        x = relay.var("x", shape=(1, 3), dtype="float32")
+        y = relay.var("y", shape=(1, 3), dtype="float32")
+        z = relay.add(x, y)
+        w = relay.add(z, y)
+        out = relay.Tuple((z, w))
+        f = relay.Function([x, y], out)
+        return f, {"x": (1, 3), "y": (1, 3)}, []
+
+    run_and_verify_func(get_graph(), run_module=run_module)
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
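
Extending coverage follows the same config-tuple recipe as the tests above; for example, a hypothetical standalone relu test (a sketch, not part of the commit) would be:

    def test_relu(run_module):
        def get_graph(x_shape=(1, 32, 8, 8)):
            x = relay.var("x", shape=x_shape, dtype="float32")
            out = relay.nn.relu(x)
            f = relay.Function([x], out)
            # (function, var name -> input shape, var names to bind as params)
            return f, {"x": x_shape}, []

        run_and_verify_func(get_graph(), run_module=run_module)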
