
Commit 87826de

manuelcandales authored and facebook-github-bot committed
intx weight only linear quantizer for mps (#1192)
Summary: Pull Request resolved: #1192 Differential Revision: D65079774
1 parent 6234116 commit 87826de
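
For orientation, the diff below adds an eager-mode quantizer driven through a small constructor-plus-`quantize` API. The following is a minimal usage sketch based on the code in this commit; it assumes the MPS custom-ops extension `torchao_mps_ops` has been built and is importable, and that an MPS device is available:

import copy

import torch
import torchao_mps_ops  # assumed: loads the torchao MPS custom ops used by the quantizer

from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer

# Toy model: bias-free linears whose in-features are a multiple of groupsize,
# which is what the quantization path currently supports.
model = torch.nn.Sequential(torch.nn.Linear(256, 128, bias=False))

quantizer = UIntxWeightOnlyLinearQuantizer(
    device="mps",              # only "mps" is accepted
    precision=torch.float32,   # float32, float16 or bfloat16
    bitwidth=4,                # 1..7 bits
    groupsize=32,              # one of 32, 64, 128, 256
)
quantized_model = quantizer.quantize(copy.deepcopy(model))

activations = torch.randn(3, 256, dtype=torch.float32, device="mps")
out = quantized_model(activations)  # dispatches to torchao._linear_fp_act_4bit_weight
print(out.shape)                    # torch.Size([3, 128])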

2 files changed: +296 -4 lines changed
New test file

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import copy
import itertools
import os
import sys

import torch
import torchao_mps_ops
import unittest

from parameterized import parameterized
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
from torchao.experimental.quant_api import _quantize


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
    BITWIDTHS = range(1, 8)
    GROUPSIZES = [32, 64, 128, 256]

    # Currently, the quantization code in quant_api.py only supports K values
    # multiple of group_size.
    # TODO(mcandales): Generalize the code in quant_api.py and add tests to
    # cover values of K not multiple of group_size.
    def _model_setup(self):
        group_size = 32
        k0 = 96
        k1 = 224
        k2 = 160
        n = 47
        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, k2, bias=False),
            torch.nn.Linear(k2, n, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        return model, group_size, k0, n

    def _quantize_model(self, model, precision, nbit, group_size):
        quantizer = UIntxWeightOnlyLinearQuantizer(
            device="mps",
            precision=precision,
            bitwidth=nbit,
            groupsize=group_size,
        )
        quantized_model = copy.deepcopy(model)
        quantized_model = quantizer.quantize(quantized_model)
        return quantized_model

    @parameterized.expand(BITWIDTHS)
    def test_export(self, nbit):
        model, group_size, k0, n = self._model_setup()
        m = 3
        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        exported = torch.export.export(quantized_model, (activations,))

        for node in exported.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(
                    str(node.target)
                    == f"torchao._linear_fp_act_{nbit}bit_weight.default"
                )

    @parameterized.expand(BITWIDTHS)
    def test_2d_output_device_and_shape(self, nbit):
        model, group_size, k0, n = self._model_setup()
        m = 3
        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        result = quantized_model(activations)
        self.assertTrue(result.is_mps)
        self.assertTrue(result.shape == (m, n))

    @parameterized.expand(BITWIDTHS)
    def test_3d_output_device_and_shape(self, nbit):
        model, group_size, k0, n = self._model_setup()
        leading_shape = (3, 5)
        activations = torch.randn(*leading_shape, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        result = quantized_model(activations)
        self.assertTrue(result.is_mps)
        self.assertTrue(result.shape == (*leading_shape, n))

    @parameterized.expand(itertools.product(BITWIDTHS, GROUPSIZES))
    def test_valid_groupsizes(self, nbit, group_size):
        k0 = 3 * group_size
        k1 = 7 * group_size
        n = 47
        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, n, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        m = 5
        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        result = quantized_model(activations)
        self.assertTrue(result.is_mps)
        self.assertTrue(result.shape == (m, n))

    @parameterized.expand(BITWIDTHS)
    def test_invalid_groupsizes(self, nbit):
        group_size = 16
        k0 = 3 * group_size
        k1 = 7 * group_size
        n = 47
        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, n, bias=False),
        ]
        model = torch.nn.Sequential(*layers)

        with self.assertRaises(ValueError):
            self._quantize_model(model, torch.float32, nbit, group_size)

    # TODO(mcandales): Consolidate with the reference impl in test_lowbit.py
    def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z):
        N = W.shape[0]
        K = W.shape[1]
        W = W.to(torch.float32)
        scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
        zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
        W = scales * W + zeros
        return torch.mm(A, W.t())

    @parameterized.expand(BITWIDTHS)
    def test_accuracy(self, nbit):
        group_size = 32
        m = 3
        n = 7
        k = 64
        with torch.no_grad():
            activations = torch.rand(m, k, dtype=torch.float32, device="mps")
            model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
            quantized_model = self._quantize_model(
                model, torch.float32, nbit, group_size
            )
            result = quantized_model(activations)

            # Compute expected result
            weight_cpu = model[0].weight.data
            weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize(
                weight_cpu, group_size, nbit, True, torch.uint8
            )
            weight_scales_cpu = weight_scales_cpu.t()
            weight_zeros_cpu = -weight_zeros_cpu.t() * weight_scales_cpu
            expected = self._reference_linear_lowbit_quant_weights(
                activations.cpu(),
                weight_qvals_cpu,
                group_size,
                weight_scales_cpu,
                weight_zeros_cpu,
            )

            # Compare results
            torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001)


if __name__ == "__main__":
    unittest.main()
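
One step worth making explicit in the accuracy test above: `_quantize` returns per-group scales and integer zero points, but both the reference helper and `UIntxWeightOnlyQuantizedLinear` consume a pre-folded offset (`-zero * scale`), so group-wise dequantization reduces to a multiply-add. A small numeric check of that identity (values are illustrative, not from the commit):

# Affine dequantization: w ≈ scale * (q - zero) = scale * q + (-zero * scale),
# which is exactly the `scales * W + zeros` expression in
# _reference_linear_lowbit_quant_weights once the zeros have been pre-folded.
scale, zero, q = 0.05, 7, 12
w_affine = scale * (q - zero)
w_folded = scale * q + (-zero * scale)
assert abs(w_affine - w_folded) < 1e-12
print(w_affine, w_folded)  # 0.25 0.25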

torchao/experimental/quant_api.py

Lines changed: 126 additions & 4 deletions
@@ -25,10 +25,14 @@
 logger.addHandler(handler)


-def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool):
+def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool, signed=True):
     assert nbit >= 1 and nbit <= 8
-    qmin = -(1 << (nbit - 1))
-    qmax = (1 << (nbit - 1)) - 1
+    if signed:
+        qmin = -(1 << (nbit - 1))
+        qmax = (1 << (nbit - 1)) - 1
+    else:
+        qmin = 0
+        qmax = (1 << nbit) - 1

     n, k = vals.shape
     vals = vals.reshape(-1, group_size)
@@ -51,7 +55,7 @@ def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros:
         zero_points=group_zeros,
         quant_min=qmin,
         quant_max=qmax,
-        dtype=torch.int8,
+        dtype=torch.int8 if signed else torch.uint8,
         group_size=group_size,
     )

@@ -516,3 +520,121 @@ def apply(weight):
         )

     return _get_linear_subclass_inserter(apply)
+
+
+class UIntxWeightOnlyQuantizedLinear(nn.Module):
+    def __init__(
+        self,
+        pack_weight_op,
+        linear_op,
+    ):
+        super().__init__()
+        self._pack_weights_op = pack_weight_op
+        self._linear_op = linear_op
+
+    def quantize_and_pack_weights(self, weights, nbit, group_size):
+        self.nbit = nbit
+        self.group_size = group_size
+
+        weight_qvals, weight_scales, weight_zeros = _quantize(
+            weights, self.group_size, self.nbit, has_weight_zeros=True, signed=False
+        )
+        weight_scales = torch.transpose_copy(weight_scales, 1, 0)
+        weight_zeros = torch.transpose_copy(weight_zeros, 1, 0)
+        self.weight_scales = weight_scales
+        self.weight_zeros = -weight_zeros * weight_scales
+
+        self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps")
+
+    def forward(self, x):
+        assert x.dim() >= 2
+        if x.dim() == 2:
+            return self._linear_op(
+                x, self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
+            )
+
+        lead_shape = x.shape[0:-1]
+        k = x.shape[-1]
+        n = self.weight_scales.shape[1]
+        return self._linear_op(
+            x.reshape(-1, k), self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
+        ).reshape(*lead_shape, n)

+# TODO(mcandales): Consolidate with _replace_linear_with_quantized_linear
+def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}):
+    group_size = kwargs["group_size"]
+    nbit = kwargs["nbit"]
+
+    assert not isinstance(module, nn.Linear)
+    assert nbit >= 1 and nbit <= 7
+
+    for name, child in module.named_children():
+        if not isinstance(child, nn.Linear):
+            _replace_linear_with_quantized_linear_mps(child, kwargs)
+        else:
+            assert child.bias is None
+            qlinear = UIntxWeightOnlyQuantizedLinear(
+                pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"),
+                linear_op=getattr(
+                    torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight"
+                ),
+            )
+            setattr(module, name, qlinear)
+            qlinear.quantize_and_pack_weights(
+                child.weight, nbit, group_size
+            )
+
+
+class UIntxWeightOnlyLinearQuantizer:
+    def __init__(
+        self,
+        device,
+        precision,
+        *,
+        bitwidth: Optional[int] = None,
+        groupsize: Optional[int] = None,
+    ):
+        if device != "mps":
+            raise NotImplementedError(
+                "Only device=mps is currently supported in UIntxWeightOnlyLinearQuantizer"
+            )
+        else:
+            self.device = device
+
+        if precision not in [torch.float32, torch.float16, torch.bfloat16]:
+            raise ValueError(
+                "Only precisions float32, float16 & bfloat16 are supported in UIntxWeightOnlyLinearQuantizer"
+            )
+        else:
+            self.precision = precision
+
+        if bitwidth is None:
+            bitwidth = 4
+            logger.warning(f"bitwidth not specified, defaulting to {bitwidth}.")
+        if bitwidth not in range(1, 8):
+            raise ValueError(
+                "Only bitwidths 1 to 7 are supported in UIntxWeightOnlyLinearQuantizer"
+            )
+        else:
+            self.bitwidth = bitwidth
+
+        if groupsize is None:
+            groupsize = 128
+            logger.warning(f"groupsize not specified, defaulting to {groupsize}.")
+        if groupsize not in [32, 64, 128, 256]:
+            raise ValueError(
+                "Only groupsizes 32, 64, 128 & 256 are supported in UIntxWeightOnlyLinearQuantizer"
+            )
+        else:
+            self.groupsize = groupsize
+
+    def quantize(self, model: nn.Module) -> nn.Module:
+        model = model.to(self.device).to(self.precision)
+        _replace_linear_with_quantized_linear_mps(
+            model,
+            kwargs={
+                "group_size": self.groupsize,
+                "nbit": self.bitwidth,
+            },
+        )
+        return model
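
A note on the `_quantize` change above: the new `signed` flag only switches the integer code range (and storage dtype) that the affine quantization targets, and the MPS path requests unsigned codes. The snippet below is a tiny illustration of the ranges the new branch produces; it is not part of the commit, it simply mirrors the qmin/qmax logic shown in the diff:

# Mirrors the qmin/qmax selection added to _quantize (illustration only).
def quant_range(nbit: int, signed: bool = True) -> tuple[int, int]:
    if signed:
        return -(1 << (nbit - 1)), (1 << (nbit - 1)) - 1  # stored as torch.int8
    return 0, (1 << nbit) - 1                             # stored as torch.uint8

print(quant_range(4, signed=True))   # (-8, 7)
print(quant_range(4, signed=False))  # (0, 15), the range UIntxWeightOnlyQuantizedLinear uses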
