Commit 3db13c0

metascroy authored and jainapurva committed
Enable 6-bit kernel
Differential Revision: D63991820
Pull Request resolved: #1027
1 parent 2761ed5 commit 3db13c0

5 files changed: +65 -12 lines

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp

Lines changed: 3 additions & 0 deletions
@@ -66,6 +66,7 @@ TORCH_LIBRARY(torchao, m) {
   DEFINE_OP(3);
   DEFINE_OP(4);
   DEFINE_OP(5);
+  DEFINE_OP(6);
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -74,6 +75,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   DEFINE_CPU_IMPL(3);
   DEFINE_CPU_IMPL(4);
   DEFINE_CPU_IMPL(5);
+  DEFINE_CPU_IMPL(6);
 }
 
 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
@@ -82,4 +84,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
   DEFINE_META_IMPL(3);
   DEFINE_META_IMPL(4);
   DEFINE_META_IMPL(5);
+  DEFINE_META_IMPL(6);
 }
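
The DEFINE_OP / DEFINE_CPU_IMPL / DEFINE_META_IMPL macros are defined earlier in this file and register one op per weight bit-width, so these three one-line additions are all that is needed on the ATen side. As a minimal sketch, assuming the ATen op names follow the same pattern as the EXECUTORCH_LIBRARY strings added below (an assumption; the actual names come from the unshown DEFINE_OP macro), the new 6-bit ops can be probed from Python after building the experimental kernels:

import torch

# Probe for the newly registered 6-bit ops. The names here are inferred
# from the EXECUTORCH_LIBRARY strings in this commit, not from the
# DEFINE_OP macro itself, so treat them as an assumption.
for name in ("_linear_8bit_act_6bit_weight", "_linear_8bit_act_6bit0zp_weight"):
    try:
        getattr(torch.ops.torchao, name)
        print(f"{name}: registered")
    except (AttributeError, RuntimeError):
        print(f"{name}: not found (experimental kernels not built?)")
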
New file: ExecuTorch registration for the 6-bit, no-zero-points variant

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to only allow one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu</*weight_nbit*/ 6, /*has_weight_zeros*/ false>(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_6bit0zp_weight.out", _op_out);
New file: ExecuTorch registration for the 6-bit variant with weight zero points

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to only allow one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu</*weight_nbit*/ 6, /*has_weight_zeros*/ true>(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_6bit_weight.out", _op_out);
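
Taken together, the two registrations follow the naming convention that quant_api.py relies on: a "0zp" suffix marks the variant without weight zero points. A minimal sketch of that convention (the helper below is illustrative, not part of the codebase):

def op_name(nbit: int, has_weight_zeros: bool) -> str:
    # Mirrors quant_api.py: the "0zp" suffix means no weight zero points.
    wzp_suffix = "" if has_weight_zeros else "0zp"
    return f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight"

assert op_name(6, True) == "_linear_8bit_act_6bit_weight"
assert op_name(6, False) == "_linear_8bit_act_6bit0zp_weight"
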

torchao/experimental/quant_api.py

Lines changed: 2 additions & 10 deletions
@@ -115,15 +115,7 @@ def forward(self, x):
         lead_shape = x.shape[0:-2]
         m, k = x.shape[-2], x.shape[-1]
         n = self._n.shape[1]
-        x = x.reshape(-1, m, k)
-
-        res = [
-            self._linear_op(
-                x[i, :, :], self.packed_weights, self._group_size, self._n, self._k
-            )
-            for i in range(x.shape[0])
-        ]
-        res = torch.stack(res)
+        res = self._linear_op(x.reshape(-1, k), self.packed_weights, self._group_size, self._n, self._k)
         res = res.reshape(*lead_shape, m, n)
         return res
@@ -206,7 +198,7 @@ def forward(self, x):
 
 def _maybe_get_quantized_linear_native(nbit, has_weight_zeros):
     try:
-        if nbit in [1, 2, 3, 4, 5]:
+        if nbit in [1, 2, 3, 4, 5, 6]:
            wzp_suffix = "" if has_weight_zeros else "0zp"
            return _Int8DynActIntxWeightQuantizedLinearNative(
                pack_weight_op=getattr(

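The forward change replaces a per-batch Python loop (reshape to (-1, m, k), call the op once per slice, then torch.stack) with a single call on all rows at once, which works because the op only depends on the trailing k dimension. A minimal sketch of the equivalence, using torch.nn.functional.linear as a stand-in for the quantized op (an assumption; the real op also takes packed weights and quantization parameters):

import torch
import torch.nn.functional as F

def linear_op(x2d, weight):
    # Stand-in: maps (rows, k) -> (rows, n), the contract the reshape relies on.
    return F.linear(x2d, weight)

x = torch.randn(2, 3, 5, 16)   # (*lead_shape, m, k)
w = torch.randn(8, 16)         # (n, k)
lead_shape, (m, k), n = x.shape[:-2], x.shape[-2:], w.shape[0]

# Old approach: loop over the flattened batch and stack.
xb = x.reshape(-1, m, k)
old = torch.stack([linear_op(xb[i], w) for i in range(xb.shape[0])])

# New approach: one call on all rows at once.
new = linear_op(x.reshape(-1, k), w)

assert torch.allclose(old.reshape(*lead_shape, m, n), new.reshape(*lead_shape, m, n))
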
torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ def test_accuracy(self):
         m = 1
         n = 1071
         k = 4096
-        activations = torch.randn(m, k, dtype=torch.float32)
+        activations = torch.randn(2, 3, m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
 
         for nbit in [1, 2, 3, 4, 5, 6, 7]:
@@ -84,7 +84,7 @@ def test_export_compile_aoti(self):
         layers = [torch.nn.Linear(k0, k1, bias=False), torch.nn.Linear(k1, k2, bias=False), torch.nn.Linear(k2, k3, bias=False)]
         model = torch.nn.Sequential(*layers)
 
-        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)
+        activations = torch.randn(m, k0, dtype=torch.float32)
 
         print("Quantizing model")
         quantizer = Int8DynActIntxWeightQuantizer(
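
test_accuracy now feeds 4-D activations to exercise the new batched forward path: leading dimensions are flattened into rows and restored after the matmul. The expected shape contract, shown here with a plain (unquantized) nn.Linear for illustration:

import torch

# Shapes taken from the updated test; plain nn.Linear stands in for the
# quantized model to show the expected (*lead_shape, m, n) output.
m, n, k = 1, 1071, 4096
model = torch.nn.Sequential(torch.nn.Linear(k, n, bias=False))
activations = torch.randn(2, 3, m, k, dtype=torch.float32)
out = model(activations)
assert out.shape == (2, 3, m, n)
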
