diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h index aece38b435..7412b795e7 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -125,61 +126,6 @@ TORCHAO_ALWAYS_INLINE inline void unpack_buffer( assert(false); } -// Packs nr * kr values for GEMM with packing params (nr, kr, sr) -// It takes (kr / sr) values from each of nr columns and writes to packed_values -// This is repeated sr times -template -void pack_values( - // Output - T* packed_values, - // Inputs - const T* values, - int nr, - int kr, - int sr) { - assert(kr % sr == 0); - int kr_per_sr = kr / sr; - int dst_idx = 0; - for (int sr_idx = 0; sr_idx < sr; sr_idx++) { - for (int n_idx = 0; n_idx < nr; n_idx++) { - // Take kr_per_sr values from column n_idx - std::memcpy( - packed_values + dst_idx, - values + n_idx * kr + sr_idx * kr_per_sr, - sizeof(T) * kr_per_sr); - dst_idx += kr_per_sr; - } - } -} - -// Undoes pack_values -template -void unpack_values( - // Output - T* values, - // Inputs - const T* packed_values, - int nr, - int kr, - int sr) { - // packed_values and values should have size nr * kr - // This function takes (kr / sr) from each column of nr columns and writes to - // output This is repeated sr times - assert(kr % sr == 0); - int kr_per_sr = kr / sr; - int dst_idx = 0; - for (int sr_idx = 0; sr_idx < sr; sr_idx++) { - for (int n_idx = 0; n_idx < nr; n_idx++) { - // Take kr_per_sr values from column n_idx - std::memcpy( - values + n_idx * kr + sr_idx * kr_per_sr, - packed_values + dst_idx, - sizeof(T) * kr_per_sr); - dst_idx += 
kr_per_sr; - } - } -} - // Size in bytes of 1 packed weights column size_t inline packed_weights_size_per_n( int k, @@ -344,7 +290,7 @@ TORCHAO_ALWAYS_INLINE inline void pack_weights_impl( } // Pack buffer - internal::pack_values(packed_values, buffer.data(), nr, kr, sr); + torchao::packing::pack_values(packed_values, buffer.data(), nr, kr, sr); if constexpr (has_lut) { internal::pack_buffer_for_lut( packed_weights_byte_ptr, packed_values); @@ -498,7 +444,7 @@ void unpack_weights_at_n_idx( internal::unpack_buffer( packed_values, packed_weights_byte_ptr); packed_weights_byte_ptr += packed_buffer_bytes; - internal::unpack_values(buffer.data(), packed_values, nr, kr, sr); + torchao::packing::unpack_values(buffer.data(), packed_values, nr, kr, sr); // Write weight_qvals for (int j = 0; j < nr; j++) { diff --git a/torchao/experimental/kernels/cpu/aarch64/packing/utils.h b/torchao/experimental/kernels/cpu/aarch64/packing/utils.h new file mode 100644 index 0000000000..32ee7000b9 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/packing/utils.h @@ -0,0 +1,67 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
#pragma once

#include <cassert>
#include <cstddef>
#include <cstring>
#include <type_traits>

namespace torchao::packing {

// Packs nr * kr values for GEMM with packing params (nr, kr, sr).
//
// `values` is read as nr columns of kr contiguous elements each (column n_idx
// starts at values + n_idx * kr). The output is written in sr passes: pass
// sr_idx copies elements [sr_idx * kr / sr, (sr_idx + 1) * kr / sr) of every
// column, in column order, back to back into `packed_values`.
//
// Both buffers must hold nr * kr elements. Requires sr > 0, kr >= 0 and
// kr % sr == 0 (checked with assert, so only in builds without NDEBUG).
template <typename T>
void pack_values(
    // Output
    T* packed_values,
    // Inputs
    const T* values,
    int nr,
    int kr,
    int sr) {
  // Elements are moved with std::memcpy, which is only defined for trivially
  // copyable types.
  static_assert(
      std::is_trivially_copyable_v<T>,
      "pack_values copies with std::memcpy; T must be trivially copyable");
  // sr == 0 would divide by zero below; a negative kr would turn into a huge
  // size_t byte count.
  assert(sr > 0);
  assert(kr >= 0);
  assert(kr % sr == 0);

  const int kr_per_sr = kr / sr;
  const std::size_t chunk_bytes =
      sizeof(T) * static_cast<std::size_t>(kr_per_sr);

  // Write cursor; advances by one chunk per (sr_idx, n_idx) pair.
  T* dst = packed_values;
  for (int sr_idx = 0; sr_idx < sr; sr_idx++) {
    for (int n_idx = 0; n_idx < nr; n_idx++) {
      // Take kr_per_sr values from column n_idx
      std::memcpy(dst, values + n_idx * kr + sr_idx * kr_per_sr, chunk_bytes);
      dst += kr_per_sr;
    }
  }
}

// Undoes pack_values.
//
// Reads `packed_values` sequentially in chunks of kr / sr elements and
// scatters them back so that `values` again holds nr columns of kr contiguous
// elements. Calling this with the same (nr, kr, sr) used for pack_values
// restores the original buffer.
//
// Both buffers must hold nr * kr elements. Requires sr > 0, kr >= 0 and
// kr % sr == 0 (checked with assert, so only in builds without NDEBUG).
template <typename T>
void unpack_values(
    // Output
    T* values,
    // Inputs
    const T* packed_values,
    int nr,
    int kr,
    int sr) {
  // Elements are moved with std::memcpy, which is only defined for trivially
  // copyable types.
  static_assert(
      std::is_trivially_copyable_v<T>,
      "unpack_values copies with std::memcpy; T must be trivially copyable");
  // sr == 0 would divide by zero below; a negative kr would turn into a huge
  // size_t byte count.
  assert(sr > 0);
  assert(kr >= 0);
  assert(kr % sr == 0);

  const int kr_per_sr = kr / sr;
  const std::size_t chunk_bytes =
      sizeof(T) * static_cast<std::size_t>(kr_per_sr);

  // Read cursor; advances by one chunk per (sr_idx, n_idx) pair, mirroring
  // the write order used by pack_values.
  const T* src = packed_values;
  for (int sr_idx = 0; sr_idx < sr; sr_idx++) {
    for (int n_idx = 0; n_idx < nr; n_idx++) {
      // Put kr_per_sr values back into column n_idx
      std::memcpy(values + n_idx * kr + sr_idx * kr_per_sr, src, chunk_bytes);
      src += kr_per_sr;
    }
  }
}

} // namespace torchao::packing