@@ -160,7 +160,7 @@ def get_output(func, golden_inputs):
160160 qnn_output = get_output (qnn_func , golden_inputs )
161161 np .testing .assert_equal (qnn_output , golden_output )
162162
163- def test_no_zero_point ():
163+ def no_zero_point_test ():
164164 # uint8 input
165165 data_shape = (2 , 1 , 2 , 4 )
166166 data_dtype = 'uint8'
@@ -203,7 +203,7 @@ def test_no_zero_point():
203203 verify (ref_func , qnn_func , data_shape , data_dtype ,
204204 kernel_shape , kernel_dtype )
205205
206- def test_kernel_zero_point ():
206+ def kernel_zero_point_test ():
207207 # uint8 input
208208 data_shape = (2 , 4 , 2 , 4 )
209209 data_dtype = 'uint8'
@@ -247,7 +247,7 @@ def test_kernel_zero_point():
247247 kernel_shape , kernel_dtype )
248248
249249
250- def test_input_zero_point ():
250+ def input_zero_point_test ():
251251 # uint8 input
252252 data_shape = (2 , 4 , 2 , 4 )
253253 data_dtype = 'uint8'
@@ -290,7 +290,7 @@ def test_input_zero_point():
290290 verify (ref_func , qnn_func , data_shape , data_dtype ,
291291 kernel_shape , kernel_dtype )
292292
293- def test_both_zero_point ():
293+ def both_zero_point_test ():
294294 # uint8 input
295295 data_shape = (2 , 4 , 2 , 4 )
296296 data_dtype = 'uint8'
@@ -333,7 +333,7 @@ def test_both_zero_point():
333333 verify (ref_func , qnn_func , data_shape , data_dtype ,
334334 kernel_shape , kernel_dtype )
335335
336- def test_layout ():
336+ def layout_test ():
337337 # uint8 input
338338 data_shape = (2 , 2 , 4 , 4 ) # NHWC
339339 data_dtype = 'uint8'
@@ -378,7 +378,7 @@ def test_layout():
378378
379379
380380
381- def test_padding ():
381+ def padding_test ():
382382 # uint8 input
383383 data_shape = (1 , 4 , 2 , 2 )
384384 data_dtype = 'uint8'
@@ -421,7 +421,7 @@ def test_padding():
421421 verify (ref_func , qnn_func , data_shape , data_dtype ,
422422 kernel_shape , kernel_dtype )
423423
424- def test_dilation ():
424+ def dilation_test ():
425425 # uint8 input
426426 data_shape = (2 , 4 , 4 , 4 )
427427 data_dtype = 'uint8'
@@ -444,7 +444,7 @@ def test_dilation():
444444 kernel_shape , kernel_dtype )
445445
446446
447- def test_const_folding ():
447+ def const_folding_test ():
448448 data_shape = (2 , 4 , 2 , 4 )
449449 data_dtype = 'uint8'
450450 kernel_shape = (3 , 4 , 2 , 2 )
@@ -470,7 +470,7 @@ def test_const_folding():
470470 folded_func = folded_mod ["main" ]
471471 assert "reshape" not in folded_func .astext ()
472472
473- def test_kernel_size_1x1 ():
473+ def kernel_size_1x1_test ():
474474 # uint8 input
475475 data_shape = (2 , 4 , 2 , 4 )
476476 data_dtype = 'uint8'
@@ -493,7 +493,7 @@ def test_kernel_size_1x1():
493493 verify (ref_func , qnn_func , data_shape , data_dtype ,
494494 kernel_shape , kernel_dtype )
495495
496- def test_tflite_large_irregular ():
496+ def tflite_large_irregular_test ():
497497 # uint8 input
498498 data_shape = (1 , 1024 , 1 , 1 )
499499 data_dtype = 'uint8'
@@ -607,7 +607,7 @@ def tflite_anistropic_strides():
607607 golden_output = np .array ((124 , - 92 , 164 , - 132 )).reshape (1 , 1 , 2 , 2 )
608608 np .testing .assert_equal (qnn_output , golden_output )
609609
610- def test_broadcast_layout ():
610+ def broadcast_layout_test ():
611611 # Test broadcast support for NHWC layout.
612612 data_shape = (1 , 229 , 229 , 3 ) # NHWC
613613 data_dtype = 'uint8'
@@ -640,52 +640,17 @@ def test_broadcast_layout():
640640 with relay .build_config (opt_level = 3 ):
641641 graph , lib , params = relay .build (mod , "llvm -mcpu=skylake-avx512" )
642642
643-
644- def test_conv2d_int8 ():
645- target = "llvm -mcpu=core-avx2"
646- if not tvm .module .enabled (target ):
647- print ("skip because %s is not enabled..." % target )
648- return
649-
650- data = relay .var ("data" , shape = (1 , 28 , 28 , 128 ), dtype = 'uint8' )
651- kernel = relay .var ("w" , shape = (3 , 3 , 128 , 256 ), dtype = 'int8' )
652- conv = relay .nn .conv2d (
653- data ,
654- kernel ,
655- kernel_size = (3 , 3 ),
656- out_dtype = 'int32' ,
657- data_layout = 'NHWC' ,
658- kernel_layout = 'HWIO' )
659- func = relay .Function ([data , kernel ], conv )
660-
661- with relay .build_config (opt_level = 0 ):
662- params = {"w" : np .zeros ((3 , 3 , 128 , 256 )).astype ("int8" )}
663- # -mcpu should be specified to avoid the llvm jitting error here:
664- # https://discuss.tvm.ai/t/segfault-in-llvm/3567
665- # To use VNNI, we need to specify the micro-architecture that supports
666- # it, e.g. cascadelake.
667- graph , lib , params = relay .build (func , target , params = params )
668- mod = graph_runtime .create (graph , lib , ctx = tvm .cpu (0 ))
669- mod .set_input ("data" , np .zeros ((1 , 28 , 28 , 128 )).astype ("uint8" ))
670- mod .set_input (** params )
671- mod .run ()
672- qnn_output = mod .get_output (0 ).asnumpy ()
673- golden_output = np .zeros ((1 , 26 , 26 , 256 )).astype ("int32" )
674- np .testing .assert_equal (qnn_output , golden_output )
675-
676-
677643if __name__ == "__main__" :
678- test_no_zero_point ()
679- test_input_zero_point ()
680- test_kernel_zero_point ()
681- test_both_zero_point ()
682- test_layout ()
683- test_padding ()
684- test_dilation ()
685- test_const_folding ()
686- test_kernel_size_1x1g ()
687- test_tflite_large_irregularg ()
688- test_tflite_output_multiplier_greater_than_one ()
689- test_tflite_anistropic_strides ()
690- test_broadcast_layoutg ()
691- test_conv2d_int8 ()
644+ no_zero_point_test ()
645+ input_zero_point_test ()
646+ kernel_zero_point_test ()
647+ both_zero_point_test ()
648+ layout_test ()
649+ padding_test ()
650+ dilation_test ()
651+ const_folding_test ()
652+ kernel_size_1x1_test ()
653+ tflite_large_irregular_test ()
654+ tflite_output_multiplier_greater_than_one ()
655+ tflite_anistropic_strides ()
656+ broadcast_layout_test ()
0 commit comments