4 changes: 2 additions & 2 deletions estimation.py
@@ -124,8 +124,8 @@ def loss_fn(pred, labels):
whole_model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(whole_model, job_config)
if job_config.training.enable_fp8_linear:
build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

# apply PT-D DP/TP parallelisms and activation checkpointing
model_parts = [whole_model]
33 changes: 33 additions & 0 deletions test_runner.py
@@ -273,6 +273,39 @@ def build_test_list():
"fsdp2_mem_tracker",
ngpu=4,
),
weifengpy (Contributor Author) commented on Jul 13, 2024:

Added the following to CI:

  • 1D FSDP original-dtype all-gather
  • 1D FSDP fp8 all-gather
  • 1D FSDP fp8 all-gather with precomputed dynamic scales

Follow-ups are needed to enable TP fp8 all-gather in CI, since the current CI tokenizer has a vocab size of 2556, which is not divisible by 16 (#461):

  • 1D TP fp8 all-gather
  • 2D FSDP + TP fp8 all-gather

OverrideDefinitions(
[
[
"--training.enable_fp8_linear",
]
],
"FSDP2 with original dtype",
"fp8_fsdp2_orig_all_gather",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.enable_fp8_linear",
"--training.enable_fsdp_fp8_all_gather",
]
],
"FSDP2 with fp8 all-gather",
"fsdp2_fp8_all_gather",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.enable_fp8_linear",
"--training.enable_fsdp_fp8_all_gather",
"--training.precompute_float8_dynamic_scale_for_fsdp",
]
],
"FSDP2 with fp8 all-gather and precomputed dynamic scales",
"fsdp2_fp8_all_gather_precompute_dynamic_scales",
ngpu=4,
),
]
return integration_tests_flavors

14 changes: 13 additions & 1 deletion torchtitan/config_manager.py
@@ -338,7 +338,7 @@ def __init__(self):
help="Whether to compile the model",
)
self.parser.add_argument(
"--training.fp8_linear",
"--training.enable_fp8_linear",
action="store_true",
help="""
If true, swaps `torch.nn.Linear` with `Float8Linear` with
@@ -347,6 +347,18 @@ def __init__(self):
here: https://github.com/pytorch-labs/float8_experimental
""",
)
Contributor: As discussed offline, let's refactor the fp8 configs, e.g. have a dedicated field for enabling fp8 or not.

weifengpy (Contributor Author): Renamed fp8_linear to enable_fp8_linear.

Contributor: One thing to note is that right now this is a boolean that swaps in the default float8 recipe: dynamic scaling x tensor-wise ScalingGranularity x all tensors involved in the matmul (input, weight, grad). I think we should brainstorm an elegant solution for users to express their desired config here.

weifengpy (Contributor Author): Good question. Eventually we might have to expose args/kwargs from swap_linear_with_float8_linear for flexibility.

self.parser.add_argument(
"--training.enable_fsdp_fp8_all_gather",
action="store_true",
default=False,
help="Whether enable fp8 all-gather in FSDP",
)
self.parser.add_argument(
"--training.precompute_float8_dynamic_scale_for_fsdp",
action="store_true",
default=False,
help="Whether precompute fp8 scales dynamically for FSDP",
)
self.parser.add_argument(
"--training.gc_freq",
type=int,
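To make the config discussion in the thread above concrete, here is a hypothetical sketch of what a dedicated scaling-type flag could look like if the args/kwargs of swap_linear_with_float8_linear were exposed. The flag name --training.fp8_scaling_type_w and its wiring are illustrative assumptions, not part of this PR; only swap_linear_with_float8_linear, scaling_type_w, and TensorScalingType come from the diff itself.

# Hypothetical flag, not part of this PR: expose the weight scaling type
# instead of hard-coding dynamic scaling when enable_fp8_linear is set.
self.parser.add_argument(
    "--training.fp8_scaling_type_w",
    type=str,
    default="dynamic",
    choices=["dynamic", "delayed"],
    help="Scaling type for Float8Linear weights (illustrative only)",
)

# build_fp8_linear could then forward it, e.g.:
#   scaling_type_w = (
#       TensorScalingType.DYNAMIC
#       if job_config.training.fp8_scaling_type_w == "dynamic"
#       else TensorScalingType.DELAYED
#   )
#   swap_linear_with_float8_linear(model, scaling_type_w=scaling_type_w)
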
41 changes: 34 additions & 7 deletions torchtitan/float8_linear.py
@@ -12,31 +12,58 @@

# Note: Performance
# Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
import contextlib
from typing import Optional

import float8_experimental.config as config

import torch
import torch.nn as nn
from float8_experimental.float8_linear import TensorScalingType

from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger


def build_fp8_linear(model: nn.Module, job_config: JobConfig):
@contextlib.contextmanager
def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
prev = config.enable_fsdp_fp8_all_gather
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
try:
yield
finally:
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = prev


def build_fp8_linear(
model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
):
"""
This function converts the linear layers to `Float8Linear`. Note that today,
only dynamic tensor scaling (the default) is supported.

This will mutate the model inplace.
"""
use_fp8_linear = job_config.training.fp8_linear
enable_fp8_linear = job_config.training.enable_fp8_linear
enable_fsdp_fp8_all_gather = (
job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
)
try:
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
)

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
Contributor: noop Q: do we need this in a context manager to make testing + resetting easier?

weifengpy (Contributor Author) replied on Jul 16, 2024: Hmm, set_enable_fsdp_fp8_all_gather is a context manager right now. Do you mean "why" should it be a context manager? EDIT: I also see you mentioned "make testing + resetting easier", which answers the why, so I am not sure it is a question for me.

with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
swap_linear_with_float8_linear(
model, scaling_type_w=TensorScalingType.DYNAMIC
)
logger.info(
f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}"
)
except ImportError as exc:
raise ImportError(
"float8_experimental is not installed. Please install it to use fp8 linear layers."
) from exc
if use_fp8_linear:
# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
swap_linear_with_float8_linear(model, Float8Linear)
logger.info("Swapped to Float8Linear layers")
16 changes: 14 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -117,12 +117,24 @@ def selective_checkpointing_context_fn():

def get_tp_parallel_strategy(
job_config: JobConfig,
model: nn.Module,
) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
"""Get the parallel strategy for the transformer model.

This function handles the special case of using float8 with tensor parallelism.
"""
if job_config.training.fp8_linear == "dynamic":
weifengpy (Contributor Author) commented on Jul 16, 2024: fp8_linear == "dynamic" is outdated after the recent unification of dynamic/delayed scaling in Float8Linear (#436). It is updated in this PR to make TP fp8 all-gather work again. EDIT: TP will be enabled in CI to prevent regressions once there is a new tokenizer with a vocab size of 2560.

if job_config.training.enable_fp8_linear:
from float8_experimental.float8_linear import Float8Linear, TensorScalingType

if any(
isinstance(m, Float8Linear)
and m.scaling_type_w is TensorScalingType.DELAYED
for m in model.modules()
):
raise NotImplementedError(
"1D TP fp8 all-gather only supports dynamic scaling"
)

from float8_experimental.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
@@ -346,7 +358,7 @@ def apply_tp(
rowwise_parallel_weight,
colwise_parallel_weight,
prepare_module_input,
) = get_tp_parallel_strategy(job_config)
) = get_tp_parallel_strategy(job_config, model)
loss_parallel = parallel_dims.loss_parallel_enabled

# 1. Parallelize the embedding and shard its outputs (which are the first
14 changes: 12 additions & 2 deletions train.py
@@ -19,6 +19,7 @@

import torch
import torch.nn.functional as F
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
wanchaol (Collaborator) commented on Jul 16, 2024: @weifengpy I think we should hide this import in the path where the enable_fp8_allgather path is taken. The problem is that for every feature that requires an additional install of another dependency, we should try to hide the import in the path that uses it instead of importing it globally; otherwise, users who didn't install float8_experimental would fail to train after they rebase. Please submit a follow-up PR to fix this.

weifengpy (Contributor Author): Got it. I am moving it from the top level into the if-else now: #464. Thanks for the timely reminder.
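
A minimal sketch of the deferred-import pattern requested here, reusing the check that appears later in this diff; the actual change is left to the follow-up #464 and may differ.

# Import float8_experimental only on the code path that needs it, so users
# without the optional dependency can still train after rebasing.
if (
    job_config.training.enable_fp8_linear
    and job_config.training.enable_fsdp_fp8_all_gather
    and job_config.training.precompute_float8_dynamic_scale_for_fsdp
):
    from float8_experimental.fsdp_utils import (
        precompute_float8_dynamic_scale_for_fsdp,
    )

    precompute_float8_dynamic_scale_for_fsdp(model)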

from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
@@ -216,8 +217,8 @@ def loss_fn(pred, labels):
whole_model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(whole_model, job_config)
if job_config.training.enable_fp8_linear:
build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

# log model size
model_param_count = get_num_params(whole_model)
@@ -398,6 +399,15 @@ def loss_fn(pred, labels):
optimizers.step()
lr_schedulers.step()

weifengpy (Contributor Author): Add a comment to explain precompute_float8_dynamic_scale_for_fsdp.

weifengpy (Contributor Author): Done.

if (
job_config.training.enable_fp8_linear
and job_config.training.enable_fsdp_fp8_all_gather
and job_config.training.precompute_float8_dynamic_scale_for_fsdp
):
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
precompute_float8_dynamic_scale_for_fsdp(model)
Contributor: Maybe a noob question: could you briefly explain what this is doing? I wonder, since we are already using context functions for FP8, whether we can have a context and run it in a .step() function here, just like the optimizer, lr scheduler, and profiler. This would make the code consistent.

weifengpy (Contributor Author):

> could you briefly explain what this is doing

precompute_float8_dynamic_scale_for_fsdp is a for-loop over model.parameters(). It issues a single all-reduce for all parameters, i.e. abs(max(param)) for param in model.parameters(), and saves the amax/scale as param._precomputed_scale. This speeds up the training loop, since we no longer need to compute the amax/scale for each parameter inside the loop.

> we are already using context functions for FP8

Do you refer to set_enable_fsdp_fp8_all_gather? That is for model initialization, where we swap nn.Linear with the user-defined float8 linear. precompute_float8_dynamic_scale_for_fsdp is for the training loop.

weifengpy (Contributor Author): Per suggestion, raise an error if use_fp8_linear=False or enable_fsdp_fp8_all_gather=False.

Commenter: Noob q: do we eventually want to just put this in FSDP2?

Collaborator: It has to be done after the optimizer step (since parameter values change). Are you suggesting running this in the root module's pre-forward?

Commenter: Yeah, anywhere between the (n-1)-th optimizer step and the first all-gather in the n-th step where FSDP2 has control (if there is any).

Collaborator: That makes sense. I think one concern is that FSDP is agnostic to the fp8 all-gather. FSDP does not know that the fsdp_pre_all_gather and fsdp_post_all_gather of the Float8Linear weights are implemented to do the fp8 all-gather, so at best the user would still need to register a module forward pre-hook or something to run this method.

yifuwang commented on Jul 15, 2024: Ah, I see. Somehow I thought FSDP2 was fp8-aware.


losses_since_last_log.append(loss)

# log metrics
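To make the precompute_float8_dynamic_scale_for_fsdp discussion above concrete, here is a simplified, conceptual sketch of the pattern described in that thread: one pass over model.parameters(), one combined abs-max, and the result stashed as param._precomputed_scale. The function name, the e4m3 dtype choice, and the clamping epsilon are assumptions for illustration; the real implementation lives in float8_experimental.fsdp_utils and handles DTensor sharding so that the reduction becomes a single all-reduce.

import torch
import torch.nn as nn


def precompute_scales_sketch(model: nn.Module) -> None:
    # Conceptual version of the behavior described in the review thread above.
    params = list(model.parameters())
    if not params:
        return
    # One stacked abs-max over all parameters; with sharded (DTensor) weights,
    # the matching reduction can then be issued as a single collective.
    amaxes = torch.stack([p.detach().abs().amax().float() for p in params])
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # assumes e4m3 fp8 weights
    scales = fp8_max / torch.clamp(amaxes, min=1e-12)
    for p, scale in zip(params, scales):
        # Stashed for the fp8 all-gather path to consume on the next step.
        p._precomputed_scale = scale
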
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -37,7 +37,7 @@ max_norm = 1.0  # grad norm clipping
steps = 10
data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)

2 changes: 1 addition & 1 deletion train_configs/llama2_13b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama2_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 8 # 8-way TP
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama2_7b.toml
@@ -32,7 +32,7 @@ max_norm = 1.0  # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama3_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 8 # 8-way TP
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama3_8b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"
