Commit cb13a69

Update QAT READMEs using new APIs

Add references to new QAT APIs including `quantize_`, `FakeQuantizedX`, and the new embedding Quantizers and ComposableQATQuantizer. Also link to new QAT + LoRA recipe in torchtune.

ghstack-source-id: 10bbe97
Pull Request resolved: #1541

1 parent 6b65cc7 commit cb13a69

File tree: 2 files changed, +167 -57 lines changed
README.md

Lines changed: 23 additions & 12 deletions
````diff
@@ -54,27 +54,38 @@ We've added kv cache quantization and other features in order to enable long con
 
 In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md)
 
+## Training
+
 ### Quantization Aware Training
 
-Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
+Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). We've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/). For more details, please see the [QAT README](./torchao/quantization/qat/README.md).
 
 ```python
-from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
-
-qat_quantizer = Int8DynActInt4WeightQATQuantizer()
+from torchao.quantization import (
+    quantize_,
+    int8_dynamic_activation_int4_weight,
+)
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    from_intx_quantization_aware_training,
+    intx_quantization_aware_training,
+)
 
-# Insert "fake quantize" operations into linear layers.
-# These operations simulate quantization numerics
-model = qat_quantizer.prepare(model)
+# Insert fake quantization
+activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+quantize_(
+    my_model,
+    intx_quantization_aware_training(activation_config, weight_config),
+)
 
-# Run Training...
+# Run training... (not shown)
 
-# Convert fake quantize to actual quantize operations
-model = qat_quantizer.convert(model)
+# Convert fake quantization to actual quantized operations
+quantize_(my_model, from_intx_quantization_aware_training())
+quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
 ```
 
-## Training
-
 ### Float8
 
 [torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
````
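For reference, a minimal sketch of how the float8 training flow is typically enabled; this assumes the `convert_to_float8_training` entry point from `torchao.float8` and recent hardware (e.g. H100), and is only an illustration rather than the full recipe described in that README:

```python
import torch
from torchao.float8 import convert_to_float8_training

# Toy model; float8 training targets the nn.Linear layers.
model = torch.nn.Sequential(
    torch.nn.Linear(2048, 4096),
    torch.nn.SiLU(),
    torch.nn.Linear(4096, 2048),
).cuda()

# Swap eligible nn.Linear modules for float8 training variants.
convert_to_float8_training(model)

# Training then proceeds as usual, typically under torch.compile for best throughput.
model = torch.compile(model)
```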

torchao/quantization/qat/README.md

Lines changed: 144 additions & 45 deletions
````diff
@@ -19,12 +19,6 @@ x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
 x_fq = (x_fq - zp) * scale
 ```
 
-## API
-
-torchao currently supports two QAT schemes for linear layers:
-- int8 per token dynamic activations + int4 per group weights
-- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
-
 QAT typically involves applying a transformation to your model before and after training.
 In torchao, these are represented as the prepare and convert steps: (1) prepare inserts
 fake quantize operations into linear layers, and (2) convert transforms the fake quantize
````
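The two formulas above capture what fake quantization does numerically. Below is a minimal sketch in plain PyTorch, assuming per-tensor symmetric int8 for simplicity; torchao's actual fake quantization operates per token or per group and is configured via `FakeQuantizeConfig`:

```python
import torch

# Illustrative helper (not a torchao API): per-tensor symmetric int8 fake
# quantization following the two formulas above.
def fake_quantize_int8(x_float: torch.Tensor) -> torch.Tensor:
    qmin, qmax = -128, 127
    zp = 0  # symmetric quantization: zero point is 0
    scale = x_float.abs().max() / qmax
    x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)  # quantize
    return (x_fq - zp) * scale                               # dequantize

x = torch.randn(4, 8)
x_fq = fake_quantize_int8(x)
print((x - x_fq).abs().max())  # small rounding error; dtype stays float32
```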
````diff
@@ -34,64 +28,169 @@ Between these two steps, training can proceed exactly as before.
 
 ![qat](images/qat_diagram.png)
 
-To use QAT in torchao, apply the prepare step using the appropriate Quantizer before
-training, then apply the convert step after training for inference or generation.
-For example, on a single GPU:
+
+## torchao APIs
+
+torchao currently supports two QAT APIs, one through the [`quantize_`](https://pytorch.org/ao/stable/generated/torchao.quantization.quantize_.html#torchao.quantization.quantize_)
+API (recommended) and one through the Quantizer classes (legacy). The `quantize_` API
+allows flexible configuration of quantization settings for both activations and weights,
+while the Quantizer classes each hardcode a specific quantization setting.
+
+For example, running QAT on a single GPU:
 
 ```python
 import torch
 from torchtune.models.llama3 import llama3
+
+# Set up a smaller version of llama3 to fit in a single GPU
+def get_model():
+    return llama3(
+        vocab_size=4096,
+        num_layers=16,
+        num_heads=16,
+        num_kv_heads=4,
+        embed_dim=2048,
+        max_seq_len=2048,
+    ).cuda()
+
+# Example training loop
+def train_loop(m: torch.nn.Module):
+    optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+    loss_fn = torch.nn.CrossEntropyLoss()
+    for i in range(10):
+        example = torch.randint(0, 4096, (2, 16)).cuda()
+        target = torch.randn((2, 16, 4096)).cuda()
+        output = m(example)
+        loss = loss_fn(output, target)
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+```
+
+### quantize_ API (recommended)
+
+The recommended way to run QAT in torchao is through the `quantize_` API:
+1. **Prepare:** specify how weights and/or activations are to be quantized through
+[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and pass these to [`intx_quantization_aware_training`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
+2. **Convert:** quantize the model using the standard post-training quantization (PTQ)
+functions such as [`int8_dynamic_activation_int4_weight`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)
+
+For example:
+
+```python
+from torchao.quantization import (
+    quantize_,
+    int8_dynamic_activation_int4_weight,
+)
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    from_intx_quantization_aware_training,
+    intx_quantization_aware_training,
+)
+model = get_model()
+
+# prepare: insert fake quantization ops
+# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
+activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+quantize_(
+    model,
+    intx_quantization_aware_training(activation_config, weight_config),
+)
+
+# train
+train_loop(model)
+
+# convert: transform fake quantization ops into actual quantized ops
+# swaps `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
+# quantized activation and weight tensor subclasses
+quantize_(model, from_intx_quantization_aware_training())
+quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
+
+# inference or generate
+```
+
+To fake quantize embeddings in addition to linear layers, additionally call
+the following with a filter function during the prepare step:
+
+```python
+quantize_(
+    m,
+    intx_quantization_aware_training(weight_config=weight_config),
+    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+)
+```
+
+### Quantizer API (legacy)
+
+Alternatively, torchao provides a few hardcoded quantization settings through
+the following Quantizers:
+- [Int8DynActInt4WeightQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L126) (linear), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight
+- [Int4WeightOnlyQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L308) (linear), targeting int4 per-group asymmetric weight (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
+- [Int4WeightOnlyEmbeddingQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/embedding.py#L94) (embedding), targeting int4 per-group symmetric weight
+
+For example:
+```python
 from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
+qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
+model = get_model()
 
-# Smaller version of llama3 to fit in a single GPU
-model = llama3(
-    vocab_size=4096,
-    num_layers=16,
-    num_heads=16,
-    num_kv_heads=4,
-    embed_dim=2048,
-    max_seq_len=2048,
-).cuda()
-
-# Quantizer for int8 dynamic per token activations +
-# int4 grouped per channel weights, only for linear layers
-qat_quantizer = Int8DynActInt4WeightQATQuantizer()
-
-# Insert "fake quantize" operations into linear layers.
-# These operations simulate quantization numerics during
-# training without performing any dtype casting
+# prepare: insert fake quantization ops
+# swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
 model = qat_quantizer.prepare(model)
 
-# Standard training loop
-optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
-loss_fn = torch.nn.CrossEntropyLoss()
-for i in range(10):
-    example = torch.randint(0, 4096, (2, 16)).cuda()
-    target = torch.randn((2, 16, 4096)).cuda()
-    output = model(example)
-    loss = loss_fn(output, target)
-    loss.backward()
-    optimizer.step()
-    optimizer.zero_grad()
-
-# Convert fake quantize to actual quantize operations
-# The quantized model has the exact same structure as the
-# quantized model produced in the corresponding PTQ flow
-# through `Int8DynActInt4WeightQuantizer`
+# train
+train_loop(model)
+
+# convert: transform fake quantization ops into actual quantized ops
+# swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
 model = qat_quantizer.convert(model)
 
 # inference or generate
 ```
 
-Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
-and apply quantized-aware fine-tuning as follows:
+To use multiple Quantizers in the same model for different layer types,
+users can also leverage the [ComposableQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L242)
+as follows:
+
+```python
+from torchao.quantization.qat import (
+    ComposableQATQuantizer,
+    Int4WeightOnlyEmbeddingQATQuantizer,
+    Int8DynActInt4WeightQATQuantizer,
+)
+
+group_size = 32
+quantizer = ComposableQATQuantizer([
+    Int8DynActInt4WeightQATQuantizer(groupsize=group_size),
+    Int4WeightOnlyEmbeddingQATQuantizer(group_size=group_size),
+])
+
+# prepare + train + convert as before
+model = quantizer.prepare(model)
+train_loop(model)
+model = quantizer.convert(model)
+```
+
+## torchtune integration
+
+torchao QAT is integrated with [torchtune](https://github.com/pytorch/torchtune)
+to allow users to run quantization-aware fine-tuning as follows:
 
 ```
 tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
 ```
 
-For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
+torchtune also supports a [QAT + LoRA distributed training recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py)
+that is 1.89x faster and uses 36.1% less memory than vanilla QAT in our early experiments.
+You can read more about it [here](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700):
 
+```
+tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora
+```
+
+For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
 
 ## Evaluation Results
 
````