Commit d69ca11

fix
1 parent 1e8ca97 commit d69ca11

4 files changed: +28 -31 lines changed


python/tvm/relay/testing/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -32,3 +32,17 @@
 from .config import ctx_list
 from .init import create_workload
 from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
+import tvm.relay as relay
+from tvm.relay import transform
+
+
+def run_opt_pass(expr, opt_pass):
+    assert isinstance(opt_pass, transform.Pass)
+    mod = relay.Module.from_expr(expr)
+    mod = opt_pass(mod)
+    entry = mod[mod.entry_func]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def run_infer_type(expr):
+    return run_opt_pass(expr, transform.InferType())
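
The run_opt_pass / run_infer_type helpers are now shared through tvm.relay.testing, so the test files below can drop their local copies. A minimal usage sketch (illustrative only, not part of this commit; assumes the Relay API at this revision):

    import tvm.relay as relay
    from tvm.relay import transform
    from tvm.relay.testing import run_infer_type, run_opt_pass

    # Hypothetical example expression: add a variable to itself.
    x = relay.var("x", shape=(1, 16), dtype="float32")
    f = relay.Function([x], relay.add(x, x))

    f = run_infer_type(f)                     # applies transform.InferType()
    f = run_opt_pass(f, transform.FuseOps())  # any transform.Pass is applied the same way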

tests/python/relay/test_pass_fuse_ops.py

Lines changed: 0 additions & 8 deletions
@@ -19,14 +19,6 @@
 from tvm.relay import transform
 
 
-def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
-    mod = relay.Module.from_expr(expr)
-    mod = opt_pass(mod)
-    entry = mod[mod.entry_func]
-    return entry if isinstance(expr, relay.Function) else entry.body
-
-
 def test_fuse_simple():
     """Simple testcase."""
     def before():

tests/python/relay/test_pass_gradient.py

Lines changed: 1 addition & 9 deletions
@@ -22,15 +22,7 @@
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, make_nat_expr
-
-
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = relay.Module.from_expr(expr)
-    mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
-    return entry if isinstance(expr, relay.Function) else entry.body
+from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type
 
 
 def rand(dtype='float32', *shape):

tests/python/relay/test_pass_to_cps.py

Lines changed: 13 additions & 14 deletions
@@ -17,11 +17,11 @@
 import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay.ir_pass import alpha_equal, infer_type, detect_feature
-from tvm.relay.ir_pass import to_cps, un_cps
+from tvm.relay.analysis import alpha_equal, detect_feature
+from tvm.relay.transform import to_cps, un_cps
 from tvm.relay.feature import Feature
 from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, make_nat_expr
+from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass
 from tvm.relay import create_executor
 from tvm.relay import Function, transform
 
@@ -42,13 +42,12 @@ def test_recursion():
     double = relay.Function([x], x + x)
     i = relay.var("i", t)
     func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
-    func = infer_type(func, mod=mod)
-    cps_func = infer_type(un_cps(infer_type(to_cps(func, mod=mod), mod=mod)), mod=mod)
-    print(mod)
-    print(cps_func)
+    mod[mod.entry_func] = func
+    mod[mod.entry_func] = to_cps(mod[mod.entry_func], mod=mod)
+    mod[mod.entry_func] = un_cps(mod[mod.entry_func])
     ex = create_executor(mod=mod)
     i_nd = rand(dtype, *shape)
-    forward = ex.evaluate(cps_func)(i_nd)
+    forward = ex.evaluate(mod.entry_func)(i_nd)
     tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
 
 
@@ -57,12 +56,12 @@ def test_recursion():
 # cps and pe can completely eliminate the allocation of reference.
 def test_cps_pe():
     def destroy_ref(x):
-        x = infer_type(x)
+        x = run_infer_type(x)
         x = to_cps(x)
-        x = infer_type(x)
+        x = run_infer_type(x)
         y = un_cps(x)
-        y = infer_type(y)
-        x = transform.OptimizeOnExpr(x, [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)])
+        y = run_infer_type(y)
+        x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
         assert Feature.fRefCreate not in detect_feature(x)
     unit = relay.Function([], relay.const(0., dtype='float32'))
     f_ref = relay.Var("f_ref")
@@ -82,7 +81,7 @@ def destroy_ref(x):
     destroy_ref(F)
 
     G = relay.Function([cond], relay.If(cond, one, two))
-    G = relay.ir_pass.gradient(G)
+    G = relay.transform.gradient(G)
     destroy_ref(G)
 
     x = relay.var("x", shape=(1, 16))
@@ -92,7 +91,7 @@ def destroy_ref(x):
     H = relay.If(cond, x, y)
     H = relay.add(H, z)
     H = relay.Function([cond,x,y,z], H)
-    H = relay.ir_pass.gradient(H)
+    H = relay.transform.gradient(H)
     destroy_ref(H)
 
 
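The transform.OptimizeOnExpr call in destroy_ref is replaced by a transform.Sequential pipeline applied through the shared run_opt_pass helper. A standalone sketch of that pattern (illustrative only, not part of the commit; it reuses only passes that appear in the diff above):

    from tvm import relay
    from tvm.relay import transform
    from tvm.relay.testing import run_opt_pass

    # Hypothetical expression; the pipeline partially evaluates it and removes
    # dead bindings, mirroring the pipeline used in destroy_ref above.
    x = relay.var("x", shape=(1, 16))
    f = relay.Function([x], relay.add(x, x))

    seq = transform.Sequential([transform.PartialEvaluate(),
                                transform.DeadCodeElimination(inline_once=True)])
    f = run_opt_pass(f, seq)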
