@@ -576,57 +576,71 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
576576 assembly = lib .get_source ("asm" )
577577 return assembly
578578
579- # compile conv2d for x86 (skylake) and test assembly contains *pmadd* instructions
580- target = "llvm -mcpu=skylake-avx512"
581- name = "llvm.x86.avx512.pmaddubs.w.512"
582- llvm_id = tvm .codegen .llvm_lookup_intrinsic_id (name )
583- if llvm_id != 0 :
584- fast_int8_dtypes = ('uint8' , 'int8' , 'int32' )
585- # Sweep the input channels to check int8 robustness
586- for ic in range (1 , 24 ):
587- asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
588- dtypes = fast_int8_dtypes )
589- assert "pmaddubs" in asm
590-
591- for ic in range (1 , 24 ):
592- asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
593- dtypes = fast_int8_dtypes )
594- assert "pmaddubs" in asm
595-
596-
597- # Sweep the output channels to check int8 robustness
598- for oc in range (2 , 24 ):
599- asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
579+ def _has_fast_int8_instructions (asm , target ):
580+ if 'skylake-avx512' in target :
581+ return "pmaddubs" in asm
582+ elif 'cascadelake' in target :
583+ return "vpdpbusd" in asm
584+ else :
585+ assert False , "Target should be Skylake or Cascadelake"
586+
587+ # compile conv2d for x86 (skylake, cascadelake) and test assembly contains *pmadd* instructions
588+ targets = ["llvm -mcpu=skylake-avx512" , "llvm -mcpu=cascadelake" ]
589+ llvm_version = tvm .codegen .llvm_version_major ()
590+ for target in targets :
591+ if llvm_version >= 8 :
592+ fast_int8_dtypes = ('uint8' , 'int8' , 'int32' )
593+ # Sweep the input channels to check int8 robustness
594+ # Input channels should be a multiple of 4 internally.
595+ for ic in [1 , 4 , 6 ]:
596+ asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NCHW" ,
597+ kernel_layout = 'OIHW' ,
598+ dtypes = fast_int8_dtypes )
599+ assert _has_fast_int8_instructions (asm , target )
600+
601+ for ic in [1 , 4 , 6 ]:
602+ asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NHWC" ,
603+ kernel_layout = 'HWIO' ,
604+ dtypes = fast_int8_dtypes )
605+ assert _has_fast_int8_instructions (asm , target )
606+
607+
608+ # Sweep the output channels to check int8 robustness
609+ # Output channels should be a multiple of 16 internally.
610+ for oc in [4 , 16 , 20 ]:
611+ asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NCHW" ,
612+ kernel_layout = 'OIHW' ,
613+ dtypes = fast_int8_dtypes )
614+ assert _has_fast_int8_instructions (asm , target )
615+
616+ for oc in [4 , 16 , 20 ]:
617+ asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NHWC" ,
618+ kernel_layout = 'HWIO' ,
619+ dtypes = fast_int8_dtypes )
620+ assert _has_fast_int8_instructions (asm , target )
621+
622+ # Check that both non-divisible oc and ic work
623+ asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
600624 dtypes = fast_int8_dtypes )
601- assert "pmaddubs" in asm
625+ assert _has_fast_int8_instructions ( asm , target )
602626
603- for oc in range (2 , 24 ):
604- asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
627+ asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
605628 dtypes = fast_int8_dtypes )
606- assert "pmaddubs" in asm
607-
608- # Check that both non-divisible oc and ic work
609- asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
610- dtypes = fast_int8_dtypes )
611- assert "pmaddubs" in asm
612-
613- asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
614- dtypes = fast_int8_dtypes )
615- assert "pmaddubs" in asm
616-
617- # Ensure that code is generated when datatypes are not HW supported.
618- dtypes = ('int8' , 'int8' , 'int32' )
619- asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
620- dtypes = dtypes )
621- # Check that intrinisic is not present in the assembly.
622- assert "pmaddubs" not in asm
623-
624- # Ensure that code is generated when datatypes are not HW supported.
625- dtypes = ('uint8' , 'uint8' , 'int32' )
626- asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
627- dtypes = dtypes )
628- # Check that intrinisic is not present in the assembly.
629- assert "pmaddubs" not in asm
629+ assert _has_fast_int8_instructions (asm , target )
630+
631+ # Ensure that code is generated when datatypes are not HW supported.
632+ dtypes = ('int8' , 'int8' , 'int32' )
633+ asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
634+ dtypes = dtypes )
635+ # Check that intrinisic is not present in the assembly.
636+ assert not _has_fast_int8_instructions (asm , target )
637+
638+ # Ensure that code is generated when datatypes are not HW supported.
639+ dtypes = ('uint8' , 'uint8' , 'int32' )
640+ asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
641+ dtypes = dtypes )
642+ # Check that intrinisic is not present in the assembly.
643+ assert not _has_fast_int8_instructions (asm , target )
630644
631645 # Check that a vectorized instruction is generated for older Intel
632646 # generations, because we default to NCHWc layout.
0 commit comments