1414# KIND, either express or implied. See the License for the
1515# specific language governing permissions and limitations
1616# under the License.
17-
17+ """AOT with C++ Runtime Tests"""
1818
1919import re
20- import sys
2120import textwrap
2221
2322import numpy as np
2827from tvm import relay
2928from tvm .relay import backend , testing
3029from tvm .testing .aot import generate_ref_data
31- from tvm .micro .testing .aot_test_utils import AOT_DEFAULT_RUNNER
3230
3331
3432def test_error_c_interface ():
35- interface_api = "c"
36- use_unpacked_api = False
37- test_runner = AOT_DEFAULT_RUNNER
33+ """Checks that an error occurs when using the packed API in combination with C interface"""
3834
3935 two = relay .add (relay .const (1 ), relay .const (1 ))
4036 func = relay .Function ([], two )
@@ -53,12 +49,11 @@ def test_error_c_interface():
5349 )
5450
5551
56- enable_usmp = tvm .testing .parameter (True , False )
57- target_kind = tvm .testing .parameter ("c" , "llvm" )
58-
59-
52+ @pytest .mark .parametrize ("enable_usmp" , [True , False ])
53+ @pytest .mark .parametrize ("target_kind" , ["c" , "llvm" ])
6054def test_conv2d (enable_usmp , target_kind ):
61- RELAY_MODEL = textwrap .dedent (
55+ """Tests compilation of convolutions"""
56+ relay_model = textwrap .dedent (
6257 """\
6358 #[version = "0.0.5"]
6459 def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), int8]) {
@@ -86,7 +81,7 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5),
8681 }
8782 """
8883 )
89- ir_mod = tvm .parser .fromtext (RELAY_MODEL )
84+ ir_mod = tvm .parser .fromtext (relay_model )
9085
9186 main_func = ir_mod ["main" ]
9287 shape_dict = {p .name_hint : p .checked_type .concrete_shape for p in main_func .params }
@@ -119,7 +114,10 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5),
119114 assert (runner .get_output (0 ).asnumpy () == list (ref_outputs .values ())[0 ]).all ()
120115
121116
117+ @pytest .mark .parametrize ("enable_usmp" , [True , False ])
118+ @pytest .mark .parametrize ("target_kind" , ["c" , "llvm" ])
122119def test_mobilenet (enable_usmp , target_kind ):
120+ """Full network test with Mobilenet"""
123121 ir_mod , params = testing .mobilenet .get_workload (batch_size = 1 )
124122 data_shape = [int (x ) for x in ir_mod ["main" ].checked_type .arg_types [0 ].shape ]
125123 data = np .random .uniform (size = data_shape ).astype ("float32" )
@@ -147,10 +145,11 @@ def test_mobilenet(enable_usmp, target_kind):
147145
148146
149147def test_module_list ():
150- x = tvm .relay .var ("x" , tvm .relay .TensorType ([1 ], dtype = "float32" ))
151- expr = tvm .relay .add (x , tvm .relay .Constant (tvm .nd .array (np .array ([1 ], dtype = "float32" ))))
148+ """Checks the correct list of module names is generated"""
149+ input_x = tvm .relay .var ("x" , tvm .relay .TensorType ([1 ], dtype = "float32" ))
150+ expr = tvm .relay .add (input_x , tvm .relay .Constant (tvm .nd .array (np .array ([1 ], dtype = "float32" ))))
152151 mod = tvm .relay .build (
153- tvm .IRModule .from_expr (tvm .relay .Function ([x ], expr )),
152+ tvm .IRModule .from_expr (tvm .relay .Function ([input_x ], expr )),
154153 target = "c" ,
155154 executor = tvm .relay .backend .Executor ("aot" , {"interface-api" : "packed" }),
156155 mod_name = "unusual_module_name_fred" ,
@@ -177,6 +176,7 @@ def test_create_executor():
177176
178177
179178def test_pass_wrong_device_arg ():
179+ """Ensure an error is generated if the incorrect number of devices are passed"""
180180 x = tvm .relay .var ("x" , tvm .relay .TensorType ([1 ], dtype = "float32" ))
181181 expr = tvm .relay .add (x , tvm .relay .Constant (tvm .nd .array (np .array ([1 ], dtype = "float32" ))))
182182 with tvm .transform .PassContext (opt_level = 3 , config = {"tir.disable_vectorize" : True }):
@@ -191,12 +191,12 @@ def test_pass_wrong_device_arg():
191191 mod .export_library (test_so_path , cc = "gcc" , options = ["-std=c11" ])
192192 loaded_mod = tvm .runtime .load_module (test_so_path )
193193
194- with pytest .raises (tvm .TVMError ) as cm :
194+ with pytest .raises (tvm .TVMError ) as error :
195195 tvm .runtime .executor .AotModule (loaded_mod ["default" ](tvm .cpu (0 ), tvm .cpu (0 )))
196196
197197 assert (
198198 "Check failed: devices_.size() == 1 (2 vs. 1) : Expect exactly 1 device passed."
199- in str (cm .exception )
199+ in str (error .exception )
200200 )
201201 # TODO write asserts for # and type of device.
202202
0 commit comments