13 changes: 13 additions & 0 deletions include/tvm/target/target_kind.h
@@ -440,6 +440,19 @@ constexpr const char* kIsExternalCodegen = "is_external_codegen";
*/
constexpr const char* kRelayToTIR = "RelayToTIR";

/*!
* \brief String representation of the host's target architecture.
*
* Currently this is set to "arm_cpu" on Arm®-based host architectures and "cpu"
* (the generic CPU key, covering x86 and other non-Arm hosts) everywhere else.
*
* TODO(@FranklandJack) dynamically detect host architecture and generalize for all targets.
*/
#if defined(__arm__) || defined(__aarch64__)
constexpr const char* kHostCPU = "arm_cpu";
#else
constexpr const char* kHostCPU = "cpu";
#endif
} // namespace attr

/*!
25 changes: 21 additions & 4 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -152,6 +152,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
is_winograd_applicable = (
"float" in data.dtype
and "float" in kernel.dtype
and not data.dtype.count("custom")
and kh == 3
and kw == 3
and stride_h == 1
@@ -284,8 +285,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
name="depthwise_conv2d_nchw.x86",
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
if target.features.has_asimd:
# TODO(@FranklandJack)
# Handle HWOI in arm_cpu schedules/compute definition.
if kernel_layout != "HWOI":
logger.warning(
"depthwise_conv2d with layout NHWC and %s kernel layout is not optimized "
"for arm_cpu target; falling back to the generic schedule.",
kernel_layout,
)
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True),
wrap_topi_schedule(conv2d_generic.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.generic",
)

elif target.features.has_asimd:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
@@ -304,8 +318,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
and kernel.shape[3] == 1 # channel_multiplier == 1
and out_type.dtype == "int32"
and (
(data.shape[3] % 4 == 0 and data.dtype == "int8" and target.features.has_dsp)
or (data.shape[3] % 2 == 0 and data.dtype == "int16")
(
(data.shape[3] % 4 == 0 and data.dtype == "int8")
or (data.shape[3] % 2 == 0 and data.dtype == "int16")
)
and target.features.has_dsp
)
and (padding != "SAME" or data.shape[1] % stride_h == data.shape[2] % stride_w == 0)
# Ideally we should check that kernel is a Relay constant, but strategy functions
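For clarity, the reworked DSP guard above can be read as a standalone predicate. The sketch below is a plain-Python restatement (hypothetical function name, not part of the diff) showing that has_dsp now gates the int16 path as well as the int8 path.

def _depthwise_dsp_applicable(channels, dtype, has_dsp):
    """Restates the guard above: the DSP feature is required for both dtypes."""
    return has_dsp and (
        (channels % 4 == 0 and dtype == "int8")
        or (channels % 2 == 0 and dtype == "int16")
    )

assert _depthwise_dsp_applicable(8, "int16", has_dsp=True)
assert not _depthwise_dsp_applicable(8, "int16", has_dsp=False)  # previously slipped through without DSP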
16 changes: 10 additions & 6 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -111,8 +111,7 @@ def qnn_conv2d_transpose_legalize(attrs, inputs, types):
# Otherwise it needs to be broadcast.
else:
shift_data = relay.nn.bias_add(
relay.cast(data, dtype="int16"),
-relay.cast(input_zero_point, dtype="int16"),
relay.cast(data, dtype="int16"), -relay.cast(input_zero_point, dtype="int16")
)

# If kernel zero point is a scalar, we can directly subtract it.
@@ -123,8 +122,7 @@ def qnn_conv2d_transpose_legalize(attrs, inputs, types):
# Otherwise it needs to be broadcast.
else:
shift_kernel = relay.nn.bias_add(
relay.cast(kernel, dtype="int16"),
-relay.cast(kernel_zero_point, dtype="int16"),
relay.cast(kernel, dtype="int16"), -relay.cast(kernel_zero_point, dtype="int16")
)

return relay.nn.conv2d_transpose(shift_data, shift_kernel, **attrs)
@@ -486,7 +484,10 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
if target.features.has_asimd and not other_options:
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
# Arm prefers the dtypes to be the same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
if types[0].dtype in ["int8", "uint8"]:
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)

return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)


@qnn_dense_legalize.register("arm_cpu")
@@ -495,7 +496,10 @@ def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
if target.features.has_asimd and not target.features.has_dotprod:
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
# Arm prefers the dtypes to be the same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
if types[0].dtype in ["int8", "uint8"]:
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)

return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)


##########################
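The new dtype guard is easiest to read as a small decision function. The sketch below (plain Python, hypothetical name, slightly simplified: one flag stands in for the dotprod/other-options checks) mirrors the flow of _qnn_conv2d_legalize_arm_cpu and _qnn_dense_legalize_arm_cpu: only (u)int8 inputs keep the QNN path with unified dtypes, while other dtypes such as int16 fall back to the plain nn op.

def _arm_cpu_qnn_legalize_choice(in_dtype, has_asimd, has_fast_int8):
    # No fast int8 hardware: lower to nn.conv2d / nn.dense directly.
    if has_asimd and not has_fast_int8:
        return "no_fast_int8_hw"
    # Fast int8 hardware and an (u)int8 input: stay on the qnn op with matching dtypes.
    if in_dtype in ("int8", "uint8"):
        return "change_dtypes_to_be_same"
    # Anything else (e.g. int16) also falls back to the plain nn op.
    return "no_fast_int8_hw"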
8 changes: 7 additions & 1 deletion python/tvm/topi/arm_cpu/conv2d.py
@@ -23,10 +23,11 @@
from tvm import autotvm
import tvm.contrib.nnpack

from ..utils import traverse_inline, get_const_tuple
from ..utils import traverse_inline, get_const_tuple, conv2d_infer_layout_helper
from .. import nn
from ..nn.utils import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
from ..nn.conv2d import conv2d_infer_layout
from .conv2d_spatial_pack import (
conv2d_spatial_pack_nchw,
conv2d_spatial_pack_nhwc,
@@ -509,3 +510,8 @@ def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype):
def schedule_conv2d_nhwc_dsp(cfg, outs):
"""Create schedule for conv2d_nhwc_dsp"""
return conv2d_nhwc_dsp_schedule(cfg, outs)


@conv2d_infer_layout.register("arm_cpu")
def _conv2d_infer_layout(workload, cfg):
return conv2d_infer_layout_helper(workload, cfg)
6 changes: 4 additions & 2 deletions python/tvm/topi/arm_cpu/injective.py
@@ -69,8 +69,10 @@ def schedule_injective(outs):
if list(s[x].op.axis):
# do not vectorize for broadcast
dtype = "uint16" if x.dtype == "bfloat16" else x.dtype
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // np.dtype(dtype).itemsize)
s[x].vectorize(ii)
# do not vectorize for custom data types
if not dtype.count("custom"):
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // np.dtype(dtype).itemsize)
s[x].vectorize(ii)
tvm.te.schedule.AutoInlineInjective(s)

if not is_empty_shape(x.shape):
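A quick illustration (not part of the diff) of what the split factor 16 // np.dtype(dtype).itemsize evaluates to: the inner loop is sized to fill one 128-bit (16-byte) vector, which is also why custom dtypes, unknown to NumPy, must skip this path.

import numpy as np

for dt in ["int8", "uint16", "float32", "float64"]:
    print(dt, 16 // np.dtype(dt).itemsize)  # 16, 8, 4 and 2 lanes per 128-bit vector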
14 changes: 2 additions & 12 deletions python/tvm/topi/intel_graphics/conv2d_alter_op.py
@@ -22,7 +22,7 @@
from tvm import relay
from tvm import autotvm

from ..utils import get_const_tuple
from ..utils import get_const_tuple, conv2d_infer_layout_helper
from ..nn import conv2d_alter_layout, conv2d_infer_layout
from .conv2d import _get_default_config

@@ -102,14 +102,4 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):

@conv2d_infer_layout.register("intel_graphics")
def _conv2d_infer_layout(workload, cfg):
_, data, kernel, strides, padding, dilation, layout, dtype = workload
batch_size, in_channel, in_height, in_width = data[1]
out_channel, _, k_height, k_width = kernel[1]
out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
in_layout = f"NCHW{tile_ic}c"
out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
out_layout = f"NCHW{tile_oc}c"
return ((in_shape, in_layout),), ((out_shape, out_layout),)
return conv2d_infer_layout_helper(workload, cfg)
2 changes: 2 additions & 0 deletions python/tvm/topi/testing/common.py
@@ -35,6 +35,8 @@
_reduce_schedule = {
"generic": topi.generic.schedule_reduce,
"cpu": topi.x86.schedule_reduce,
# TODO(@FranklandJack) Write arm_cpu specific reduction schedule.
"arm_cpu": topi.x86.schedule_reduce,
"gpu": topi.cuda.schedule_reduce,
"hls": topi.cuda.schedule_reduce,
}
24 changes: 24 additions & 0 deletions python/tvm/topi/utils.py
@@ -24,6 +24,7 @@
import tvm
from tvm import te
from tvm.tir import Any, SizeVar, bijective_layout, layout
import tvm.topi

from . import cpp, tag

@@ -526,3 +527,26 @@ def is_target(names):
def is_dynamic_shape(shape):
"""Checks if any part of a shape is dynamic"""
return any([isinstance(x, (Any, SizeVar)) for x in shape])


def conv2d_infer_layout_helper(workload, cfg):
"""Infers input and output layouts for a conv2d operator
scheduled using "tile_ic" and "tile_oc" scheduling configuration knobs which
is the case for cpu, arm_cpu and intel_graphics targets."""
_, data, kernel, strides, padding, dilation, _, _, _ = workload
batch_size, in_channel, in_height, in_width = data[1]
out_channel, _, k_height, k_width = kernel[1]
idxdiv = tvm.tir.indexdiv

pt, pl, pb, pr = tvm.topi.nn.get_pad_tuple(padding, (k_height, k_width))
hdilation, wdilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
dilated_kernel_h = (k_height - 1) * hdilation + 1
dilated_kernel_w = (k_width - 1) * wdilation + 1
out_height = idxdiv(in_height + pt + pb - dilated_kernel_h, strides[0]) + 1
out_width = idxdiv(in_width + pl + pr - dilated_kernel_w, strides[1]) + 1
tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
in_layout = f"NCHW{tile_ic}c"
out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
out_layout = f"NCHW{tile_oc}c"
return ((in_shape, in_layout),), ((out_shape, out_layout),)
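For a concrete feel of what the helper reports, here is an illustration-only sketch in plain Python with stand-in numbers (not a real autotvm workload or config): a 1x64x56x56 NCHW input, a 64x64x3x3 OIHW kernel, stride 1, padding 1, dilation 1, and tile sizes tile_ic=8, tile_oc=16.

batch, in_c, height, width = 1, 64, 56, 56
out_c, k_h, k_w = 64, 3, 3
tile_ic, tile_oc = 8, 16
out_h = (height + 1 + 1 - k_h) // 1 + 1  # 56
out_w = (width + 1 + 1 - k_w) // 1 + 1   # 56
inputs = (((batch, in_c // tile_ic, height, width, tile_ic), f"NCHW{tile_ic}c"),)
outputs = (((batch, out_c // tile_oc, out_h, out_w, tile_oc), f"NCHW{tile_oc}c"),)
print(inputs)   # (((1, 8, 56, 56, 8), 'NCHW8c'),)
print(outputs)  # (((1, 4, 56, 56, 16), 'NCHW16c'),)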
20 changes: 2 additions & 18 deletions python/tvm/topi/x86/conv2d.py
@@ -30,7 +30,7 @@
from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.utils import get_pad_tuple
from ..utils import get_const_tuple, traverse_inline
from ..utils import get_const_tuple, traverse_inline, conv2d_infer_layout_helper
from . import conv2d_avx_1x1, conv2d_avx_common

logger = logging.getLogger("topi")
@@ -65,23 +65,7 @@ def _get_default_config(

@conv2d_infer_layout.register("cpu")
def _conv2d_infer_layout(workload, cfg):
_, data, kernel, strides, padding, dilation, layout, _, dtype = workload
batch_size, in_channel, in_height, in_width = data[1]
out_channel, _, k_height, k_width = kernel[1]
idxdiv = tvm.tir.indexdiv

pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width))
hdilation, wdilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
dilated_kernel_h = (k_height - 1) * hdilation + 1
dilated_kernel_w = (k_width - 1) * wdilation + 1
out_height = idxdiv(in_height + pt + pb - dilated_kernel_h, strides[0]) + 1
out_width = idxdiv(in_width + pl + pr - dilated_kernel_w, strides[1]) + 1
tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
in_layout = f"NCHW{tile_ic}c"
out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
out_layout = f"NCHW{tile_oc}c"
return ((in_shape, in_layout),), ((out_shape, out_layout),)
return conv2d_infer_layout_helper(workload, cfg)


def schedule_conv2d_nhwc(outs):
2 changes: 1 addition & 1 deletion src/target/target_kind.cc
@@ -293,7 +293,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Integer>("opt-level")
// LLVM command line flags, see below
.add_attr_option<Array<String>>("cl-opt")
.set_default_keys({"cpu"})
.set_default_keys({attr::kHostCPU})
// Force the external codegen kind attribute to be registered, even if no external
// codegen targets are enabled by the TVM build.
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(false))
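As a quick sanity check (not part of the diff), the default key of a plain "llvm" target should now track the host architecture. A minimal Python sketch, assuming a TVM build that contains this change:

import tvm

host_target = tvm.target.Target("llvm")
print(list(host_target.keys))  # expect ["arm_cpu"] on Arm hosts and ["cpu"] elsewhere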
11 changes: 10 additions & 1 deletion src/tir/transforms/ir_utils.cc
@@ -28,6 +28,7 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>
@@ -701,7 +702,15 @@ std::optional<bool> IsHostFunc(const PrimFunc& func) {
if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
return true;
} else if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
return target.value()->HasKey("cpu");
// The function runs on the host iff its target carries the host CPU key.
const auto& keys = target.value()->keys;
return std::any_of(std::begin(keys), std::end(keys),
[](const String& key) { return key == tvm::attr::kHostCPU; });
} else {
return std::nullopt;
}
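For readers more comfortable with Python, the check above amounts to: a function is a host function when its target carries the host-CPU key. A hedged sketch (hypothetical helper name, not an API added by this diff):

import tvm

HOST_CPU_KEY = tvm.target.Target("llvm").keys[0]  # "arm_cpu" on Arm hosts, "cpu" elsewhere

def is_host_target(target):
    return HOST_CPU_KEY in list(target.keys)

print(is_host_target(tvm.target.Target("llvm")))  # True
print(is_host_target(tvm.target.Target("cuda")))  # False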
2 changes: 1 addition & 1 deletion tests/python/integration/test_ewise.py
@@ -222,7 +222,7 @@ def check_device(device):
print("skip because %s is not enabled.." % device)
return
target = tvm.target.Target(device)
if "cpu" not in target.keys:
if not any(["cpu" in key for key in target.keys]):
schedule[placeholder_b].bind(axis1, te.thread_axis("blockIdx.x"))
schedule[placeholder_b].bind(axis2, te.thread_axis("threadIdx.x"))
func = tvm.build(schedule, [placeholder_a, placeholder_b], device)
6 changes: 5 additions & 1 deletion tests/python/relay/strategy/test_select_implementation.py
@@ -20,13 +20,17 @@
import tvm
from tvm import relay
from tvm import te
from tvm import target
from tvm.relay.testing import run_infer_type
import tvm.testing


native_arch = target.Target("llvm").keys[0]


@pytest.mark.parametrize(
"target, expected_implementation",
[("llvm", "concatenate.cpu"), ("llvm -device=arm_cpu", "concatenate.arm_cpu")],
[("llvm", "concatenate." + native_arch), ("llvm -device=arm_cpu", "concatenate.arm_cpu")],
)
def test_concatenate(target, expected_implementation):
target = tvm.target.Target(target)
8 changes: 5 additions & 3 deletions tests/python/topi/python/test_topi_bitserial_dense.py
@@ -17,6 +17,7 @@
"""Test code for bitserial_dense operator"""
import os
import numpy as np
from tvm.target.target import Target
import tvm
from tvm import te
from tvm import topi
@@ -53,11 +54,12 @@ def get_ref_data(a_shape, b_shape, input_dtype):
c_np = np.dot(a_np, b_np.T)
return a_np, b_np, c_np

for target in ["llvm", "llvm -device=arm_cpu"]:
if "arm_cpu" in target and "arm" not in os.uname()[4]:
for target_string in ["llvm", "llvm -device=arm_cpu"]:
target = Target(target_string)
if "arm_cpu" in target.keys and "arm" not in os.uname()[4]:
print("Skipped running code, not an arm device")
continue
input_dtype = "uint8" if "arm_cpu" in target else "uint32"
input_dtype = "uint8" if "arm_cpu" in target.keys else "uint32"
A = te.placeholder((batch, in_dim), dtype=input_dtype, name="A")
B = te.placeholder((out_dim, in_dim), dtype=input_dtype, name="B")
fcompute, fschedule = tvm.topi.testing.dispatch(target, _bitserial_dense_implement)