diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py
index 2f09e24c87a2..23b8734a0ba4 100644
--- a/python/tvm/topi/arm_cpu/matmul.py
+++ b/python/tvm/topi/arm_cpu/matmul.py
@@ -26,6 +26,7 @@
 from tvm.topi.utils import get_const_tuple
 from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
 from tvm.topi.arm_cpu.arm_utils import pad_dim_to_multiple
+from tvm.dlight.base.analysis import normalize_prim_func
 
 
 @autotvm.register_topi_compute("matmul.arm_cpu.sme")
@@ -126,9 +127,10 @@ def tir_schedule_matmul_sme(sch):
     in_dtype = main_func.buffer_map[data_handle].dtype
     out_dtype = "float32"
 
-    root_block = sch.get_block(main_func.body.block.name_hint)
-    gemm_block = sch.get_child_blocks(root_block)[-2]
-
+    block_infos = normalize_prim_func(sch)
+    reduction_block_infos = [block_info for block_info in block_infos if block_info.is_reduction()]
+    assert len(reduction_block_infos) == 1, "Expected a single gemm reduction block."
+    gemm_block = reduction_block_infos[0].block_rv
     gemm_block_name = sch.get(gemm_block).name_hint
     transpose = gemm_block_name.split("_")[-1]
     transpose_b = transpose[1] == "T"
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 77c22761a9c8..9b0408b949a0 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -540,6 +540,21 @@ def check_correct_assembly(dtype):
     check_correct_assembly(dtype=dtype)
 
 
+def test_matmul_sme_no_reduction_block():
+    @T.prim_func
+    def prim_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (4,))
+        B = T.match_buffer(b, (4,))
+        for i in range(3):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi]
+
+    sch = tvm.tir.Schedule(prim_func)
+    with pytest.raises(AssertionError, match="Expected a single gemm reduction block."):
+        tvm.topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
+
+
 @pytest.mark.skipif(
     llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
 )
diff --git a/tests/python/relay/strategy/arm_cpu/test_dense.py b/tests/python/relay/strategy/arm_cpu/test_dense.py
index 3a8427e8154d..fee8a87f1253 100644
--- a/tests/python/relay/strategy/arm_cpu/test_dense.py
+++ b/tests/python/relay/strategy/arm_cpu/test_dense.py
@@ -99,16 +99,16 @@ class TestDense(BasicDenseTests):
 )
 @tvm.testing.requires_aprofile_aem_fvp
 @pytest.mark.parametrize(
-    "data_shape,weight_shape",
+    "data_shape,weight_shape,enable_bias",
     [
-        ((32, 32), (32, 32)),
-        ((2, 35), (6, 35)),
-        ((3, 3), (68, 3)),
-        ((79, 65), (152, 65)),
+        ((32, 32), (32, 32), False),
+        ((2, 35), (6, 35), False),
+        ((3, 3), (68, 3), False),
+        ((79, 65), (152, 65), True),
     ],
 )
 @pytest.mark.parametrize("in_dtype", ["float32", "float16"])
-def test_sme_dense(data_shape, weight_shape, in_dtype):
+def test_sme_dense(data_shape, weight_shape, enable_bias, in_dtype):
     np.random.seed(0)
     out_dtype = "float32"
 
@@ -117,8 +117,14 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):
     weight_data = np.random.uniform(size=weight_shape).astype(in_dtype)
     weight = relay.const(weight_data, dtype=in_dtype)
 
-    dense = relay.nn.dense(inp, weight, out_dtype=out_dtype)
-    func = relay.Function(relay.analysis.free_vars(dense), dense)
+    relay_op = relay.nn.dense(inp, weight, out_dtype=out_dtype)
+
+    if enable_bias:
+        bias_data = np.random.uniform(size=weight_shape[0]).astype(out_dtype)
+        bias = relay.const(bias_data, dtype=out_dtype)
+        relay_op = relay.nn.bias_add(relay_op, bias)
+
+    func = relay.Function(relay.analysis.free_vars(relay_op), relay_op)
 
     ir_mod = tvm.IRModule.from_expr(func)
     ir_mod = tvm.relay.transform.InferType()(ir_mod)
@@ -147,8 +153,10 @@
         runtime=runtime,
         params=params,
     )
+
+    bias_postfix = "_add" if enable_bias else ""
     generated_func = executor_factory.lowered_ir_mods.items()[0][1][
-        "tvmgen_default_fused_nn_matmul"
+        f"tvmgen_default_fused_nn_matmul{bias_postfix}"
     ]
 
     extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)