
Commit 6f44d25

Minor upgrades to bit pack (#347)
* added dim=-1 support; the device is now taken from the input data
* removed device from the parameter list
* fixed the randint range
1 parent: c2235af
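
At call sites, the change amounts to roughly the sketch below. This is not taken from the diff; the import path is assumed from the module touched in this commit and the tensor shape is arbitrary.

    import torch
    from torchao.prototype.common.bitpacking import pack, unpack  # path assumed

    # values drawn with torch.randint(2**8, ...) as in the updated benchmark
    data = torch.randint(2**8, (1, 1024, 1024), dtype=torch.uint8).cuda()
    packed = pack(data, 4, dim=-1, container_dtype=torch.uint8)   # no device= argument; device comes from data
    unpacked = unpack(packed, 4, dim=-1, output_dtype=torch.float16)  # dim=-1 is now accepted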

3 files changed: +72 additions, −204 deletions

benchmarks/benchmark_bitpacking.py

Lines changed: 58 additions & 192 deletions
@@ -5,8 +5,7 @@
 from torchao.dtypes.uint4 import unpack_uint4, pack_uint4
 
 
-def benchmark(function, num_runs, setup =None):
-    args = setup()
+def benchmark(function, args, num_runs):
     torch.cuda.synchronize()
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)
@@ -21,207 +20,74 @@ def benchmark(function, num_runs, setup =None):
 
 
 def test_vs_existing():
-    def new_():
-        fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda()
+    def new_(scale):
+        fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
         packed = pack(fake_tensor, 4, dim=1)
         unpacked = unpack(packed, 4, dim=1)
-    def old_():
-        fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda()
+    def old_(scale):
+        fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
         packed = pack_uint4(fake_tensor)
         unpacked = unpack_uint4(packed)
-    new_ = torch.compile(new_, fullgraph=True)
-    old_ = torch.compile(old_, fullgraph=True)
-    new_()
-    old_()
-    print(f"new: {benchmark(new_, 1000)} ms ")
-    print(f"old: {benchmark(old_, 1000)} ms")
-
+
 
-def test_iso_bitpack():
-    def load4x(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, 4*scale,scale), dtype=torch.uint8).cuda()
+    for scale in [256,512, 1024, 2048,4096, 8192]:
+        new_ = torch.compile(new_, fullgraph=True)
+        old_ = torch.compile(old_, fullgraph=True)
+        new_(scale)
+        old_(scale)
+        print("scale: ", scale)
+        print(f"new: {benchmark(new_,[scale], 10)} ms ")
+        print(f"old: {benchmark(old_,[scale], 10)} ms")
+
+
+def compare_to_fp16():
+    class Linear16(torch.nn.Module):
+        def __init__(self, scale):
+            super().__init__()
+            scale += scale % 2
+            self.l1 = torch.nn.Linear(scale * 2, scale, bias=False,dtype=torch.float16).cuda()
+            self.l2 = torch.nn.Linear(scale, scale//2, bias=False,dtype=torch.float16).cuda()
+
+        def forward(self, x):
+            return self.l2(self.l1(x))
 
-    def load2x(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, 2*scale,scale), dtype=torch.uint8).cuda()
-
-    def loadx(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
+    class W4A16_symmetric_weight_only(torch.nn.Module):
+        def __init__(self, scale):
+            super().__init__()
+            assert scale % 4 == 0
+            self.l1 = torch.randint(2**8,(scale, scale), dtype=torch.uint8).cuda()
+            self.s1 = torch.tensor((scale),dtype=torch.float16).cuda()
+            self.l2 = torch.randint(2**8,(scale//2, scale//4), dtype=torch.uint8).cuda()
+            self.s2 = torch.tensor((scale//4),dtype=torch.float16).cuda()
 
-    def unpack8to2(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
-        unpacked_tensor = unpack_c(fake_tensor, 2, dim=1)
 
-    def unpack8to4(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
-        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
-
-    def t8to4wmm(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (8, 1024,1024), dtype=torch.uint8).cuda()
-        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
+        def forward(self, x):
+            w = unpack(self.l1.detach(), 4, output_dtype=torch.float16)
+            x = x * self.s1
+            x = x @ w
+            w = unpack(self.l2.detach(), 4, output_dtype=torch.float16)
+            x = x * self.s2
+            x = x @ w
 
-    torch._dynamo.config.specialize_int = True
-    # _unpack_c = torch.compile(_unpack, fullgraph=True)
-    unpack_c = torch.compile(unpack, fullgraph=True)
-
-    scale = [16,64,256,1024,4096]
-    load4x_times = []
-    unpack8to2_times = []
-    load2x_times = []
-    unpack8to4_times = []
-    for s in scale:
-        res = benchmark(load4x, 50, scale=s)
-        load4x_times.append(res)
-        print(f"load(1, {4*s},{s}) time: {res} ms")
-
-        res=benchmark(unpack8to2, 50, scale=s)
-        unpack8to2_times.append(res)
-        print(f"load(1, {s},{s}) unpack uint2 time: {res} ms")
+            return x
+
+    torch._dynamo.config.specialize_int = True
+    for scale in [256,512, 1024, 2048,4096, 8192]:
+        a = Linear16(scale)
+        b = W4A16_symmetric_weight_only(scale)
+        # a = torch.compile(a, fullgraph=True)
+        b = torch.compile(b, fullgraph=True)
 
-        res = benchmark(load2x, 50, scale=s)
-        load2x_times.append(res)
-        print(f"load(1, {2*s},{s}) time: {res} ms")
-
-        res = benchmark(unpack8to4, 50, scale=s)
-        unpack8to4_times.append(res)
-        print(f"load(1, {s},{s}) unpack uint4 time: {res} ms")
-        print()
-
-    # import matplotlib.pyplot as plt
-    # plt.plot(scale, load4x_times, label="load(1, 4x, x)")
-    # plt.plot(scale, unpack8to2_times, label="unpack uint8 to uint2")
-    # plt.plot(scale, load2x_times, label="load(1, 2x, x)")
-    # plt.plot(scale, unpack8to4_times, label="unpack uint8 to uint4")
-    # plt.xlabel("scale")
-    # plt.ylabel("time (ms)")
-    # plt.yscale("log")
-    # plt.legend()
-    # plt.savefig("benchmark_bitpacking.png")
-
-
-def test_vs_hqqpack():
-    #requires hqq to be installed
-    import hqq
-    import hqq.core.quantize as hqq_quantize
-    HQQLinear = hqq_quantize.HQQLinear
-    BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
-    from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
-
-    BASE_QUANT_CONFIG = {
-        "optimize": True,
-        "view_as_float": False,
-        "nbits": 4,
-        "bitpack": False,
-        "axis": 1,
-    }
+        test_input = torch.randn(scale*2, dtype=torch.float16).cuda()
+        forward_args = [test_input]
+        b.forward(test_input)
+        print("scale: ", scale)
+        print("fp16 time: ", benchmark(a.forward, forward_args, 100))
+        print("uint4 time: ", benchmark(b.forward, forward_args, 100))
 
-    def mixed_mm(
-        shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True
-    ):
-        qcfg = {
-            **BASE_QUANT_CONFIG,
-            **dict(group_size=group_size, axis=axis),
-        }
-        M, N, K = shape
-
-        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
-
-        quant_config = BaseQuantizeConfig(
-            quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
-        )
-        quant_config.update({"weight_quant_params": qcfg})
-        hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
-        W_q, meta = hqq_linear.W_q, hqq_linear.meta
-        W_q = W_q.to(dtype=quant_dtype)
-        W_q = (
-            W_q.reshape(meta["shape"])
-            if quant_config["weight_quant_params"]["bitpack"] == False
-            else W_q
-        )
-        W_dq = hqq_linear.dequantize()
-
-        scales, zeros = meta["scale"], meta["zero"]
-        scales = scales.reshape(N, -1)
-        zeros = zeros.reshape(N, -1)
-        if pack_fn:
-            packed_w = pack(W_q.T,4,dim=0,order=False)
-        else:
-            packed_w = pack_2xint4(W_q.T)
-
-        if transposed:
-            x = torch.randn(M, N, dtype=dtype, device="cuda")
-            hqq_out = x @ W_dq
-
-            tt_out = triton_mixed_mm(
-                x,
-                packed_w,
-                scales.T,
-                zeros.T,
-                transposed=True,
-                group_size=group_size,
-                fp8_fast_accum=False,
-                kernel_type=kernel_type,
-            )
-
-        else:
-            x = torch.randn(M, K, dtype=dtype, device="cuda")
-            hqq_out = x @ W_dq.T
-
-            tt_out = triton_mixed_mm(
-                x,
-                packed_w,
-                scales.T,
-                zeros.T,
-                transposed=False,
-                group_size=group_size,
-                fp8_fast_accum=False,
-                kernel_type=kernel_type,
-            )
-
-    shapes = [
-        [16, 128, 128],
-        [16, 4096, 4096],
-    ]
-    group_sizes = [64, 128]
-    shape = [16, 128, 128]
-    group_size = 64
-    pack = torch.compile(pack, fullgraph=True)
-    for i in range(2):
-        shape = shapes[i]
-        group_size = group_sizes[i]
-        print("linear layer size: ", shape)
-        print("group size: ", group_size)
-        # run once to compile
-        test_mixed_mm(
-            shape,
-            group_size,
-            1,
-            torch.float16,
-            True,
-            "compute_bound",
-            torch.uint8,
-        )
-        # shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
-        print("pack time (ms): ", benchmark(test_mixed_mm, 100,
-            shape,
-            group_size,
-            1,
-            torch.float16,
-            True,
-            "compute_bound",
-            torch.uint8))
 
-        print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 100,
-            shape,
-            group_size,
-            1,
-            torch.float16,
-            True,
-            "compute_bound", #max autotune doesnt work?
-            torch.uint8,
-            pack_fn=False))
-        print("")
-
-
+
 if __name__ == "__main__":
+    compare_to_fp16()
     test_vs_existing()
-
+
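
For reference, a minimal sketch of how the reworked helper is now driven. Only the new signature and the first few lines of benchmark() appear in this diff, so the timing loop and return value below are assumptions, and workload() is a hypothetical stand-in for the new_/old_ closures.

    import torch

    def benchmark(function, args, num_runs):
        # signature as changed above; everything past the event setup is an assumed sketch
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(num_runs):
            function(*args)
        end_event.record()
        torch.cuda.synchronize()
        return start_event.elapsed_time(end_event) / num_runs  # milliseconds per run

    def workload(scale):  # hypothetical stand-in for new_/old_
        t = torch.randint(2**8, (1, scale, scale), dtype=torch.uint8).cuda()
        return t.to(torch.float16).sum()

    print(f"{benchmark(workload, [1024], 10)} ms")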

test/prototype/test_bitpacking.py

Lines changed: 3 additions & 5 deletions
@@ -8,7 +8,7 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4))
-dimensions = (2, 1, 0)
+dimensions = (2, 1, 0, -1)
 orders = (True, False)
 
 
@@ -41,15 +41,13 @@ def test_CPU(dtype, dim, order):
                   element_type=element_type,
                   dim = dim,
                   order = order,
-                  container_dtype = torch.uint8,
-                  device='cpu')
+                  container_dtype = torch.uint8)
     assert(packed.shape[dim] == expected_pack_size)
     unpacked = unpack(packed,
                       element_bit_width,
                       element_type=element_type,
                       dim = dim,
-                      order = order,
-                      device='cpu')
+                      order = order)
     assert(unpacked.allclose(test_tensor))
 
 
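The same round trip, written out for a CPU tensor: the device keyword is gone and dim=-1 is now one of the parametrized axes. A sketch assuming a uint8 container, 4-bit values, and the import path of the prototype module.

    import torch
    from torchao.prototype.common.bitpacking import pack, unpack  # path assumed

    test_tensor = torch.randint(0, 2**4, (8, 8, 8), dtype=torch.uint8)   # CPU tensor, values fit in 4 bits
    packed = pack(test_tensor, 4, dim=-1, container_dtype=torch.uint8)   # device inferred from test_tensor
    assert packed.shape[-1] == 4          # two 4-bit elements per uint8 along the packed dim
    unpacked = unpack(packed, 4, dim=-1)
    assert unpacked.allclose(test_tensor)
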
torchao/prototype/common/bitpacking.py

Lines changed: 11 additions & 7 deletions
@@ -3,15 +3,16 @@
 
 def mod_shape(shape, mod, dim):
     """changes a select dimension of the input shape to mod"""
-    return (*shape[:dim], mod, *shape[dim+1:])
+    a = list(shape)
+    a[dim] = mod
+    return tuple(a)
 
 def unpack(data: torch.Tensor,
            element_bit_width: int,
            element_type: Optional[str] = None,
            dim: Optional[int] = 0,
            order: Optional[bool] = True,
-           output_dtype: Optional[torch.dtype] = None,
-           device: Optional[str] ="cuda") -> torch.Tensor:
+           output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
     """
     Unpacks small dtype elements from a larger dtype.
 
@@ -27,8 +28,10 @@ def unpack(data: torch.Tensor,
     """
     container_size = torch.iinfo(data.dtype).bits
     scale = container_size // element_bit_width
-
+    device = data.device
+
     unpacked = _unpack(data, element_bit_width, container_size, scale, order, dim, device)
+
     if element_type == "trinary":
         unpacked = unpacked.to(torch.int8) - 1
     elif output_dtype is not None:
@@ -59,8 +62,7 @@ def pack(data: torch.Tensor,
          dim: Optional[int] = 0,
          container_dtype: Optional[torch.dtype] = None,
          pad: Optional[bool] = False,
-         order: Optional[bool] = True,
-         device: Optional[str] = "cuda") -> torch.Tensor:
+         order: Optional[bool] = True) -> torch.Tensor:
     """
     Packs small dtype elements into a container of a larger dtype.
 
@@ -93,6 +95,8 @@ def pack(data: torch.Tensor,
     if container_dtype is not None:
        data = data.to(container_dtype)
 
+    device = data.device
+
     container_size = torch.iinfo(data.dtype).bits
     scale = container_size // element_bit_width
 
@@ -117,4 +121,4 @@ def _pack(data, container_size, element_bit_width, scale, dim, order, device) ->
     else:
         packed |= data[slices] << element_bit_width*i
     return packed
-
+
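
The mod_shape rewrite is what makes dim=-1 work: with a negative dim, the old tuple-splice returns the wrong shape, because shape[dim+1:] with dim=-1 is shape[0:], i.e. the whole tuple, while the list-assignment version handles negative indices naturally. A small illustration (the shape and values are arbitrary):

    shape = (2, 8, 8)

    def mod_shape_old(shape, mod, dim):          # pre-commit version
        return (*shape[:dim], mod, *shape[dim+1:])

    def mod_shape_new(shape, mod, dim):          # version introduced here
        a = list(shape)
        a[dim] = mod
        return tuple(a)

    print(mod_shape_old(shape, 4, -1))  # (2, 8, 4, 2, 8, 8) -- wrong
    print(mod_shape_new(shape, 4, -1))  # (2, 8, 4)

The device handling change is the same in pack() and unpack(): instead of a device parameter defaulting to "cuda", both now read data.device, so CPU tensors no longer need an explicit override.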
