diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index da90ebd3a7..94df824d2e 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -1,16 +1,15 @@
# torchao.float8
-This is an early version of a library for accelerating training with float8 in native PyTorch
+This is a workflow for accelerating training with float8 in native PyTorch
according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling,
and composable with key systems such as autograd, ```torch.compile``` and distributed.
With ``torch.compile`` on, initial results show
-throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.
+throughput speedups of up to 1.5x on 128-GPU LLaMa 3 70B pretraining jobs.
:warning: See the [feature tracker](https://github.com/pytorch/ao/issues/556) for upcoming features.
-:warning: Backwards compatibility is not guaranteed at this point. The codebase is in active development and
-will change rapidly.
+:warning: The codebase is stable, but backwards compatibility is not yet guaranteed.
# Single GPU User API
@@ -21,18 +20,22 @@ We provide three per-tensor scaling strategies: dynamic, delayed and static. Se
This is the most accurate recipe as every tensor is scaled dynamically.
```python
-from torchao.float8 import (
- convert_to_float8_training,
- precompute_float8_dynamic_scale_for_fsdp,
-)
-
-# create model
-m = Model(...)
+import torch
+import torch.nn as nn
+from torchao.float8 import convert_to_float8_training
+
+# create model and sample input
+m = nn.Sequential(
+ nn.Linear(2048, 4096),
+ nn.Linear(4096, 128),
+).bfloat16().cuda()
+x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
- # don't convert the output module
- if fqn == "output":
+ # don't convert the last module
+ if fqn == "1":
return False
# don't convert linear modules with weight dimensions not divisible by 16
if isinstance(mod, torch.nn.Linear):
@@ -40,27 +43,18 @@ def module_filter_fn(mod: torch.nn.Module, fqn: str):
return False
return True
-# convert all `torch.nn.Linear` modules to `Float8Linear`
+# convert the `torch.nn.Linear` modules selected by `module_filter_fn` to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)
-# optional: use FSDP
-model = FSDP(model, use_orig_params=True)
-
-# optional: enable torch.compile for improved performance
+# enable torch.compile for competitive performance
m = torch.compile(m)
# toy training loop
-for _ in range(N_ITER):
+for _ in range(10):
optimizer.zero_grad()
y = m(x)
y.sum().backward()
optimizer.step()
-
- # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
- # this method is optional but is highly recommended for performance
- # it calcuclates scales for all parameters in a single all-reduce
- precompute_float8_dynamic_scale_for_fsdp(model)
-
```
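+
+Since the dynamic recipe recomputes scales from live tensor values, it is also straightforward
+to sanity-check numerically. The sketch below is plain PyTorch rather than part of the torchao
+API: it assumes the same setup as the example above (`m`, `x`, `module_filter_fn`), but snapshots
+a bfloat16 reference copy of the model *before* the `convert_to_float8_training` call.
+
+```python
+import copy
+
+# hypothetical check: keep a high-precision reference taken before conversion
+m_ref = copy.deepcopy(m)
+convert_to_float8_training(m, module_filter_fn=module_filter_fn)
+
+with torch.no_grad():
+    y_fp8 = m(x)
+    y_ref = m_ref(x)
+
+# float8 matmuls use lower-precision operands, so expect small differences
+# rather than bitwise equality
+print((y_fp8 - y_ref).abs().max())
+```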
## float8 linear with delayed scaling
@@ -68,50 +62,48 @@ for _ in range(N_ITER):
This is theoretically the most performant recipe as it minimizes memory reads.
```python
+import torch
+import torch.nn as nn
from torchao.float8 import (
convert_to_float8_training,
sync_float8_amax_and_scale_history,
+ Float8LinearConfig,
ScalingType,
+ CastConfig,
)
-# create model
-m = Model(...)
+# create model and sample input
+m = nn.Sequential(
+ nn.Linear(2048, 4096),
+ nn.Linear(4096, 128),
+).bfloat16().cuda()
+x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
-# optional: configure for compatibility with FSDP. Note that workarounds
-# gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for
-# autocast + compile + FSDP + float8 to work
-from torchao.float8 import Float8LinearConfig, ScalingType, CastConfig
+# configure delayed scaling for the input, weight, and grad_output tensors
config = Float8LinearConfig(
- enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed
- enable_pre_and_post_forward=False # only needed for autocast + compile + FSDP + float8 delayed
cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+ # enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed
+ # enable_pre_and_post_forward=False # only needed for autocast + compile + FSDP + float8 delayed
)
-# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
-# type
-convert_to_float8_training(
- m,
- config=config,
-)
-
-# optional: use FSDP
-model = FSDP(model, use_orig_params=True)
+# convert all `torch.nn.Linear` modules to `Float8Linear`, using the delayed scaling config above
+convert_to_float8_training(m, config=config)
-# optional: enable torch.compile for improved performance
+# enable torch.compile for competitive performance
m = torch.compile(m)
# toy training loop
-for _ in range(N_ITER):
+for _ in range(10):
optimizer.zero_grad()
y = m(x)
y.sum().backward()
# specific to float8 with delayed scaling: separate step to sync scales/amaxes
# in the future, this may move to a context manager
- sync_float8_amax_and_scale_history(model)
+ sync_float8_amax_and_scale_history(m)
optimizer.step()
```
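+
+For either recipe, ordinary `nn.Module` introspection is enough to confirm which layers a
+conversion touched: converted layers report `Float8Linear` as their class name, while anything
+skipped (for example by `module_filter_fn`) remains a plain `Linear`. A minimal sketch, reusing
+`m` from the example above:
+
+```python
+# run this right after convert_to_float8_training (before torch.compile,
+# which wraps the model in an OptimizedModule)
+for name, child in m.named_modules():
+    print(f"{name or '<root>'}: {type(child).__name__}")
+```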