@@ -546,9 +546,11 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
546546
547547 n , h , w , ch , cw = 1 , 64 , 64 , 3 , 3
548548 if data_layout == 'NCHW' :
549- x = relay .var ("x" , relay .TensorType ((n , ic , h , w ), input_dtype ))
549+ data_shape = (n , ic , h , w )
550+ x = relay .var ("x" , relay .TensorType (data_shape , input_dtype ))
550551 elif data_layout == 'NHWC' :
551- x = relay .var ("x" , relay .TensorType ((n , h , w , ic ), input_dtype ))
552+ data_shape = (n , h , w , ic )
553+ x = relay .var ("x" , relay .TensorType (data_shape , input_dtype ))
552554 else :
553555 raise ValueError ('Not supported' )
554556
@@ -559,20 +561,22 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
559561 else :
560562 raise ValueError ('Not supported' )
561563
562- w = relay .var ("w " , relay .TensorType (kernel_shape , weight_dtype ))
563- y = relay .nn .conv2d (x , w ,
564+ weight = relay .var ("weight " , relay .TensorType (kernel_shape , weight_dtype ))
565+ y = relay .nn .conv2d (x , weight ,
564566 kernel_size = (ch , cw ),
565567 channels = oc ,
566568 padding = (1 , 1 ),
567569 dilation = (1 , 1 ),
568570 data_layout = data_layout ,
569571 kernel_layout = kernel_layout ,
570572 out_dtype = output_dtype )
571- func = relay .Function ([x , w ], y )
573+ func = relay .Function ([x , weight ], y )
572574 wdata = np .random .rand (* kernel_shape ) * 10
573- parameters = {"w" : tvm .nd .array (wdata .astype (weight_dtype ))}
575+ parameters = {"weight" : tvm .nd .array (wdata .astype (weight_dtype ))}
576+
574577 with relay .build_config (opt_level = 3 ):
575578 graph , lib , params = relay .build (func , target , params = parameters )
579+
576580 assembly = lib .get_source ("asm" )
577581 return assembly
578582
@@ -589,58 +593,63 @@ def _has_fast_int8_instructions(asm, target):
589593 llvm_version = tvm .codegen .llvm_version_major ()
590594 for target in targets :
591595 if llvm_version >= 8 :
592- fast_int8_dtypes = ('uint8' , 'int8' , 'int32' )
596+ dtypes = ('uint8' , 'int8' , 'int32' )
593597 # Sweep the input channels to check int8 robustness
594598 # Input channels should be a multiple of 4 internally.
595599 for ic in [1 , 4 , 6 ]:
596- asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NCHW" ,
600+ asm = _compile (ic = ic , oc = 16 , target = target , data_layout = "NCHW" ,
597601 kernel_layout = 'OIHW' ,
598- dtypes = fast_int8_dtypes )
602+ dtypes = dtypes )
599603 assert _has_fast_int8_instructions (asm , target )
600604
601605 for ic in [1 , 4 , 6 ]:
602- asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NHWC" ,
606+ asm = _compile (ic = ic , oc = 16 , target = target , data_layout = "NHWC" ,
603607 kernel_layout = 'HWIO' ,
604- dtypes = fast_int8_dtypes )
608+ dtypes = dtypes )
605609 assert _has_fast_int8_instructions (asm , target )
606610
607-
608611 # Sweep the output channels to check int8 robustness
609612 # Output channels should be a multiple of 16 internally.
610613 for oc in [4 , 16 , 20 ]:
611- asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NCHW" ,
614+ asm = _compile (ic = 8 , oc = oc , target = target , data_layout = "NCHW" ,
612615 kernel_layout = 'OIHW' ,
613- dtypes = fast_int8_dtypes )
616+ dtypes = dtypes )
614617 assert _has_fast_int8_instructions (asm , target )
615618
616619 for oc in [4 , 16 , 20 ]:
617- asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NHWC" ,
620+ asm = _compile (ic = 8 , oc = oc , target = target , data_layout = "NHWC" ,
618621 kernel_layout = 'HWIO' ,
619- dtypes = fast_int8_dtypes )
622+ dtypes = dtypes )
620623 assert _has_fast_int8_instructions (asm , target )
621624
622625 # Check that both non-divisible oc and ic work
623626 asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
624- dtypes = fast_int8_dtypes )
627+ dtypes = dtypes )
625628 assert _has_fast_int8_instructions (asm , target )
626629
627630 asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
628- dtypes = fast_int8_dtypes )
631+ dtypes = dtypes )
629632 assert _has_fast_int8_instructions (asm , target )
630633
631- # Ensure that code is generated when datatypes are not HW supported.
632- dtypes = ('int8' , 'int8' , 'int32' )
633- asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
634+ # Check that int8 x int8 goes through legalization so that fast instructions can be picked up.
635+ for target in targets :
636+ if llvm_version >= 8 :
637+ dtypes = (('int8' , 'int8' , 'int32' ))
638+ # Check that both non-divisible oc and ic work
639+ asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
634640 dtypes = dtypes )
635- # Check that intrinisic is not present in the assembly.
636- assert not _has_fast_int8_instructions (asm , target )
641+ assert _has_fast_int8_instructions (asm , target )
637642
638- # Ensure that code is generated when datatypes are not HW supported.
639- dtypes = ('uint8' , 'uint8' , 'int32' )
640- asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
643+ asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
641644 dtypes = dtypes )
642- # Check that intrinisic is not present in the assembly.
643- assert not _has_fast_int8_instructions (asm , target )
645+ assert _has_fast_int8_instructions (asm , target )
646+
647+ # Ensure that code is generated when datatypes are not HW supported.
648+ dtypes = ('uint8' , 'uint8' , 'int32' )
649+ asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
650+ dtypes = dtypes )
652+ # Check that intrinsic is not present in the assembly.
652+ assert not _has_fast_int8_instructions (asm , target )
644653
645654 # Check that a vectorized instruction is generated for older Intel
646655 # generations, because we default to NCHWc layout.
0 commit comments