88#include  < benchmark/benchmark.h> 
99
1010#include  < torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h> 
11+ #include  < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h> 
1112#include  < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h> 
1213#include  < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h> 
1314#include  < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h> 
1617
1718namespace  {
1819
20+ //  Benchmark utility to compare variants of uint1 packing
21+ void  pack_uint1_values (
22+     uint8_t * packed,
23+     uint8_t * unpacked,
24+     int  packed_size,
25+     int  unpacked_size,
26+     int  variant) {
27+   constexpr  int  nbit = 1 ;
28+   constexpr  int  bitsPerByte = 8 ;
29+   assert (unpacked_size * nbit / bitsPerByte == packed_size);
30+   assert (packed_size % variant == 0 );
31+ 
32+   uint8x16_t  unpacked0;
33+   uint8x16_t  unpacked1;
34+   uint8x16_t  unpacked2;
35+   uint8x16_t  unpacked3;
36+   uint8x16_t  unpacked4;
37+   uint8x16_t  unpacked5;
38+   uint8x16_t  unpacked6;
39+   uint8x16_t  unpacked7;
40+ 
41+   switch  (variant) {
42+     case  8 :
43+       for  (int  i = 0 ; i < unpacked_size; i += 8 ) {
44+         torchao::bitpacking::internal::pack_8_uint1_values (
45+             packed + ((i * nbit) / bitsPerByte), unpacked + i);
46+       }
47+       break ;
48+     case  64 :
49+       for  (int  i = 0 ; i < unpacked_size; i += 64 ) {
50+         torchao::bitpacking::internal::vec_load_64_uint8_values (
51+             unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
52+         torchao::bitpacking::internal::vec_pack_64_uint1_values (
53+             packed + ((i * nbit) / bitsPerByte),
54+             unpacked0,
55+             unpacked1,
56+             unpacked2,
57+             unpacked3);
58+       }
59+       break ;
60+     case  128 :
61+       for  (int  i = 0 ; i < unpacked_size; i += 128 ) {
62+         torchao::bitpacking::internal::vec_load_64_uint8_values (
63+             unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
64+         torchao::bitpacking::internal::vec_load_64_uint8_values (
65+             unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64 );
66+         torchao::bitpacking::internal::vec_pack_128_uint1_values (
67+             packed + ((i * nbit) / bitsPerByte),
68+             unpacked0,
69+             unpacked1,
70+             unpacked2,
71+             unpacked3,
72+             unpacked4,
73+             unpacked5,
74+             unpacked6,
75+             unpacked7);
76+       }
77+       break ;
78+   }
79+ }
80+ 
81+ //  Benchmark utility to compare variants of uint1 packing
82+ void  unpack_uint1_values (
83+     uint8_t * unpacked,
84+     uint8_t * packed,
85+     int  unpacked_size,
86+     int  packed_size,
87+     int  variant) {
88+   constexpr  int  nbit = 1 ;
89+   constexpr  int  bitsPerByte = 8 ;
90+   assert (unpacked_size * nbit / bitsPerByte == packed_size);
91+   assert (packed_size % variant == 0 );
92+ 
93+   uint8x16_t  unpacked0;
94+   uint8x16_t  unpacked1;
95+   uint8x16_t  unpacked2;
96+   uint8x16_t  unpacked3;
97+   uint8x16_t  unpacked4;
98+   uint8x16_t  unpacked5;
99+   uint8x16_t  unpacked6;
100+   uint8x16_t  unpacked7;
101+ 
102+   switch  (variant) {
103+     case  8 :
104+       for  (int  i = 0 ; i < unpacked_size; i += 8 ) {
105+         torchao::bitpacking::internal::unpack_8_uint1_values (
106+             unpacked + i, packed + ((i * nbit) / bitsPerByte));
107+       }
108+       break ;
109+     case  64 :
110+       for  (int  i = 0 ; i < unpacked_size; i += 64 ) {
111+         torchao::bitpacking::internal::vec_unpack_64_uint1_values (
112+             unpacked0,
113+             unpacked1,
114+             unpacked2,
115+             unpacked3,
116+             packed + ((i * nbit) / bitsPerByte));
117+         torchao::bitpacking::internal::vec_store_64_uint8_values (
118+             unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
119+       }
120+       break ;
121+     case  128 :
122+       for  (int  i = 0 ; i < unpacked_size; i += 128 ) {
123+         torchao::bitpacking::internal::vec_unpack_128_uint1_values (
124+             unpacked0,
125+             unpacked1,
126+             unpacked2,
127+             unpacked3,
128+             unpacked4,
129+             unpacked5,
130+             unpacked6,
131+             unpacked7,
132+             packed + ((i * nbit) / bitsPerByte));
133+         torchao::bitpacking::internal::vec_store_64_uint8_values (
134+             unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
135+         torchao::bitpacking::internal::vec_store_64_uint8_values (
136+             unpacked + i + 64 , unpacked4, unpacked5, unpacked6, unpacked7);
137+       }
138+       break ;
139+   }
140+ }
141+ 
19142//  Benchmark utility to compare variants of uint2 packing
20143void  pack_uint2_values (
21144    uint8_t * packed,
@@ -470,6 +593,44 @@ void unpack_uint5_values(
470593
471594} //  namespace
472595
596+ static  void  benchmark_pack_uint1_values (benchmark::State& state) {
597+   int  unpacked_size = state.range (0 );
598+   int  variant = state.range (1 );
599+   int  nbit = 1 ;
600+ 
601+   assert (unpacked_size % 8  == 0 );
602+   int  packed_size = (unpacked_size / 8 ) * nbit;
603+ 
604+   auto  packed = std::vector<uint8_t >(packed_size, 0 );
605+   auto  unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit);
606+ 
607+   for  (auto  _ : state) {
608+     pack_uint1_values (
609+         packed.data (), unpacked.data (), packed_size, unpacked_size, variant);
610+   }
611+ }
612+ 
613+ static  void  benchmark_unpack_uint1_values (benchmark::State& state) {
614+   int  unpacked_size = state.range (0 );
615+   int  variant = state.range (1 );
616+   int  nbit = 1 ;
617+ 
618+   assert (unpacked_size % 8  == 0 );
619+   int  packed_size = (unpacked_size / 8 ) * nbit;
620+ 
621+   auto  packed = torchao::get_random_lowbit_vector (packed_size, 8 );
622+   auto  unpacked = std::vector<uint8_t >(unpacked_size, 0 );
623+ 
624+   for  (auto  _ : state) {
625+     unpack_uint1_values (
626+         unpacked.data (),
627+         packed.data (),
628+         unpacked.size (),
629+         packed.size (),
630+         variant);
631+   }
632+ }
633+ 
473634static  void  benchmark_pack_uint2_values (benchmark::State& state) {
474635  int  unpacked_size = state.range (0 );
475636  int  variant = state.range (1 );
@@ -478,8 +639,8 @@ static void benchmark_pack_uint2_values(benchmark::State& state) {
478639  assert (unpacked_size % 8  == 0 );
479640  int  packed_size = (unpacked_size / 8 ) * nbit;
480641
481-   auto  packed = std::vector<uint8_t >(unpacked_size , 0 );
482-   auto  unpacked = torchao::get_random_lowbit_vector (packed_size,  8 );
642+   auto  packed = std::vector<uint8_t >(packed_size , 0 );
643+   auto  unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
483644
484645  for  (auto  _ : state) {
485646    pack_uint2_values (
@@ -516,8 +677,8 @@ static void benchmark_pack_uint3_values(benchmark::State& state) {
516677  assert (unpacked_size % 8  == 0 );
517678  int  packed_size = (unpacked_size / 8 ) * nbit;
518679
519-   auto  packed = std::vector<uint8_t >(unpacked_size , 0 );
520-   auto  unpacked = torchao::get_random_lowbit_vector (packed_size,  8 );
680+   auto  packed = std::vector<uint8_t >(packed_size , 0 );
681+   auto  unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
521682
522683  for  (auto  _ : state) {
523684    pack_uint3_values (
@@ -554,8 +715,8 @@ static void benchmark_pack_uint4_values(benchmark::State& state) {
554715  assert (unpacked_size % 8  == 0 );
555716  int  packed_size = (unpacked_size / 8 ) * nbit;
556717
557-   auto  packed = std::vector<uint8_t >(unpacked_size , 0 );
558-   auto  unpacked = torchao::get_random_lowbit_vector (packed_size,  8 );
718+   auto  packed = std::vector<uint8_t >(packed_size , 0 );
719+   auto  unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
559720
560721  for  (auto  _ : state) {
561722    pack_uint4_values (
@@ -592,8 +753,8 @@ static void benchmark_pack_uint5_values(benchmark::State& state) {
592753  assert (unpacked_size % 8  == 0 );
593754  int  packed_size = (unpacked_size / 8 ) * nbit;
594755
595-   auto  packed = std::vector<uint8_t >(unpacked_size , 0 );
596-   auto  unpacked = torchao::get_random_lowbit_vector (packed_size,  8 );
756+   auto  packed = std::vector<uint8_t >(packed_size , 0 );
757+   auto  unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
597758
598759  for  (auto  _ : state) {
599760    pack_uint5_values (
@@ -622,6 +783,8 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) {
622783  }
623784}
624785
// Benchmark registrations. ArgsProduct runs each benchmark once per element of
// the cross product of its argument lists. In the uint1 and uint2 benchmark
// bodies visible in this file, the first list is read as state.range(0)
// (unpacked_size) and the second as state.range(1) (variant).
786+ BENCHMARK (benchmark_pack_uint1_values)->ArgsProduct({{128 }, {8 , 64 , 128 }});
787+ BENCHMARK (benchmark_unpack_uint1_values)->ArgsProduct({{128 }, {8 , 64 , 128 }});
625788BENCHMARK (benchmark_pack_uint2_values)->ArgsProduct({{128 }, {4 , 32 , 64 }});
626789BENCHMARK (benchmark_unpack_uint2_values)->ArgsProduct({{128 }, {4 , 32 , 64 }});
627790BENCHMARK (benchmark_pack_uint3_values)->ArgsProduct({{128 }, {8 , 64 , 128 }});
0 commit comments