# torchao.float8

This is a workflow for accelerating training with float8 in native PyTorch
according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling,
and composable with key systems such as autograd, `torch.compile` and distributed.
With `torch.compile` on, initial results show
throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs.

:warning: <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) for upcoming features.</em>

:warning: <em>The codebase is stable, but backwards compatibility is not yet guaranteed.</em>

# Single GPU User API

We provide three per-tensor scaling strategies: dynamic, delayed and static.

## float8 linear with dynamic scaling

This is the most accurate recipe as every tensor is scaled dynamically.

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)
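
# (optional, illustrative) sanity check, not part of the conversion API: list
# which modules were swapped; converted modules report a Float8Linear class
# name, filtered ones remain Linear
for name, mod in m.named_modules():
    print(name, type(mod).__name__)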

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
```
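
For intuition, dynamic scaling derives each tensor's scale from that tensor's current absolute maximum at the time of the cast. The snippet below is a minimal standalone sketch of that idea, assuming the `torch.float8_e4m3fn` format; it is not the code `Float8Linear` runs internally.

```python
import torch

# illustrative sketch of dynamic per-tensor scaling (not the library internals):
# the scale is recomputed from the tensor's abs-max on every cast
x = torch.randn(16, 32, dtype=torch.bfloat16)
f8_max = torch.finfo(torch.float8_e4m3fn).max             # largest representable e4m3 value
scale = f8_max / x.abs().max().float().clamp(min=1e-12)   # guard against an all-zero tensor
x_f8 = (x.float() * scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
x_hp = x_f8.float() / scale                                # dequantized copy, for inspection only
```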

## float8 linear with delayed scaling

This is theoretically the most performant recipe as it minimizes memory reads.

```python
import torch
import torch.nn as nn
from torchao.float8 import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# configure delayed scaling
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    # enable_amax_init=False,  # only needed for autocast + compile + FSDP + float8 delayed
    # enable_pre_and_post_forward=False,  # only needed for autocast + compile + FSDP + float8 delayed
)

# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
convert_to_float8_training(m, config=config)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()

    # specific to float8 with delayed scaling: separate step to sync scales/amaxes
    # in the future, this may move to a context manager
    sync_float8_amax_and_scale_history(m)

    optimizer.step()
```
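
Conceptually, delayed scaling reuses absolute-maximum (amax) values recorded in earlier iterations, so the scale for the current cast is already known before the tensor is read. The snippet below is a minimal standalone sketch of that idea, assuming the `torch.float8_e4m3fn` format and a simple max-over-history policy; it is not the code that `Float8Linear` and `sync_float8_amax_and_scale_history` run internally.

```python
import torch

f8_max = torch.finfo(torch.float8_e4m3fn).max

# rolling buffer of recent abs-max observations (initialized to 1.0 here just to
# keep the sketch well-defined on the first iteration)
amax_history = torch.ones(16)

def scale_from_history(history: torch.Tensor) -> torch.Tensor:
    # one possible policy: derive the scale from the max over the recorded window
    return f8_max / history.max().clamp(min=1e-12)

for _ in range(3):
    x = torch.randn(16, 32)
    scale = scale_from_history(amax_history)   # uses past amaxes, not the current x
    x_f8 = (x * scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)

    # record the current amax so future iterations can use it
    amax_history = torch.roll(amax_history, 1)
    amax_history[0] = x.abs().max()
```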