Add references to new QAT APIs including `quantize_`,
`FakeQuantizedX`, and the new embedding Quantizers and
ComposableQATQuantizer. Also link to new QAT + LoRA recipe
in torchtune.
ghstack-source-id: 10bbe97
Pull Request resolved: #1541

`README.md`:

We've added kv cache quantization and other features in order to enable long context length inference.
In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md).
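
The int4 weight-only piece of this setup is exposed through torchao's `quantize_` API. As a rough sketch, assuming a model has already been loaded (the `load_model()` helper and the `group_size` value are placeholders; kv cache quantization itself lives in the llama example scripts linked above):

```python
from torchao.quantization import quantize_, int4_weight_only

model = load_model()  # hypothetical helper standing in for your own checkpoint loading
quantize_(model, int4_weight_only(group_size=128))  # int4 weight-only, per-group quantization
```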
## Training
### Quantization Aware Training
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with [torchtune](https://github.com/pytorch/torchtune), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional post-training quantization (PTQ), recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3. We've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/). For more details, please see the [QAT README](./torchao/quantization/qat/README.md).

For example, a minimal sketch of the prepare and convert steps using `Int8DynActInt4WeightQATQuantizer`, assuming `model` is your `torch.nn.Module`:

```python
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# prepare: swap `torch.nn.Linear` layers for fake-quantized equivalents
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
model = qat_quantizer.prepare(model)

# ... train the model as usual ...

# convert: replace the fake quantize ops with actual quantized ops for inference
model = qat_quantizer.convert(model)
```

`torchao/quantization/qat/README.md`:

torchao currently supports two QAT schemes for linear layers:

- int8 per token dynamic activations + int4 per group weights
- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)

QAT typically involves applying a transformation to your model before and after training. In torchao, these are represented as the prepare and convert steps: (1) prepare inserts fake quantize operations into linear layers, and (2) convert transforms the fake quantize operations into actual quantized operations after training. Between these two steps, training can proceed exactly as before.

## torchao APIs

torchao currently supports two QAT APIs, one through the [`quantize_`](https://pytorch.org/ao/stable/generated/torchao.quantization.quantize_.html#torchao.quantization.quantize_) API (recommended) and one through the Quantizer classes (legacy). The `quantize_` API allows flexible configuration of quantization settings for both activations and weights, while the Quantizer classes each hardcode a specific quantization setting.

For example, running QAT on a single GPU:

```python
import torch
from torchtune.models.llama3 import llama3

# Set up smaller version of llama3 to fit in a single GPU
# (the sizes below are illustrative placeholders)
model = llama3(vocab_size=4096, num_layers=16, num_heads=16, num_kv_heads=4, embed_dim=2048, max_seq_len=2048).cuda()
```

The recommended way to run QAT in torchao is through the `quantize_` API:

1. **Prepare:** specify how weights and/or activations are to be quantized through [`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and pass these to [`intx_quantization_aware_training`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
2. **Convert:** quantize the model using the standard post-training quantization (PTQ) functions such as [`int8_dynamic_activation_int4_weight`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)

For example:

```python
import torch

from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

model = get_model()  # any torch.nn.Module with linear layers

# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
# (example settings: int8 per-token asymmetric activations + int4 per-group weights)
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))

# train the model as usual (not shown)

# convert: swap `FakeQuantizedLinear` back to `torch.nn.Linear` and quantize the weights
quantize_(model, from_intx_quantization_aware_training())
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
```

Alternatively, the legacy Quantizer classes each target a specific scheme, for example [`Int4WeightOnlyQATQuantizer`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L308) (linear), which targets int4 per-group asymmetric weights using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training.
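
The new embedding Quantizers and `ComposableQATQuantizer` mentioned in the summary above can be composed to fake-quantize linear and embedding layers together. Below is a minimal sketch of that composition; the embedding Quantizer name `Int4WeightOnlyEmbeddingQATQuantizer` and the default constructor arguments are assumptions here, so consult the QAT README for the exact classes and settings:

```python
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,  # assumed name for the embedding Quantizer
    Int8DynActInt4WeightQATQuantizer,
)

model = get_model()  # any torch.nn.Module with linear and embedding layers

# prepare: each composed Quantizer fake-quantizes the layers it targets
quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(),     # linear layers
    Int4WeightOnlyEmbeddingQATQuantizer(),  # embedding layers
])
model = quantizer.prepare(model)

# ... train the model as usual ...

# convert: replace fake quantization with actual quantized layers
model = quantizer.convert(model)
```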

torchao QAT is integrated with [torchtune](https://github.com/pytorch/torchtune) to allow users to run quantization-aware fine-tuning as follows:

```
tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
```

torchtune also supports a [QAT + LoRA distributed training recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py) that is 1.89x faster and uses 36.1% less memory compared to vanilla QAT in our early experiments. You can read more about it [here](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700):

```
tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora
```

For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).