Commit bf11ee6 (1 parent: f91de7e)

Removing the conv2d legalization for x86. Will send a separate PR.

File tree: 4 files changed, +7 −139 lines

python/tvm/relay/op/nn/_nn.py
Lines changed: 0 additions & 7 deletions

@@ -191,13 +191,6 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
 
 
-@reg.register_legalize("nn.conv2d")
-def rewrite_conv2d(attrs, inputs, tinfos):
-    """Rewrite conv2d"""
-    from ... import op
-    return topi.nn.conv2d_legalize(attrs, inputs, tinfos, op)
-
-
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
tests/python/relay/test_op_level2.py
Lines changed: 4 additions & 55 deletions

@@ -21,7 +21,6 @@
 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.testing import ctx_list
-import tvm.contrib.graph_runtime as runtime
 import topi.testing
 
 def run_infer_type(expr):
@@ -30,15 +29,6 @@ def run_infer_type(expr):
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
-def run_opt_pass(expr, passes):
-    passes = passes if isinstance(passes, list) else [passes]
-    mod = relay.Module.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
-        mod = seq(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
-
 def test_conv2d_infer_type():
     # symbolic in batch dimension
     n, c, h, w = tvm.var("n"), 10, 224, 224
@@ -116,7 +106,7 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape,
                     **attrs):
     if except_targets is None:
         except_targets = []
-
+
     x = relay.var("x", shape=dshape, dtype=dtype)
     w = relay.var("w", dtype=dtype)
     y = relay.nn.conv2d(x, w,
@@ -580,12 +570,13 @@ def _compile(input_dtype, weight_dtype, output_dtype, target):
     # Check that intrinisic is present in the assembly.
     assert "pmaddubs" in asm
 
-    # Ensure that code is generated for i8 x i8 conv.
+    # Ensure that code is generated when datatypes are not HW supported.
     asm = _compile(input_dtype="int8",
                    weight_dtype="int8",
                    output_dtype="int32",
                    target=target)
-    assert "pmaddubs" in asm
+    # Check that intrinisic is not present in the assembly.
+    assert "pmaddubs" not in asm
 
     # Ensure that code is generated when datatypes are not HW supported.
     asm = _compile(input_dtype="uint8",
@@ -605,47 +596,6 @@ def _compile(input_dtype, weight_dtype, output_dtype, target):
     # Check that vector int mult and add instructions are generated.
     assert "vpmulld" in asm and "vpadd" in asm
 
-def test_rewrite_conv2d_intel_int8():
-    def verify(orig, rewritten, input_shape, weight_shape):
-        data = np.random.random_integers(-10, 10,
-                                         size=input_shape).astype('int8')
-        weight = np.random.random_integers(-10, 10,
-                                           size=weight_shape).astype('int8')
-        def _get_output(func):
-            params = {"w": weight}
-            with relay.build_config(opt_level=0):
-                graph, lib, params = relay.build(func, 'llvm', params=params)
-
-            ctx = tvm.cpu(0)
-            module = runtime.create(graph, lib, ctx)
-            module.set_input('data', data)
-            module.set_input(**params)
-            module.run()
-            return module.get_output(0).asnumpy()
-        orig_output = _get_output(orig)
-        rewritten_output = _get_output(rewritten)
-        np.testing.assert_equal(orig_output, rewritten_output)
-
-    input_shape = (1, 128, 28, 28)
-    weight_shape = (256, 128, 3, 3)
-    idtype = "int8"
-    wdtype = "int8"
-    odtype = "int32"
-
-    var_input = relay.var("data", shape=input_shape, dtype=idtype)
-    var_weight = relay.var("w", shape=weight_shape, dtype=wdtype)
-
-    f = relay.nn.conv2d(var_input,
-                        var_weight,
-                        kernel_size=(3, 3),
-                        channels=256,
-                        out_dtype=odtype)
-
-    orig = relay.Function([var_input, var_weight], f)
-    with tvm.target.create("llvm"):
-        rewritten = run_opt_pass(orig, transform.InferType())
-        rewritten = run_opt_pass(rewritten, transform.RewriteOp())
-    verify(orig, rewritten, input_shape, weight_shape)
 
 if __name__ == "__main__":
     test_pool2d()
@@ -663,4 +613,3 @@ def _get_output(func):
     test_batch_flatten()
     test_upsampling()
     test_conv2d_int8_intrinsics()
-    test_rewrite_conv2d_intel_int8()
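For reference, a minimal sketch of how an assembly check like the ones kept above can be written. The body of the test's `_compile` helper is not part of this diff, so the helper name, shapes, and target flag below are assumptions rather than the test's actual code, and whether the intrinsic appears for a given shape depends on the schedule TVM picks.

```python
import tvm
from tvm import relay

def compile_conv2d_asm(input_dtype, weight_dtype, output_dtype,
                       target="llvm -mcpu=skylake-avx512"):
    # Hypothetical helper: build a small conv2d and return the generated x86 assembly.
    data = relay.var("data", shape=(1, 16, 14, 14), dtype=input_dtype)
    kernel = relay.var("kernel", shape=(32, 16, 3, 3), dtype=weight_dtype)
    out = relay.nn.conv2d(data, kernel, kernel_size=(3, 3), channels=32,
                          out_dtype=output_dtype)
    func = relay.Function([data, kernel], out)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(func, target)
    return lib.get_source("asm")

# The kept assertions expect "pmaddubs" (vpmaddubsw) for u8 x i8 convolutions and,
# now that the i8 -> u8 legalization is gone, its absence for i8 x i8.
print("pmaddubs" in compile_conv2d_asm("uint8", "int8", "int32"))
```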

topi/python/topi/nn/conv2d.py
Lines changed: 0 additions & 23 deletions

@@ -94,29 +94,6 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F):
     # not to change by default
     return None
 
-@tvm.target.generic_func
-def conv2d_rewrite_op(attrs, inputs, tinfos, F):
-    """Rewrite Conv2D op.
-
-    Parameters
-    ----------
-    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
-        Attributes of current convolution
-    inputs : nnvm.symbol or tvm.relay.Expr
-        Grouped input symbols
-    tinfos : list
-        Input shape and dtype
-    F: symbol
-        The context, can be either nnvm.sym or relay.op
-
-    Note
-    ----
-    Unlike other TOPI functions, this function operates on both graph level and operator level,
-    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
-    """
-    # not to change by default
-    return None
-
 @tvm.target.generic_func
 def conv2d_infer_layout(workload, cfg):
     """Infer input/output shapes and layouts from a workload and cfg.

topi/python/topi/x86/conv2d.py
Lines changed: 3 additions & 54 deletions

@@ -27,9 +27,8 @@
 from .. import generic, tag
 from .. import nn
 from ..util import get_const_tuple, get_shape
-from ..nn.conv2d import conv2d, conv2d_NCHWc
-from ..nn.conv2d import conv2d_alter_layout, conv2d_infer_layout, conv2d_rewrite_op
-from ..nn.conv2d import _get_workload as _get_conv2d_workload
+from ..nn.conv2d import conv2d, conv2d_NCHWc, \
+    conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
 from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
 from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
 from ..nn.pad import pad
@@ -38,7 +37,7 @@
 
 logger = logging.getLogger('topi')
 
-def _is_int8_hw_support(data_dtype, kernel_dtype, target, ignore_dtype=False):
+def _is_int8_hw_support(data_dtype, kernel_dtype, target):
     """
     Checks to ensure that we can use Intel DLBoost instructions
     1) The datatypes are correct.
@@ -59,8 +58,6 @@ def _is_int8_hw_support(data_dtype, kernel_dtype, target, ignore_dtype=False):
         if opt == '-mcpu=skylake-avx512':
             is_target_support = True
 
-    if ignore_dtype:
-        return is_llvm_support and is_target_support
     return is_dtype_support and is_llvm_support and is_target_support
 
 def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
@@ -412,54 +409,6 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     s = _schedule_conv2d_NCHWc(cfg, [C])
     return s, [new_data, new_kernel, C]
 
-@conv2d_rewrite_op.register("cpu")
-def _conv2d_rewrite_op(attrs, inputs, arg_types, F):
-    if F.__name__ != 'tvm.relay.op':
-        return None
-    data_type, kernel_type = arg_types[0], arg_types[1]
-    target = tvm.target.current_target()
-    data_layout = attrs['data_layout']
-    kernel_layout = attrs['kernel_layout']
-    # Uncomment when this bug is resolved
-    # https://discuss.tvm.ai/t/segfault-in-llvm/3567
-    # if not ((data_layout == 'NCHW' and kernel_layout == 'OIHW')
-    #         or (data_layout == 'NHWC' and kernel_layout == 'HWIO')):
-    #     return None
-    if not (data_layout == 'NCHW' and kernel_layout == 'OIHW'):
-        return None
-
-    if not (data_type.dtype == 'int8' and kernel_type.dtype == 'int8'):
-        return None
-
-    if not _is_int8_hw_support(data_type.dtype, kernel_type.dtype,
-                               target,
-                               ignore_dtype=True):
-        return None
-
-    # Convert i8 x i8 to u8 x i8
-    # Intel has fast instructions for u8 x i8 conv. For i8 x i8 conv, we can
-    # convert the i8 tensor to u8 by adding 128 and use u8 x i8 conv. Since 128
-    # has been added, the output now has to be adjusted.
-    out_channel = attrs["channels"]
-    data_expr, kernel_expr = inputs
-    data_expr = F.cast(data_expr, "int32")
-    data_expr = F.add(data_expr, F.const(128, "int32"))
-    data_expr = F.clip(data_expr, a_min=0, a_max=255)
-    data_expr = F.cast(data_expr, "uint8")
-    conv = F.nn.conv2d(data_expr, kernel_expr, **attrs)
-    bias_adjust = F.cast(kernel_expr, "int32")
-    if kernel_layout == 'OIHW' and data_layout == 'NCHW':
-        bias_adjust = F.sum(bias_adjust, axis=(1, 2, 3))
-        bias_adjust = F.reshape(bias_adjust,
-                                newshape=(1, out_channel, 1, 1))
-    elif kernel_layout == 'HWIO' and data_layout == 'NHWC':
-        bias_adjust = F.sum(bias_adjust, axis=(0, 1, 2))
-        bias_adjust = F.reshape(bias_adjust,
-                                newshape=(1, 1, 1, out_channel))
-    bias_adjust = F.cast(bias_adjust, 'int32')
-    bias_adjust = F.multiply(bias_adjust, F.const(128, 'int32'))
-    return F.subtract(conv, bias_adjust)
-
 
 @conv2d_alter_layout.register("cpu")
 def _alter_conv2d_layout(attrs, inputs, tinfo, F):
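The removed `_conv2d_rewrite_op` relied on convolution being linear in its data: shifting every int8 input element by +128 (so it fits in uint8) changes each output element by 128 times the sum of the corresponding kernel slice, which the rewrite subtracted back out as `bias_adjust`. A small NumPy sketch of that identity; the shapes and the naive convolution helper are illustrative only, not TVM code.

```python
import numpy as np

def conv_nchw(data, kernel):
    # Naive direct convolution: NCHW data, OIHW kernel, stride 1, no padding.
    n, c, h, w = data.shape
    o, _, kh, kw = kernel.shape
    out = np.zeros((n, o, h - kh + 1, w - kw + 1), dtype=np.int32)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = data[:, :, i:i + kh, j:j + kw].astype(np.int32)
            out[:, :, i, j] = np.tensordot(patch, kernel.astype(np.int32),
                                           axes=([1, 2, 3], [1, 2, 3]))
    return out

data = np.random.randint(-10, 10, size=(1, 4, 8, 8)).astype("int8")
kernel = np.random.randint(-10, 10, size=(8, 4, 3, 3)).astype("int8")

# i8 x i8 convolution, accumulated in int32.
reference = conv_nchw(data, kernel)

# Legalized form: shift the data into uint8 range, convolve, then subtract the
# per-output-channel correction 128 * sum(kernel over the I, H, W axes),
# mirroring the bias_adjust term in the removed rewrite.
shifted = (data.astype(np.int32) + 128).astype("uint8")  # int8 + 128 always lies in [0, 255]
adjust = 128 * kernel.astype(np.int32).sum(axis=(1, 2, 3)).reshape(1, -1, 1, 1)
legalized = conv_nchw(shifted, kernel) - adjust

np.testing.assert_array_equal(reference, legalized)
```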
