Commit 29070d0

[TOPI][x86] Cascade lake support.
1 parent bfb811c commit 29070d0

8 files changed: +119, -80 lines changed

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def _is_int8_hw_support(target):
     Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake
     and above.
     """
-    supported_arches = {'-mcpu=skylake-avx512',}
+    supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
     return supported_arches.intersection(set(target.options))

     # Collect the dtypes.
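
For context, here is a minimal sketch (not part of the commit) of how this gate behaves, assuming a TVM build of this vintage where tvm.target.create parses -mcpu flags into target.options:

    import tvm

    supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
    for tgt_str in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake",
                    "llvm -mcpu=core-avx2"]:
        target = tvm.target.create(tgt_str)
        # A non-empty (truthy) intersection enables the fast int8 legalization.
        print(tgt_str, bool(supported_arches.intersection(set(target.options))))
    # Expected: True, True, False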

tests/python/contrib/test_gemm_acc16.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
 import tvm
 import numpy as np
-from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int16
+from topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int16


 def benchmark_fc_int8_acc16():
@@ -40,7 +40,7 @@ def verify(target="llvm -mcpu=skylake-avx512"):
         ctx = tvm.context(target, 0)
         X = tvm.placeholder((m, k), name='X', dtype="uint8")
         W = tvm.placeholder((n, k), name='W', dtype="int8")
-        pc = dot_16x1x16_int8_int8_int16()
+        pc = dot_16x1x16_uint8_int8_int16()
         ak = tvm.reduce_axis((0, k), name='k')

         packedW = tvm.placeholder((n//128, 128*(k//2), 2), name='packedW', dtype="int8")
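
As a reference for what the renamed intrinsic computes, a numpy sketch of the uint8 x int8 -> int16 pairwise dot product follows (shapes taken from the pseudo code in tensor_intrin.py; this illustration ignores the saturating behavior of the underlying pmaddubsw-style instruction):

    import numpy as np

    data = np.random.randint(0, 256, (2,)).astype('uint8')           # uint8 data[2]
    kernel = np.random.randint(-128, 128, (128, 2)).astype('int8')   # int8 kernel[32*4][2]
    # Widen to int16, multiply, and reduce over the pair dimension.
    output = kernel.astype('int16').dot(data.astype('int16'))        # int16 output[32*4]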

tests/python/contrib/test_gemm_acc32_vnni.py

Lines changed: 3 additions & 3 deletions
@@ -18,8 +18,8 @@

 import tvm
 import numpy as np
-from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32_vnni
-from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32
+from topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32_vnni
+from topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32
 import pytest


@@ -46,7 +46,7 @@ def verify(target="llvm -mcpu=cascadelake"):
             return

         ctx = tvm.context(target, 0)
-        pc = dot_16x1x16_int8_int8_int32_vnni()
+        pc = dot_16x1x16_uint8_int8_int32_vnni()
         ak = tvm.reduce_axis((0, k), name='k')
         packedW = tvm.placeholder(
             (n // 16, 16 * (k // 4), 4), name='packedW', dtype="int8")
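
For reference, the VNNI variant's semantics in numpy (an illustration, not part of the commit): vpdpbusd multiplies unsigned 8-bit data with signed 8-bit kernel values and accumulates groups of four into 32-bit lanes, so no saturation occurs for these ranges:

    import numpy as np

    data = np.random.randint(0, 256, (4,)).astype('uint8')          # uint8 data[4]
    kernel = np.random.randint(-128, 128, (16, 4)).astype('int8')   # int8 kernel[16][4]
    # Widen to int32, multiply, and reduce over the 4-element axis.
    output = kernel.astype('int32').dot(data.astype('int32'))       # int32 output[16]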

tests/python/relay/test_op_level2.py

Lines changed: 74 additions & 56 deletions
@@ -576,66 +576,84 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
         assembly = lib.get_source("asm")
         return assembly

-    # compile conv2d for x86 (skylake) and test assembly contains *pmadd* instructions
-    target = "llvm -mcpu=skylake-avx512"
-    name = "llvm.x86.avx512.pmaddubs.w.512"
-    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
-    if llvm_id != 0:
-        fast_int8_dtypes = ('uint8', 'int8', 'int32')
-        # Sweep the input channels to check int8 robustness
-        for ic in range(1, 24):
-            asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW',
-                           dtypes=fast_int8_dtypes)
-            assert "pmaddubs" in asm
-
-        for ic in range(1, 24):
-            asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
-                           dtypes=fast_int8_dtypes)
-            assert "pmaddubs" in asm
-
-
-        # Sweep the output channels to check int8 robustness
-        for oc in range(2, 24):
-            asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW", kernel_layout='OIHW',
+    def has_fast_int8_instruction(asm, target):
+        if 'skylake-avx512' in target:
+            return "pmaddubs" in asm
+        elif 'cascadelake' in target:
+            return "vpdpbusd" in asm
+        else:
+            assert False, "Target should be Skylake or Cascadelake"
+
+    # Compile conv2d for x86 (skylake, cascadelake) and check for fast int8 instructions
+    targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]
+    name_skylake = "llvm.x86.avx512.pmaddubs.w.512"
+    name_cascadelake = 'llvm.x86.avx512.vpdpbusd.512'
+    llvm_id_skylake = tvm.codegen.llvm_lookup_intrinsic_id(name_skylake)
+    llvm_id_cascadelake = tvm.codegen.llvm_lookup_intrinsic_id(name_cascadelake)
+    for target in targets:
+        if llvm_id_skylake != 0 and llvm_id_cascadelake != 0:
+            fast_int8_dtypes = ('uint8', 'int8', 'int32')
+            # Sweep the input channels to check int8 robustness
+            # Input channels should be a multiple of 4 internally.
+            for ic in [1, 4, 6]:
+                asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW",
+                               kernel_layout='OIHW',
+                               dtypes=fast_int8_dtypes)
+                assert has_fast_int8_instruction(asm, target)
+
+            for ic in [1, 4, 6]:
+                asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC",
+                               kernel_layout='HWIO',
+                               dtypes=fast_int8_dtypes)
+                assert has_fast_int8_instruction(asm, target)
+
+
+            # Sweep the output channels to check int8 robustness
+            # Output channels should be a multiple of 16 internally.
+            for oc in [4, 16, 20]:
+                asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW",
+                               kernel_layout='OIHW',
+                               dtypes=fast_int8_dtypes)
+                assert has_fast_int8_instruction(asm, target)
+
+            for oc in [4, 16, 20]:
+                asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC",
+                               kernel_layout='HWIO',
+                               dtypes=fast_int8_dtypes)
+                assert has_fast_int8_instruction(asm, target)
+
+            # Check that both non-divisible oc and ic work
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
                            dtypes=fast_int8_dtypes)
-            assert "pmaddubs" in asm
+            assert has_fast_int8_instruction(asm, target)

-        for oc in range(2, 24):
-            asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC", kernel_layout='HWIO',
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
                            dtypes=fast_int8_dtypes)
-            assert "pmaddubs" in asm
-
-        # Check that both non-divisible oc and ic work
-        asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
-                       dtypes=fast_int8_dtypes)
-        assert "pmaddubs" in asm
-
-        asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
+            assert has_fast_int8_instruction(asm, target)
+
+            # Ensure that code is generated when datatypes are not HW supported.
+            dtypes = ('int8', 'int8', 'int32')
+            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                           dtypes=dtypes)
+            # Check that the intrinsic is not present in the assembly.
+            assert not has_fast_int8_instruction(asm, target)
+
+            # Ensure that code is generated when datatypes are not HW supported.
+            dtypes = ('uint8', 'uint8', 'int32')
+            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                           dtypes=dtypes)
+            # Check that the intrinsic is not present in the assembly.
+            assert not has_fast_int8_instruction(asm, target)
+
+    # Check that a vectorized instruction is generated for older Intel
+    # generations, because we default to NCHWc layout.
+    target = "llvm -mcpu=core-avx2"
+    fast_int8_dtypes = ('uint8', 'int8', 'int32')
+    asm = _compile(ic=16, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW',
                    dtypes=fast_int8_dtypes)
-    assert "pmaddubs" in asm
-
-    # Ensure that code is generated when datatypes are not HW supported.
-    dtypes = ('int8', 'int8', 'int32')
-    asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
-                   dtypes=dtypes)
-    # Check that the intrinsic is not present in the assembly.
-    assert "pmaddubs" not in asm
-
-    # Ensure that code is generated when datatypes are not HW supported.
-    dtypes = ('uint8', 'uint8', 'int32')
-    asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
-                   dtypes=dtypes)
-    # Check that the intrinsic is not present in the assembly.
-    assert "pmaddubs" not in asm
-
-    # Check that a vectorized instruction is generated for older Intel
-    # generations, because we default to NCHWc layout.
-    target = "llvm -mcpu=core-avx2"
-    fast_int8_dtypes = ('uint8', 'int8', 'int32')
-    asm = _compile(ic=16, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW',
-                   dtypes=fast_int8_dtypes)
-    # Check that vector int mult and add instructions are generated.
-    assert "vpmulld" in asm and "vpadd" in asm
+    # Check that vector int mult and add instructions are generated.
+    assert "vpmulld" in asm and "vpadd" in asm


 def test_bitserial_conv2d_infer_type():
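
The guard in test_conv2d_int8_intrinsics above depends on whether the linked LLVM knows the relevant intrinsics. A small sketch of that probe (same API the test uses; a zero id means the intrinsic is unknown, so the fast int8 checks are skipped):

    import tvm

    for name in ["llvm.x86.avx512.pmaddubs.w.512", "llvm.x86.avx512.vpdpbusd.512"]:
        intrin_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
        print(name, "available" if intrin_id != 0 else "missing")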

topi/python/topi/x86/conv2d_avx_1x1.py

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@
 from ..nn.util import infer_pad, get_pad_tuple
 from ..generic import conv2d as conv2d_generic
 from ..util import get_const_tuple, simplify
-from .tensor_intrin import dot_16x1x16_int8_int8_int32
+from .tensor_intrin import dot_16x1x16_uint8_int8_int32
 from .util import get_fp32_len

 def _fallback_schedule(cfg, wkl):
@@ -183,7 +183,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
 def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
     return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last,
                                                            int32_lanes=16,
-                                                           intrin=dot_16x1x16_int8_int8_int32())
+                                                           intrin=dot_16x1x16_uint8_int8_int32())


 def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype):
@@ -282,7 +282,7 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
     ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor)
     s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner)

-    pc = dot_16x1x16_int8_int8_int32()
+    pc = dot_16x1x16_uint8_int8_int32()
     s[C].tensorize(oc_inner, pc)

     if C != O:

topi/python/topi/x86/conv2d_avx_common.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 from ..nn.util import infer_pad
 from ..generic import conv2d as conv2d_generic
 from ..util import get_const_tuple
-from .tensor_intrin import dot_16x1x16_int8_int8_int32
+from .tensor_intrin import dot_16x1x16_uint8_int8_int32
 from .util import get_fp32_len

 def _fallback_schedule(cfg, wkl):
@@ -209,4 +209,4 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
 def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
     return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last,
                                                               int32_lanes=16,
-                                                              intrin=dot_16x1x16_int8_int8_int32())
+                                                              intrin=dot_16x1x16_uint8_int8_int32())

topi/python/topi/x86/conv2d_int8.py

Lines changed: 6 additions & 4 deletions
@@ -57,15 +57,17 @@ def _is_int8_hw_support(data_dtype, kernel_dtype):
     is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'

     # 2) Check LLVM support
-    llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512"
-    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8)
-    is_llvm_support = llvm_id != 0
+    llvm_intrin_fast_int8_skylake = "llvm.x86.avx512.pmaddubs.w.512"
+    llvm_intrin_fast_int8_cascadelake = "llvm.x86.avx512.vpdpbusd.512"
+    llvm_id_skylake = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8_skylake)
+    llvm_id_cascadelake = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8_cascadelake)
+    is_llvm_support = llvm_id_skylake != 0 and llvm_id_cascadelake != 0

     # 3) Check target
     target = tvm.target.current_target()
     is_target_support = False
     for opt in target.options:
-        if opt == '-mcpu=skylake-avx512':
+        if opt == '-mcpu=skylake-avx512' or opt == '-mcpu=cascadelake':
             is_target_support = True

     return is_dtype_support and is_llvm_support and is_target_support
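
A hedged usage sketch of this three-way gate (assuming this era's API, where a Target created via tvm.target.create can be entered as a context manager to become current_target):

    import tvm

    with tvm.target.create("llvm -mcpu=cascadelake"):
        # True only if dtypes are (uint8, int8) AND the linked LLVM provides
        # both intrinsics AND the -mcpu flag is skylake-avx512 or cascadelake.
        print(_is_int8_hw_support('uint8', 'int8'))
        print(_is_int8_hw_support('int8', 'int8'))   # False: data must be uint8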

topi/python/topi/x86/tensor_intrin.py

Lines changed: 28 additions & 9 deletions
@@ -19,15 +19,34 @@
 import tvm


-def dot_16x1x16_int8_int8_int32():
+def dot_16x1x16_uint8_int8_int32():
+    """Dispatch the most optimized intrin depending on the target."""
+    target = tvm.target.current_target()
+    intel_device_type = None
+    for opt in target.options:
+        if opt == '-mcpu=skylake-avx512':
+            intel_device_type = "skylake"
+        elif opt == '-mcpu=cascadelake':
+            intel_device_type = "cascadelake"
+
+    assert intel_device_type is not None, \
+        "Fast int8 intrinsics require a Skylake or Cascade Lake target."
+
+    if intel_device_type == "skylake":
+        return dot_16x1x16_uint8_int8_int32_skylake()
+    else:  # cascade lake
+        return dot_16x1x16_uint8_int8_int32_vnni()
+
+
+def dot_16x1x16_uint8_int8_int32_skylake():
     """
     Int8 dot product by every 4 elements using AVX512 Skylake instructions.
-    This function takes two arrays of int8 datatype -- data[4] and
+    This function takes two arrays of uint8 and int8 datatype -- data[4] and
     kernel[16][4] -- and computes a dot product of data[4] with every
     4 elements of kernels, resulting in output[16] of int32 datatype.
     The pseudo code is as follows.
     .. code-block:: c
-        void dot_16x1x16_int8_int8_int32(int8 data[4], int8 kernel[16][4],
+        void dot_16x1x16_uint8_int8_int32(uint8 data[4], int8 kernel[16][4],
                 int32 output[16]){
             for (int i = 0; i < 16; i++){
                 output[i] = 0;
@@ -100,15 +119,15 @@ def _instr(index):
     return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})


-def dot_16x1x16_int8_int8_int16():
+def dot_16x1x16_uint8_int8_int16():
     """
     Int8 dot product by every 2 elements using AVX512 Skylake instructions.
-    This function takes two arrays of int8 datatype -- data[2] and
+    This function takes two arrays of uint8 and int8 datatype -- data[2] and
     kernel[4][32][2] -- and computes a dot product of data[2] with every
     2 elements of kernels, resulting in output[4][32] of int16 datatype.
     The pseudo code is as follows.
     .. code-block:: c
-        void dot_16x1x16_int8_int8_int16(int8 data[2], int8 kernel[32*4][2],
+        void dot_16x1x16_uint8_int8_int16(uint8 data[2], int8 kernel[32*4][2],
                 int16 output[32*4]){
             for (int i = 0; i < 4; i++){
                 for (int j = 0; j < 32; j++){
@@ -182,15 +201,15 @@ def _instr(index):
     return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})


-def dot_16x1x16_int8_int8_int32_vnni():
+def dot_16x1x16_uint8_int8_int32_vnni():
     """
     Int8 dot product by every 4 elements using AVX512VNNI Cascade Lake instructions.
-    This function takes two arrays of int8 datatype -- data[4] and
+    This function takes two arrays of uint8 and int8 datatype -- data[4] and
     kernel[16][4] -- and computes a dot product of data[4] with every
     4 elements of kernels, resulting in output[16] of int32 datatype.
     The pseudo code is as follows.
     .. code-block:: c
-        void dot_16x1x16_int8_int8_int32_vnni(int8 data[4], int8 kernel[16][4],
+        void dot_16x1x16_uint8_int8_int32_vnni(uint8 data[4], int8 kernel[16][4],
                 int32 output[16]){
             for (int i = 0; i < 16; i++){
                 output[i] = 0;