Commit 8846f31

Add FlattenAtrousConv pass into the default optimize pipeline. (#11077)
1 parent 7612b22 · commit 8846f31

3 files changed: +56 −0 lines

include/tvm/relay/transform.h

Lines changed: 9 additions & 0 deletions
@@ -494,6 +494,15 @@ TVM_DLL Pass ManifestLifetimes();
  */
 TVM_DLL Pass PlanDevices(CompilationConfig config);
 
+/*!
+ * \brief This transform flattens atrous convolution, which corresponds to the sequence of
+ * operations "space_to_batch_nd" -> "conv2d" -> "batch_to_space_nd", and converts it into a
+ * subgraph with a single convolution with modified "dilation" and recalculated "padding".
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass FlattenAtrousConv();
+
 }  // namespace transform
 
 /*!
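For context, the pass can also be applied on its own from Python. Below is a minimal sketch, not part of this commit, assuming the pass is exposed as tvm.relay.transform.FlattenAtrousConv (the binding is exercised by the existing test file; only the pipeline registration is new here). It reuses the operator pattern from the test further down:

import numpy as np
import tvm
from tvm import relay

# Build the "space_to_batch_nd" -> "conv2d" -> "batch_to_space_nd" pattern
# that TensorFlow-style atrous (dilated) convolution lowers to.
data = relay.var("data", shape=[1, 5, 5, 4], dtype="float32")
weight = relay.const(np.ones([3, 3, 4, 1], dtype="float32"))
s2b = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
conv = relay.nn.conv2d(
    s2b,
    weight,
    padding=[0, 0, 0, 0],
    groups=4,
    channels=4,
    kernel_size=[3, 3],
    data_layout="NHWC",
    kernel_layout="HWOI",
)
expr = relay.nn.batch_to_space_nd(conv, block_shape=[2, 2], crops=[[0, 1], [0, 1]])

mod = relay.transform.InferType()(tvm.IRModule.from_expr(expr))
mod = relay.transform.FlattenAtrousConv()(mod)
# Expected: a single conv2d with dilation=[2, 2] and recalculated padding
# replaces the three-op sequence.
print(mod)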

src/relay/backend/utils.cc

Lines changed: 2 additions & 0 deletions
@@ -262,6 +262,8 @@ Array<Pass> GetPassPrefix(bool is_homegeneous, bool is_vm) {
   // Fast math optimizations.
   pass_seqs.push_back(transform::FastMath());
   pass_seqs.push_back(transform::FoldConstant());
+
+  pass_seqs.push_back(transform::FlattenAtrousConv());
   return pass_seqs;
 }
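Since GetPassPrefix feeds both the graph-executor and VM build pipelines, the pass now runs on every build by default. If it ever needs to be skipped, it should be possible to opt out per build through the pass context; this is a hedged sketch, assuming the pass is registered under the name "FlattenAtrousConv":

import tvm
from tvm import relay

# Hypothetical opt-out: disabled_pass filters default-pipeline passes by name.
with tvm.transform.PassContext(opt_level=3, disabled_pass=["FlattenAtrousConv"]):
    lib = relay.build(mod, target="llvm")  # "mod" as in the sketch above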

tests/python/relay/test_pass_flatten_atrous_conv.py

Lines changed: 45 additions & 0 deletions
@@ -19,6 +19,7 @@
 import pytest
 import tvm
 from tvm import relay
+from tvm.contrib import graph_executor
 
 
 def compare_expected_fac(expr, expected_expr, args):
@@ -421,6 +422,50 @@ def test_fac_op_btwn_conv_b2s():
     compare_expected_fac(expr, expected_expr, [x_np])
 
 
+def test_fac_relay_build():
+    # Check the default optimize pipeline
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    mod_def = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr))
+    result_def = (
+        relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm")
+        .evaluate()(x_np)
+        .numpy()
+    )
+
+    graph, lib, params = relay.build(mod_def, "llvm", params=None)
+    rt_mod = graph_executor.create(graph, lib, device=tvm.cpu())
+    rt_mod.set_input("data", x_np)
+    rt_mod.set_input(**params)
+    rt_mod.run()
+    result_flat = rt_mod.get_output(0).numpy()
+
+    assert "space_to_batch_nd" not in graph
+    assert "conv2d" in graph
+    assert "batch_to_space_nd" not in graph
+
+    assert np.array_equal(result_def, result_flat)
+
+
 if __name__ == "__main__":
     import sys
 
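A note on the substring asserts: the "graph" returned by relay.build is the graph-executor JSON string, and the names of the fused nodes embed the original operator names, so checking the string for "space_to_batch_nd" and "batch_to_space_nd" is sufficient. A small sketch of inspecting it directly, assuming the "graph" variable from the test above (the printed output is illustrative, not taken from this commit):

import json

# List the node names in the lowered graph; after flattening, only the
# input and the fused convolution should remain.
node_names = [node["name"] for node in json.loads(graph)["nodes"]]
print(node_names)  # e.g. ["data", "tvmgen_default_fused_nn_conv2d"]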