14 changes: 7 additions & 7 deletions estimation.py
@@ -16,10 +16,7 @@

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.float8_linear import (
Contributor:

not directly related to this PR: is there a way to share code between estimation.py and train.py? it's quite painful to manually keep them in sync

Contributor Author:

I agree. People are aware of this -- some context here: #425 (comment)

Another step I'm going to do is to have a separate training script for PP, which could make it worse.

maybe_build_fp8_linear,
maybe_precompute_fp8_dynamic_scale_for_fsdp,
)
from torchtitan.float8_linear import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
@@ -127,8 +124,10 @@ def loss_fn(pred, labels):
with torch.device("meta"):
whole_model = model_cls.from_model_args(model_config)

# a no-op handler if fp8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
# swap to Float8Linear base on fp8 config
maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
float8_handler.convert_to_float8_training(whole_model)

# apply PT-D DP/TP parallelisms and activation checkpointing
model_parts = [whole_model]
@@ -184,13 +183,14 @@ def loss_fn(pred, labels):
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model)
Contributor (@weifengpy, Jul 31, 2024):

because train.py has it but estimation.py is outdated?

Contributor Author:

I think so. @sanketpurandare to double check.

# optimizer step
optimizers.step()
lr_schedulers.step()
# when fp8 config is on,
# calculate float8 dynamic amax/scale for all parameters for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config)
float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
optimizers.zero_grad()
print(f"Peak Memory at iter: {iter_idx}")
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
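
The estimation.py changes above mirror the new handler-based flow in train.py: a single Float8Handler instance replaces the standalone maybe_* helpers. Below is a minimal sketch of that flow, assuming `job_config`, `parallel_dims`, `model`, `optimizer`, `lr_scheduler`, `loss_fn`, `data_iter`, and `max_norm` are built elsewhere (as in train.py); only the `Float8Handler` calls are taken from this diff.

```python
# Sketch of the handler-based fp8 flow from this PR (not a complete script).
# job_config, parallel_dims, model, optimizer, lr_scheduler, loss_fn,
# data_iter, and max_norm are assumed to be built elsewhere, as in train.py.
import torch

from torchtitan.float8_linear import Float8Handler

float8_handler = Float8Handler(job_config, parallel_dims)  # no-op if fp8 is disabled
float8_handler.convert_to_float8_training(model)  # swap nn.Linear -> Float8Linear
# ... apply FSDP/TP parallelisms, build optimizers and schedulers ...

for inputs, labels in data_iter:
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, foreach=True)
    # only does work when a delayed scaling type is configured
    float8_handler.sync_float8_amax_and_scale_history(model)
    optimizer.step()
    lr_scheduler.step()
    # with FSDP2 float8 all-gather enabled, precompute dynamic scales
    # using a single all-reduce over all parameters
    float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
    optimizer.zero_grad()
```
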
83 changes: 43 additions & 40 deletions torchtitan/config_manager.py
@@ -348,46 +348,6 @@ def __init__(self):
action="store_true",
help="Whether to compile the model",
)
self.parser.add_argument(
"--training.enable_float8_linear",
action="store_true",
help="""
If true, swaps `torch.nn.Linear` with `Float8Linear`.
This feature requires you to install 'torchao' which can be found
here: https://github.com/pytorch/ao
""",
)
self.parser.add_argument(
"--training.enable_fsdp_float8_all_gather",
action="store_true",
default=False,
help="Whether enable float8 all-gather in FSDP",
)
self.parser.add_argument(
"--training.precompute_float8_dynamic_scale_for_fsdp",
action="store_true",
default=False,
help="Whether precompute float8 scales dynamically for FSDP",
)
self.parser.add_argument(
"--training.float8_scaling_type_input",
type=str,
default="dynamic",
help="float8 scaling for input, dynamic (default) or delayed",
choices=["dynamic", "delayed"],
)
self.parser.add_argument(
"--training.float8_scaling_type_weight",
type=str,
default="dynamic",
help="float8 scaling for input, dynamic (default) or delayed",
)
self.parser.add_argument(
"--training.float8_scaling_type_grad_output",
type=str,
default="dynamic",
help="float8 scaling for input, dynamic (default) or delayed",
)
self.parser.add_argument(
"--training.gc_freq",
type=int,
@@ -483,6 +443,7 @@ def __init__(self):
0 is the default value.
""",
)

# activation checkpointing configs
self.parser.add_argument(
"--activation_checkpoint.mode",
@@ -500,6 +461,48 @@
""",
)

# float8 configs
self.parser.add_argument(
"--float8.enable_float8_linear",
action="store_true",
help="""
If true, swaps `torch.nn.Linear` with `Float8Linear`.
This feature requires you to install 'torchao' which can be found
here: https://github.com/pytorch/ao
""",
)
self.parser.add_argument(
"--float8.enable_fsdp_float8_all_gather",
action="store_true",
default=False,
help="Whether enable float8 all-gather in FSDP",
)
self.parser.add_argument(
"--float8.precompute_float8_dynamic_scale_for_fsdp",
action="store_true",
default=False,
help="Whether precompute float8 scales dynamically for FSDP",
)
self.parser.add_argument(
"--float8.scaling_type_input",
type=str,
default="dynamic",
help="float8 scaling for input, dynamic (default) or delayed",
choices=["dynamic", "delayed"],
)
self.parser.add_argument(
"--float8.scaling_type_weight",
type=str,
default="dynamic",
help="float8 scaling for input, dynamic (default) or delayed",
)
self.parser.add_argument(
"--float8.scaling_type_grad_output",
type=str,
default="dynamic",
help="float8 scaling for input, dynamic (default) or delayed",
)

# communications library settings
self.parser.add_argument(
"--comm.init_timeout_seconds",
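
The config change above is a namespace move: the float8 flags now live under `--float8.*` instead of `--training.*` (and, correspondingly, under a `[float8]` section in the TOML config files). As a hedged illustration of how dotted argparse flags of this shape can be grouped into per-section config dictionaries; the helper below is a simplified stand-in for JobConfig's prefix grouping, not torchtitan's actual implementation.

```python
# Simplified sketch (not torchtitan's config_manager): group dotted argparse
# destinations such as "float8.enable_float8_linear" into nested sections.
import argparse
from collections import defaultdict

parser = argparse.ArgumentParser()
parser.add_argument("--float8.enable_float8_linear", action="store_true")
parser.add_argument(
    "--float8.scaling_type_input",
    type=str,
    default="dynamic",
    choices=["dynamic", "delayed"],
)

args = parser.parse_args(["--float8.enable_float8_linear"])

# group "section.key" destinations into {"section": {"key": value}}
sections = defaultdict(dict)
for dotted_name, value in vars(args).items():
    section, key = dotted_name.split(".", 1)
    sections[section][key] = value

print(sections["float8"]["enable_float8_linear"])  # True
print(sections["float8"]["scaling_type_input"])    # dynamic
```
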
173 changes: 87 additions & 86 deletions torchtitan/float8_linear.py
@@ -12,127 +12,128 @@

# Note: Performance
# Float8 experimental is intended to be run under `torch.compile` for competitive performance
import functools
from typing import Optional

import torch
import torch.nn as nn
from torch._logging import warning_once

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.parallelisms import ParallelDims


@functools.lru_cache(None)
def is_sm90_or_later():
# Float8 is only supported on H100+ GPUs
return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def maybe_build_fp8_linear(
model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
):
"""
This function converts the linear layers to `Float8Linear`. Note that today,
only dynamic tensor scaling (the default) is supported.

This will mutate the model inplace.
"""
enable_float8_linear = job_config.training.enable_float8_linear
if not enable_float8_linear:
return
if not is_sm90_or_later():
warning_once(
logger,
"Failed to swap to Float8Linear because SM90 or later is not available",
)
return
try:
from torchao.float8 import (
CastConfig,
convert_to_float8_training,
Float8LinearConfig,
ScalingType,
)
class Float8Handler:
def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False

float8_config = job_config.float8
if not float8_config.enable_float8_linear:
return
if not is_sm90_or_later():
logger.warning(
"Failed to swap to Float8Linear because SM90 or later is not available",
)
return
try:
from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
except ImportError as e:
raise ImportError(
"torchao is not installed. Please install it to use fp8 linear layers."
) from e

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
job_config.training.enable_fsdp_float8_all_gather and dp_enabled
)
scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input)
scaling_type_weight = ScalingType(
job_config.training.float8_scaling_type_weight
parallel_dims.dp_enabled
and parallel_dims.dp_type == "fsdp"
and float8_config.enable_fsdp_float8_all_gather
)
scaling_type_grad_output = ScalingType(
job_config.training.float8_scaling_type_grad_output
)
float8_config = Float8LinearConfig(
scaling_type_input = ScalingType(float8_config.scaling_type_input)
scaling_type_weight = ScalingType(float8_config.scaling_type_weight)
scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output)
self.config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
cast_config_input=CastConfig(scaling_type=scaling_type_input),
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
enable_pre_and_post_forward=False,
)

self.enabled = True

# for precompute_fp8_dynamic_scale_for_fsdp
self.precompute_scale = (
enable_fsdp_float8_all_gather
and float8_config.precompute_float8_dynamic_scale_for_fsdp
)

# for sync_float8_amax_and_scale_history
self.delayed_scaling = (
scaling_type_input == "delayed"
or scaling_type_weight == "delayed"
or scaling_type_grad_output == "delayed"
)
self._sync_float8_amax_and_scale_history = None
self.compile = job_config.training.compile

logger.info("Float8 training active")

def convert_to_float8_training(self, model: nn.Module):
"""
This function converts the linear layers of `model` to `Float8Linear`.
Note that today, only dynamic tensor scaling (the default) is supported.
This will mutate the model inplace.
"""
if not self.enabled:
return

from torchao.float8 import convert_to_float8_training

# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
convert_to_float8_training(
model,
config=float8_config,
config=self.config,
module_filter_fn=lambda mod, fqn: fqn != "output",
)
logger.info(
f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
f"{self.config.enable_fsdp_float8_all_gather}"
)
except ImportError as exc:
raise ImportError(
"torchao is not installed. Please install it to use fp8 linear layers."
) from exc


def maybe_precompute_fp8_dynamic_scale_for_fsdp(
model: nn.Module, job_config: JobConfig
):
if not (
job_config.training.enable_float8_linear
and job_config.training.enable_fsdp_float8_all_gather
and job_config.training.precompute_float8_dynamic_scale_for_fsdp
):
return
if not is_sm90_or_later():
warning_once(
logger,
"Skipped precomputing fp8 scales because SM90 or later is not available",
)
return
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

precompute_float8_dynamic_scale_for_fsdp(model)
def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module):
if not self.enabled:
return

if not self.precompute_scale:
return

_sync_float8_amax_and_scale_history = None
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

precompute_float8_dynamic_scale_for_fsdp(model)

def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig):
if not (
job_config.training.enable_float8_linear
and (
job_config.training.float8_scaling_type_input == "delayed"
or job_config.training.float8_scaling_type_weight == "delayed"
or job_config.training.float8_scaling_type_grad_output == "delayed"
)
):
return
def sync_float8_amax_and_scale_history(self, model: nn.Module):
if not self.enabled:
return

from torchao.float8 import sync_float8_amax_and_scale_history
if not self.delayed_scaling:
return

# TODO(future): see if precalculating the modules to sync over is going to
# meaningfully help performance
from torchao.float8 import sync_float8_amax_and_scale_history

global _sync_float8_amax_and_scale_history
if _sync_float8_amax_and_scale_history is None:
if job_config.training.compile:
_sync_float8_amax_and_scale_history = torch.compile(
sync_float8_amax_and_scale_history
)
else:
_sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history
# TODO(vkuzo): see if precalculating the modules to sync over is going to
# meaningfully help performance

if self._sync_float8_amax_and_scale_history is None:
if self.compile:
self._sync_float8_amax_and_scale_history = torch.compile(
sync_float8_amax_and_scale_history
)
else:
self._sync_float8_amax_and_scale_history = (
sync_float8_amax_and_scale_history
)

sync_float8_amax_and_scale_history(model)
self._sync_float8_amax_and_scale_history(model)
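
One detail worth noting in the refactor above: for delayed scaling, the handler compiles torchao's `sync_float8_amax_and_scale_history` lazily on first use and caches the compiled callable on the instance (previously a module-level global). Below is a minimal, generic sketch of that compile-once-and-cache pattern, with a toy function standing in for the torchao call.

```python
# Minimal sketch of the lazy compile-and-cache pattern used by
# Float8Handler.sync_float8_amax_and_scale_history; toy_sync is a stand-in
# for torchao's sync_float8_amax_and_scale_history.
import torch


class LazyCompiledCall:
    """Compile a function with torch.compile on first call, then reuse it."""

    def __init__(self, fn, use_compile: bool):
        self._fn = fn
        self._use_compile = use_compile
        self._compiled = None  # filled in on first call

    def __call__(self, *args, **kwargs):
        if self._compiled is None:
            self._compiled = torch.compile(self._fn) if self._use_compile else self._fn
        return self._compiled(*args, **kwargs)


def toy_sync(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the real sync call
    return x * 2.0


sync = LazyCompiledCall(toy_sync, use_compile=False)
print(sync(torch.ones(2)))  # tensor([2., 2.])
```
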
2 changes: 1 addition & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -541,7 +541,7 @@ def parallelize_llama(
model,
world_mesh["tp"],
loss_parallel=parallel_dims.loss_parallel_enabled,
enable_float8=job_config.training.enable_float8_linear,
enable_float8=job_config.float8.enable_float8_linear,
enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
)
