Skip to content

Commit 343f7d4

Browse files
committed
WIP unit-tests
Passes, but requires a higher ktol of 0.05 instead of 0.0001.
1 parent 1a2a404 commit 343f7d4

File tree

7 files changed

+353
-150
lines changed

7 files changed

+353
-150
lines changed

torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ include(FetchContent)
1111
# intelligence (AI) workloads tailored for Arm® CPUs.
1212
FetchContent_Declare(kleidiai
1313
GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
14-
GIT_TAG main) # TODO: set a pin
14+
GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31)
1515

1616
FetchContent_MakeAvailable(kleidiai)
1717

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
// namespace example
2+
// Copyright (c) Meta Platforms, Inc. and affiliates.
3+
// All rights reserved.
4+
//
5+
// This source code is licensed under the license found in the
6+
// LICENSE file in the root directory of this source tree.
7+
8+
#pragma once

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <vector>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

#include <kai/kai_common.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>
20+
21+
namespace torchao::kernels::cpu::aarch64::kleidi {
22+
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
23+
24+
using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
25+
26+
namespace neon_dotprod_1x4x32 {
27+
// Returns the micro-kernel dispatch table wired to the
// 1x4x32 neon_dotprod entry points of KleidiAI.
ukernel get_ukernel() {
  ukernel uk{};
  uk.get_m_step =
      kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.get_n_step =
      kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.get_mr =
      kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.get_nr =
      kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.get_kr =
      kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.get_sr =
      kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.get_lhs_packed_offset =
      kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.get_rhs_packed_offset =
      kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.get_dst_offset =
      kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.get_dst_size =
      kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  uk.run_matmul =
      kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
  return uk;
}
52+
53+
// Rounds `a` up to the nearest multiple of `b`.
size_t roundup(size_t a, size_t b) {
  const size_t quotient = (a + b - 1) / b;
  return quotient * b;
}
56+
57+
int activation_data_size(int m, int k, int group_size) {
58+
auto ukernel = get_ukernel();
59+
auto lhs_packing = get_lhs_packing();
60+
return lhs_packing.get_lhs_packed_size(
61+
m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
62+
}
63+
64+
void prepare_activation_data(
65+
void* activation_data,
66+
// Inputs
67+
int m,
68+
int k,
69+
// Ignored if has_weight_zeros = false
70+
int group_size,
71+
const float* activations) {
72+
auto ukernel = get_ukernel();
73+
auto lhs_pack = get_lhs_packing();
74+
75+
lhs_pack.run_lhs_pack(
76+
m,
77+
k,
78+
ukernel.get_mr(),
79+
ukernel.get_kr(),
80+
ukernel.get_sr(),
81+
/*m_index_start=*/0,
82+
activations,
83+
/*lhs_stride=*/k * sizeof(float),
84+
activation_data);
85+
}
86+
87+
int weight_data_size(int n, int k, int group_size) {
88+
auto ukernel = get_ukernel();
89+
auto rhs_pack = get_rhs_packing();
90+
return rhs_pack.get_rhs_packed_size(
91+
n,
92+
k,
93+
ukernel.get_nr(),
94+
ukernel.get_kr(),
95+
ukernel.get_sr(),
96+
group_size,
97+
kai_datatype::kai_dt_bf16);
98+
}
99+
100+
// Extracts the bfloat16 representation of a float by taking its upper
// 16 bits (truncation, no rounding). Requires <cstring> for std::memcpy
// and <cstdint> for the fixed-width types.
inline uint16_t get_bf16_from_float(float f) {
  uint16_t bf16;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // Big endian: the most-significant two bytes are stored first.
  std::memcpy(&bf16, &f, sizeof(uint16_t));
#else
  // Little endian: the most-significant two bytes are stored last.
  const void* fp = reinterpret_cast<const void*>(
      reinterpret_cast<uintptr_t>(&f) + sizeof(float) - sizeof(uint16_t));
  std::memcpy(&bf16, fp, sizeof(uint16_t));
#endif // __BYTE_ORDER__
  return bf16;
}
111+
112+
// TODO: move most of these functions in the parent namespace and take in
113+
// ukernel as a parameter
114+
void prepare_weight_data(
115+
void* weight_data,
116+
// Inputs
117+
int n,
118+
int k,
119+
int group_size,
120+
const int8_t* weight_qvals,
121+
const float* weight_scales,
122+
const int8_t* weight_zeros) {
123+
// TODO - remove this constraint and pad when possible
124+
assert(n % 2 == 0);
125+
126+
assert(group_size % 32 == 0);
127+
assert(k % group_size == 0);
128+
129+
// Convert scales to bf16
130+
// TODO SIMDify this
131+
size_t n_groups = n * k / group_size;
132+
auto weight_scales_bf16 = std::vector<uint16_t>(n_groups, 0);
133+
for (size_t i = 0; i < n_groups; i++) {
134+
assert(weight_zeros[i] == 0);
135+
weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]);
136+
}
137+
138+
// Prepack weights before packing
139+
// TODO SIMDify this
140+
auto packed_weight_qvals = std::vector<uint8_t>(n * k / 2, 255);
141+
uint8_t wzp = 8;
142+
for (size_t i = 0; i < n * k; i += 2) {
143+
const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
144+
const uint8_t high = static_cast<uint8_t>(weight_qvals[i+1] + wzp);
145+
packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
146+
}
147+
148+
// Parameters for packing
149+
rhs_packing::qparams_t qparams{
150+
.lhs_zero_point=1, .rhs_zero_point=wzp, .scale_dt = kai_datatype::kai_dt_bf16};
151+
152+
auto ukernel = get_ukernel();
153+
auto rhs_pack = get_rhs_packing();
154+
155+
rhs_pack.run_rhs_pack(
156+
/*groups=*/1,
157+
n,
158+
k,
159+
ukernel.get_nr(),
160+
ukernel.get_kr(),
161+
ukernel.get_sr(),
162+
group_size,
163+
/*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
164+
/*rhs_stride=*/roundup(k, 2) / 2,
165+
/*bias=*/nullptr, // TODO fix APIs to move bias here
166+
/*scale=*/reinterpret_cast<const uint16_t*>(weight_scales_bf16.data()),
167+
/*scale_stride=*/ sizeof(uint16_t) * (roundup(k, group_size) / group_size),
168+
/*rhs_packed=*/weight_data,
169+
/*extra_bytes=*/0,
170+
/*qparams=*/&qparams);
171+
}
172+
173+
// Runs the quantized matmul: output[m x n] = activations[m x k] *
// weights[k x n], with optional clamping of the result range.
// activation_data / weight_data must have been produced by
// prepare_activation_data / prepare_weight_data.
void kernel(
    // Outputs
    float32_t* output,
    // Inputs
    int output_m_stride,
    int m,
    int n,
    int k,
    int group_size,
    const void* weight_data,
    const void* activation_data,
    // Not applied if nullptr
    const float* bias,
    // zeros if has_clamp = false
    float clamp_min,
    float clamp_max) {
  // Offset-based output tiling is not supported yet: rows must be
  // written contiguously, so the stride must equal n.
  assert(output_m_stride == n);
  // clamp_min == clamp_max == 0 is the sentinel for "no clamp": widen
  // to the full float range. Use float here — float_t may alias double
  // (whose max() overflows a float to +inf) and is not guaranteed to be
  // declared by the headers this file includes.
  if (clamp_min == clamp_max && clamp_min == 0) {
    clamp_min = std::numeric_limits<float>::lowest();
    clamp_max = std::numeric_limits<float>::max();
  }
  auto ukernel = get_ukernel();
  ukernel.run_matmul(
      m,
      n,
      k,
      group_size,
      activation_data,
      weight_data,
      output,
      /*dst_stride_row=*/n * sizeof(float),
      /*dst_stride_col=*/sizeof(float),
      clamp_min,
      clamp_max);
}
208+
209+
// Required byte alignment for packed activation/weight buffers.
// (Function name keeps its existing misspelling of "alignment" because
// it is part of the public interface.)
size_t get_alignement() {
  constexpr size_t kPackedBufferAlignment = 16;
  return kPackedBufferAlignment;
}
212+
} // namespace neon_dotprod_1x4x32
213+
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
214+
} // namespace torchao::kernels::cpu::aarch64::kleidi
Lines changed: 4 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// namespace example
12
// Copyright (c) Meta Platforms, Inc. and affiliates.
23
// All rights reserved.
34
//
@@ -6,102 +7,12 @@
67

