
Commit 8b60c09

[target] Use native architecture for llvm target
Set the default `-device=` key for llvm targets based on the native architecture rather than hard-coding it to `cpu`, which is x86-specific. This means that when llvm target triples are not specified we will test `arm_cpu` schedules on Arm®-based architectures and `cpu` schedules on x86-based architectures. Fix the schedule test failures that result from this change.
Parent: 521465e

15 files changed: +126, -70 lines changed
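Before the per-file diffs, a minimal sketch (not part of the commit) of the behaviour it enables: with the default key derived from the native architecture, a bare llvm target now dispatches arm_cpu schedules on Arm hosts without any extra flags. The exact key tuple printed below is illustrative.

from tvm.target import Target

# A bare "llvm" target with no explicit -device= now defaults its key to the
# native architecture: "arm_cpu" on an Arm host, "cpu" on an x86 host.
target = Target("llvm")
print(target.keys)  # e.g. ("arm_cpu", "cpu") on Arm, ("cpu",) on x86

# Strategy and schedule dispatch is keyed on this, so tests can simply check:
if "arm_cpu" in target.keys:
    print("arm_cpu schedules will be exercised")
else:
    print("x86/cpu schedules will be exercised")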


python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 25 additions & 8 deletions
@@ -132,13 +132,16 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     plevel=15,
                 )
             else:
+                # TODO(@FranklandJack)
+                # Investigate why this producing output tensor of
+                # incorrect dimensions in
+                # test_runtime_module_based_interface.py
                 # ARM conv2d spatial pack schedule.
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
-                    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.arm_cpu",
-                    plevel=10,
-                )
+                # strategy.add_implementation(
+                #     wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
+                #     wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
+                #     name="conv2d_nchw_spatial_pack.arm_cpu",
+                #     plevel=10,
 
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.conv2d_nchw),
@@ -152,6 +155,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
         is_winograd_applicable = (
             "float" in data.dtype
             and "float" in kernel.dtype
+            and not data.dtype.count("custom")
             and kh == 3
             and kw == 3
             and stride_h == 1
@@ -284,8 +288,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 name="depthwise_conv2d_nchw.x86",
             )
         elif layout == "NHWC":
-            assert kernel_layout == "HWOI"
-            if target.features.has_asimd:
+            # TODO(@FranklandJack)
+            # Handle HWOI in arm_cpu schedules/compute definition.
+            if kernel_layout != "HWOI":
+                logger.warning(
+                    """depthwise_conv2d with layout NHWC and HWOI
+                    kernel layout is not optimized for arm_cpu target.
+                    """
+                )
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True),
+                    wrap_topi_schedule(conv2d_generic.schedule_depthwise_conv2d_nhwc),
+                    name="depthwise_conv2d_nhwc.generic",
+                )
+
+            elif target.features.has_asimd:
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
                     wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 10 additions & 6 deletions
@@ -111,8 +111,7 @@ def qnn_conv2d_transpose_legalize(attrs, inputs, types):
     # Otherwise it needs to be broadcast.
     else:
         shift_data = relay.nn.bias_add(
-            relay.cast(data, dtype="int16"),
-            -relay.cast(input_zero_point, dtype="int16"),
+            relay.cast(data, dtype="int16"), -relay.cast(input_zero_point, dtype="int16")
         )
 
     # If kernel zero point is a scalar, we can directly subtract it.
@@ -123,8 +122,7 @@ def qnn_conv2d_transpose_legalize(attrs, inputs, types):
     # Otherwise it needs to be broadcast.
     else:
         shift_kernel = relay.nn.bias_add(
-            relay.cast(kernel, dtype="int16"),
-            -relay.cast(kernel_zero_point, dtype="int16"),
+            relay.cast(kernel, dtype="int16"), -relay.cast(kernel_zero_point, dtype="int16")
         )
 
     return relay.nn.conv2d_transpose(shift_data, shift_kernel, **attrs)
@@ -486,7 +484,10 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
     if target.features.has_asimd and not other_options:
         return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
     # ARM prefers the dtypes to be same.
-    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+    if types[0].dtype == "int8":
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
 
 
 @qnn_dense_legalize.register("arm_cpu")
@@ -495,7 +496,10 @@ def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
     if target.features.has_asimd and not target.features.has_dotprod:
         return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
     # ARM prefers the dtypes to be same.
-    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+    if types[0].dtype == "int8":
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
 
 
 ##########################

python/tvm/topi/arm_cpu/conv2d.py

Lines changed: 7 additions & 1 deletion
@@ -23,10 +23,11 @@
 from tvm import autotvm
 import tvm.contrib.nnpack
 
-from ..utils import traverse_inline, get_const_tuple
+from ..utils import traverse_inline, get_const_tuple, conv2d_infer_layout_helper
 from .. import nn
 from ..nn.utils import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
+from ..nn.conv2d import conv2d_infer_layout
 from .conv2d_spatial_pack import (
     conv2d_spatial_pack_nchw,
     conv2d_spatial_pack_nhwc,
@@ -509,3 +510,8 @@ def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype):
 def schedule_conv2d_nhwc_dsp(cfg, outs):
     """Create schedule for conv2d_nhwc_dsp"""
     return conv2d_nhwc_dsp_schedule(cfg, outs)
+
+
+@conv2d_infer_layout.register("arm_cpu")
+def _conv2d_infer_layout(workload, cfg):
+    return conv2d_infer_layout_helper(workload, cfg)

python/tvm/topi/arm_cpu/injective.py

Lines changed: 4 additions & 2 deletions
@@ -69,8 +69,10 @@ def schedule_injective(outs):
     if list(s[x].op.axis):
         # do not vectorize for broadcast
         dtype = "uint16" if x.dtype == "bfloat16" else x.dtype
-        (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // np.dtype(dtype).itemsize)
-        s[x].vectorize(ii)
+        # do not vectorize for custom data types
+        if 0 == dtype.count("custom"):
+            (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // np.dtype(dtype).itemsize)
+            s[x].vectorize(ii)
     tvm.te.schedule.AutoInlineInjective(s)
 
     if not is_empty_shape(x.shape):
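A small illustration of why the guard above is needed (a sketch, not part of the commit; the specific dtype string is only an example): NumPy cannot derive an item size for TVM's custom datatype strings, so the vectorization split factor computed from np.dtype(dtype).itemsize would raise before the schedule is built.

import numpy as np

# The split factor is 16 // np.dtype(dtype).itemsize, which assumes NumPy
# understands the dtype string; "custom[...]" dtypes are opaque to it.
try:
    np.dtype("custom[posites2]32").itemsize
except TypeError as err:
    print("skip vectorization:", err)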

python/tvm/topi/intel_graphics/conv2d_alter_op.py

Lines changed: 2 additions & 12 deletions
@@ -22,7 +22,7 @@
 from tvm import relay
 from tvm import autotvm
 
-from ..utils import get_const_tuple
+from ..utils import get_const_tuple, conv2d_infer_layout_helper
 from ..nn import conv2d_alter_layout, conv2d_infer_layout
 from .conv2d import _get_default_config
 
@@ -102,14 +102,4 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
 
 @conv2d_infer_layout.register("intel_graphics")
 def _conv2d_infer_layout(workload, cfg):
-    _, data, kernel, strides, padding, dilation, layout, dtype = workload
-    batch_size, in_channel, in_height, in_width = data[1]
-    out_channel, _, k_height, k_width = kernel[1]
-    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
-    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
-    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
-    in_layout = f"NCHW{tile_ic}c"
-    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
-    out_layout = f"NCHW{tile_oc}c"
-    return ((in_shape, in_layout),), ((out_shape, out_layout),)
+    return conv2d_infer_layout_helper(workload, cfg)

python/tvm/topi/testing/common.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@
 _reduce_schedule = {
     "generic": topi.generic.schedule_reduce,
     "cpu": topi.x86.schedule_reduce,
+    # TODO(@FranklandJack) Write arm_cpu specific reduction schedule.
+    "arm_cpu": topi.x86.schedule_reduce,
     "gpu": topi.cuda.schedule_reduce,
     "hls": topi.cuda.schedule_reduce,
 }

python/tvm/topi/utils.py

Lines changed: 23 additions & 0 deletions
@@ -526,3 +526,26 @@ def is_target(names):
 def is_dynamic_shape(shape):
     """Checks if any part of a shape is dynamic"""
     return any([isinstance(x, (Any, SizeVar)) for x in shape])
+
+
+def conv2d_infer_layout_helper(workload, cfg):
+    """Infers input and output layouts for a conv2d operator
+    scheduled using "tile_ic" and "tile_oc" scheduling configuration knobs which
+    is the case for cpu, arm_cpu and intel_graphics targets."""
+    _, data, kernel, strides, padding, dilation, layout, _, dtype = workload
+    batch_size, in_channel, in_height, in_width = data[1]
+    out_channel, _, k_height, k_width = kernel[1]
+    idxdiv = tvm.tir.indexdiv
+
+    pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width))
+    hdilation, wdilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    dilated_kernel_h = (k_height - 1) * hdilation + 1
+    dilated_kernel_w = (k_width - 1) * wdilation + 1
+    out_height = idxdiv(in_height + pt + pb - dilated_kernel_h, strides[0]) + 1
+    out_width = idxdiv(in_width + pl + pr - dilated_kernel_w, strides[1]) + 1
+    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+    in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
+    in_layout = f"NCHW{tile_ic}c"
+    out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
+    out_layout = f"NCHW{tile_oc}c"
+    return ((in_shape, in_layout),), ((out_shape, out_layout),)
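A worked example of what the helper computes, in plain Python with hypothetical workload numbers (the real helper reads the tile sizes from an AutoTVM config entity and uses tvm.tir.indexdiv):

# Hypothetical NCHW workload: 1x32x56x56 input, 64 3x3 filters, stride 1, pad 1.
batch, in_channel, in_height, in_width = 1, 32, 56, 56
out_channel, k_height, k_width = 64, 3, 3
stride_h = stride_w = 1
pad_t = pad_l = pad_b = pad_r = 1
tile_ic, tile_oc = 8, 16  # assumed values of the "tile_ic"/"tile_oc" knobs

out_height = (in_height + pad_t + pad_b - k_height) // stride_h + 1  # 56
out_width = (in_width + pad_l + pad_r - k_width) // stride_w + 1     # 56

in_shape = (batch, in_channel // tile_ic, in_height, in_width, tile_ic)      # (1, 4, 56, 56, 8)
in_layout = f"NCHW{tile_ic}c"                                                # "NCHW8c"
out_shape = (batch, out_channel // tile_oc, out_height, out_width, tile_oc)  # (1, 4, 56, 56, 16)
out_layout = f"NCHW{tile_oc}c"                                               # "NCHW16c"
print(((in_shape, in_layout),), ((out_shape, out_layout),))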

python/tvm/topi/x86/conv2d.py

Lines changed: 2 additions & 18 deletions
@@ -30,7 +30,7 @@
 from ..nn.conv2d import unpack_NCHWc_to_nchw
 from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
 from ..nn.utils import get_pad_tuple
-from ..utils import get_const_tuple, traverse_inline
+from ..utils import get_const_tuple, traverse_inline, conv2d_infer_layout_helper
 from . import conv2d_avx_1x1, conv2d_avx_common
 
 logger = logging.getLogger("topi")
@@ -65,23 +65,7 @@ def _get_default_config(
 
 @conv2d_infer_layout.register("cpu")
 def _conv2d_infer_layout(workload, cfg):
-    _, data, kernel, strides, padding, dilation, layout, _, dtype = workload
-    batch_size, in_channel, in_height, in_width = data[1]
-    out_channel, _, k_height, k_width = kernel[1]
-    idxdiv = tvm.tir.indexdiv
-
-    pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width))
-    hdilation, wdilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
-    dilated_kernel_h = (k_height - 1) * hdilation + 1
-    dilated_kernel_w = (k_width - 1) * wdilation + 1
-    out_height = idxdiv(in_height + pt + pb - dilated_kernel_h, strides[0]) + 1
-    out_width = idxdiv(in_width + pl + pr - dilated_kernel_w, strides[1]) + 1
-    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
-    in_layout = f"NCHW{tile_ic}c"
-    out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
-    out_layout = f"NCHW{tile_oc}c"
-    return ((in_shape, in_layout),), ((out_shape, out_layout),)
+    return conv2d_infer_layout_helper(workload, cfg)
 
 
 def schedule_conv2d_nhwc(outs):

src/target/target_kind.cc

Lines changed: 9 additions & 1 deletion
@@ -257,6 +257,12 @@ TargetJSON TestTargetParser(TargetJSON target) {
 
 /********** Register Target kinds and attributes **********/
 
+#if defined(__arm__) || defined(__aarch64__)
+#define NATIVE_CPU "arm_cpu"
+#else
+#define NATIVE_CPU "cpu"
+#endif
+
 TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
     .add_attr_option<Array<String>>("mattr")
     .add_attr_option<String>("mcpu")
@@ -275,12 +281,14 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
     .add_attr_option<Integer>("opt-level")
     // LLVM command line flags, see below
     .add_attr_option<Array<String>>("cl-opt")
-    .set_default_keys({"cpu"})
+    .set_default_keys({NATIVE_CPU})
     // Force the external codegen kind attribute to be registered, even if no external
    // codegen targets are enabled by the TVM build.
     .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(false))
     .set_target_parser(tvm::target::parsers::cpu::ParseTarget);
 
+#undef NATIVE_CPU
+
 // Note regarding the "cl-opt" attribute:
 // Each string in the array has the format
 //   -optionname[[:type]=value]
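The NATIVE_CPU macro only changes the default key; an explicit -device= on the target string still takes precedence. A quick way to inspect the result from Python (a sketch, not part of the commit; exact key contents depend on the host and the target parser):

from tvm.target import Target

# Default keys now follow the build host: "arm_cpu" on Arm, "cpu" elsewhere.
print(Target("llvm").keys)

# An explicit device overrides the native default on any host.
print(Target("llvm -device=arm_cpu").keys)  # expected to include "arm_cpu"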

tests/python/topi/python/test_topi_bitserial_dense.py

Lines changed: 5 additions & 3 deletions
@@ -17,6 +17,7 @@
 """Test code for bitserial_dense operator"""
 import os
 import numpy as np
+from tvm.target.target import Target
 import tvm
 from tvm import te
 from tvm import topi
@@ -53,11 +54,12 @@ def get_ref_data(a_shape, b_shape, input_dtype):
         c_np = np.dot(a_np, b_np.T)
         return a_np, b_np, c_np
 
-    for target in ["llvm", "llvm -device=arm_cpu"]:
-        if "arm_cpu" in target and "arm" not in os.uname()[4]:
+    for target_string in ["llvm", "llvm -device=arm_cpu"]:
+        target = Target(target_string)
+        if "arm_cpu" in target.keys and "arm" not in os.uname()[4]:
             print("Skipped running code, not an arm device")
             continue
-        input_dtype = "uint8" if "arm_cpu" in target else "uint32"
+        input_dtype = "uint8" if "arm_cpu" in target.keys else "uint32"
         A = te.placeholder((batch, in_dim), dtype=input_dtype, name="A")
         B = te.placeholder((out_dim, in_dim), dtype=input_dtype, name="B")
         fcompute, fschedule = tvm.topi.testing.dispatch(target, _bitserial_dense_implement)
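Why the test switches from substring matching on the target string to inspecting Target.keys (a sketch for illustration, not part of the commit): with the native default key, a plain "llvm" target on an Arm host now selects arm_cpu schedules even though "arm_cpu" never appears in the string.

from tvm.target import Target

target_string = "llvm"
# Old check: substring match misses the native default entirely.
print("arm_cpu" in target_string)               # always False for a bare "llvm"
# New check: parse the target and inspect its keys instead.
print("arm_cpu" in Target(target_string).keys)  # True on an Arm host after this change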
