@@ -230,3 +230,73 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
230230 - float8 rowwise with bf16 all-gather + compile: ` TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh `
231231
232232See the float8 training benchmarking [guide](torchao/float8/benchmarking/README.md) for more details.
233+
234+ ## E2E training + inference flow
235+
236+ There are two float8 inference quantization strategies that can be used after training with float8: 1) weight only, and 2) dynamic activation and weight.
237+
238+ ### Weight only quantization
239+ ```python
240+ import torch
241+ from torch import nn
242+ from torchao.float8.float8_linear_utils import convert_to_float8_training
243+ from torchao.float8.float8_linear import Float8Linear
244+ from torchao.quantization.quant_api import float8_weight_only, quantize_
245+
246+ # simple example model and input
247+ x = torch.randn(32, 32, device="cuda")
248+ m = nn.Sequential(nn.Linear(32, 32)).cuda()
249+
250+ # train with dynamic float8 training with tensorwise scaling
251+ m = convert_to_float8_training(m)
252+
253+
254+ # ... train ...
255+ # ... save/load checkpoint ...
256+
257+
258+ assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
259+
260+ # convert to weight only quantization for inference
261+ quantize_(m, float8_weight_only())
262+
263+ # run inference
264+ with torch.inference_mode():
265+ out = m(x)
266+ ```
267+
268+
269+ ### Dynamic activation and weight quantization
270+
271+ ```python
272+ import torch
273+ from torch import nn
274+
275+ from torchao.float8.float8_linear_utils import convert_to_float8_training
276+ from torchao.float8.float8_linear import Float8Linear
277+ from torchao.quantization.granularity import PerTensor
278+ from torchao.quantization.quant_api import quantize_
279+ from torchao.quantization import (
280+ Float8DynamicActivationFloat8WeightConfig,
281+ )
282+
283+ # simple example model and input
284+ x = torch.randn(32, 32, device="cuda")
285+ m = nn.Sequential(nn.Linear(32, 32)).cuda()
286+
287+ # train with dynamic float8 training with tensorwise scaling
288+ m = convert_to_float8_training(m)
289+
290+ # ... train ...
291+ # ... save/load checkpoint ...
292+
293+ # apply dynamic float8 quantization on both activations and weights for inference
294+ assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
295+ quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
296+
297+ # run inference
298+ with torch.inference_mode():
299+ out = m(x)
300+ ```
301+
302+ For more float8 inference performance benchmarks, see the inference docs [here](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#cuda-backend-1).
0 commit comments