
Commit 1c488e8

make float8 README.md examples standalone (#809)
Updates the two float8 README.md examples (for dynamic and delayed scaling) to be standalone.

Test plan: copy-paste each code sample and execute it; runs successfully.
1 parent 317392d commit 1c488e8

1 file changed

torchao/float8/README.md

Lines changed: 38 additions & 46 deletions
@@ -1,16 +1,15 @@
 # torchao.float8

-This is an early version of a library for accelerating training with float8 in native PyTorch
+This is a workflow for accelerating training with float8 in native PyTorch
 according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
 The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling,
 and composable with key systems such as autograd, ```torch.compile``` and distributed.
 With ``torch.compile`` on, initial results show
-throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.
+throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs.

 :warning: <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) for upcoming features.</em>

-:warning: <em>Backwards compatibility is not guaranteed at this point. The codebase is in active development and
-will change rapidly.</em>
+:warning: <em>The codebase is stable, but backwards compatibility is not yet guaranteed.</em>

 # Single GPU User API

@@ -21,97 +20,90 @@ We provide three per-tensor scaling strategies: dynamic, delayed and static. Se
 This is the most accurate recipe as every tensor is scaled dynamically.

 ```python
-from torchao.float8 import (
-    convert_to_float8_training,
-    precompute_float8_dynamic_scale_for_fsdp,
-)
-
-# create model
-m = Model(...)
+import torch
+import torch.nn as nn
+from torchao.float8 import convert_to_float8_training
+
+# create model and sample input
+m = nn.Sequential(
+    nn.Linear(2048, 4096),
+    nn.Linear(4096, 128),
+).bfloat16().cuda()
+x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

 # optional: filter modules from being eligible for float8 conversion
 def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the output module
-    if fqn == "output":
+    # don't convert the last module
+    if fqn == "1":
         return False
     # don't convert linear modules with weight dimensions not divisible by 16
     if isinstance(mod, torch.nn.Linear):
         if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
             return False
     return True

-# convert all `torch.nn.Linear` modules to `Float8Linear`
+# convert specified `torch.nn.Linear` modules to `Float8Linear`
 convert_to_float8_training(m, module_filter_fn=module_filter_fn)

-# optional: use FSDP
-model = FSDP(model, use_orig_params=True)
-
-# optional: enable torch.compile for improved performance
+# enable torch.compile for competitive performance
 m = torch.compile(m)

 # toy training loop
-for _ in range(N_ITER):
+for _ in range(10):
     optimizer.zero_grad()
     y = m(x)
     y.sum().backward()
     optimizer.step()
-
-# specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
-# this method is optional but is highly recommended for performance
-# it calcuclates scales for all parameters in a single all-reduce
-precompute_float8_dynamic_scale_for_fsdp(model)
-
 ```

 ## float8 linear with delayed scaling

 This is theoretically the most performant recipe as it minimizes memory reads.

 ```python
+import torch
+import torch.nn as nn
 from torchao.float8 import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
+    Float8LinearConfig,
     ScalingType,
+    CastConfig,
 )

-# create model
-m = Model(...)
+# create model and sample input
+m = nn.Sequential(
+    nn.Linear(2048, 4096),
+    nn.Linear(4096, 128),
+).bfloat16().cuda()
+x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

-# optional: configure for compatibility with FSDP. Note that workarounds
-# gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for
-# autocast + compile + FSDP + float8 to work
-from torchao.float8 import Float8LinearConfig, ScalingType, CastConfig
+# configure delayed scaling
 config = Float8LinearConfig(
-    enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed
-    enable_pre_and_post_forward=False # only needed for autocast + compile + FSDP + float8 delayed
     cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
     cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
     cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+    # enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed
+    # enable_pre_and_post_forward=False # only needed for autocast + compile + FSDP + float8 delayed
 )

-# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
-# type
-convert_to_float8_training(
-    m,
-    config=config,
-)
-
-# optional: use FSDP
-model = FSDP(model, use_orig_params=True)
+# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
+convert_to_float8_training(m, config=config)

-# optional: enable torch.compile for improved performance
+# enable torch.compile for competitive performance
 m = torch.compile(m)

 # toy training loop
-for _ in range(N_ITER):
+for _ in range(10):
     optimizer.zero_grad()
     y = m(x)
     y.sum().backward()

     # specific to float8 with delayed scaling: separate step to sync scales/amaxes
     # in the future, this may move to a context manager
-    sync_float8_amax_and_scale_history(model)
+    sync_float8_amax_and_scale_history(m)

     optimizer.step()
 ```
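For reference, a minimal sketch in the spirit of the commit's test plan ("copy-paste each code sample and execute it"): it builds the same toy model as the dynamic-scaling example and checks that `convert_to_float8_training` swapped the `nn.Linear` children for `Float8Linear`. The verification loop is an illustration, not part of the README or the commit.

```python
# Minimal sketch (not part of the commit): confirm the dynamic-scaling example's
# conversion step swaps nn.Linear for Float8Linear. Requires a CUDA GPU.
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# same toy model as in the README example
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()

# convert all nn.Linear modules (no module_filter_fn, so both layers are converted)
convert_to_float8_training(m)

# each child should now report its class as Float8Linear
for name, child in m.named_children():
    print(name, type(child).__name__)
```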
