diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index da90ebd3a7..94df824d2e 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -1,16 +1,15 @@
 # torchao.float8
 
-This is an early version of a library for accelerating training with float8 in native PyTorch
+This is a workflow for accelerating training with float8 in native PyTorch
 according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
 The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling,
 and composable with key systems such as autograd, ```torch.compile``` and distributed.
 With ``torch.compile`` on, initial results show
-throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.
+throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs.
 
 :warning: See the [feature tracker](https://github.com/pytorch/ao/issues/556) for upcoming features.
 
-:warning: Backwards compatibility is not guaranteed at this point. The codebase is in active development and
-will change rapidly.
+:warning: The codebase is stable, but backwards compatibility is not yet guaranteed.
 
 # Single GPU User API
 
@@ -21,18 +20,22 @@ We provide three per-tensor scaling strategies: dynamic, delayed and static. Se
 This is the most accurate recipe as every tensor is scaled dynamically.
 
 ```python
-from torchao.float8 import (
-    convert_to_float8_training,
-    precompute_float8_dynamic_scale_for_fsdp,
-)
-
-# create model
-m = Model(...)
+import torch
+import torch.nn as nn
+from torchao.float8 import convert_to_float8_training
+
+# create model and sample input
+m = nn.Sequential(
+    nn.Linear(2048, 4096),
+    nn.Linear(4096, 128),
+).bfloat16().cuda()
+x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
 # optional: filter modules from being eligible for float8 conversion
 def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the output module
-    if fqn == "output":
+    # don't convert the last module
+    if fqn == "1":
         return False
     # don't convert linear modules with weight dimensions not divisible by 16
     if isinstance(mod, torch.nn.Linear):
@@ -40,27 +43,18 @@ def module_filter_fn(mod: torch.nn.Module, fqn: str):
             return False
     return True
 
-# convert all `torch.nn.Linear` modules to `Float8Linear`
+# convert specified `torch.nn.Linear` modules to `Float8Linear`
 convert_to_float8_training(m, module_filter_fn=module_filter_fn)
 
-# optional: use FSDP
-model = FSDP(model, use_orig_params=True)
-
-# optional: enable torch.compile for improved performance
+# enable torch.compile for competitive performance
 m = torch.compile(m)
 
 # toy training loop
-for _ in range(N_ITER):
+for _ in range(10):
     optimizer.zero_grad()
     y = m(x)
     y.sum().backward()
     optimizer.step()
-
-    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
-    # this method is optional but is highly recommended for performance
-    # it calcuclates scales for all parameters in a single all-reduce
-    precompute_float8_dynamic_scale_for_fsdp(model)
-
 ```
 
 ## float8 linear with delayed scaling
@@ -68,50 +62,48 @@ for _ in range(N_ITER):
 This is theoretically the most performant recipe as it minimizes memory reads.
 
 ```python
+import torch
+import torch.nn as nn
 from torchao.float8 import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
+    Float8LinearConfig,
     ScalingType,
+    CastConfig,
 )
 
-# create model
-m = Model(...)
+# create model and sample input
+m = nn.Sequential(
+    nn.Linear(2048, 4096),
+    nn.Linear(4096, 128),
+).bfloat16().cuda()
+x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
-# optional: configure for compatibility with FSDP. Note that workarounds
-# gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for
-# autocast + compile + FSDP + float8 to work
-from torchao.float8 import Float8LinearConfig, ScalingType, CastConfig
+# configure delayed scaling
 config = Float8LinearConfig(
-    enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed
-    enable_pre_and_post_forward=False # only needed for autocast + compile + FSDP + float8 delayed
     cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
     cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
     cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+    # enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed
+    # enable_pre_and_post_forward=False # only needed for autocast + compile + FSDP + float8 delayed
 )
 
-# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
-# type
-convert_to_float8_training(
-    m,
-    config=config,
-)
-
-# optional: use FSDP
-model = FSDP(model, use_orig_params=True)
+# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
+convert_to_float8_training(m, config=config)
 
-# optional: enable torch.compile for improved performance
+# enable torch.compile for competitive performance
 m = torch.compile(m)
 
 # toy training loop
-for _ in range(N_ITER):
+for _ in range(10):
     optimizer.zero_grad()
     y = m(x)
     y.sum().backward()
 
     # specific to float8 with delayed scaling: separate step to sync scales/amaxes
     # in the future, this may move to a context manager
-    sync_float8_amax_and_scale_history(model)
+    sync_float8_amax_and_scale_history(m)
 
     optimizer.step()
 ```
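As a follow-up to the dynamic-scaling example rewritten in this patch, here is a minimal sketch of how one might confirm which modules were actually converted. It assumes the same toy `nn.Sequential` model and `module_filter_fn` shown above; inspecting module class names is only an illustrative check, not a documented torchao API.

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# same toy model as the updated README example
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()

def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # mirror the README filter: skip the last module (fqn "1")
    return fqn != "1"

convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# expected: module "0" is swapped to Float8Linear, module "1" stays nn.Linear
for fqn, mod in m.named_modules():
    if fqn:
        print(fqn, type(mod).__name__)
```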