Commit d70f7f7

document e2e training -> inference flow
1 parent cdced21 commit d70f7f7

File tree

1 file changed (+70, -0 lines)


torchao/float8/README.md

@@ -230,3 +230,73 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
- float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh`

See the float8 training benchmarking [guide](.torchao/float8/benchmarking/README.md) for more details.

# E2E training + inference flow

There are two float8 inference quantization strategies that can be used after training with float8: 1) weight only quantization, and 2) dynamic activation and weight quantization.

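Both strategies rely on dynamic tensorwise scaling: at runtime, a per-tensor scale is computed from the tensor's absolute maximum so that values fit into the float8 range. Below is a minimal pure-Python sketch of that idea — an illustration under simplifying assumptions, not the torchao implementation (real float8 e4m3 rounding keeps a floating-point mantissa rather than rounding to integers):

```python
# Illustration of dynamic tensorwise scaling; NOT the torchao implementation.
F8_E4M3_MAX = 448.0  # largest magnitude representable in float8 e4m3

def dynamic_scale(values):
    """Per-tensor scale mapping the largest magnitude onto F8_E4M3_MAX."""
    amax = max(abs(v) for v in values)
    return amax / F8_E4M3_MAX if amax > 0 else 1.0

def fake_quantize(values):
    """Scale into the float8 range, round (a crude stand-in for float8
    rounding), then rescale back to the original range."""
    scale = dynamic_scale(values)
    return [round(v / scale) * scale for v in values]

weights = [0.5, -1.25, 3.0, 0.01]
dequantized = fake_quantize(weights)
# each value is recovered to within half a quantization step
```

Because the scale is recomputed from each tensor's own max, no calibration pass is needed — which is what makes both recipes "dynamic".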
### Weight only quantization
```python
import torch
from torch import nn

from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.quantization.quant_api import float8_weight_only, quantize_

# simple example model and input
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()

# convert the model for float8 training with dynamic tensorwise scaling
m = convert_to_float8_training(m)

# ... train ...
# ... save/load checkpoint ...

assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"

# convert to weight only quantization for inference
quantize_(m, float8_weight_only())

# run inference
with torch.inference_mode():
    out = m(x)
```

### Dynamic activation and weight quantization

```python
import torch
from torch import nn

from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_api import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

# simple example model and input
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()

# convert the model for float8 training with dynamic tensorwise scaling
m = convert_to_float8_training(m)

# ... train ...
# ... save/load checkpoint ...

# apply dynamic float8 quantization to both activations and weights for inference
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))

# run inference
with torch.inference_mode():
    out = m(x)
```

For more float8 inference performance benchmarks, see the inference docs [here](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#cuda-backend-1).
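
To make the difference between the two recipes concrete, here is a pure-Python sketch of a single dot product (the helpers are made up for illustration and are not torchao APIs): weight only quantization dequantizes the stored weights while activations stay in high precision, whereas dynamic activation and weight quantization also scales the activations at runtime, so the multiply-accumulate can run in the low-precision domain and be rescaled once at the end.

```python
# Contrast of the two inference recipes on one dot product (illustration only;
# these helpers are hypothetical, not torchao APIs).
F8_MAX = 448.0  # largest magnitude representable in float8 e4m3

def quantize(values):
    """Dynamic tensorwise quantization: return (low-precision ints, scale)."""
    scale = max(abs(v) for v in values) / F8_MAX
    return [round(v / scale) for v in values], scale

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

w = [0.5, -1.5, 2.0]    # trained weights
x = [1.0, 0.25, -0.75]  # activation

# weight only: weights are quantized once, activations stay in high precision
qw, w_scale = quantize(w)
y_weight_only = sum(qi * w_scale * xi for qi, xi in zip(qw, x))

# dynamic activation + weight: activations are also quantized at runtime,
# so the accumulate runs on low-precision values and is rescaled once
qx, x_scale = quantize(x)
y_dynamic = dot(qw, qx) * w_scale * x_scale

y_ref = dot(w, x)  # high-precision reference
```

The second form is what enables float8 matmul hardware to do the heavy lifting, at the cost of one extra scale computation per activation tensor per forward pass.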
