@@ -160,7 +160,7 @@ def get_output(func, golden_inputs):
160160 qnn_output = get_output (qnn_func , golden_inputs )
161161 np .testing .assert_equal (qnn_output , golden_output )
162162
163- def no_zero_point_test ():
163+ def test_no_zero_point ():
164164 # uint8 input
165165 data_shape = (2 , 1 , 2 , 4 )
166166 data_dtype = 'uint8'
@@ -203,7 +203,7 @@ def no_zero_point_test():
203203 verify (ref_func , qnn_func , data_shape , data_dtype ,
204204 kernel_shape , kernel_dtype )
205205
206- def kernel_zero_point_test ():
206+ def test_kernel_zero_point ():
207207 # uint8 input
208208 data_shape = (2 , 4 , 2 , 4 )
209209 data_dtype = 'uint8'
@@ -247,7 +247,7 @@ def kernel_zero_point_test():
247247 kernel_shape , kernel_dtype )
248248
249249
250- def input_zero_point_test ():
250+ def test_input_zero_point ():
251251 # uint8 input
252252 data_shape = (2 , 4 , 2 , 4 )
253253 data_dtype = 'uint8'
@@ -290,7 +290,7 @@ def input_zero_point_test():
290290 verify (ref_func , qnn_func , data_shape , data_dtype ,
291291 kernel_shape , kernel_dtype )
292292
293- def both_zero_point_test ():
293+ def test_both_zero_point ():
294294 # uint8 input
295295 data_shape = (2 , 4 , 2 , 4 )
296296 data_dtype = 'uint8'
@@ -333,7 +333,7 @@ def both_zero_point_test():
333333 verify (ref_func , qnn_func , data_shape , data_dtype ,
334334 kernel_shape , kernel_dtype )
335335
336- def layout_test ():
336+ def test_layout ():
337337 # uint8 input
338338 data_shape = (2 , 2 , 4 , 4 ) # NHWC
339339 data_dtype = 'uint8'
@@ -378,7 +378,7 @@ def layout_test():
378378
379379
380380
381- def padding_test ():
381+ def test_padding ():
382382 # uint8 input
383383 data_shape = (1 , 4 , 2 , 2 )
384384 data_dtype = 'uint8'
@@ -421,7 +421,7 @@ def padding_test():
421421 verify (ref_func , qnn_func , data_shape , data_dtype ,
422422 kernel_shape , kernel_dtype )
423423
424- def dilation_test ():
424+ def test_dilation ():
425425 # uint8 input
426426 data_shape = (2 , 4 , 4 , 4 )
427427 data_dtype = 'uint8'
@@ -444,7 +444,7 @@ def dilation_test():
444444 kernel_shape , kernel_dtype )
445445
446446
447- def const_folding_test ():
447+ def test_const_folding ():
448448 data_shape = (2 , 4 , 2 , 4 )
449449 data_dtype = 'uint8'
450450 kernel_shape = (3 , 4 , 2 , 2 )
@@ -470,7 +470,7 @@ def const_folding_test():
470470 folded_func = folded_mod ["main" ]
471471 assert "reshape" not in folded_func .astext ()
472472
473- def kernel_size_1x1_test ():
473+ def test_kernel_size_1x1 ():
474474 # uint8 input
475475 data_shape = (2 , 4 , 2 , 4 )
476476 data_dtype = 'uint8'
@@ -493,7 +493,7 @@ def kernel_size_1x1_test():
493493 verify (ref_func , qnn_func , data_shape , data_dtype ,
494494 kernel_shape , kernel_dtype )
495495
496- def tflite_large_irregular_test ():
496+ def test_tflite_large_irregular ():
497497 # uint8 input
498498 data_shape = (1 , 1024 , 1 , 1 )
499499 data_dtype = 'uint8'
@@ -607,7 +607,7 @@ def tflite_anistropic_strides():
607607 golden_output = np .array ((124 , - 92 , 164 , - 132 )).reshape (1 , 1 , 2 , 2 )
608608 np .testing .assert_equal (qnn_output , golden_output )
609609
610- def broadcast_layout_test ():
610+ def test_broadcast_layout ():
611611 # Test broadcast support for NHWC layout.
612612 data_shape = (1 , 229 , 229 , 3 ) # NHWC
613613 data_dtype = 'uint8'
@@ -640,17 +640,52 @@ def broadcast_layout_test():
640640 with relay .build_config (opt_level = 3 ):
641641 graph , lib , params = relay .build (mod , "llvm -mcpu=skylake-avx512" )
642642
643+
644+ def test_conv2d_int8 ():
645+ target = "llvm -mcpu=core-avx2"
646+ if not tvm .module .enabled (target ):
647+ print ("skip because %s is not enabled..." % target )
648+ return
649+
650+ data = relay .var ("data" , shape = (1 , 28 , 28 , 128 ), dtype = 'uint8' )
651+ kernel = relay .var ("w" , shape = (3 , 3 , 128 , 256 ), dtype = 'int8' )
652+ conv = relay .nn .conv2d (
653+ data ,
654+ kernel ,
655+ kernel_size = (3 , 3 ),
656+ out_dtype = 'int32' ,
657+ data_layout = 'NHWC' ,
658+ kernel_layout = 'HWIO' )
659+ func = relay .Function ([data , kernel ], conv )
660+
661+ with relay .build_config (opt_level = 0 ):
662+ params = {"w" : np .zeros ((3 , 3 , 128 , 256 )).astype ("int8" )}
663+ # -mcpu should be specified to avoid the llvm jitting error here:
664+ # https://discuss.tvm.ai/t/segfault-in-llvm/3567
665+ # To use VNNI, we need to specify the micro-architecture that supports
666+ # it, e.g. cascadelake.
667+ graph , lib , params = relay .build (func , target , params = params )
668+ mod = graph_runtime .create (graph , lib , ctx = tvm .cpu (0 ))
669+ mod .set_input ("data" , np .zeros ((1 , 28 , 28 , 128 )).astype ("uint8" ))
670+ mod .set_input (** params )
671+ mod .run ()
672+ qnn_output = mod .get_output (0 ).asnumpy ()
673+ golden_output = np .zeros ((1 , 26 , 26 , 256 )).astype ("int32" )
674+ np .testing .assert_equal (qnn_output , golden_output )
675+
676+
643677if __name__ == "__main__" :
644- no_zero_point_test ()
645- input_zero_point_test ()
646- kernel_zero_point_test ()
647- both_zero_point_test ()
648- layout_test ()
649- padding_test ()
650- dilation_test ()
651- const_folding_test ()
652- kernel_size_1x1_test ()
653- tflite_large_irregular_test ()
654- tflite_output_multiplier_greater_than_one ()
655- tflite_anistropic_strides ()
656- broadcast_layout_test ()
678+ test_no_zero_point ()
679+ test_input_zero_point ()
680+ test_kernel_zero_point ()
681+ test_both_zero_point ()
682+ test_layout ()
683+ test_padding ()
684+ test_dilation ()
685+ test_const_folding ()
686+ test_kernel_size_1x1g ()
687+ test_tflite_large_irregularg ()
688+ test_tflite_output_multiplier_greater_than_one ()
689+ test_tflite_anistropic_strides ()
690+ test_broadcast_layoutg ()
691+ test_conv2d_int8 ()
0 commit comments