 import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay.ir_pass import alpha_equal, infer_type, detect_feature
-from tvm.relay.ir_pass import to_cps, un_cps
+from tvm.relay.analysis import alpha_equal, detect_feature
+from tvm.relay.transform import to_cps, un_cps
 from tvm.relay.feature import Feature
 from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, make_nat_expr
+from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass
 from tvm.relay import create_executor
 from tvm.relay import Function, transform

@@ -42,13 +42,12 @@ def test_recursion():
     double = relay.Function([x], x + x)
     i = relay.var("i", t)
     func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
-    func = infer_type(func, mod=mod)
-    cps_func = infer_type(un_cps(infer_type(to_cps(func, mod=mod), mod=mod)), mod=mod)
-    print(mod)
-    print(cps_func)
+    mod[mod.entry_func] = func
+    mod[mod.entry_func] = to_cps(mod[mod.entry_func], mod=mod)
+    mod[mod.entry_func] = un_cps(mod[mod.entry_func])
     ex = create_executor(mod=mod)
     i_nd = rand(dtype, *shape)
-    forward = ex.evaluate(cps_func)(i_nd)
+    forward = ex.evaluate(mod.entry_func)(i_nd)
     tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())


@@ -57,12 +56,12 @@ def test_recursion():
 # cps and pe can completely eliminate the allocation of reference.
 def test_cps_pe():
     def destroy_ref(x):
-        x = infer_type(x)
+        x = run_infer_type(x)
         x = to_cps(x)
-        x = infer_type(x)
+        x = run_infer_type(x)
         y = un_cps(x)
-        y = infer_type(y)
-        x = transform.OptimizeOnExpr(x, [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)])
+        y = run_infer_type(y)
+        x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
         assert Feature.fRefCreate not in detect_feature(x)
     unit = relay.Function([], relay.const(0., dtype='float32'))
     f_ref = relay.Var("f_ref")
@@ -82,7 +81,7 @@ def destroy_ref(x):
     destroy_ref(F)

     G = relay.Function([cond], relay.If(cond, one, two))
-    G = relay.ir_pass.gradient(G)
+    G = relay.transform.gradient(G)
     destroy_ref(G)

     x = relay.var("x", shape=(1, 16))
@@ -92,7 +91,7 @@ def destroy_ref(x):
     H = relay.If(cond, x, y)
     H = relay.add(H, z)
     H = relay.Function([cond,x,y,z], H)
-    H = relay.ir_pass.gradient(H)
+    H = relay.transform.gradient(H)
     destroy_ref(H)
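For reference, a minimal sketch of the same to_cps/un_cps round trip using the relocated APIs. This is an illustration only, not part of the changed test file; it assumes the TVM version targeted by this change, where to_cps/un_cps live in tvm.relay.transform and run_infer_type is exported from tvm.relay.testing.

# Illustrative sketch (assumes the TVM version targeted by this change).
import numpy as np
import tvm
from tvm import relay
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.testing import run_infer_type

# Build a trivial doubling function, type-check it, convert it to
# continuation-passing style, convert it back, and check the result.
x = relay.var("x", shape=(1, 16))
f = run_infer_type(relay.Function([x], x + x))
cps_f = run_infer_type(to_cps(f))          # CPS form of f
direct_f = run_infer_type(un_cps(cps_f))   # back to direct style
ex = relay.create_executor()
inp = np.ones((1, 16), dtype="float32")
out = ex.evaluate(direct_f)(inp)
tvm.testing.assert_allclose(out.asnumpy(), 2 * inp)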