99#if defined(__aarch64__) || defined(__ARM_NEON)
1010
1111#include < arm_neon.h>
12- #include < torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
1312#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
1413#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
1514#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
1615#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
1716#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
1817#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
18+ #include < torchao/experimental/kernels/cpu/aarch64/macro.h>
1919#include < cassert>
2020
2121namespace torchao {
@@ -142,7 +142,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
142142 break ;
143143 case 6 :
144144 torchao::bitpacking::internal::vec_pack_32_uint6_values (
145- packed, shifted0, shifted1);
145+ packed, shifted0, shifted1);
146146 break ;
147147 default :
148148 assert (false );
@@ -153,7 +153,7 @@ template <int nbit>
153153TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values (
154154 int8x16_t & unpacked0,
155155 int8x16_t & unpacked1,
156- uint8_t * packed) {
156+ const uint8_t * packed) {
157157 static_assert (nbit < 8 );
158158 static_assert (nbit >= 1 );
159159
@@ -217,7 +217,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
217217 break ;
218218 case 6 :
219219 torchao::bitpacking::internal::vec_unpack_32_uint6_values (
220- shifted0, shifted1, packed);
220+ shifted0, shifted1, packed);
221221 break ;
222222 default :
223223 assert (false );
@@ -288,7 +288,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
288288 int8x16_t & unpacked1,
289289 int8x16_t & unpacked2,
290290 int8x16_t & unpacked3,
291- uint8_t * packed) {
291+ const uint8_t * packed) {
292292 static_assert (nbit < 8 );
293293 static_assert (nbit >= 1 );
294294
@@ -443,7 +443,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
443443 int8x16_t & unpacked5,
444444 int8x16_t & unpacked6,
445445 int8x16_t & unpacked7,
446- uint8_t * packed) {
446+ const uint8_t * packed) {
447447 static_assert (nbit < 8 );
448448 static_assert (nbit >= 1 );
449449
0 commit comments