diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index 496b590a43..9b92400c34 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -13,10 +13,12 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 import contextlib
+import functools
 from typing import Optional
 
 import torch
 import torch.nn as nn
+from torch._logging import warning_once
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
@@ -36,7 +38,13 @@ def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
         config.enable_fsdp_fp8_all_gather = prev
 
 
-def build_fp8_linear(
+@functools.lru_cache(None)
+def is_sm90_or_later():
+    # Float8 is only supported on H100+ GPUs
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+def maybe_build_fp8_linear(
     model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
 ):
     """
@@ -46,9 +54,14 @@ def build_fp8_linear(
     This will mutate the model inplace.
     """
     enable_fp8_linear = job_config.training.enable_fp8_linear
-    enable_fsdp_fp8_all_gather = (
-        job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
-    )
+    if not enable_fp8_linear:
+        return
+    if not is_sm90_or_later():
+        warning_once(
+            logger,
+            "Failed to swap to Float8Linear because SM90 or later is not available",
+        )
+        return
     try:
         from float8_experimental.float8_linear import TensorScalingType
         from float8_experimental.float8_linear_utils import (
@@ -56,6 +69,9 @@ def build_fp8_linear(
         )
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        enable_fsdp_fp8_all_gather = (
+            job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
+        )
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
             swap_linear_with_float8_linear(
                 model, scaling_type_w=TensorScalingType.DYNAMIC
@@ -67,3 +83,23 @@ def build_fp8_linear(
         raise ImportError(
             "float8_experimental is not installed. Please install it to use fp8 linear layers."
         ) from exc
+
+
+def maybe_precompute_fp8_dynamic_scale_for_fsdp(
+    model: nn.Module, job_config: JobConfig
+):
+    if not (
+        job_config.training.enable_fp8_linear
+        and job_config.training.enable_fsdp_fp8_all_gather
+        and job_config.training.precompute_float8_dynamic_scale_for_fsdp
+    ):
+        return
+    if not is_sm90_or_later():
+        warning_once(
+            logger,
+            "Skipped precomputing fp8 scales because SM90 or later is not available",
+        )
+        return
+    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+
+    precompute_float8_dynamic_scale_for_fsdp(model)
diff --git a/train.py b/train.py
index 2c63e29903..afd1d88872 100644
--- a/train.py
+++ b/train.py
@@ -27,7 +27,10 @@
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_hf_data_loader, create_tokenizer
-from torchtitan.float8_linear import build_fp8_linear
+from torchtitan.float8_linear import (
+    maybe_build_fp8_linear,
+    maybe_precompute_fp8_dynamic_scale_for_fsdp,
+)
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_schedulers
 from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
@@ -215,9 +218,8 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)
 
-    # apply fp8 linear module swap
-    if job_config.training.enable_fp8_linear:
-        build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
+    # swap to Float8Linear based on fp8 config
+    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
 
     # log model size
     model_param_count = get_num_params(whole_model)
@@ -398,18 +400,10 @@ def loss_fn(pred, labels):
             optimizers.step()
             lr_schedulers.step()
 
-            if (
-                job_config.training.enable_fp8_linear
-                and job_config.training.enable_fsdp_fp8_all_gather
-                and job_config.training.precompute_float8_dynamic_scale_for_fsdp
-            ):
-                from float8_experimental.fsdp_utils import (
-                    precompute_float8_dynamic_scale_for_fsdp,
-                )
-
-                # calculate float8 dynamic amax/scale for all-parameter for FSDP2
-                # it issues a single all-reduce for all parameters at once for better performance
-                precompute_float8_dynamic_scale_for_fsdp(model)
+            # when fp8 config is on,
+            # calculate float8 dynamic amax/scale for all-parameter for FSDP2
+            # it issues a single all-reduce for all parameters at once for better performance
+            maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)
 
             losses_since_last_log.append(loss)
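
For reference, below is a minimal, self-contained sketch (not part of the diff above) of the gating pattern this PR adopts: an `lru_cache`'d SM90 capability check combined with a once-per-process warning, so that calling a `maybe_*` helper on every training step does not flood the log on unsupported hardware. The names `_warn_once` and `maybe_do_fp8_work` are hypothetical stand-ins for `torch._logging.warning_once` and the helpers introduced above; only `torch` and the standard library are assumed.

```python
import functools
import logging

import torch

logger = logging.getLogger(__name__)


@functools.lru_cache(None)
def is_sm90_or_later() -> bool:
    # Float8 kernels need H100-class (SM90) GPUs; cache the result since
    # device capability cannot change within a process.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


@functools.lru_cache(None)
def _warn_once(msg: str) -> None:
    # Stand-in for torch._logging.warning_once: the lru_cache ensures each
    # distinct message is emitted at most once per process.
    logger.warning(msg)


def maybe_do_fp8_work(enabled: bool) -> None:
    # No-op unless fp8 is enabled in the config AND the hardware supports it.
    if not enabled:
        return
    if not is_sm90_or_later():
        _warn_once("Skipping fp8 work because SM90 or later is not available")
        return
    # ... fp8-specific work would go here ...


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    for _ in range(3):  # simulate repeated per-step calls
        maybe_do_fp8_work(enabled=True)  # warns at most once on non-SM90 machines
```

Pushing the config and hardware checks inside the `maybe_*` helpers is what lets the call sites in train.py become unconditional one-liners, as the second half of the diff shows.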