python/tvm/topi/arm_cpu/matmul.py (5 additions, 3 deletions)

@@ -26,6 +26,7 @@
 from tvm.topi.utils import get_const_tuple
 from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
 from tvm.topi.arm_cpu.arm_utils import pad_dim_to_multiple
+from tvm.dlight.base.analysis import normalize_prim_func
 
 
 @autotvm.register_topi_compute("matmul.arm_cpu.sme")
@@ -126,9 +127,10 @@ def tir_schedule_matmul_sme(sch):
     in_dtype = main_func.buffer_map[data_handle].dtype
     out_dtype = "float32"
 
-    root_block = sch.get_block(main_func.body.block.name_hint)
-    gemm_block = sch.get_child_blocks(root_block)[-2]
-
+    block_infos = normalize_prim_func(sch)
+    reduction_block_infos = [block_info for block_info in block_infos if block_info.is_reduction()]
+    assert len(reduction_block_infos) == 1, "Expected a single gemm reduction block."
+    gemm_block = reduction_block_infos[0].block_rv
     gemm_block_name = sch.get(gemm_block).name_hint
     transpose = gemm_block_name.split("_")[-1]
     transpose_b = transpose[1] == "T"
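Note on the change above: the gemm block was previously looked up by position (sch.get_child_blocks(root_block)[-2]), which silently breaks when fused operators such as bias_add alter the block layout. normalize_prim_func instead classifies each block by its iteration domain. A minimal sketch of that classification, assuming a TVM build that ships the tvm.dlight module (the matmul prim func below is illustrative, not taken from this patch):

import tvm
from tvm.script import tir as T
from tvm.dlight.base.analysis import normalize_prim_func


# Illustrative matmul prim func with an explicit "R" (reduction) axis.
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle):
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


sch = tvm.tir.Schedule(matmul)
for info in normalize_prim_func(sch):
    # The gemm block is recognised by its reduction axis, not by its index.
    print(sch.get(info.block_rv).name_hint, info.is_reduction())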
tests/python/codegen/test_target_codegen_aarch64.py (15 additions, 0 deletions)

@@ -540,6 +540,21 @@ def check_correct_assembly(dtype):
     check_correct_assembly(dtype=dtype)
 
 
+def test_matmul_sme_no_reduction_block():
+    @T.prim_func
+    def prim_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (4,))
+        B = T.match_buffer(b, (4,))
+        for i in range(3):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi]
+
+    sch = tvm.tir.Schedule(prim_func)
+    with pytest.raises(AssertionError, match="Expected a single gemm reduction block."):
+        tvm.topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
+
+
 @pytest.mark.skipif(
     llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
 )
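The regression test above feeds tir_schedule_matmul_sme a prim func that has no reduction block at all. A small standalone sketch (same prim func as the test, tvm.dlight assumed available) showing why the new assert fires: every block is classified as spatial, so the reduction filter comes back empty:

import tvm
from tvm.script import tir as T
from tvm.dlight.base.analysis import normalize_prim_func


@T.prim_func
def copy_func(a: T.handle, b: T.handle):
    A = T.match_buffer(a, (4,))
    B = T.match_buffer(b, (4,))
    for i in range(3):
        with T.block("block"):
            vi = T.axis.remap("S", [i])
            B[vi] = A[vi]


block_infos = normalize_prim_func(tvm.tir.Schedule(copy_func))
# The lone copy block is purely spatial ("S"), so no reduction block is
# found, and the assert added in matmul.py would trip on this schedule.
assert not any(info.is_reduction() for info in block_infos)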
tests/python/relay/strategy/arm_cpu/test_dense.py (17 additions, 9 deletions)

@@ -99,16 +99,16 @@ class TestDense(BasicDenseTests):
 )
 @tvm.testing.requires_aprofile_aem_fvp
 @pytest.mark.parametrize(
-    "data_shape,weight_shape",
+    "data_shape,weight_shape,enable_bias",
     [
-        ((32, 32), (32, 32)),
-        ((2, 35), (6, 35)),
-        ((3, 3), (68, 3)),
-        ((79, 65), (152, 65)),
+        ((32, 32), (32, 32), False),
+        ((2, 35), (6, 35), False),
+        ((3, 3), (68, 3), False),
+        ((79, 65), (152, 65), True),
     ],
 )
 @pytest.mark.parametrize("in_dtype", ["float32", "float16"])
-def test_sme_dense(data_shape, weight_shape, in_dtype):
+def test_sme_dense(data_shape, weight_shape, enable_bias, in_dtype):
     np.random.seed(0)
     out_dtype = "float32"
 
@@ -117,8 +117,14 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):
     weight_data = np.random.uniform(size=weight_shape).astype(in_dtype)
     weight = relay.const(weight_data, dtype=in_dtype)
 
-    dense = relay.nn.dense(inp, weight, out_dtype=out_dtype)
-    func = relay.Function(relay.analysis.free_vars(dense), dense)
+    relay_op = relay.nn.dense(inp, weight, out_dtype=out_dtype)
+
+    if enable_bias:
+        bias_data = np.random.uniform(size=weight_shape[0]).astype(out_dtype)
+        bias = relay.const(bias_data, dtype=out_dtype)
+        relay_op = relay.nn.bias_add(relay_op, bias)
+
+    func = relay.Function(relay.analysis.free_vars(relay_op), relay_op)
 
     ir_mod = tvm.IRModule.from_expr(func)
     ir_mod = tvm.relay.transform.InferType()(ir_mod)
@@ -147,8 +153,10 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):
         runtime=runtime,
         params=params,
     )
+
+    bias_postfix = "_add" if enable_bias else ""
     generated_func = executor_factory.lowered_ir_mods.items()[0][1][
-        "tvmgen_default_fused_nn_matmul"
+        f"tvmgen_default_fused_nn_matmul{bias_postfix}"
     ]
     extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)
 
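One more note on the symbol lookup at the end of the test: Relay fusion names the generated primitive function after the fused operators, so appending bias_add to the dense (which the SME strategy lowers via matmul) changes the symbol the test must fetch. A tiny sketch of the naming convention the bias_postfix logic relies on:

# Mirrors the bias_postfix selection above; the "_add" suffix follows the
# fused-function naming this test observes in the lowered module.
for enable_bias in (False, True):
    bias_postfix = "_add" if enable_bias else ""
    print(f"tvmgen_default_fused_nn_matmul{bias_postfix}")
# tvmgen_default_fused_nn_matmul
# tvmgen_default_fused_nn_matmul_add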