|
19 | 19 | import pytest |
20 | 20 | import tvm |
21 | 21 | from tvm import relay |
| 22 | +from tvm.contrib import graph_executor |
22 | 23 |
|
23 | 24 |
|
24 | 25 | def compare_expected_fac(expr, expected_expr, args): |
@@ -421,6 +422,50 @@ def test_fac_op_btwn_conv_b2s(): |
421 | 422 | compare_expected_fac(expr, expected_expr, [x_np]) |
422 | 423 |
|
423 | 424 |
|
| 425 | +def test_fac_relay_build(): |
| 426 | + # Check the default optimize pipeline |
| 427 | + shape_x = [1, 5, 5, 4] |
| 428 | + shape_w = [3, 3, 4, 1] |
| 429 | + |
| 430 | + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") |
| 431 | + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") |
| 432 | + |
| 433 | + weight = relay.const(w_np) |
| 434 | + data = relay.var("data", shape=shape_x, dtype="float32") |
| 435 | + op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) |
| 436 | + op2 = relay.nn.conv2d( |
| 437 | + op1, |
| 438 | + weight, |
| 439 | + padding=[0, 0, 0, 0], |
| 440 | + groups=4, |
| 441 | + channels=4, |
| 442 | + kernel_size=[3, 3], |
| 443 | + data_layout="NHWC", |
| 444 | + kernel_layout="HWOI", |
| 445 | + ) |
| 446 | + expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) |
| 447 | + |
| 448 | + mod_def = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr)) |
| 449 | + result_def = ( |
| 450 | + relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm") |
| 451 | + .evaluate()(x_np) |
| 452 | + .numpy() |
| 453 | + ) |
| 454 | + |
| 455 | + graph, lib, params = relay.build(mod_def, "llvm", params=None) |
| 456 | + rt_mod = graph_executor.create(graph, lib, device=tvm.cpu()) |
| 457 | + rt_mod.set_input("data", x_np) |
| 458 | + rt_mod.set_input(**params) |
| 459 | + rt_mod.run() |
| 460 | + result_flat = rt_mod.get_output(0).numpy() |
| 461 | + |
| 462 | + assert "space_to_batch_nd" not in graph |
| 463 | + assert "conv2d" in graph |
| 464 | + assert "batch_to_space_nd" not in graph |
| 465 | + |
| 466 | + assert np.array_equal(result_def, result_flat) |
| 467 | + |
| 468 | + |
424 | 469 | if __name__ == "__main__": |
425 | 470 | import sys |
426 | 471 |
|
|
0 commit comments