
Commit 399f4fc

document e2e training -> inference flow
1 parent cdced21 commit 399f4fc

1 file changed: +62 -0 lines

torchao/float8/README.md

@@ -230,3 +230,65 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
- float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh`

See the float8 training benchmarking [guide](torchao/float8/benchmarking/README.md) for more details.
# E2E training + inference flow
There are two float8 inference quantization strategies that can be used after training with float8: 1) weight only quantization, and 2) dynamic activation and weight quantization.

### Weight only quantization

```python
import torch
from torch import nn
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear
from torchao.quantization.quant_api import float8_weight_only, quantize_

# simple example model and input
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()

# convert the model to use dynamic float8 training with tensorwise scaling
m = convert_to_float8_training(m)
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"

# (your training loop would run here)

# convert the trained model to weight only quantization for inference
quantize_(m, float8_weight_only())

# run inference
with torch.inference_mode():
    out = m(x)
```
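
After `quantize_`, the linear layer's weight has been swapped for a float8 quantized tensor subclass. As a quick sanity check (a minimal sketch; the exact subclass name printed depends on your torchao version), you can inspect the module and its weight:

```python
# sketch: the concrete tensor subclass shown here varies across torchao versions
print(type(m[0]))         # the (possibly converted) linear module
print(type(m[0].weight))  # the float8 quantized weight tensor subclass
```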
### Dynamic activation and weight quantization

```python
import torch
from torch import nn

from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_api import quantize_
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
)

# simple example model and input
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()

# convert the model to use dynamic float8 training with tensorwise scaling
m = convert_to_float8_training(m)
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"

# (your training loop would run here)

# apply dynamic float8 quantization to both activations and weights for inference
quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))

# run inference
with torch.inference_mode():
    out = m(x)
```
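
For additional inference speedups, the quantized model can usually also be compiled. This is a minimal sketch on top of the example above using the standard `torch.compile` API; compile support for a given quantization config may vary across torchao versions:

```python
# optional: compile the quantized model for faster inference
m = torch.compile(m)

with torch.inference_mode():
    out = m(x)
```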
For more float8 inference performance benchmarks, see the inference docs [here](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#cuda-backend-1).
