|
11 | 11 | #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h> |
12 | 12 | #include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h> |
13 | 13 | #include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h> |
| 14 | +#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h> |
14 | 15 | #include <vector> |
15 | 16 |
|
16 | 17 | float kTol = 0.0001; |
@@ -350,4 +351,75 @@ TEST( |
350 | 351 | } |
351 | 352 | } |
352 | 353 |
|
| 354 | +template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp> |
| 355 | +void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( |
| 356 | + int m, |
| 357 | + int k, |
| 358 | + int n, |
| 359 | + int group_size) { |
| 360 | + auto test_case = torchao:: |
| 361 | + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( |
| 362 | + m, |
| 363 | + k, |
| 364 | + n, |
| 365 | + group_size, |
| 366 | + weight_nbit, |
| 367 | + has_weight_zeros, |
| 368 | + has_bias, |
| 369 | + has_clamp); |
| 370 | + |
| 371 | + using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32; |
| 372 | + |
| 373 | + std::vector<char> activation_data( |
| 374 | + activation_data_size(m, k, group_size)); |
| 375 | + |
| 376 | + prepare_activation_data( |
| 377 | + (void*)activation_data.data(), |
| 378 | + m, |
| 379 | + k, |
| 380 | + group_size, |
| 381 | + test_case.activations.data()); |
| 382 | + |
| 383 | + std::vector<char> weight_data( |
| 384 | + weight_data_size(n, k, group_size)); |
| 385 | + |
| 386 | + prepare_weight_data( |
| 387 | + (void*)weight_data.data(), |
| 388 | + n, |
| 389 | + k, |
| 390 | + group_size, |
| 391 | + test_case.weight_qvals.data(), |
| 392 | + test_case.weight_scales.data(), |
| 393 | + /*weight_zeros=*/test_case.weight_zeros.data()); |
| 394 | + |
| 395 | + std::vector<float> output(m * n); |
| 396 | + kernel( |
| 397 | + output.data(), |
| 398 | + /*output_m_stride=*/n, |
| 399 | + m, |
| 400 | + n, |
| 401 | + k, |
| 402 | + group_size, |
| 403 | + weight_data.data(), |
| 404 | + activation_data.data(), |
| 405 | + /*bias=*/test_case.bias.data(), |
| 406 | + /*clamp_min=*/test_case.clamp_min, |
| 407 | + /*clamp_max=*/test_case.clamp_max); |
| 408 | + |
| 409 | + for (int i = 0; i < m * n; i++) { |
| 410 | + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); |
| 411 | + } |
| 412 | +} |
| 413 | + |
| 414 | +TEST( |
| 415 | + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, |
| 416 | + only_supported) { |
| 417 | + test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< |
| 418 | + 4 /*weight_nbit*/, |
| 419 | + false /*has_weight_zeros*/, |
| 420 | + false /*has_bias*/, |
| 421 | + false /*has_clamp*/>( |
| 422 | + /*m=*/16, /*k=*/64, /*n=*/16, /*group_size=*/32); |
| 423 | +} |
| 424 | + |
353 | 425 | #endif // defined(__aarch64__) || defined(__ARM_NEON) |
0 commit comments