Skip to content

Commit 0e0adf5

Browse files
authored
[Relay] Support nchwc layout in ConvertLayout pass (#9681)
1 parent d13e2b6 commit 0e0adf5

File tree

2 files changed

+122
-1
lines changed

2 files changed

+122
-1
lines changed

python/tvm/relay/op/nn/_nn.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in
1818
"""Backend compiler related feature registration"""
1919
from __future__ import absolute_import
20+
import re
2021

2122
from tvm import topi, relay
2223
from tvm.topi.utils import get_const_tuple
@@ -283,8 +284,9 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
283284
desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
284285
assert desired_data_layout != "default", "Data layout cannot be default"
285286
new_attrs["data_layout"] = desired_data_layout
287+
need_tile = re.match(r"NCHW(\d*)c", desired_data_layout)
286288

287-
if desired_kernel_layout != "default":
289+
if desired_kernel_layout != "default" and not need_tile:
288290
new_attrs["kernel_layout"] = desired_kernel_layout
289291
return relay.nn.conv2d(data, weight, **new_attrs)
290292

@@ -309,6 +311,14 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
309311
elif desired_data_layout == "HWNC":
310312
new_attrs["kernel_layout"] = "HWOI"
311313
return relay.nn.conv2d(data, weight, **new_attrs)
314+
elif need_tile:
315+
assert desired_kernel_layout != "default", "Kernel layout cannot be default."
316+
tile = int(need_tile.group(1))
317+
if isinstance(data, relay.expr.Var) and data.checked_type.shape[1] % tile != 0:
318+
return relay.nn.conv2d(data, weight, **attrs)
319+
else:
320+
new_attrs["kernel_layout"] = desired_kernel_layout
321+
return relay.nn.contrib_conv2d_nchwc(data, weight, **new_attrs)
312322

313323
raise ValueError("Layout %s is not yet supported." % desired_data_layout)
314324

tests/python/relay/test_pass_convert_op_layout.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2323,5 +2323,116 @@ def expected():
23232323
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
23242324

23252325

2326+
@pytest.mark.parametrize(
2327+
"data_layout, kernel_layout",
2328+
[
2329+
("NCHW1c", "OIHW1i1o"),
2330+
("NCHW4c", "OIHW4i4o"),
2331+
("NCHW8c", "OIHW8i8o"),
2332+
("NCHW16c", "OIHW16i16o"),
2333+
],
2334+
)
2335+
def test_resnet_convert_layout_nchwc(data_layout, kernel_layout):
2336+
x = relay.var("x", shape=(1, 3, 224, 224))
2337+
weight1 = relay.var("weight1", shape=(64, 3, 7, 7))
2338+
weight2 = relay.var("weight2", shape=(64, 64, 3, 3))
2339+
weight3 = relay.var("weight3", shape=(64, 64, 1, 1))
2340+
2341+
def before():
2342+
y = relay.nn.conv2d(
2343+
x,
2344+
weight1,
2345+
strides=(2, 2),
2346+
padding=(3, 3),
2347+
channels=64,
2348+
kernel_size=(7, 7),
2349+
data_layout="NCHW",
2350+
kernel_layout="OIHW",
2351+
)
2352+
y = relay.nn.relu(y)
2353+
y = relay.nn.max_pool2d(y, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
2354+
y1 = relay.nn.conv2d(
2355+
y,
2356+
weight2,
2357+
channels=64,
2358+
kernel_size=(3, 3),
2359+
padding=(1, 1),
2360+
data_layout="NCHW",
2361+
kernel_layout="OIHW",
2362+
)
2363+
y1 = relay.nn.relu(y1)
2364+
y2 = relay.nn.conv2d(
2365+
y,
2366+
weight3,
2367+
channels=64,
2368+
kernel_size=(1, 1),
2369+
data_layout="NCHW",
2370+
kernel_layout="OIHW",
2371+
)
2372+
y2 = relay.nn.relu(y2)
2373+
y = y1 + y2
2374+
y = relay.nn.global_max_pool2d(y, layout="NCHW")
2375+
return y
2376+
2377+
def expected():
2378+
if data_layout == "NCHW1c":
2379+
y = relay.nn.contrib_conv2d_nchwc(
2380+
relay.layout_transform(x, "NCHW", data_layout),
2381+
relay.layout_transform(weight1, "OIHW", kernel_layout),
2382+
strides=(2, 2),
2383+
padding=(3, 3),
2384+
channels=64,
2385+
kernel_size=(7, 7),
2386+
data_layout=data_layout,
2387+
kernel_layout=kernel_layout,
2388+
)
2389+
y = relay.nn.relu(y)
2390+
y = relay.nn.max_pool2d(
2391+
y, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), layout=data_layout
2392+
)
2393+
else:
2394+
y = relay.nn.conv2d(
2395+
x,
2396+
weight1,
2397+
strides=(2, 2),
2398+
padding=(3, 3),
2399+
channels=64,
2400+
kernel_size=(7, 7),
2401+
data_layout="NCHW",
2402+
kernel_layout="OIHW",
2403+
)
2404+
y = relay.nn.relu(y)
2405+
y = relay.nn.max_pool2d(y, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
2406+
y = relay.layout_transform(y, "NCHW", data_layout)
2407+
y1 = relay.nn.contrib_conv2d_nchwc(
2408+
y,
2409+
relay.layout_transform(weight2, "OIHW", kernel_layout),
2410+
channels=64,
2411+
kernel_size=(3, 3),
2412+
padding=(1, 1),
2413+
data_layout=data_layout,
2414+
kernel_layout=kernel_layout,
2415+
)
2416+
y1 = relay.nn.relu(y1)
2417+
y2 = relay.nn.contrib_conv2d_nchwc(
2418+
y,
2419+
relay.layout_transform(weight3, "OIHW", kernel_layout),
2420+
channels=64,
2421+
kernel_size=(1, 1),
2422+
data_layout=data_layout,
2423+
kernel_layout=kernel_layout,
2424+
)
2425+
y2 = relay.nn.relu(y2)
2426+
y = y1 + y2
2427+
y = relay.nn.global_max_pool2d(y, layout=data_layout)
2428+
y = relay.layout_transform(y, data_layout, "NCHW")
2429+
return y
2430+
2431+
a = before()
2432+
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": [data_layout, kernel_layout]}))
2433+
b = run_opt_pass(expected(), transform.InferType())
2434+
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n Expect = \n" + str(b)
2435+
2436+
23262437
if __name__ == "__main__":
23272438
pytest.main([__file__])

0 commit comments

Comments
 (0)