
Commit a898df9

Move Uintx out of prototype for future extension
Summary: Thanks @vayuda for adding the initial version of the Uintx tensor subclass. We can now integrate it with the `torch.uint1` to `torch.uint7` dtypes, with some helpers that first unblock the benefit of bitpacking (model size savings) for users, and then let us gradually optimize performance. ExecuTorch is also planning to integrate its low-bit kernels with us, so a more native experience with these lower-bit types will be required / useful there as well.

Test Plan: python test/dtypes/test_uintx.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 1cfe69e commit a898df9
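
As the summary describes, this commit exposes the packed uintx path through `uintx_weight_only` in `torchao.quantization.quant_api` (the test diff below exercises it). A minimal usage sketch mirroring that test, assuming a CUDA device and a torchao nightly with the 2.5 fix; the single-layer model here is just a stand-in:

import torch
from torchao.quantization.quant_api import quantize_, uintx_weight_only

# stand-in model: one float16 linear layer, similar to the Linear16 test module
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# pack each weight element into 4 bits, quantized per group of 64 along the last dim
quantize_(model, uintx_weight_only(4, group_size=64))

out = model(torch.randn(1, 1024, dtype=torch.float16, device="cuda"))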

File tree

7 files changed: +142 −148 lines changed

Lines changed: 18 additions & 20 deletions
@@ -1,46 +1,46 @@
 import torch
-from torchao.prototype.uintx import pack, unpack, pack_cpu, unpack_cpu
+from torchao.dtypes.uintx.bitpacking import pack, unpack, pack_cpu, unpack_cpu
 import pytest
 from torch.utils._triton import has_triton

-element_bit_width = (1,2,3,4,5,6,7)
+bit_widths = (1,2,3,4,5,6,7)
 dimensions = (0, -1, 1)

 @pytest.fixture(autouse=True)
 def run_before_and_after_tests():
     yield
     torch._dynamo.reset() # reset cache between tests

-@pytest.mark.parametrize("element_bit_width", element_bit_width)
+@pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
-def test_CPU(element_bit_width, dim):
-    test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8, device='cpu')
-    packed = pack_cpu(test_tensor, element_bit_width, dim = dim)
-    unpacked = unpack_cpu(packed, element_bit_width, dim = dim)
+def test_CPU(bit_width, dim):
+    test_tensor = torch.randint(0, 2**bit_width, (32,32,32), dtype=torch.uint8, device='cpu')
+    packed = pack_cpu(test_tensor, bit_width, dim = dim)
+    unpacked = unpack_cpu(packed, bit_width, dim = dim)
     assert(unpacked.allclose(test_tensor))


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("element_bit_width", element_bit_width)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
-def test_GPU(element_bit_width, dim):
-    test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
-    packed = pack(test_tensor, element_bit_width, dim = dim)
-    unpacked = unpack(packed, element_bit_width, dim = dim)
+def test_GPU(bit_width, dim):
+    test_tensor = torch.randint(0, 2**bit_width, (32,32,32), dtype=torch.uint8).cuda()
+    packed = pack(test_tensor, bit_width, dim = dim)
+    unpacked = unpack(packed, bit_width, dim = dim)
     assert(unpacked.allclose(test_tensor))


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.parametrize("element_bit_width", element_bit_width)
+@pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
-def test_compile(element_bit_width, dim):
+def test_compile(bit_width, dim):
     torch._dynamo.config.specialize_int = True
     pack_compile = torch.compile(pack, fullgraph=True)
     unpack_compile = torch.compile(unpack, fullgraph=True)
-    test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
-    packed = pack(test_tensor, element_bit_width, dim = dim)
-    unpacked = unpack(packed, element_bit_width, dim = dim)
+    test_tensor = torch.randint(0, 2**bit_width, (32,32,32), dtype=torch.uint8).cuda()
+    packed = pack(test_tensor, bit_width, dim = dim)
+    unpacked = unpack(packed, bit_width, dim = dim)
     assert(unpacked.allclose(test_tensor))

 # these test cases are for the example pack walk through in the bitpacking.py file
@@ -62,5 +62,3 @@ def test_pack_example_CPU():
     assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2)
     unpacked = unpack([shard_4, shard_2], 6)
     assert unpacked.allclose(test_tensor)
-
-

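The renamed tests above exercise the low-level `pack`/`unpack` helpers from their new home in `torchao.dtypes.uintx.bitpacking`. A minimal round-trip sketch based on the GPU test, assuming CUDA is available:

import torch
from torchao.dtypes.uintx.bitpacking import pack, unpack

bit_width, dim = 3, -1
# values must fit in `bit_width` bits, hence the 2**bit_width upper bound
t = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
shards = pack(t, bit_width, dim=dim)           # bit-packed uint8 shards
restored = unpack(shards, bit_width, dim=dim)
assert restored.allclose(t)                    # lossless round trip
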
test/prototype/test_uintx.py renamed to test/dtypes/test_uintx.py

Lines changed: 29 additions & 31 deletions
@@ -4,28 +4,26 @@

 import torch

-from torchao.prototype.uintx import uintx_affine_weight_only, to_uintx
-from torchao.quantization.quant_api import quantize_
+from torchao.dtypes.uintx.Uintx import to_uintx
+from torchao.quantization.quant_api import quantize_, uintx_weight_only
 from torchao.utils import TORCH_VERSION_AFTER_2_5

 from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-    choose_qparams_affine,
-    quantize_affine,
-    dequantize_affine,
-    )
+    MappingType,
+    ZeroPointDomain,
+    choose_qparams_affine,
+    quantize_affine,
+    dequantize_affine,
+)

-bit_sizes = (1,2,3,4,5,6,7)
-group_sizes = [32,64,128]
+bit_widths = (1, 2, 3, 4, 5, 6, 7)
+group_sizes = [32, 64, 128]
 devices = ["cpu", "cuda"]
 @pytest.fixture(autouse=True)
 def run_before_and_after_tests():
     yield
     torch._dynamo.reset() # reset cache between tests

-
-
 class Linear16(torch.nn.Module):
     def __init__(self, scale, device):
         super().__init__()
@@ -37,52 +35,52 @@ def __init__(self, scale, device):

     def forward(self, x):
         return self.net(x)
-
-@pytest.mark.parametrize("bit_size", bit_sizes)
+
+@pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
-def test_uintx_affine_weight_only_model_quant(bit_size, group_size, device):
+def test_uintx_weight_only_model_quant(bit_width, group_size, device):
     scale = 512
     fp16 = Linear16(scale, device)
-    quantize_(fp16, uintx_affine_weight_only(bit_size, group_size=group_size))
+    quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
     uintx = torch.compile(fp16, fullgraph=True)
     test_input = torch.randn(scale*2, dtype=torch.float16, device=device)
     output = uintx.forward(test_input)
     assert output != None, "model quantization failed"
-
-@pytest.mark.parametrize("bit_size", bit_sizes)
+
+@pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
-def test_uintx_affine_weight_only_quant(bit_size, group_size, device):
-    input_float = torch.randn((1,256), dtype=torch.float16, device = device)
+def test_uintx_weight_only_quant(bit_width, group_size, device):
+    input_float = torch.randn((1, 256), dtype=torch.float16, device = device)
     mapping_type = MappingType.SYMMETRIC
     quant_min = 0
-    quant_max = 2**bit_size - 1
+    quant_max = 2 ** bit_width - 1
     eps = torch.finfo(torch.float32).eps
     zero_point_dtype = torch.int32
     zero_point_domain = ZeroPointDomain.INT
    target_dtype = torch.uint8
     block_size = (1, group_size)
-
+
     scale, zero_point = choose_qparams_affine(
-        input_float, mapping_type, block_size,
-        target_dtype, quant_min, quant_max, eps, torch.float32,
-        zero_point_dtype, True, zero_point_domain
+        input_float, mapping_type, block_size,
+        target_dtype, quant_min, quant_max, eps, torch.float32,
+        zero_point_dtype, True, zero_point_domain
     )
-
+
     aqt = quantize_affine(
         input_float, block_size, scale,
         zero_point, target_dtype,
         quant_min = quant_min,
         quant_max = quant_max,
         zero_point_domain = zero_point_domain
-    )
-
-    q = to_uintx(aqt, bit_size, -1)
+    )
+
+    q = to_uintx(aqt, bit_width, -1)
     assert q != None, "quantization failed"
     deqaunt = dequantize_affine(
         q, block_size, scale,

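The second test above builds an affine-quantized uint8 tensor and hands it to `to_uintx` for bit-packing. A stripped-down sketch of just that packing step, assuming a CUDA device as in the tests:

import torch
from torchao.dtypes.uintx.Uintx import to_uintx

bit_width = 5
data = torch.randint(0, 2**bit_width, (1, 256), dtype=torch.uint8).cuda()

packed = to_uintx(data, bit_width, -1)    # UintxTensor backed by bit-packed shards
assert packed.get_plain().allclose(data)  # unpacking restores the original uint8 values
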
torchao/prototype/uintx/Uintx.py renamed to torchao/dtypes/uintx/Uintx.py

Lines changed: 34 additions & 70 deletions
@@ -27,7 +27,7 @@ class UintxTensor(torch.Tensor):
     int4_shard (torch.Tensor): 4 bit packed shard
     int2_shard (torch.Tensor): 2 bit packed shard
     int1_shard (torch.Tensor): 1 bit packed shard
-    bit_size (int): element size in bits
+    bit_width (int): number of bits for each element
     pack_dim: (int) dimension to pack along
     """
     bits_to_shard = {
@@ -43,71 +43,71 @@ def __new__(
         cls,
         shards: List[torch.Tensor],
         packed_shape: List[int],
-        bit_size: int,
+        bit_width: int,
         pack_dim: int = -1,
     ):
         kwargs = {"device": shards[0].device}
         kwargs["device"] = shards[0].device
         kwargs["layout"] = shards[0].layout
         kwargs["requires_grad"] = False
         kwargs["dtype"] = torch.uint8
-        return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs)
+        return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs)

     def __init__(
         self,
         shards: List[torch.Tensor],
         packed_shape: List[int],
-        bit_size: int,
+        bit_width: int,
         pack_dim: int = -1,
     ):
-        for i, attrib in enumerate(self.bits_to_shard[bit_size]):
+        for i, attrib in enumerate(self.bits_to_shard[bit_width]):
             setattr(self, attrib, shards[i])
-
+
         self.packed_shape = packed_shape
-        self.bit_size = bit_size
+        self.bit_width = bit_width
         self.pack_dim = pack_dim
-
+
     def get_shards(self):
-        return [getattr(self,i) for i in self.__class__.bits_to_shard[self.bit_size]]
-
+        return [getattr(self,i) for i in self.__class__.bits_to_shard[self.bit_width]]
+
     def __repr__(self):
-        return f"Int{self.bit_size}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_size, dim = self.pack_dim)})"
-
+        return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim = self.pack_dim)})"
+
     def __tensor_flatten__(self):
-        return self.__class__.bits_to_shard[self.bit_size], [self.packed_shape, self.bit_size, self.pack_dim]
-
+        return self.__class__.bits_to_shard[self.bit_width], [self.packed_shape, self.bit_width, self.pack_dim]
+
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         shards = list(tensor_data_dict.values())
-        packed_shape, bit_size, pack_dim = tensor_attributes
-        return cls(shards, packed_shape, bit_size, pack_dim)
+        packed_shape, bit_width, pack_dim = tensor_attributes
+        return cls(shards, packed_shape, bit_width, pack_dim)

     implements = classmethod(_implements)
     __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
     __torch_function__ = classmethod(_dispatch__torch_function__)

     def get_plain(self):
-        return unpack(self.get_shards(), self.bit_size, dim = self.pack_dim)
-
+        return unpack(self.get_shards(), self.bit_width, dim = self.pack_dim)
+
     # temporary until kernels on packed tensors are created
     def apply_transformation(self, fn):
         og = self.get_plain()
         new = fn(og)
-        return self.from_uint8(new, self.bit_size, self.pack_dim)
-
+        return self.from_uint8(new, self.bit_width, self.pack_dim)
+
     # temporary until kernels on packed tensors are created
     def apply_fn_to_shards(self, fn):
         new_shards = [fn(shard) for shard in self.get_shards()]
-        return self.__class__(new_shards, self.packed_shape, self.bit_size, self.pack_dim)
-
+        return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim)
+
     @classmethod
-    def from_uint8(cls, int_data: torch.Tensor, bit_size, pack_dim: int = -1):
-        shards = pack(int_data, bit_size, dim=pack_dim)
+    def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):
+        shards = pack(int_data, bit_width, dim=pack_dim)
         shape = list(int_data.shape)
-        shape[pack_dim] = shape[pack_dim] * bit_size // 8
-        return cls(shards, int_data.shape, bit_size, pack_dim)
+        shape[pack_dim] = shape[pack_dim] * bit_width // 8
+        return cls(shards, int_data.shape, bit_width, pack_dim)


 implements = UintxTensor.implements
@@ -118,19 +118,19 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0].apply_fn_to_shards(torch.detach)
     )
-
+
 @implements(aten.view.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:]))
     )
-
+
 @implements(aten._to_copy.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]
     )
-
+
 @implements(aten.sub.Tensor)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
@@ -147,18 +147,18 @@ def _(func, types, args, kwargs):

 @dataclass(frozen=True)
 class UintxLayoutType(LayoutType):
-    bit_size: int
+    bit_width: int
     pack_dim: int = -1
-
+
     def post_process(self, input: torch.Tensor) -> torch.Tensor:
-        return to_uintx(input, self.bit_size, self.pack_dim)
+        return to_uintx(input, self.bit_width, self.pack_dim)

 @register_layout_cls(UintxLayoutType)
 class UintxAQTLayout(PlainAQTLayout):
-
+
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return self.int_data.get_plain(), self.scale, self.zero_point
-
+
     @classmethod
     def from_plain(
         cls,
@@ -169,39 +169,3 @@ def from_plain(
     ):
         assert isinstance(layout_type, UintxLayoutType)
         return cls(int_data, scale, zero_point, layout_type)
-
-
-def uintx_affine_weight_only(bit_size, group_size=64, pack_dim=-1):
-    """
-    Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
-    x is the number of bits specified by the `nbits` argument
-    """
-    from torchao.quantization.quant_primitives import (
-        MappingType,
-        ZeroPointDomain,
-        choose_qparams_affine,
-        quantize_affine,
-        dequantize_affine,
-    )
-    from torchao.dtypes import to_affine_quantized
-    from torchao.quantization.quant_api import _get_linear_subclass_inserter
-    def apply_uintx_weight_only_quant(weight):
-
-        layout_type = UintxLayoutType(bit_size=bit_size, pack_dim=pack_dim)
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-        quant_min = 0
-        quant_max = 2**bit_size - 1
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int32
-        zero_point_domain = ZeroPointDomain.INT
-
-        return to_affine_quantized(
-            weight, mapping_type, block_size, torch.uint8,
-            quant_min = quant_min, quant_max = quant_max,
-            eps = eps, zero_point_dtype=zero_point_dtype,
-            zero_point_domain=zero_point_domain,
-            layout_type=layout_type,
-        )
-
-    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant)

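For reference, the `uintx_affine_weight_only` helper removed above (now exposed as `uintx_weight_only` in `torchao.quantization.quant_api`) shows how `UintxLayoutType` plugs into affine quantization. A condensed sketch of that flow using the renamed `bit_width` field; the wrapper name `quantize_weight_to_uintx` is hypothetical, and the moved API presumably wraps a similar sequence:

import torch
from torchao.dtypes import to_affine_quantized
from torchao.dtypes.uintx.Uintx import UintxLayoutType
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

def quantize_weight_to_uintx(weight: torch.Tensor, bit_width: int, group_size: int = 64):
    # pack each quantized element into `bit_width` bits along the last dimension
    layout_type = UintxLayoutType(bit_width=bit_width, pack_dim=-1)
    return to_affine_quantized(
        weight, MappingType.ASYMMETRIC, (1, group_size), torch.uint8,
        quant_min=0, quant_max=2**bit_width - 1,
        eps=torch.finfo(torch.float32).eps,
        zero_point_dtype=torch.int32,
        zero_point_domain=ZeroPointDomain.INT,
        layout_type=layout_type,
    )
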
torchao/dtypes/uintx/__init__.py

Whitespace-only changes.
