diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h index 545628f344..15e36b81c4 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h @@ -288,7 +288,7 @@ struct KernelImpl { constexpr int nr = 8; constexpr int kr = 8; assert(m % mr == 0); - assert(k % kr == 0); + assert(k % 16 == 0); assert(n >= nr); std::vector rhs_packed(n * k); // Since we are casting int8_t to float32_t in order to tranpose matrix in a diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp index 300ac8c442..18c9986393 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -53,7 +53,7 @@ struct test_channelwise_8bit_channelwise_8bit_b< const int, const int); kernel_fn_type kernel_fn = nullptr; - if (use_gemm && (m % 4 == 0) && (n % 8 == 0) && (k % 8 == 0)) { + if (use_gemm && (m % 4 == 0) && (n % 8 == 0) && (k % 16 == 0)) { using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot; kernel_fn = kernel; @@ -531,9 +531,6 @@ static void test_8bit_per_token_q_at_k_matmul_attention( channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case:: generate(b, s_q, s_k, h, d, transpose); - using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: - channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot; - size_t q_b_stride = test_case.b_q_stride; size_t q_h_stride = test_case.h_q_stride; size_t q_s_q_stride = test_case.s_q_stride; @@ -553,9 +550,36 @@ static void test_8bit_per_token_q_at_k_matmul_attention( size_t output_h_stride = s_q * s_k; size_t output_s_q_stride = s_k; + using kernel_fn_type = void (*)( + int, + int, + int, + const void*, + int, + const void*, + int, + float*, + int, + const int8_t*, + const int8_t*, + const float*, + const float*, + const int, + const int); + kernel_fn_type kernel_fn = nullptr; + if ((s_q % 4 == 0) && (s_k % 8 == 0) && (d % 16 == 0)) { + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot; + kernel_fn = kernel; + } else { + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot; + kernel_fn = kernel; + } + for (int b_idx = 0; b_idx < b; b_idx++) { for (int h_idx = 0; h_idx < h; h_idx++) { - kernel( + kernel_fn( s_q, s_k, d, @@ -587,6 +611,14 @@ TEST(test_8bit_per_token_q_at_k_matmul_attention, Basic) { test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16); } +TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicGemmKernel) { + test_8bit_per_token_q_at_k_matmul_attention(1, 4, 16, 4, 16); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicGemmKernelNoTranspose) { + test_8bit_per_token_q_at_k_matmul_attention(1, 4, 16, 4, 16, false); +} + TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndHeadDim) { test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 33); }