@@ -230,3 +230,65 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
230230 - float8 rowwise with bf16 all-gather + compile: ` TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh `
231231
232232See the float8 training benchmarking [guide](./torchao/float8/benchmarking/README.md) for more details.
233+
234+ # E2E training + inference flow
235+
236+ There are two float8 inference quantization strategies that can be used after training with float8: 1) weight-only quantization, and 2) dynamic activation and weight quantization.
237+
238+ ### Weight only quantization
239+ ``` python
240+ import torch
241+ from torch import nn
242+ from torchao.float8.float8_linear_utils import convert_to_float8_training
243+ from torchao.float8.float8_linear import Float8Linear
244+ from torchao.quantization.quant_api import float8_weight_only, quantize_
245+
246+ # simple example model and input
247+ x = torch.randn(32, 32, device="cuda")
248+ m = nn.Sequential(nn.Linear(32, 32)).cuda()
249+
250+ # train using dynamic float8 with tensorwise scaling
251+ m = convert_to_float8_training(m)
252+
253+ assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
254+
255+ # convert to weight only quantization for inference
256+ quantize_(m, float8_weight_only())
257+
258+ # run inference
259+ with torch.inference_mode():
260+ out = m(x)
261+ ```
262+
263+
264+ ### Dynamic activation and weight quantization
265+
266+ ``` python
267+ import torch
268+ from torch import nn
269+
270+ from torchao.float8.float8_linear_utils import convert_to_float8_training
271+ from torchao.float8.float8_linear import Float8Linear
272+ from torchao.quantization.granularity import PerTensor
273+ from torchao.quantization.quant_api import quantize_
274+ from torchao.quantization import (
275+ Float8DynamicActivationFloat8WeightConfig,
276+ )
277+
278+ # simple example model and input
279+ x = torch.randn(32, 32, device="cuda")
280+ m = nn.Sequential(nn.Linear(32, 32)).cuda()
281+
282+ # train using dynamic float8 with tensorwise scaling
283+ m = convert_to_float8_training(m)
284+
285+ # apply dynamic float8 quantization on both activations and weights for inference
286+ assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
287+ quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
288+
289+ # run inference
290+ with torch.inference_mode():
291+ out = m(x)
292+ ```
293+
294+ For more float8 inference performance benchmarks, see the inference docs [here](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#cuda-backend-1).
0 commit comments