python/tvm/topi/arm_cpu/arm_utils.py (6 additions, 1 deletion)
@@ -74,8 +74,13 @@ def get_tiling_B_transformed(interleave_A, in_dtype):
             # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
             tile_N = 4
             tile_K = 16
-    # In non-quantized cases, A is not interleaved.
+    elif in_dtype == "float16" and target.features.has_fp16_simd:
+        # Each load from B' contains 32 elements (i.e. 32 columns from B)
+        # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
+        tile_N = 32
+        tile_K = 4
     else:
+        # In non-quantized cases, A is not interleaved.
         # Each load from B' contains 16 elements (i.e. 16 columns from B)
         # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
         tile_N = 16
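The hunk above only adds the fp16 branch; the full get_tiling_B_transformed also covers the dotprod/i8mm quantized paths, which are untouched here. A minimal sketch of the tile-size choice this diff establishes, with has_fp16_simd passed as a plain bool for illustration rather than read from target.features:

# Sketch only: mirrors the branch added above; the real function reads
# target.features and has more quantized-path branches than shown here.
def pick_tile_sizes(in_dtype, interleave_A, has_fp16_simd):
    if in_dtype in ("int8", "uint8") and interleave_A:
        return 4, 16   # tile_N, tile_K: 4 columns of B, 16 elements each
    if in_dtype == "float16" and has_fp16_simd:
        return 32, 4   # wider fp16 loads: 32 columns of B, 4 reduction rows
    return 16, 4       # default non-quantized tiling

print(pick_tile_sizes("float16", False, True))   # (32, 4)
print(pick_tile_sizes("float32", False, False))  # (16, 4)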
python/tvm/topi/arm_cpu/conv2d_gemm.py (12 additions, 4 deletions)
@@ -22,6 +22,7 @@
 from tvm import te
 from tvm.topi import nn
 from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
+from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed
 from ..utils import get_const_tuple, get_const_int
 from ..nn.utils import get_pad_tuple
 from .tensor_intrin import (
@@ -339,7 +340,15 @@ def compute_conv2d_gemm_without_weight_transform(
         ),
         name="C",
     )
-    zero = tvm.tir.const(0)
+    # Ensure padding on the N axis does not get removed during tir passes
+    # by adding a dummy reference to the specific padded area of the result
+    if in_dtype == "float16" and target.features.has_fp16_simd:
+        zero = (
+            tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
+            - tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
+        )
+    else:
+        zero = tvm.tir.const(0)
 
     # Reshape the result into a convolution output
     out_shape = (batches, OH, OW, OC)
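For context on the dummy-reference trick above: reading the last padded column and immediately cancelling the read contributes nothing numerically, but keeps the padded region of C alive so later TIR passes cannot shrink it away. A standalone sketch of the same pattern, with made-up shapes and output tiling (batches, M_padded, N_padded and the 4x4 reshape are placeholders, not the PR's values):

import tvm
from tvm import te

batches, M_padded, N_padded = 1, 16, 32
C = te.placeholder((batches, M_padded, N_padded), dtype="float16", name="C")

# Same pattern as the diff: x - x evaluates to zero but still references
# the padded element C[0, 0, N_padded - 1].
zero = (
    tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
    - tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
)

# Adding `zero` to every output element carries the reference into the final stage.
out = te.compute(
    (batches, 4, 4, N_padded),
    lambda b, i, j, k: (C[b, i * 4 + j, k] + zero).astype(C.dtype),
    name="conv2d_gemm_output",
)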
@@ -454,14 +463,14 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     C = out.op.input_tensors[0]
     A = C.op.input_tensors[0]
     in_type = A.dtype
+    y_tile_size, _ = get_tiling_B_transformed(False, in_type)
 
     # Computation
     b, x, y = C.op.axis
     (k,) = C.op.reduce_axis
 
     if in_type in ["int8", "uint8"]:
         k_outer, k_inner = s[C].split(k, 16)
-        y_tile_size = 16
         x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
         s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
         gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
@@ -470,9 +479,8 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
         s[C].parallel(x_outer)
     else:
         k_outer, k_inner = s[C].split(k, 4)
-        y_tile_size = 16
         x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
-        y_inner_outer, y_inner_inner = s[C].split(y_inner, 4)
+        y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4)
         b_x_outer_fused = s[C].fuse(b, x_outer)
         s[C].parallel(b_x_outer_fused)
         s[C].reorder(
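The schedule now takes y_tile_size from get_tiling_B_transformed (32 on the fp16 SIMD path) and splits y_inner by nparts rather than by a fixed factor, so the inner extent scales with the tile width instead of being pinned to 4. A small standalone sketch of the factor-vs-nparts difference; the toy compute and shapes are assumptions for illustration:

import tvm
from tvm import te

y_tile_size = 32  # what get_tiling_B_transformed would return on the fp16 path
A = te.placeholder((4, y_tile_size), name="A")
B = te.compute(A.shape, lambda i, j: A[i, j] * 2.0, name="B")

s = te.create_schedule(B.op)
i, j = B.op.axis
# nparts=4 fixes the number of outer iterations at 4, so the inner extent is
# y_tile_size // 4 = 8; split(j, 4) would instead fix the inner extent at 4.
j_outer, j_inner = s[B].split(j, nparts=4)
print(tvm.lower(s, [A, B], simple_mode=True))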
src/target/parsers/aprofile.cc (7 additions, 5 deletions)
@@ -127,11 +127,13 @@ static TargetFeatures GetFeatures(TargetJSON target) {
   const bool has_dotprod =
       (dotprod_default && !dotprod_disable) || (dotprod_support && dotprod_flag);
 
-  return {
-      {"is_aarch64", Bool(is_aarch64)}, {"has_asimd", Bool(has_asimd)},
-      {"has_sve", Bool(has_sve)}, {"has_dotprod", Bool(has_dotprod)},
-      {"has_matmul_i8", Bool(has_i8mm)},
-  };
+  const bool fp16_flag = HasFlag(mcpu, mattr, "+fullfp16");
+  const bool fp16_support = arch_version >= 8.2;
+  const bool has_fp16_simd = fp16_support && (fp16_flag || has_sve);
+
+  return {{"is_aarch64", Bool(is_aarch64)}, {"has_asimd", Bool(has_asimd)},
+          {"has_sve", Bool(has_sve)}, {"has_dotprod", Bool(has_dotprod)},
+          {"has_matmul_i8", Bool(has_i8mm)}, {"has_fp16_simd", Bool(has_fp16_simd)}};
 }
 
 static Array<String> MergeKeys(Optional<Array<String>> existing_keys) {
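With the parser change above, has_fp16_simd is reported for Armv8.2-A and newer when either +fullfp16 or +sve is present. A hedged sketch of querying the flag from Python; the exact target strings are illustrative assumptions, while the feature name follows the parser above:

import tvm

# Assumed target strings for illustration; the parser keys off the aarch64
# triple plus the architecture version and feature flags in -mattr.
fp16_target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+fullfp16")
sve_target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve")
base_target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8a,+fullfp16")

print(bool(fp16_target.features.has_fp16_simd))  # expected: True
print(bool(sve_target.features.has_fp16_simd))   # expected: True (SVE counts as fp16 SIMD support here)
print(bool(base_target.features.has_fp16_simd))  # expected: False (arch version below 8.2)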
tests/cpp/target/parsers/aprofile_test.cc (27 additions, 0 deletions)
@@ -307,11 +307,38 @@ TEST_P(AProfileOptionalSVE, OptionalSVESupport) {
   EXPECT_TRUE(Downcast<Bool>(features.at("has_sve")));
 }
 
+using AProfileOptionalFP16 = testing::TestWithParam<float>;
+TEST_P(AProfileOptionalFP16, OptionalFP16Support) {
+  const std::string arch_attr = "+v" + std::to_string(GetParam()) + "a";
+
+  // Check that the "has_fp16_simd" feature is not set by default when "+fullfp16" isn't set as an
+  // attribute.
+  TargetJSON target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr});
+  TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
+  EXPECT_TRUE(IsArch(target));
+  EXPECT_FALSE(Downcast<Bool>(features.at("has_fp16_simd")));
+
+  // Check that the "has_fp16_simd" feature is set when "+fullfp16" is explicitly set as an
+  // attribute.
+  target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr, "+fullfp16"});
+  features = Downcast<TargetFeatures>(target.at("features"));
+  EXPECT_TRUE(IsArch(target));
+  EXPECT_TRUE(Downcast<Bool>(features.at("has_fp16_simd")));
+
+  // Check that the "has_fp16_simd" feature is set when "+sve" is explicitly set as an attribute.
+  target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr, "+sve"});
+  features = Downcast<TargetFeatures>(target.at("features"));
+  EXPECT_TRUE(IsArch(target));
+  EXPECT_TRUE(Downcast<Bool>(features.at("has_fp16_simd")));
+}
+
 INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalI8MM, ::testing::ValuesIn(optionalI8MM));
 INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalDotProd,
                         ::testing::ValuesIn(optionalDotProd));
 INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalSVE,
                         ::testing::Values(8.0, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9.0));
+INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalFP16,
+                        ::testing::Values(8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9.0));
 
 } // namespace aprofile
 } // namespace parsers