Skip to content

Commit 1a2a404

Browse files
committed
[experimental][kleidi] Add a basic test - compiles
1 parent 78ce74e commit 1a2a404

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ add_library(
2929
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
3030
)
3131

32+
if(NOT TORCHAO_INCLUDE_DIRS)
33+
set(TORCHAO_INCLUDE_DIRS ${TORCHAO_LIBRARIES})
34+
endif()
35+
add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)
36+
3237
enable_testing()
3338

3439
add_executable(test_quantization test_quantization.cpp)
@@ -61,6 +66,7 @@ target_link_libraries(
6166
PRIVATE
6267
GTest::gtest_main
6368
dep
69+
torchao_kernels_aarch64
6470
)
6571

6672
add_executable(test_valpacking test_valpacking.cpp)

torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
1212
#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
1313
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
14+
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>
1415
#include <vector>
1516

1617
float kTol = 0.0001;
@@ -350,4 +351,75 @@ TEST(
350351
}
351352
}
352353

354+
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
355+
void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
356+
int m,
357+
int k,
358+
int n,
359+
int group_size) {
360+
auto test_case = torchao::
361+
channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate(
362+
m,
363+
k,
364+
n,
365+
group_size,
366+
weight_nbit,
367+
has_weight_zeros,
368+
has_bias,
369+
has_clamp);
370+
371+
using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
372+
373+
std::vector<char> activation_data(
374+
activation_data_size(m, k, group_size));
375+
376+
prepare_activation_data(
377+
(void*)activation_data.data(),
378+
m,
379+
k,
380+
group_size,
381+
test_case.activations.data());
382+
383+
std::vector<char> weight_data(
384+
weight_data_size(n, k, group_size));
385+
386+
prepare_weight_data(
387+
(void*)weight_data.data(),
388+
n,
389+
k,
390+
group_size,
391+
test_case.weight_qvals.data(),
392+
test_case.weight_scales.data(),
393+
/*weight_zeros=*/test_case.weight_zeros.data());
394+
395+
std::vector<float> output(m * n);
396+
kernel(
397+
output.data(),
398+
/*output_m_stride=*/n,
399+
m,
400+
n,
401+
k,
402+
group_size,
403+
weight_data.data(),
404+
activation_data.data(),
405+
/*bias=*/test_case.bias.data(),
406+
/*clamp_min=*/test_case.clamp_min,
407+
/*clamp_max=*/test_case.clamp_max);
408+
409+
for (int i = 0; i < m * n; i++) {
410+
EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
411+
}
412+
}
413+
414+
TEST(
415+
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
416+
only_supported) {
417+
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
418+
4 /*weight_nbit*/,
419+
false /*has_weight_zeros*/,
420+
false /*has_bias*/,
421+
false /*has_clamp*/>(
422+
/*m=*/16, /*k=*/64, /*n=*/16, /*group_size=*/32);
423+
}
424+
353425
#endif // defined(__aarch64__) || defined(__ARM_NEON)

0 commit comments

Comments
 (0)