
Commit cf24ae1

Xia-Weiwen authored and jainapurva committed
SmoothQuant using tensor subclassing (#1030)
* SmoothQuant using tensor subclassing
* Update UT
* Add SmoothQuant example
* Remove duplicate implementation of int_scaled_matmul for CPU
* Update example.py
* Remove unused code
* Implement with LinearActivationQuantizedTensor
* Fix load/save
* Fix device mismatch in observer
* Fix fp16 overflow issue in int_scaled_matmul
* Add linear_activation_scale_quantized.py for torch.compile
* Quantize act/wei to 7 bit on old CPU platforms
* Fix device mismatch
* Fix UT failures
* Fix UT
* Don't use torch._int_mm for CPU now because it may overflow
* Remove reduce_range
* Refine code
* Remove torch.compile from example
* Add torch.compile in example
* Debug CI failures
* Debug CI failures (1)
* Debug CI failures (2)
* Debug CI failures (3)
* Work with torch.compile
* Update torchao/kernel/intmm.py
* Update readme.md
* Update readme.md
* Debug CI failures (4)
* Reimplement with nested tensor subclassing
* Test torch.compile only with PyTorch >= 2.5
* Debug CI failures (5)
* Debug CI failures (6)
* Debug CI failures (7)
* Use MovingAvg observer for activation; Update UT and readme
* Revert changes to test_spinquant.py; refine readme
* Debug CI failures (8)
* Debug CI failures (9)
* Fix CI failure
* Refactor SmoothQuantObserver
* Rename readme.md -> README.md
* Rename insert_smooth_quant_observer -> insert_smooth_quant_observer_ to indicate inplace
* Fix device mismatch in observer
* Fall back to conventional quantization if alpha is None
* Update README.md to provide more benchmark data; fix CI
* Fix CI failures
* Add a comment in affine_quantized_tensor.py
1 parent 1b2c2ae commit cf24ae1

File tree

8 files changed: +797, -3 lines changed


test/prototype/test_smoothquant.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
```python
from copy import deepcopy
import pytest
import torch
import tempfile
from torchao.quantization import quantize_
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_2,
    TORCH_VERSION_AT_LEAST_2_4,
    TORCH_VERSION_AT_LEAST_2_5,
)
from torchao.quantization.utils import (
    dynamically_quantize_per_channel,
    dequantize_per_channel,
)
from torchao.prototype.smoothquant import (
    insert_smooth_quant_observer_,
    smooth_quant,
    SmoothQuantObservedLinear,
    save_smooth_quant_recipe,
    load_smooth_quant_recipe
)


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=512, n=256, k=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)
        self.linear3 = torch.nn.Linear(k, 1, bias=False)

    def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"):
        return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)]

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x


bias_list = [True, False]
alpha_list = [None, 0.5, 0.75]
quant_mode_list = ["static", "dynamic"]
devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")
idtypes = (torch.float, torch.bfloat16, torch.half)

if TORCH_VERSION_AT_LEAST_2_5:
    # This test case will trigger recompilation many times, so set a large cache_size_limit here
    torch._dynamo.config.cache_size_limit = 128


@pytest.mark.parametrize("bias", bias_list)
@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
def test_compute(bias, alpha, quant_mode, device, idtype):
    class Linear(torch.nn.Module):
        def __init__(self, bias: bool):
            super().__init__()
            self.fc = torch.nn.Linear(32, 32, bias)
            self.fc.weight.data = torch.randn_like(self.fc.weight.data)

        def forward(self, x):
            return self.fc(x)

    m = Linear(bias).eval().to(idtype).to(device)
    m_ref = deepcopy(m)
    data = torch.randn(2, 32, dtype=idtype, device=device)

    # calibrate
    insert_smooth_quant_observer_(m, alpha, quant_mode)
    m(data)
    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
    quantize_(m, smooth_quant(), is_observed_linear)
    with torch.inference_mode():
        if TORCH_VERSION_AT_LEAST_2_5:
            m = torch.compile(m, fullgraph=True)
        out = m(data)

    # reference
    weight = m_ref.fc.weight.data.float()
    b = m_ref.fc.bias if bias else None
    x_abs_max_per_ic = torch.abs(data).max(dim=0).values
    w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
    smoothing_factor = 1 if alpha is None else (
        torch.pow(x_abs_max_per_ic, alpha) / torch.pow(
            w_abs_max_per_ic, 1 - alpha)
    )
    act = data / smoothing_factor
    wei = weight * smoothing_factor
    qw, w_scales, w_zps = dynamically_quantize_per_channel(
        wei, -127, 127, torch.int8
    )
    fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype)
    if quant_mode == "static":
        # activation is quantized per-tensor
        act_min, act_max = torch.aminmax(act.float())
        max_val_pos = torch.max(-act_min, act_max)
        act_scale = max_val_pos / 127.0
        fq_act = torch.quantize_per_tensor(
            act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8
        ).dequantize().to(idtype)
        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
    else:
        # activation is quantized per-row (batch * sequence_length)
        qx, x_scales, x_zps = dynamically_quantize_per_channel(
            act.float(), -127, 127, torch.int8
        )
        fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype)
        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)

    # BFloat16 and Float16 have larger errors
    atol = 0.1 if idtype == torch.float else (
        0.2 if idtype == torch.half else 0.3
    )
    assert torch.allclose(out, out_ref.to(idtype), atol=atol)


@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
def test_save_load_recipe(alpha, quant_mode, device, idtype):
    dataset_size = 20
    l1, l2, l3 = 512, 256, 128
    original_dtype = idtype
    n_calib_examples = 10
    sequence_length = 5

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    m_save_load = deepcopy(m)

    dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
    calibration_data = dataset[:n_calib_examples]

    # calibrate
    insert_smooth_quant_observer_(m, alpha, quant_mode)
    insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)

    for example in calibration_data:
        m(example.to(device))
        m_save_load(example.to(device))

    with tempfile.NamedTemporaryFile() as fp:
        save_path = fp.name
        save_smooth_quant_recipe(m_save_load, save_path)
        load_smooth_quant_recipe(m_save_load, save_path)

    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
    quantize_(m, smooth_quant(), is_observed_linear)
    if TORCH_VERSION_AT_LEAST_2_5:
        # earlier versions are not compatible
        m = torch.compile(m, fullgraph=True)
        m_save_load = torch.compile(m_save_load, fullgraph=True)
    out_list = [m(data.squeeze(0)) for data in dataset]
    out = torch.cat(out_list)
    save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset]
    save_load_out = torch.cat(save_load_out_list)

    assert out is not None
    assert save_load_out is not None
    assert torch.allclose(out, save_load_out)
```

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -1419,7 +1419,6 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
         isinstance(input_tensor, AffineQuantizedTensor) and
         _aqt_is_int8_reduced_range(input_tensor) and
         isinstance(weight_tensor, AffineQuantizedTensor) and
-        weight_tensor.is_cuda and
         input_tensor.dtype == weight_tensor.dtype and
         isinstance(input_tensor._layout, PlainLayout) and
         isinstance(weight_tensor._layout, PlainLayout)
@@ -1442,7 +1441,11 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
     w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
     w_scales = weight_tensor.tensor_impl.scale
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
-    y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
+    x_scales_dtype = x_scales.dtype
+    # Cast fp16 scale to float to avoid overflow in int_scaled_matmul
+    intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
+    y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype))
+    y_dot_scaled = y_dot_scaled.to(x_scales_dtype)

     y = (y_dot_scaled * w_scales).reshape(
         *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
```
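The cast to float above matters because float16 tops out at 65504: an int8×int8 dot product over K input channels can reach 127 × 127 × K, far beyond that range. A tiny standalone illustration (not part of the diff):

```python
import torch

# float16 cannot represent values above 65504, so a realistic int8*int8
# accumulator value overflows to inf if handled directly in fp16.
acc = torch.tensor(127.0 * 127 * 64)          # ~1.03e6, a plausible int32 accumulation
print(acc.to(torch.half))                     # tensor(inf, dtype=torch.float16)

# Scaling in float32 first, then casting back (as the patched code does), stays finite.
scale = torch.tensor(1e-4, dtype=torch.half)  # a typical activation scale magnitude
print((acc * scale.float()).to(torch.half))   # finite (~103) instead of inf
```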

torchao/kernel/intmm.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -37,6 +37,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
     """
     # torch.compile path
     if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
+        if input.device.type == "cpu":
+            # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float())
         return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)

     # error checking for cublas path
@@ -126,7 +129,7 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
     """
     M, K = a.shape
     K, N = b.shape
-    assert M == scales1.size(0)
+    assert M == scales1.size(0) or scales1.numel() == 1
     assert 1 == scales1.size(1)
     assert scales1.is_contiguous()
     scales1 = scales1.expand((M, N))
```
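The relaxed assertion above also accepts a single per-tensor scale (`scales1.numel() == 1`) in addition to one scale per row. As a rough mental model only — the real torchao kernels are more involved — `int_scaled_matmul` behaves like an int8 matmul with int32 accumulation whose result is multiplied row-wise by `scales1`; a hypothetical reference sketch:

```python
import torch

def int_scaled_matmul_reference(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
    # Reference semantics only (a sketch, not the torchao kernel): int8 x int8
    # matmul with int32 accumulation, then row-wise scaling. After this change,
    # scales1 may be per-row (M, 1) or a single per-tensor scale with numel() == 1.
    M, K = a.shape
    _, N = b.shape
    assert M == scales1.size(0) or scales1.numel() == 1
    c = torch.mm(a.to(torch.int32), b.to(torch.int32))  # int32 accumulation on CPU
    return c.to(scales1.dtype) * scales1

a = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
b = torch.randint(-128, 128, (8, 3), dtype=torch.int8)
per_row_scale = torch.rand(4, 1)                  # one scale per row of `a`
per_tensor_scale = torch.tensor([[0.02]])         # numel() == 1, broadcasts to all rows
print(int_scaled_matmul_reference(a, b, per_row_scale).shape)     # torch.Size([4, 3])
print(int_scaled_matmul_reference(a, b, per_tensor_scale).shape)  # torch.Size([4, 3])
```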
torchao/prototype/smoothquant/README.md

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@

# SmoothQuant quantization
This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438).

In this implementation, weights are smoothed (equalized) and quantized to int8 at quantization time, while activations are smoothed and quantized to int8 at runtime. Activation quantization can be dynamic or static: with dynamic quantization, qparams (i.e., scales) are computed at runtime and activations are quantized per token; with static quantization, qparams are determined during calibration and activations are quantized per tensor. Dynamic quantization generally gives better accuracy, while static quantization gives better latency. In both cases, weights and activations are quantized symmetrically.
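To illustrate the idea, here is a small self-contained sketch of the smoothing step (it mirrors the reference math in `test/prototype/test_smoothquant.py`; the `smooth` helper below is for illustration only and is not a torchao API):

```python
import torch

def smooth(X: torch.Tensor, W: torch.Tensor, alpha: float = 0.5):
    # X: (tokens, in_features) activations, W: (out_features, in_features) weights.
    # Per input channel, migrate activation outliers into the weights; alpha
    # balances how much of the dynamic range each side absorbs.
    x_abs_max_per_ic = X.abs().max(dim=0).values
    w_abs_max_per_ic = W.abs().max(dim=0).values
    s = torch.pow(x_abs_max_per_ic, alpha) / torch.pow(w_abs_max_per_ic, 1 - alpha)
    return X / s, W * s  # X @ W.T is mathematically unchanged; both are then int8-quantized

X, W = torch.randn(16, 64), torch.randn(32, 64)
X_s, W_s = smooth(X, W, alpha=0.5)
assert torch.allclose(X @ W.t(), X_s @ W_s.t(), atol=1e-4)
```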
## Quick start
Run the example code with
```bash
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static>
# An example
python example.py -m meta-llama/Llama-2-7b-hf --device=cuda --quant-mode=dynamic
```
To use `torch.compile` for speedup, add `--compile`. You may want to export `TORCHINDUCTOR_FREEZING=1` for even better performance.
```bash
TORCHINDUCTOR_FREEZING=1 python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --compile
```
To save a quantized model for reuse, specify `--model-save-path`
```bash
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-save-path ./quantized_model.pt
```
And load it with `--model-load-path`
```bash
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-load-path ./quantized_model.pt
```

## Usage of API
The following APIs are provided:
- `insert_smooth_quant_observer_`
- `smooth_quant`
- `save_smooth_quant_recipe` (advanced)
- `load_smooth_quant_recipe` (advanced)

`insert_smooth_quant_observer_` inserts observers into the model to be quantized. For example:
```python
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
```
After insertion, run the model on a calibration dataset or (advanced) load a recipe.

`smooth_quant` applies SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example:
```python
from torchao.prototype.smoothquant import SmoothQuantObservedLinear
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
torchao.quantization.quantize_(model, smooth_quant(), is_observed_linear)
```
`is_observed_linear` is a filter so that only observed linear layers are quantized.

(Advanced) `save_smooth_quant_recipe` and `load_smooth_quant_recipe` save or load a recipe for a model.

A recipe contains the smoothing factors and the quantization parameters of weights and activations for all linear layers to be quantized. Advanced users can save and then modify these parameters to improve accuracy, e.g., use a different alpha per layer, or leave some linear layers unquantized by deleting them from the recipe. Such modifications can be published as a recipe; by loading it, the recipe can be reused and calibration is no longer needed.

To save a recipe, insert observers and run calibration first. For example,
```python
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
for data in dataset_for_calibration:
    model(data)
save_smooth_quant_recipe(model, "./smooth_quant_recipe.json")
```
To load a recipe, insert observers first. For example,
```python
insert_smooth_quant_observer_(model)
load_smooth_quant_recipe(model, "./smooth_quant_recipe.json")
```

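Putting it together, a minimal end-to-end sketch based on the snippets above and the unit tests (the toy model and random calibration data are placeholders):

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.smoothquant import (
    insert_smooth_quant_observer_,
    smooth_quant,
    SmoothQuantObservedLinear,
)

# Placeholder model and calibration data; replace with a real model and dataset.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16)).eval()
calibration_data = [torch.randn(8, 64) for _ in range(4)]

insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
for batch in calibration_data:
    model(batch)  # observers record activation statistics

is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(model, smooth_quant(), is_observed_linear)  # swap in SmoothQuant-quantized linears

with torch.inference_mode():
    out = model(torch.randn(8, 64))
```
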
## Benchmark
Benchmarks were run with `torch.compile` on an NVIDIA A10G GPU.
### meta-llama/Llama-2-7b-hf
Perplexity
| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
|-|-|-|-|-|
| Dynamic | 8.1872 | 7.4257 | 7.2518 | 7.5509 |
| Static | 43.8051 | 11.2984 | 7.5791 | 19.5050 |

Note*: Conventional quantization without SmoothQuant

### meta-llama/Meta-Llama-3-8B
Perplexity
| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
|-|-|-|-|-|
| Dynamic | 21.2475 | 8.8288 | 9.6514 | 8.3574 |
| Static | 301.7118 | 18.0617 | 10.8343 | 278.9819 |

Note*: Conventional quantization without SmoothQuant

### Test method
**Commands**
```bash
# dynamic quant
TORCHINDUCTOR_FREEZING=1 python example.py -m <model_id> --device=cuda --quant-mode=dynamic --compile
# static quant
TORCHINDUCTOR_FREEZING=1 python example.py -m <model_id> --device=cuda --quant-mode=static --compile
```
Use `--alpha` to specify the alpha parameter. Add `--disable-smooth-quant` to run quantization without SmoothQuant.

**Environment**
- AWS g5.12xlarge instance
- torch==2.6.0.dev20241017+cu124
- python==3.12.6
torchao/prototype/smoothquant/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
```python
from .api import (
    insert_smooth_quant_observer_,
    smooth_quant,
    save_smooth_quant_recipe,
    load_smooth_quant_recipe,
)
from .core import SmoothQuantObservedLinear
```
