@@ -288,7 +288,7 @@ struct KernelImpl<true, true, false, true> {
   constexpr int nr = 8;
   constexpr int kr = 8;
   assert(m % mr == 0);
-  assert(k % kr == 0);
+  assert(k % 16 == 0);
   assert(n >= nr);
   std::vector<int8_t> rhs_packed(n * k);
   // Since we are casting int8_t to float32_t in order to transpose matrix in a
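Note on the assertion change above: the GEMM path now requires `k` to be a multiple of 16 rather than of `kr` (8), presumably because the packing loop consumes K in 16-element chunks; the dispatch checks in `test_qmatmul.cpp` below are updated to match. A minimal caller-side guard capturing the full shape requirement (a sketch; the helper name `shapes_fit_gemm_4x8x8` is hypothetical, not a torchao API):

```cpp
// Hypothetical guard mirroring the updated dispatch condition in
// test_qmatmul.cpp: the 4x8x8 GEMM kernel expects m % 4 == 0,
// n % 8 == 0, and (after this change) k % 16 == 0.
inline bool shapes_fit_gemm_4x8x8(int m, int n, int k) {
  return (m % 4 == 0) && (n % 8 == 0) && (k % 16 == 0);
}
```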
torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp (37 additions, 5 deletions)
@@ -53,7 +53,7 @@ struct test_channelwise_8bit_channelwise_8bit_b<
       const int,
       const int);
   kernel_fn_type kernel_fn = nullptr;
-  if (use_gemm && (m % 4 == 0) && (n % 8 == 0) && (k % 8 == 0)) {
+  if (use_gemm && (m % 4 == 0) && (n % 8 == 0) && (k % 16 == 0)) {
     using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
         channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot;
     kernel_fn = kernel<a_has_zeros, b_has_zeros, false, true>;
@@ -531,9 +531,6 @@ static void test_8bit_per_token_q_at_k_matmul_attention(
   channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case::
       generate(b, s_q, s_k, h, d, transpose);
 
-  using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
-      channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot;
-
   size_t q_b_stride = test_case.b_q_stride;
   size_t q_h_stride = test_case.h_q_stride;
   size_t q_s_q_stride = test_case.s_q_stride;
@@ -553,9 +550,36 @@
   size_t output_h_stride = s_q * s_k;
   size_t output_s_q_stride = s_k;
 
+  using kernel_fn_type = void (*)(
+      int,
+      int,
+      int,
+      const void*,
+      int,
+      const void*,
+      int,
+      float*,
+      int,
+      const int8_t*,
+      const int8_t*,
+      const float*,
+      const float*,
+      const int,
+      const int);
+  kernel_fn_type kernel_fn = nullptr;
+  if ((s_q % 4 == 0) && (s_k % 8 == 0) && (d % 16 == 0)) {
+    using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot;
+    kernel_fn = kernel<true, true, false, true>;
+  } else {
+    using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot;
+    kernel_fn = kernel<true, true, false, true>;
+  }
+
   for (int b_idx = 0; b_idx < b; b_idx++) {
     for (int h_idx = 0; h_idx < h; h_idx++) {
-      kernel<true, true, false, true>(
+      kernel_fn(
           s_q,
           s_k,
           d,
@@ -587,6 +611,14 @@ TEST(test_8bit_per_token_q_at_k_matmul_attention, Basic) {
   test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16);
 }
 
+TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicGemmKernel) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 4, 16, 4, 16);
+}
+
+TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicGemmKernelNoTranspose) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 4, 16, 4, 16, false);
+}
+
 TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndHeadDim) {
   test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 33);
 }
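The new tests exercise both sides of the dispatch: `BasicGemmKernel` passes (s_q, s_k, d) = (4, 16, 16), which satisfies every GEMM alignment check, while `PrimeHeadsAndHeadDim` uses d = 33, not a multiple of 16, and so takes the 1x8x16 fallback. Below is a self-contained sketch of the function-pointer dispatch pattern the test uses; the kernel bodies are stubs with hypothetical names, and only the selection logic is taken from the diff:

```cpp
#include <cstdint>
#include <cstdio>

// Shared signature of the two quantized-matmul kernel instantiations,
// copied from the test above.
using kernel_fn_type = void (*)(
    int, int, int,
    const void*, int,
    const void*, int,
    float*, int,
    const std::int8_t*, const std::int8_t*,
    const float*, const float*,
    const int, const int);

// Stub stand-ins for the real NEON kernels (hypothetical names).
void gemm_4x8x8_stub(int, int, int, const void*, int, const void*, int,
                     float*, int, const std::int8_t*, const std::int8_t*,
                     const float*, const float*, const int, const int) {}
void gemv_1x8x16_stub(int, int, int, const void*, int, const void*, int,
                      float*, int, const std::int8_t*, const std::int8_t*,
                      const float*, const float*, const int, const int) {}

// Selection mirrors the diff: prefer the 4x8x8 GEMM kernel when the
// shape meets its alignment requirements, else fall back to 1x8x16.
kernel_fn_type select_kernel(int m, int n, int k) {
  if ((m % 4 == 0) && (n % 8 == 0) && (k % 16 == 0)) {
    return gemm_4x8x8_stub;
  }
  return gemv_1x8x16_stub;
}

int main() {
  // BasicGemmKernel shape: m = s_q = 4, n = s_k = 16, k = d = 16 -> GEMM.
  std::printf("(4,16,16) -> %s\n",
              select_kernel(4, 16, 16) == gemm_4x8x8_stub ? "gemm" : "gemv");
  // PrimeHeadsAndHeadDim shape: k = d = 33 is not a multiple of 16 -> fallback.
  std::printf("(8,8,33) -> %s\n",
              select_kernel(8, 8, 33) == gemm_4x8x8_stub ? "gemm" : "gemv");
  return 0;
}
```

A function pointer rather than a template parameter keeps a single call site in the batch/head loops, which is why both branches assign to the same `kernel_fn` variable in the test.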