Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
#include <torchao/experimental/kernels/cpu/aarch64/packing/utils.h>
#include <array>
#include <cstring>

Expand Down Expand Up @@ -125,61 +126,6 @@ TORCHAO_ALWAYS_INLINE inline void unpack_buffer(
assert(false);
}

// Packs nr * kr values for GEMM with packing params (nr, kr, sr)
// It takes (kr / sr) values from each of nr columns and writes to packed_values
// This is repeated sr times
template <typename T>
void pack_values(
// Output
T* packed_values,
// Inputs
const T* values,
int nr,
int kr,
int sr) {
assert(kr % sr == 0);
int kr_per_sr = kr / sr;
int dst_idx = 0;
for (int sr_idx = 0; sr_idx < sr; sr_idx++) {
for (int n_idx = 0; n_idx < nr; n_idx++) {
// Take kr_per_sr values from column n_idx
std::memcpy(
packed_values + dst_idx,
values + n_idx * kr + sr_idx * kr_per_sr,
sizeof(T) * kr_per_sr);
dst_idx += kr_per_sr;
}
}
}

// Undoes pack_values
template <typename T>
void unpack_values(
// Output
T* values,
// Inputs
const T* packed_values,
int nr,
int kr,
int sr) {
// packed_values and values should have size nr * kr
// This function takes (kr / sr) from each column of nr columns and writes to
// output This is repeated sr times
assert(kr % sr == 0);
int kr_per_sr = kr / sr;
int dst_idx = 0;
for (int sr_idx = 0; sr_idx < sr; sr_idx++) {
for (int n_idx = 0; n_idx < nr; n_idx++) {
// Take kr_per_sr values from column n_idx
std::memcpy(
values + n_idx * kr + sr_idx * kr_per_sr,
packed_values + dst_idx,
sizeof(T) * kr_per_sr);
dst_idx += kr_per_sr;
}
}
}

// Size in bytes of 1 packed weights column
size_t inline packed_weights_size_per_n(
int k,
Expand Down Expand Up @@ -344,7 +290,7 @@ TORCHAO_ALWAYS_INLINE inline void pack_weights_impl(
}

// Pack buffer
internal::pack_values(packed_values, buffer.data(), nr, kr, sr);
torchao::packing::pack_values(packed_values, buffer.data(), nr, kr, sr);
if constexpr (has_lut) {
internal::pack_buffer_for_lut<weight_nbit, kr, nr>(
packed_weights_byte_ptr, packed_values);
Expand Down Expand Up @@ -498,7 +444,7 @@ void unpack_weights_at_n_idx(
internal::unpack_buffer<weight_nbit, kr, nr>(
packed_values, packed_weights_byte_ptr);
packed_weights_byte_ptr += packed_buffer_bytes;
internal::unpack_values(buffer.data(), packed_values, nr, kr, sr);
torchao::packing::unpack_values(buffer.data(), packed_values, nr, kr, sr);

// Write weight_qvals
for (int j = 0; j < nr; j++) {
Expand Down
67 changes: 67 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/packing/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <cassert>
#include <cstring>

namespace torchao::packing {

// Packs nr * kr values for GEMM with packing params (nr, kr, sr)
// It takes (kr / sr) values from each of nr columns and writes to packed_values
// This is repeated sr times
template <typename T>
void pack_values(
// Output
T* packed_values,
// Inputs
const T* values,
int nr,
int kr,
int sr) {
assert(kr % sr == 0);
int kr_per_sr = kr / sr;
int dst_idx = 0;
for (int sr_idx = 0; sr_idx < sr; sr_idx++) {
for (int n_idx = 0; n_idx < nr; n_idx++) {
// Take kr_per_sr values from column n_idx
std::memcpy(
packed_values + dst_idx,
values + n_idx * kr + sr_idx * kr_per_sr,
sizeof(T) * kr_per_sr);
dst_idx += kr_per_sr;
}
}
}

// Undoes pack_values
template <typename T>
void unpack_values(
// Output
T* values,
// Inputs
const T* packed_values,
int nr,
int kr,
int sr) {
// packed_values and values should have size nr * kr
// This function takes (kr / sr) from each column of nr columns and writes to
// output This is repeated sr times
assert(kr % sr == 0);
int kr_per_sr = kr / sr;
int dst_idx = 0;
for (int sr_idx = 0; sr_idx < sr; sr_idx++) {
for (int n_idx = 0; n_idx < nr; n_idx++) {
// Take kr_per_sr values from column n_idx
std::memcpy(
values + n_idx * kr + sr_idx * kr_per_sr,
packed_values + dst_idx,
sizeof(T) * kr_per_sr);
dst_idx += kr_per_sr;
}
}
}

} // namespace torchao::packing
Loading