9
9
#if defined(__aarch64__) || defined(__ARM_NEON)
10
10
11
11
#include < arm_neon.h>
12
- #include < torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
13
12
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
14
13
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
15
14
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
16
15
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
17
16
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
18
17
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
18
+ #include < torchao/experimental/kernels/cpu/aarch64/macro.h>
19
19
#include < cassert>
20
20
21
21
namespace torchao {
@@ -142,7 +142,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
142
142
break ;
143
143
case 6 :
144
144
torchao::bitpacking::internal::vec_pack_32_uint6_values (
145
- packed, shifted0, shifted1);
145
+ packed, shifted0, shifted1);
146
146
break ;
147
147
default :
148
148
assert (false );
@@ -153,7 +153,7 @@ template <int nbit>
153
153
TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values (
154
154
int8x16_t & unpacked0,
155
155
int8x16_t & unpacked1,
156
- uint8_t * packed) {
156
+ const uint8_t * packed) {
157
157
static_assert (nbit < 8 );
158
158
static_assert (nbit >= 1 );
159
159
@@ -217,7 +217,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
217
217
break ;
218
218
case 6 :
219
219
torchao::bitpacking::internal::vec_unpack_32_uint6_values (
220
- shifted0, shifted1, packed);
220
+ shifted0, shifted1, packed);
221
221
break ;
222
222
default :
223
223
assert (false );
@@ -288,7 +288,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
288
288
int8x16_t & unpacked1,
289
289
int8x16_t & unpacked2,
290
290
int8x16_t & unpacked3,
291
- uint8_t * packed) {
291
+ const uint8_t * packed) {
292
292
static_assert (nbit < 8 );
293
293
static_assert (nbit >= 1 );
294
294
@@ -443,7 +443,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
443
443
int8x16_t & unpacked5,
444
444
int8x16_t & unpacked6,
445
445
int8x16_t & unpacked7,
446
- uint8_t * packed) {
446
+ const uint8_t * packed) {
447
447
static_assert (nbit < 8 );
448
448
static_assert (nbit >= 1 );
449
449
0 commit comments