78
#pragma once
89

9-
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>
1010
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>
11-
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
1211

1312
namespace torchao::kernels::cpu::aarch64::kleidi {
14-
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
13+
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
1514

16-
using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
15+
using ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
1716

18-
namespace neon_dotprod_1x4x32 {
19-
ukernel get_ukernel() {
20-
return ukernel {
21-
.get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
22-
.get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
23-
.get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
24-
.get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
25-
.get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
26-
.get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
27-
.get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
28-
.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
29-
.get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
30-
.get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
31-
.run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod
32-
};
33-
}
34-
35-
int activation_data_size(int m, int k, int group_size) {
36-
auto ukernel = get_ukernel();
37-
auto lhs_packing = get_lhs_packing();
38-
return lhs_packing.get_lhs_packed_size(m, k, group_size, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
39-
}
40-
41-
void prepare_activation_data(
42-
void* activation_data,
43-
// Inputs
44-
int m,
45-
int k,
46-
// Ignored if has_weight_zeros = false
47-
int group_size,
48-
const float* activations) {
49-
auto ukernel = get_ukernel();
50-
auto lhs_pack = get_lhs_packing();
51-
lhs_pack.run_lhs_pack(m, k, group_size, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr(), /*m_index_start=*/0,
52-
activations, /*lhs_stride=*/ k*sizeof(float), activation_data);
53-
}
54-
55-
int weight_data_size(int n, int k, int group_size) {
56-
auto ukernel = get_ukernel();
57-
auto rhs_pack = get_rhs_packing();
58-
return rhs_pack.get_rhs_packed_size(n, k, ukernel.get_nr(), ukernel.get_kr(), group_size);
59-
}
60-
61-
void prepare_weight_data(
62-
void* weight_data,
63-
// Inputs
64-
int n,
65-
int k,
66-
int group_size,
67-
const int8_t* weight_qvals,
68-
const float* weight_scales,
69-
const int8_t* weight_zeros) {
70-
if (weight_zeros) {
71-
// TODO check all zeros
72-
assert (weight_zeros[0] == 8);
73-
}
74-
auto ukernel = get_ukernel();
75-
auto rhs_pack = get_rhs_packing();
76-
rhs_packing::qparams_t qparams{1, 8};
77-
// @nocommit - Unsigned hack, add a naive packing routine
78-
rhs_pack.run_rhs_pack(/*groups=*/1, n, k, ukernel.get_nr(), ukernel.get_kr(), ukernel.get_sr(),
79-
group_size, reinterpret_cast<const uint8_t*>(weight_qvals), /*bias=*/nullptr, weight_data, /*extra_bytes=*/0, &qparams);
80-
}
81-
82-
void kernel(
83-
// Outputs
84-
float32_t* output,
85-
// Inputs
86-
int output_m_stride,
87-
int m,
88-
int n,
89-
int k,
90-
int group_size,
91-
const void* weight_data,
92-
const void* activation_data,
93-
// Not applied if nullptr
94-
const float* bias,
95-
// Ignored if has_clamp = false
96-
float clamp_min,
97-
float clamp_max) {
98-
auto ukernel = get_ukernel();
99-
ukernel.run_matmul(m, n, k, group_size, activation_data, weight_data, output, output_m_stride, /*dst_stride_col=*/1, clamp_min, clamp_max);
100-
}
101-
102-
size_t get_alignement() {
103-
return 16;
104-
}
105-
} // namespace neon_dotprod_1x4x32
106-
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
17+
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
10718
} // namespace torchao::kernels::cpu::aarch64::kleidi

0 commit comments

Comments
 (0)