|
5 | 5 | #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h> |
6 | 6 | #include <torchao/experimental/kernels/cpu/aarch64/macro.h> |
7 | 7 | #include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h> |
| 8 | +#include <torchao/experimental/kernels/cpu/aarch64/packing/utils.h> |
8 | 9 | #include <array> |
9 | 10 | #include <cstring> |
10 | 11 |
|
@@ -125,61 +126,6 @@ TORCHAO_ALWAYS_INLINE inline void unpack_buffer( |
125 | 126 | assert(false); |
126 | 127 | } |
127 | 128 |
|
// Rearranges an nr x kr tile of values into the GEMM packed layout described
// by the packing params (nr, kr, sr).
//
// `values` is read as nr consecutive columns of kr elements each. Every
// column is cut into sr chunks of (kr / sr) elements. The output holds chunk 0
// of every column (in column order), then chunk 1 of every column, and so on
// until all sr chunks have been written.
//
// kr must be divisible by sr (checked with assert). Both buffers must hold
// nr * kr elements. Elements are moved with std::memcpy.
template <typename T>
void pack_values(
    // Output
    T* packed_values,
    // Inputs
    const T* values,
    int nr,
    int kr,
    int sr) {
  assert(kr % sr == 0);
  const int chunk_len = kr / sr;
  const auto chunk_bytes = sizeof(T) * chunk_len;

  // The destination is filled strictly front to back, so a moving pointer is
  // enough to track the write position.
  T* write_cursor = packed_values;
  for (int chunk = 0; chunk < sr; ++chunk) {
    // Offset of this chunk inside any single column
    const int offset_in_column = chunk * chunk_len;
    for (int col = 0; col < nr; ++col) {
      const T* chunk_src = values + col * kr + offset_in_column;
      std::memcpy(write_cursor, chunk_src, chunk_bytes);
      write_cursor += chunk_len;
    }
  }
}
154 | | - |
// Inverse of pack_values: restores the nr x kr column-major-by-column tile
// from its packed form.
//
// `packed_values` is consumed front to back in pieces of (kr / sr) elements.
// The pieces arrive grouped by chunk index: first chunk 0 for each of the nr
// columns, then chunk 1 for each column, and so on for sr chunks. Each piece
// is written to its position inside the matching column of `values`.
//
// kr must be divisible by sr (checked with assert). Both buffers must hold
// nr * kr elements. Elements are moved with std::memcpy.
template <typename T>
void unpack_values(
    // Output
    T* values,
    // Inputs
    const T* packed_values,
    int nr,
    int kr,
    int sr) {
  assert(kr % sr == 0);
  const int chunk_len = kr / sr;
  const auto chunk_bytes = sizeof(T) * chunk_len;

  // The packed source is read strictly front to back, so a moving pointer is
  // enough to track the read position.
  const T* read_cursor = packed_values;
  for (int chunk = 0; chunk < sr; ++chunk) {
    // Offset of this chunk inside any single column
    const int offset_in_column = chunk * chunk_len;
    for (int col = 0; col < nr; ++col) {
      T* chunk_dst = values + col * kr + offset_in_column;
      std::memcpy(chunk_dst, read_cursor, chunk_bytes);
      read_cursor += chunk_len;
    }
  }
}
182 | | - |
183 | 129 | // Size in bytes of 1 packed weights column |
184 | 130 | size_t inline packed_weights_size_per_n( |
185 | 131 | int k, |
@@ -344,7 +290,7 @@ TORCHAO_ALWAYS_INLINE inline void pack_weights_impl( |
344 | 290 | } |
345 | 291 |
|
346 | 292 | // Pack buffer |
347 | | - internal::pack_values(packed_values, buffer.data(), nr, kr, sr); |
| 293 | + torchao::packing::pack_values(packed_values, buffer.data(), nr, kr, sr); |
348 | 294 | if constexpr (has_lut) { |
349 | 295 | internal::pack_buffer_for_lut<weight_nbit, kr, nr>( |
350 | 296 | packed_weights_byte_ptr, packed_values); |
@@ -498,7 +444,7 @@ void unpack_weights_at_n_idx( |
498 | 444 | internal::unpack_buffer<weight_nbit, kr, nr>( |
499 | 445 | packed_values, packed_weights_byte_ptr); |
500 | 446 | packed_weights_byte_ptr += packed_buffer_bytes; |
501 | | - internal::unpack_values(buffer.data(), packed_values, nr, kr, sr); |
| 447 | + torchao::packing::unpack_values(buffer.data(), packed_values, nr, kr, sr); |
502 | 448 |
|
503 | 449 | // Write weight_qvals |
504 | 450 | for (int j = 0; j < nr; j++) { |
|
0 commit comments