
Commit a439c78

anijain2305 authored and wweic committed
[Relay] Legalize and AlterOpLayout for Int8 Intel. (apache#3961)
1 parent ebb5a68 commit a439c78
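
The updated test below drives both passes through the normal build pipeline rather than calling them directly. A minimal standalone sketch of that entry point (shapes, dtypes and the skylake-avx512 target string are taken from the test; this snippet is illustrative only, not part of the commit):

import numpy as np
import tvm
from tvm import relay

# Minimal sketch: AlterOpLayout (and Legalize) run inside relay.build at
# opt_level=3, which is how the updated test exercises this change.
# Shapes/dtypes mirror the non-divisible-channel case from the test below.
x = relay.var("x", relay.TensorType((1, 17, 64, 64), "uint8"))
w = relay.var("w", relay.TensorType((29, 17, 3, 3), "int8"))
y = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=29,
                    padding=(1, 1), out_dtype="int32")
func = relay.Function([x, w], y)
wdata = (np.random.rand(29, 17, 3, 3) * 10).astype("int8")
params = {"w": tvm.nd.array(wdata)}
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(func, "llvm -mcpu=skylake-avx512",
                                     params=params)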

File tree: 8 files changed, +432 / -183 lines


tests/python/relay/test_op_level2.py

Lines changed: 62 additions & 23 deletions
@@ -541,18 +541,35 @@ def test_upsampling():
 
 
 def test_conv2d_int8_intrinsics():
-    def _compile(input_dtype, weight_dtype, output_dtype, target):
-        n, ic, h, w, oc, ch, cw = 1, 16, 224, 224, 32, 3, 3
-        x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
-        w = relay.var("w", relay.TensorType((oc, ic, ch, cw), weight_dtype))
+    def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
+        input_dtype, weight_dtype, output_dtype = dtypes
+
+        n, h, w, ch, cw = 1, 64, 64, 3, 3
+        if data_layout == 'NCHW':
+            x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
+        elif data_layout == 'NHWC':
+            x = relay.var("x", relay.TensorType((n, h, w, ic), input_dtype))
+        else:
+            raise ValueError('Not supported')
+
+        if kernel_layout == 'OIHW':
+            kernel_shape = (oc, ic, ch, cw)
+        elif kernel_layout == 'HWIO':
+            kernel_shape = (ch, cw, ic, oc)
+        else:
+            raise ValueError('Not supported')
+
+        w = relay.var("w", relay.TensorType(kernel_shape, weight_dtype))
         y = relay.nn.conv2d(x, w,
                             kernel_size=(ch, cw),
                             channels=oc,
                             padding=(1, 1),
                             dilation=(1, 1),
+                            data_layout=data_layout,
+                            kernel_layout=kernel_layout,
                             out_dtype=output_dtype)
         func = relay.Function([x, w], y)
-        wdata = np.random.rand(oc, ic, ch, cw) * 10
+        wdata = np.random.rand(*kernel_shape) * 10
         parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build(func, target, params=parameters)
@@ -564,37 +581,59 @@ def _compile(input_dtype, weight_dtype, output_dtype, target):
     name = "llvm.x86.avx512.pmaddubs.w.512"
     llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
     if llvm_id != 0:
-        # Intel Int8 instruction need uint8 data and int8 kernel
-        asm = _compile(input_dtype="uint8",
-                       weight_dtype="int8",
-                       output_dtype="int32",
-                       target=target)
-        # Check that intrinisic is present in the assembly.
+        fast_int8_dtypes = ('uint8', 'int8', 'int32')
+        # Sweep the input channels to check int8 robustness
+        for ic in range(1, 24):
+            asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW',
+                           dtypes=fast_int8_dtypes)
+            assert "pmaddubs" in asm
+
+        for ic in range(1, 24):
+            asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                           dtypes=fast_int8_dtypes)
+            assert "pmaddubs" in asm
+
+
+        # Sweep the output channels to check int8 robustness
+        for oc in range(2, 24):
+            asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW", kernel_layout='OIHW',
+                           dtypes=fast_int8_dtypes)
+            assert "pmaddubs" in asm
+
+        for oc in range(2, 24):
+            asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                           dtypes=fast_int8_dtypes)
+            assert "pmaddubs" in asm
+
+        # Check that both non-divisible oc and ic work
+        asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
+                       dtypes=fast_int8_dtypes)
+        assert "pmaddubs" in asm
+
+        asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                       dtypes=fast_int8_dtypes)
         assert "pmaddubs" in asm
 
         # Ensure that code is generated when datatypes are not HW supported.
-        asm = _compile(input_dtype="int8",
-                       weight_dtype="int8",
-                       output_dtype="int32",
-                       target=target)
+        dtypes = ('int8', 'int8', 'int32')
+        asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                       dtypes=dtypes)
         # Check that intrinisic is not present in the assembly.
         assert "pmaddubs" not in asm
 
         # Ensure that code is generated when datatypes are not HW supported.
-        asm = _compile(input_dtype="uint8",
-                       weight_dtype="uint8",
-                       output_dtype="int32",
-                       target=target)
+        dtypes = ('uint8', 'uint8', 'int32')
+        asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                       dtypes=dtypes)
         # Check that intrinisic is not present in the assembly.
         assert "pmaddubs" not in asm
 
     # Check that a vectorized instruction is generated for older Intel
     # generations, because we default to NCHWc layout.
     target = "llvm -mcpu=core-avx2"
-    asm = _compile(input_dtype="int8",
-                   weight_dtype="int8",
-                   output_dtype="int32",
-                   target=target)
+    fast_int8_dtypes = ('uint8', 'int8', 'int32')
+    asm = _compile(ic=16, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW',
+                   dtypes=fast_int8_dtypes)
     # Check that vector int mult and add instructions are generated.
     assert "vpmulld" in asm and "vpadd" in asm
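
The whole fast-int8 block in this test is gated on whether the local LLVM build knows the AVX-512 pmaddubs intrinsic. A standalone version of that probe, using only the same call the test makes:

import tvm

# The test only runs its fast-int8 assertions when this lookup succeeds,
# i.e. when the LLVM that TVM was built against exposes the intrinsic.
name = "llvm.x86.avx512.pmaddubs.w.512"
fast_int8_available = tvm.codegen.llvm_lookup_intrinsic_id(name) != 0
print("fast int8 path available:", fast_int8_available)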

topi/python/topi/nn/conv2d.py

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
     if data_layout == 'NCHW':
         CO, CIG, KH, KW = [x.value for x in kernel.shape]
     else:
-        KH, KW, CO, CIG = [x.value for x in kernel.shape]
+        KH, KW, CIG, CO = [x.value for x in kernel.shape]
 
     HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
     GRPS = CI // CIG
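
The one-line fix above matters because HWIO kernels are laid out as (kernel_h, kernel_w, in_channels_per_group, out_channels), so the last two axes were previously read in the wrong order. A tiny illustration with hypothetical shapes:

# Hypothetical 3x3 kernel, 16 input channels (per group), 32 output channels.
# OIHW stores it as (32, 16, 3, 3); HWIO stores it as (3, 3, 16, 32).
kernel_shape_hwio = (3, 3, 16, 32)
KH, KW, CIG, CO = kernel_shape_hwio   # corrected unpacking
assert (CIG, CO) == (16, 32)          # the old code would have swapped these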

topi/python/topi/x86/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,3 +17,4 @@
 from .roi_align import roi_align_nchw
 from .conv2d_transpose import _schedule_conv2d_transpose_nchw
 from .sparse import *
+from .conv2d_alter_op import *
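
The conv2d_alter_op module imported here is new in this commit, but its contents are not among the diffs rendered on this page. Presumably it carries the registration that the next file removes from conv2d.py; a hedged sketch of the shape such a module would take (illustrative only, not the commit's actual code):

# topi/python/topi/x86/conv2d_alter_op.py -- hedged sketch only; the real
# file added by this commit is not rendered on this page.
from ..nn.conv2d import conv2d_alter_layout

@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo, F):
    # Stub body: query the AutoTVM config, pick NCHWc / OIHWio blocking
    # factors, and rewrite conv2d into the matching contrib op, as the
    # implementation removed from conv2d.py (next diff) did.
    ...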

topi/python/topi/x86/conv2d.py

Lines changed: 3 additions & 159 deletions
@@ -26,40 +26,16 @@
 from tvm.autotvm.task import get_config
 from .. import generic, tag
 from .. import nn
-from ..util import get_const_tuple, get_shape
-from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, \
-    conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
+from ..nn.conv2d import conv2d, conv2d_NCHWc, \
+    conv2d_infer_layout, _get_workload as _get_conv2d_workload
 from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
-from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
 from ..nn.pad import pad
+from ..util import get_const_tuple
 
 from . import conv2d_avx_1x1, conv2d_avx_common
 
 logger = logging.getLogger('topi')
 
-def _is_int8_hw_support(data_dtype, kernel_dtype, target):
-    """
-    Checks to ensure that we can use Intel DLBoost instructions
-    1) The datatypes are correct.
-    2) LLVM version has support for the instructions.
-    3) Target is skylake and above.
-    """
-    # 1) Check datatypes
-    is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'
-
-    # 2) Check LLVM support
-    llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512"
-    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8)
-    is_llvm_support = llvm_id != 0
-
-    # 3) Check target
-    is_target_support = False
-    for opt in target.options:
-        if opt == '-mcpu=skylake-avx512':
-            is_target_support = True
-
-    return is_dtype_support and is_llvm_support and is_target_support
-
 def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
                         layout='NCHW'):
     """
@@ -353,138 +329,6 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     return s, [new_data, new_kernel, C]
 
 
-@conv2d_alter_layout.register("cpu")
-def _alter_conv2d_layout(attrs, inputs, tinfo, F):
-
-    copy_inputs = [s for s in inputs]
-    new_attrs = {k : attrs[k] for k in attrs.keys()}
-
-    if F.__name__ == 'tvm.relay.op':
-        # Derive channels for frontends (e.g ONNX) that miss "channel" field.
-        new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
-
-    data, kernel = tinfo[0], tinfo[1]
-    batch_size, in_channel, height, width = get_const_tuple(data.shape)
-
-    groups = attrs.get_int("groups")
-    out_channel = attrs.get_int("channels") \
-        if F.__name__ == 'nnvm.symbol' else new_attrs["channels"]
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    out_dtype = attrs["out_dtype"]
-
-    layout_name = 'layout' if F.__name__ == 'nnvm.symbol' else 'data_layout'
-
-    layout = attrs[layout_name]
-    kh, kw = attrs.get_int_tuple("kernel_size")
-
-    dtype = data.dtype
-    out_dtype = dtype if out_dtype in ("same", "") else out_dtype
-
-    kshape = get_shape(kernel.shape, attrs["kernel_layout"], "OIHW")
-    is_depthwise = groups == kshape[0] and kshape[1] == 1
-
-    # only optimize for NCHW
-    if layout != 'NCHW' or attrs["kernel_layout"] != "OIHW":
-        return None
-
-    if groups != 1 and not is_depthwise:
-        return None
-
-    dispatch_ctx = autotvm.task.DispatchContext.current
-    target = tvm.target.current_target()
-    # query schedule and fallback if necessary
-    workload = autotvm.task.args_to_workload(
-        [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \
-        if is_depthwise else \
-        autotvm.task.args_to_workload(
-            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
-    cfg = dispatch_ctx.query(target, workload)
-    if cfg.is_fallback:
-        _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
-
-    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-
-    new_attrs[layout_name] = 'NCHW%dc' % ic_bn
-    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
-
-    # Remove attached compilation target because conv2d_NCHWc needs to create
-    # a conv2d_nchwc op and target is not one of conv2d's parameters.
-    if "target" in new_attrs:
-        del new_attrs["target"]
-
-    new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                               dtype=data.dtype)
-
-    if is_depthwise:
-        new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
-        # Store altered operator's config
-        new_kernel = tvm.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel.dtype)
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
-             new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
-        dispatch_ctx.update(target, new_workload, cfg)
-        if F.__name__ == 'nnvm.symbol':
-            logging.warning("Use native layout for depthwise convolution on NNVM.")
-            return None
-        return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
-
-    if _is_int8_hw_support(data.dtype, kernel.dtype, target):
-        # Convert kernel data layout from 4D to 7D
-        n_elems = 4
-        out_channel, _, kh, kw = get_const_tuple(kernel.shape)
-        data_expr, kernel_expr = inputs
-        kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0))
-        kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
-        kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
-        kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn,
-                                                 in_channel//ic_bn, ic_bn))
-        kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn,
-                                                   in_channel//ic_bn, ic_bn//n_elems, n_elems))
-        kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
-        copy_inputs = [data_expr, kernel_OIHWioe]
-
-        # Store altered operator's config. New kernel layout OIHWio4
-        new_kernel = tvm.placeholder((out_channel // oc_bn,
-                                      in_channel // ic_bn,
-                                      kh,
-                                      kw,
-                                      ic_bn // n_elems,
-                                      oc_bn,
-                                      n_elems), dtype=kernel.dtype)
-
-        new_workload = autotvm.task.args_to_workload([new_data,
-                                                      new_kernel,
-                                                      strides,
-                                                      padding,
-                                                      dilation,
-                                                      new_attrs[layout_name],
-                                                      new_attrs['out_layout'],
-                                                      out_dtype],
-                                                     conv2d_NCHWc_int8)
-        dispatch_ctx.update(target, new_workload, cfg)
-        if F.__name__ == 'nnvm.symbol':
-            logging.warning("Use native layout for int8 convolution on NNVM.")
-            return None
-        return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs)
-
-    out_channel, _, kh, kw = get_const_tuple(kernel.shape)
-    # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
-    new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
-    # Store altered operator's config
-    new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
-                                  kh, kw, ic_bn, oc_bn), dtype=kernel.dtype)
-    new_workload = autotvm.task.args_to_workload(
-        [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
-         new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
-    dispatch_ctx.update(target, new_workload, cfg)
-
-    if F.__name__ == 'nnvm.symbol':
-        return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
-    return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
-
-
 @conv2d_infer_layout.register("cpu")
 def _conv2d_infer_layout(workload, cfg):
     _, data, kernel, strides, padding, dilation, layout, dtype = workload
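
The removed _is_int8_hw_support helper gated the fast-int8 path on three conditions: uint8/int8 dtypes, availability of the LLVM intrinsic, and a skylake-avx512 target. The only piece not visible in the diff above is how a target object with the right options is constructed; a hedged, self-contained version of the same check:

import tvm

# Hedged sketch: the three conditions the removed helper combined, checked
# inline. All must hold before the pmaddubs-based int8 schedule is used.
target = tvm.target.create("llvm -mcpu=skylake-avx512")
data_dtype, kernel_dtype = "uint8", "int8"

is_dtype_support = data_dtype == "uint8" and kernel_dtype == "int8"
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512")
is_llvm_support = llvm_id != 0
is_target_support = any(opt == "-mcpu=skylake-avx512" for opt in target.options)

use_fast_int8 = is_dtype_support and is_llvm_support and is_target_support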
