enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather #413
```diff
@@ -338,7 +338,7 @@ def __init__(self):
             help="Whether to compile the model",
         )
         self.parser.add_argument(
-            "--training.fp8_linear",
+            "--training.enable_fp8_linear",
             action="store_true",
             help="""
                 If true, swaps `torch.nn.Linear` with `Float8Linear` with
@@ -347,6 +347,18 @@ def __init__(self):
                 here: https://github.com/pytorch-labs/float8_experimental
             """,
         )
+        self.parser.add_argument(
+            "--training.enable_fsdp_fp8_all_gather",
+            action="store_true",
+            default=False,
+            help="Whether enable fp8 all-gather in FSDP",
+        )
+        self.parser.add_argument(
+            "--training.precompute_float8_dynamic_scale_for_fsdp",
+            action="store_true",
+            default=False,
+            help="Whether precompute fp8 scales dynamically for FSDP",
+        )
         self.parser.add_argument(
             "--training.gc_freq",
             type=int,
```

Review discussion:
- As discussed offline, let's refactor the fp8 configs, e.g. have a dedicated field for enabling fp8 or not.
- Renamed.
- One thing to note is that right now this is a boolean which swaps in the default float8 recipe. We should brainstorm an elegant solution for users to express their desired config here.
- Good question. Eventually we might have to expose args/kwargs from …
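As a reading aid (not part of the diff), here is a minimal sketch of how the three flags compose downstream in this PR: fp8 linear is the base switch, fp8 all-gather only takes effect when data parallelism is enabled, and the scale precompute is an optimization on top of fp8 all-gather. The helper name is hypothetical; the same gating appears in `build_fp8_linear` and `train.py` below.

```python
# Hypothetical helper, mirroring the gating used later in this PR.
def resolve_fp8_flags(training_cfg, dp_enabled: bool) -> tuple[bool, bool, bool]:
    enable_fp8_linear = training_cfg.enable_fp8_linear
    # fp8 all-gather only matters when FSDP (data parallel) is actually used.
    enable_fsdp_fp8_all_gather = (
        enable_fp8_linear and training_cfg.enable_fsdp_fp8_all_gather and dp_enabled
    )
    # The single-all-reduce scale precompute only applies on top of fp8 all-gather.
    precompute_scales = (
        enable_fsdp_fp8_all_gather
        and training_cfg.precompute_float8_dynamic_scale_for_fsdp
    )
    return enable_fp8_linear, enable_fsdp_fp8_all_gather, precompute_scales
```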
```diff
@@ -12,31 +12,58 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
+import contextlib
+from typing import Optional
+
+import float8_experimental.config as config
 
 import torch
 import torch.nn as nn
+from float8_experimental.float8_linear import TensorScalingType
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
 
 
-def build_fp8_linear(model: nn.Module, job_config: JobConfig):
+@contextlib.contextmanager
+def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
+    prev = config.enable_fsdp_fp8_all_gather
+    torch.distributed.barrier()
+    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
+    try:
+        yield
+    finally:
+        torch.distributed.barrier()
+        config.enable_fsdp_fp8_all_gather = prev
+
+
+def build_fp8_linear(
+    model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
+):
     """
     This function converts the linear layers to `Float8Linear`. Note that today,
     only dynamic tensor scaling (the default) is supported.
 
     This will mutate the model inplace.
     """
-    use_fp8_linear = job_config.training.fp8_linear
+    enable_fp8_linear = job_config.training.enable_fp8_linear
+    enable_fsdp_fp8_all_gather = (
+        job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
+    )
     try:
-        from float8_experimental.float8_linear import Float8Linear
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
+
+        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
+            swap_linear_with_float8_linear(
+                model, scaling_type_w=TensorScalingType.DYNAMIC
+            )
+        logger.info(
+            f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}"
+        )
     except ImportError as exc:
         raise ImportError(
             "float8_experimental is not installed. Please install it to use fp8 linear layers."
         ) from exc
-    if use_fp8_linear:
-        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        swap_linear_with_float8_linear(model, Float8Linear)
-        logger.info("Swapped to Float8Linear layers")
```

Review discussion:
- Noop question: do we need this in a context manager to make testing + resetting easier?
- Hmm. EDIT: I also see you mentioned "make testing + resetting easier", which answers the why, so I am not sure if it's a question for me.
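To illustrate the "testing + resetting" point from the discussion above, here is a hedged sketch. It assumes the module path `torchtitan.float8_linear` and a single-rank gloo process group, since the context manager calls `torch.distributed.barrier()` on enter and exit.

```python
import os

import torch.distributed as dist

import float8_experimental.config as config
from torchtitan.float8_linear import set_enable_fsdp_fp8_all_gather  # assumed path


def check_flag_is_restored():
    # A single-rank gloo group so the barriers inside the context manager are valid.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    prev = config.enable_fsdp_fp8_all_gather
    with set_enable_fsdp_fp8_all_gather(True):
        # Any Float8Linear swapped inside this block picks up fp8 all-gather weights.
        assert config.enable_fsdp_fp8_all_gather is True
    # The `finally` branch restores the previous value, even if the body raises.
    assert config.enable_fsdp_fp8_all_gather is prev

    dist.destroy_process_group()
```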
```diff
@@ -117,12 +117,24 @@ def selective_checkpointing_context_fn():
 def get_tp_parallel_strategy(
     job_config: JobConfig,
+    model: nn.Module,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.
 
     This function handles the special case of using float8 with tensor parallelism.
     """
-    if job_config.training.fp8_linear == "dynamic":
+    if job_config.training.enable_fp8_linear:
+        from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+
+        if any(
+            isinstance(m, Float8Linear)
+            and m.scaling_type_w is TensorScalingType.DELAYED
+            for m in model.modules()
+        ):
+            raise NotImplementedError(
+                "1D TP fp8 all-gather only supports dynamic scaling"
+            )
+
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
```

Review discussion:
- Update it in this PR to make TP fp8 all-gather work again. EDIT: will enable TP in CI for prevention once there is a new tokenizer with vocab size 2560.

```diff
@@ -346,7 +358,7 @@ def apply_tp(
         rowwise_parallel_weight,
         colwise_parallel_weight,
         prepare_module_input,
-    ) = get_tp_parallel_strategy(job_config)
+    ) = get_tp_parallel_strategy(job_config, model)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
     # 1. Parallelize the embedding and shard its outputs (which are the first
```
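For context, a hedged sketch of how the returned strategy tuple is typically consumed. The submodule names are illustrative rather than the exact torchtitan plan, and `get_tp_parallel_strategy` from the hunk above is assumed to be in scope; with `enable_fp8_linear` set, the styles resolve to the Float8 variants so TP all-gathers fp8-casted tensors instead of the high-precision ones.

```python
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module


def apply_tp_to_block(block: nn.Module, tp_mesh: DeviceMesh, job_config, model: nn.Module):
    # Float8RowwiseParallel / Float8ColwiseParallel when fp8 is enabled,
    # plain RowwiseParallel / ColwiseParallel otherwise.
    rowwise, colwise, _prepare_input = get_tp_parallel_strategy(job_config, model)
    parallelize_module(
        block,
        tp_mesh,
        {
            # Illustrative attention projections; the real plan covers more modules.
            "attention.wq": colwise(),
            "attention.wk": colwise(),
            "attention.wv": colwise(),
            "attention.wo": rowwise(),
        },
    )
```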
```diff
@@ -19,6 +19,7 @@
 import torch
 import torch.nn.functional as F
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
```

Review discussion:
- @weifengpy I think we should hide this import in the code path where it is used. For every feature that requires an additional install of another dependency, we should hide the import in the path that uses it instead of importing it globally; otherwise, users who didn't install float8_experimental would fail to train after they rebase. Please submit a follow-up PR to fix this.
- Got you. I am moving it from top-level to the if-else now: #464. Thanks for the timely reminder.
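A minimal sketch of the follow-up being requested (the wrapper name is hypothetical; #464 is the actual fix): keep the float8_experimental import inside the code path that needs it, so users without the package installed can still train.

```python
def maybe_precompute_fp8_scales(model, job_config):
    # Hypothetical wrapper around the call in the training loop below.
    if not (
        job_config.training.enable_fp8_linear
        and job_config.training.enable_fsdp_fp8_all_gather
        and job_config.training.precompute_float8_dynamic_scale_for_fsdp
    ):
        return
    # Imported lazily: only reached when the fp8 all-gather path is enabled,
    # so a missing float8_experimental install no longer breaks every run.
    from float8_experimental.fsdp_utils import (
        precompute_float8_dynamic_scale_for_fsdp,
    )

    precompute_float8_dynamic_scale_for_fsdp(model)
```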
```diff
@@ -216,8 +217,8 @@ def loss_fn(pred, labels):
     whole_model = model_cls.from_model_args(model_config)
 
     # apply fp8 linear module swap
-    if job_config.training.fp8_linear:
-        build_fp8_linear(whole_model, job_config)
+    if job_config.training.enable_fp8_linear:
+        build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
 
     # log model size
     model_param_count = get_num_params(whole_model)
```
```diff
@@ -398,6 +399,15 @@ def loss_fn(pred, labels):
             optimizers.step()
             lr_schedulers.step()
 
+            if (
+                job_config.training.enable_fp8_linear
+                and job_config.training.enable_fsdp_fp8_all_gather
+                and job_config.training.precompute_float8_dynamic_scale_for_fsdp
+            ):
+                # calculate float8 dynamic amax/scale for all-parameter for FSDP2
+                # it issues a single all-reduce for all parameters at once for better performance
+                precompute_float8_dynamic_scale_for_fsdp(model)
+
             losses_since_last_log.append(loss)
 
             # log metrics
```

Review discussion:
- Add a comment to explain.
- Done.

Review discussion:
- Maybe a noob question: could you briefly explain what this is doing?
- Do you refer to …
- Per suggestion, raise an error if …
- Noob question: do we eventually want to just put this in FSDP2?
- It has to be done after the optimizer step (since parameter values change). Are you suggesting to run this in the root module's pre-forward?
- Yeah, anywhere between the (n-1)-th optimizer step and the first all-gather in the n-th step, where FSDP2 has control (if there is any).
- That makes sense. I think one concern is that FSDP is agnostic to the fp8 all-gather. FSDP does not know that the …
- Ah, I see. Somehow I thought FSDP2 was fp8-aware.
Added the following to CI: …

Need follow-ups to enable TP fp8 all-gather in CI (the current CI tokenizer has a vocab size of 2556, which is not divisible by 16): #461
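On the tokenizer follow-up: fp8 GEMMs (via `torch._scaled_mm`) require matrix dimensions that are multiples of 16, so a vocab size of 2556 leaves the output projection with an fp8-unfriendly shape once it is sharded by TP, while 2560 works. A small illustrative check, not torchtitan code:

```python
def fp8_friendly(dim: int, tp_degree: int = 1, multiple: int = 16) -> bool:
    # Each TP shard of the vocab dimension must still be a multiple of 16
    # for the fp8 GEMM to be usable.
    shard = dim // tp_degree
    return dim % tp_degree == 0 and shard % multiple == 0


assert not fp8_friendly(2556)            # current CI tokenizer
assert fp8_friendly(2560, tp_degree=8)   # proposed replacement
```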