2121from tvm .script import relax as R
2222
2323
24- def main ():
24+ def _iter_binding_names (mod ):
25+ """Helper function to compare the names of relax variables"""
26+ for block in mod ["forward" ].body .blocks :
27+ for binding in block .bindings :
28+ yield binding .var .name_hint
29+
30+
31+ def test_nn_export_to_relax ():
2532 class TestModule (nn .Module ):
2633 def __init__ (self , in_features : int , out_features : int ):
2734 super ().__init__ ()
@@ -35,39 +42,28 @@ def forward(self, x: nn.Tensor):
3542 x2 = self .linear_2 (x )
3643 return x1 + x2
3744
38- # pylint: disable=line-too-long
3945 @I .ir_module
40- class ExpectedModule : # pylint: disable=too-few-public-methods
46+ class ExpectedModule :
4147 @R .function
4248 def forward (
4349 x : R .Tensor ((1 , 10 ), dtype = "float32" ),
4450 packed_params : R .Tuple (
4551 R .Tensor ((20 , 10 ), dtype = "float32" ), R .Tensor ((20 , 10 ), dtype = "float32" )
4652 ),
47- ) -> R .Tensor ((1 , 20 ), dtype = "float32" ):
48- R .func_attr ({"num_input" : 1 }) # type: ignore[attr-defined]
49- with R .dataflow (): # type: ignore[attr-defined]
50- linear_1_weight : R .Tensor ((20 , 10 ), dtype = "float32" ) = packed_params [0 ] # type: ignore[valid-type]
51- linear_2_weight : R .Tensor ((20 , 10 ), dtype = "float32" ) = packed_params [1 ] # type: ignore[valid-type]
52- permute_dims : R .Tensor ((10 , 20 ), dtype = "float32" ) = R .permute_dims ( # type: ignore[attr-defined,valid-type]
53- linear_1_weight , axes = None
54- )
55- matmul : R .Tensor ((1 , 20 ), dtype = "float32" ) = R .matmul ( # type: ignore[attr-defined,valid-type]
56- x , permute_dims , out_dtype = "void"
57- )
58- permute_dims1 : R .Tensor ((10 , 20 ), dtype = "float32" ) = R .permute_dims ( # type: ignore[attr-defined,valid-type]
59- linear_2_weight , axes = None
60- )
61- matmul1 : R .Tensor ((1 , 20 ), dtype = "float32" ) = R .matmul ( # type: ignore[attr-defined,valid-type]
62- x , permute_dims1 , out_dtype = "void"
63- )
64- add : R .Tensor ((1 , 20 ), dtype = "float32" ) = R .add (matmul , matmul1 ) # type: ignore[attr-defined,valid-type]
65- gv : R .Tensor ((1 , 20 ), dtype = "float32" ) = add # type: ignore[attr-defined,valid-type]
66- R .output (gv ) # type: ignore[attr-defined,valid-type]
53+ ):
54+ R .func_attr ({"num_input" : 1 })
55+ with R .dataflow ():
56+ linear_1_weight = packed_params [0 ]
57+ linear_2_weight = packed_params [1 ]
58+ matmul_1_weight = R .permute_dims (linear_1_weight )
59+ matmul = R .matmul (x , matmul_1_weight )
60+ matmul_2_weight = R .permute_dims (linear_2_weight )
61+ matmul1 = R .matmul (x , matmul_2_weight )
62+ add = R .add (matmul , matmul1 )
63+ gv = add
64+ R .output (gv )
6765 return gv
6866
69- # pylint: enable=line-too-long
70-
7167 model = TestModule (10 , 20 )
7268 mod , _ = model .export_tvm (
7369 spec = {
@@ -82,6 +78,9 @@ def forward(
8278 )
8379 tvm .ir .assert_structural_equal (mod , ExpectedModule )
8480
81+ for name , expected_name in zip (_iter_binding_names (mod ), _iter_binding_names (ExpectedModule )):
82+ assert name == expected_name
83+
8584
8685if __name__ == "__main__" :
87- main ()
86+ tvm . testing . main ()
0 commit comments