@@ -576,66 +576,84 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
576576        assembly  =  lib .get_source ("asm" )
577577        return  assembly 
578578
579-     # compile conv2d for x86 (skylake) and test assembly contains *pmadd* instructions 
580-     target  =  "llvm -mcpu=skylake-avx512" 
581-     name  =  "llvm.x86.avx512.pmaddubs.w.512" 
582-     llvm_id  =  tvm .codegen .llvm_lookup_intrinsic_id (name )
583-     if  llvm_id  !=  0 :
584-         fast_int8_dtypes  =  ('uint8' , 'int8' , 'int32' )
585-         # Sweep the input channels to check int8 robustness 
586-         for  ic  in  range (1 , 24 ):
587-             asm  =  _compile (ic = ic , oc = 32 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
588-                            dtypes = fast_int8_dtypes )
589-             assert  "pmaddubs"  in  asm 
590- 
591-         for  ic  in  range (1 , 24 ):
592-             asm  =  _compile (ic = ic , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
593-                            dtypes = fast_int8_dtypes )
594-             assert  "pmaddubs"  in  asm 
595- 
596- 
597-         # Sweep the output channels to check int8 robustness 
598-         for  oc  in  range (2 , 24 ):
599-             asm  =  _compile (ic = 16 , oc = oc , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
def has_fast_int8_instruction(asm, target):
    """Report whether ``asm`` contains the fast int8 instruction for ``target``.

    Parameters
    ----------
    asm : str
        Assembly text to search.
    target : str
        Target string; it must contain either ``skylake-avx512`` or
        ``cascadelake``.

    Returns
    -------
    bool
        True if the instruction mnemonic expected for the target
        (``pmaddubs`` for skylake-avx512, ``vpdpbusd`` for cascadelake)
        occurs in ``asm``.

    Raises
    ------
    AssertionError
        If ``target`` names neither supported CPU.
    """
    # Checked in order: the first CPU marker found in the target string
    # decides which mnemonic is looked up.
    expected_instructions = (
        ("skylake-avx512", "pmaddubs"),
        ("cascadelake", "vpdpbusd"),
    )
    for cpu, instruction in expected_instructions:
        if cpu in target:
            return instruction in asm

    # Raised explicitly instead of `assert False` so the check survives
    # `python -O`; otherwise the helper would silently return None and
    # negative checks (`assert not ...`) would pass for unsupported targets.
    raise AssertionError("Target should be Skylake or Cascadelake")
587+ 
588+     # compile conv2d for x86 (skylake, cascadelake) and test that the assembly contains the fast int8 instruction for each target (pmaddubs on skylake, vpdpbusd on cascadelake) 
589+     targets  =  ["llvm -mcpu=skylake-avx512" , "llvm -mcpu=cascadelake" ]
590+     name_skylake  =  "llvm.x86.avx512.pmaddubs.w.512" 
591+     name_cascadelake  =  'llvm.x86.avx512.vpdpbusd.512' 
592+     llvm_id_skylake  =  tvm .codegen .llvm_lookup_intrinsic_id (name_skylake )
593+     llvm_id_cascadelake  =  tvm .codegen .llvm_lookup_intrinsic_id (name_cascadelake )
594+     for  target  in  targets :
595+         if  llvm_id_skylake  !=  0  and  llvm_id_cascadelake :
596+             fast_int8_dtypes  =  ('uint8' , 'int8' , 'int32' )
597+             # Sweep the input channels to check int8 robustness 
598+             # Input channels should be a multiple of 4 internally. 
599+             for  ic  in  [1 , 4 , 6 ]:
600+                 asm  =  _compile (ic = ic , oc = 32 , target = target , data_layout = "NCHW" ,
601+                                kernel_layout = 'OIHW' ,
602+                                dtypes = fast_int8_dtypes )
603+                 assert  has_fast_int8_instruction (asm , target )
604+ 
605+             for  ic  in  [1 , 4 , 6 ]:
606+                 asm  =  _compile (ic = ic , oc = 32 , target = target , data_layout = "NHWC" ,
607+                                kernel_layout = 'HWIO' ,
608+                                dtypes = fast_int8_dtypes )
609+                 assert  has_fast_int8_instruction (asm , target )
610+ 
611+ 
612+             # Sweep the output channels to check int8 robustness 
613+             # Output channels should be a multiple of 16 internally. 
614+             for  oc  in  [4 , 16 , 20 ]:
615+                 asm  =  _compile (ic = 16 , oc = oc , target = target , data_layout = "NCHW" ,
616+                                kernel_layout = 'OIHW' ,
617+                                dtypes = fast_int8_dtypes )
618+                 assert  has_fast_int8_instruction (asm , target )
619+ 
620+             for  oc  in  [4 , 16 , 20 ]:
621+                 asm  =  _compile (ic = 16 , oc = oc , target = target , data_layout = "NHWC" ,
622+                                kernel_layout = 'HWIO' ,
623+                                dtypes = fast_int8_dtypes )
624+                 assert  has_fast_int8_instruction (asm , target )
625+ 
626+             # Check that both non-divisible oc and ic work 
627+             asm  =  _compile (ic = 17 , oc = 29 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
600628                           dtypes = fast_int8_dtypes )
601-             assert  "pmaddubs"   in   asm 
629+             assert  has_fast_int8_instruction ( asm ,  target ) 
602630
603-         for  oc  in  range (2 , 24 ):
604-             asm  =  _compile (ic = 16 , oc = oc , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
631+             asm  =  _compile (ic = 17 , oc = 29 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
605632                           dtypes = fast_int8_dtypes )
606-             assert  "pmaddubs"  in  asm 
607- 
608-         # Check that both non-divisible oc and ic work 
609-         asm  =  _compile (ic = 17 , oc = 29 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
610-                        dtypes = fast_int8_dtypes )
611-         assert  "pmaddubs"  in  asm 
612- 
613-         asm  =  _compile (ic = 17 , oc = 29 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
633+             assert  has_fast_int8_instruction (asm , target )
634+ 
635+             # Ensure that code is generated when datatypes are not HW supported. 
636+             dtypes  =  ('int8' , 'int8' , 'int32' )
637+             asm  =  _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
638+                            dtypes = dtypes )
639+             # Check that the intrinsic is not present in the assembly. 
640+             assert  not  has_fast_int8_instruction (asm , target )
641+ 
642+             # Ensure that code is generated when datatypes are not HW supported. 
643+             dtypes  =  ('uint8' , 'uint8' , 'int32' )
644+             asm  =  _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
645+                            dtypes = dtypes )
646+             # Check that the intrinsic is not present in the assembly. 
647+             assert  not  has_fast_int8_instruction (asm , target )
648+ 
649+         # Check that a vectorized instruction is generated for older Intel 
650+         # generations, because we default to NCHWc layout. 
651+         target  =  "llvm -mcpu=core-avx2" 
652+         fast_int8_dtypes  =  ('uint8' , 'int8' , 'int32' )
653+         asm  =  _compile (ic = 16 , oc = 32 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
614654                       dtypes = fast_int8_dtypes )
615-         assert  "pmaddubs"  in  asm 
616- 
617-         # Ensure that code is generated when datatypes are not HW supported. 
618-         dtypes  =  ('int8' , 'int8' , 'int32' )
619-         asm  =  _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
620-                        dtypes = dtypes )
621-         # Check that intrinisic is not present in the assembly. 
622-         assert  "pmaddubs"  not  in   asm 
623- 
624-         # Ensure that code is generated when datatypes are not HW supported. 
625-         dtypes  =  ('uint8' , 'uint8' , 'int32' )
626-         asm  =  _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
627-                        dtypes = dtypes )
628-         # Check that intrinisic is not present in the assembly. 
629-         assert  "pmaddubs"  not  in   asm 
630- 
631-     # Check that a vectorized instruction is generated for older Intel 
632-     # generations, because we default to NCHWc layout. 
633-     target  =  "llvm -mcpu=core-avx2" 
634-     fast_int8_dtypes  =  ('uint8' , 'int8' , 'int32' )
635-     asm  =  _compile (ic = 16 , oc = 32 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
636-                    dtypes = fast_int8_dtypes )
637-     # Check that vector int mult and add instructions are generated. 
638-     assert  "vpmulld"  in  asm  and  "vpadd"  in  asm 
655+         # Check that vector int mult and add instructions are generated. 
656+         assert  "vpmulld"  in  asm  and  "vpadd"  in  asm 
639657
640658
641659def  test_bitserial_conv2d_infer_type ():
0 commit comments