
Commit 62c239c

adding kv cache quantization to READMEs (#813)
Summary: see READMEs
1 parent 1c488e8 commit 62c239c

File tree

3 files changed: +32 -3 lines changed


README.md

Lines changed: 6 additions & 0 deletions
@@ -48,6 +48,12 @@ model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

We also provide a developer-facing API so you can implement your own quantization algorithms; please use the excellent [HQQ](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) algorithm as a motivating example.

+### KV Cache Quantization
+
+We've added kv cache quantization and other features to enable long context length (and therefore memory-efficient) inference.
+
+In practice, these features alongside int4 weight-only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md).
+
### Quantization Aware Training

Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional post-training quantization (PTQ), recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3. A full recipe is provided [here](https://pytorch.org/blog/quantization-aware-training/).
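As a rough sketch of the int4 weight-only half of the recipe above, assuming a recent torchao where `quantize_` and `int4_weight_only` are available, something like the following applies it to a model (the kv cache quantization itself currently lives in the llama example under `torchao/_models/llama` rather than behind a one-line API; the toy model and group size below are illustrative, not the benchmark setup):

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

# Toy stand-in for a real model; int4 weight-only quantization targets nn.Linear layers
# and expects bfloat16 weights on CUDA (it routes through tinygemm).
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 4096, bias=False),
).to(torch.bfloat16).cuda()

# Swap the Linear weights for int4 weight-only quantized tensors;
# group_size=64 mirrors the int4wo-64 configuration referenced in these READMEs.
quantize_(model, int4_weight_only(group_size=64))

# Inference runs as usual; dequantization happens inside the int4 kernel path.
with torch.no_grad():
    out = model(torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda"))
```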

torchao/_models/llama/README.md

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,22 @@ and follow the steps to gain access.
Then from the torchao root directory use `huggingface-cli login` and follow the steps to login, then `sh ./scripts/prepare.sh` to
download and convert the model weights

-once done you can execute benchmarks from the torchao/_models/llama dir with `sh benchmarks.sh`. You can perform and benchmarking
-directly using `generate.py`.
+Once done, you can execute benchmarks from the torchao/_models/llama dir with `sh benchmarks.sh`. You can run benchmarking or evaluation
+directly using `generate.py` or `eval.py`.
+
+## KV Cache Quantization - Memory Efficient Inference
+Compared to the original gpt-fast implementation, we've added some features to `model.py` to enable long context length (and therefore memory-efficient) inference. Specifically, we've added kv_cache quantization and a linear_causal_mask implementation, which together are **able to reduce memory usage by 50-60%** at long context lengths.
+
+In practice, these features alongside int4 weight-only quantization allow us to run Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.**
+
+You can try it yourself with `generate.py`. These features exist as a proof of concept and technical demonstration of the techniques, though we're working on a way to release them more generally; until then, feel free to copy them into your own models. The details and a full explanation can be found in this [PR](https://github.com/pytorch/ao/pull/738).
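Conceptually, kv cache quantization amounts to storing the cached k/v tensors in int8 with a scale per token and dequantizing on read. The snippet below is a minimal sketch of that round trip, assuming a symmetric per-token scheme; it is an illustration of the idea, not the actual code in `model.py` (see the PR above for that):

```python
import torch

def quantize_kv(x: torch.Tensor, eps: float = 1e-8):
    """Symmetric int8 quantization with one scale per token (reduced over head_dim)."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
    return q.to(dtype) * scale.to(dtype)

# A bf16 cache entry of shape [batch, n_heads, seq_len, head_dim].
k = torch.randn(1, 32, 4096, 128, dtype=torch.bfloat16)

k_q, k_scale = quantize_kv(k)        # stored: int8 values plus small per-token scales
k_hat = dequantize_kv(k_q, k_scale)  # reconstructed on read, before attention

# int8 storage is ~2x smaller than bf16 (plus the scales), which is roughly where
# the cache-memory savings quoted above come from.
print((k - k_hat).abs().max())
```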
+
+To see how these techniques scale, we've run `generate.py` with subsets of these features for different context lengths on an A100 GPU. You can find commands to reproduce these numbers in `benchmarks.sh`.
+
+| context length (tokens) | normal peak (GB) | kv_quant peak (GB) | kv quant + linear_causal_mask peak (GB) |
+|-------------------------|------------------|--------------------|-----------------------------------------|
+| 8192                    | 17.86            | 17.52              | 17.47                                   |
+| 16384                   | 19.81            | 18.75              | 18.48                                   |
+| 32768                   | 23.83            | 21.72              | 20.64                                   |
+| 65536                   | 33.5             | 29.54              | 25.24                                   |
+| 131072                  | 59.27            | 52.62              | 34.18                                   |
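For intuition on the last column: a fully materialized `[max_seq_len, max_seq_len]` boolean causal mask alone costs about 16 GiB at 131072 tokens, which lines up with the gap between the last two columns. Below is a minimal sketch of building the mask per decode step instead, assuming that is the essential trick; the actual linear_causal_mask implementation is in `model.py` and the PR linked above:

```python
import torch

MAX_SEQ_LEN = 131072

# Precomputing a full [max_seq_len, max_seq_len] boolean causal mask costs
# max_seq_len**2 bytes (~16 GiB at 128k context), regardless of how much is used.
full_mask_gib = MAX_SEQ_LEN * MAX_SEQ_LEN / 2**30
print(f"full causal mask: ~{full_mask_gib:.0f} GiB")

def causal_mask_row(input_pos: int, max_seq_len: int = MAX_SEQ_LEN) -> torch.Tensor:
    """Mask for a single decode step: attend only to positions <= input_pos."""
    row = torch.zeros(1, 1, 1, max_seq_len, dtype=torch.bool)
    row[..., : input_pos + 1] = True
    return row

# One row per step is max_seq_len bytes (~128 KiB), so mask memory grows
# linearly with context length instead of quadratically.
mask = causal_mask_row(input_pos=4096)
```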

torchao/quantization/README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ note: Int8 dynamic quantization works best on compute bound models like [SAM](ht

For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3bd8c674f2123af232a0231b5e38ddafa756a8/torchao/dtypes/aqt.py#L526), i.e. `torch.ops.aten._weight_int4pack_mm`, to bitpack weights into a layout optimized for tensor cores.

-And a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization that means that the layer will be upcasted to a larger dtype like fp16 so an int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher precision formats like bf16 to lower precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers so we also typically set a group size that applies a scale factor per group of 64 elements in the case of `int4wo64`.
+And a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization that means that the layer will be upcasted to a larger dtype like fp16 so an int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher precision formats like bf16 to lower precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers so we also typically set a group size that applies a scale factor per group of 64 elements in the case of `int4wo-64`.
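To make the group size remark concrete, here is a toy sketch of weight-only quantization with one scale per group of 64 input elements, followed by the upcast-then-matmul compute pattern described above. It uses int8 storage for simplicity (the real int4 path packs two values per byte and dispatches to tinygemm), and all names and shapes are illustrative:

```python
import torch
import torch.nn.functional as F

GROUP_SIZE = 64

def quantize_weight_groupwise(w: torch.Tensor, group_size: int = GROUP_SIZE):
    """Symmetric weight-only quantization with one scale per group of `group_size` input elements."""
    out_features, in_features = w.shape
    wg = w.reshape(out_features, in_features // group_size, group_size)
    scale = (wg.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(wg / scale), -128, 127).to(torch.int8)
    return q.reshape(out_features, in_features), scale

def weight_only_linear(x, q, scale, group_size: int = GROUP_SIZE):
    out_features, in_features = q.shape
    # Dequantize group-wise, then upcast and compute in the activation dtype,
    # i.e. the F.linear(input, weight.to(input.dtype)) pattern described above.
    w_dq = (q.reshape(out_features, -1, group_size).to(scale.dtype) * scale)
    w_dq = w_dq.reshape(out_features, in_features)
    return F.linear(x, w_dq.to(x.dtype))

w = torch.randn(256, 128)
x = torch.randn(4, 128)
q, scale = quantize_weight_groupwise(w)
print((weight_only_linear(x, q, scale) - F.linear(x, w)).abs().max())  # small quantization error
```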

## Autoquantization

@@ -233,6 +233,12 @@ change_linear_weights_to_int4_woqtensors(model)

Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.

+### KV Cache Quantization
+
+We've added kv cache quantization and other features to enable long context length (and therefore memory-efficient) inference.
+
+In practice, these features alongside int4 weight-only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md).
+
## (To be moved to prototype) A16W4 WeightOnly Quantization with GPTQ

```python
