Skip to content

Commit 48485a8

Browse files
committed
[BE][5/n] simplify pp vs. non-pp setup
ghstack-source-id: 003bfbf Pull Request resolved: #510
1 parent 6e7a183 commit 48485a8

File tree

6 files changed

+73
-85
lines changed

6 files changed

+73
-85
lines changed

estimation.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -122,33 +122,25 @@ def loss_fn(pred, labels):
122122
f"Building {model_name} {job_config.model.flavor} with {model_config}"
123123
)
124124
with torch.device("meta"):
125-
whole_model = model_cls.from_model_args(model_config)
125+
model = model_cls.from_model_args(model_config)
126126

127127
# a no-op handler if float8 is not enabled
128128
float8_handler = Float8Handler(job_config, parallel_dims)
129129
# swap to Float8Linear based on float8 configs
130-
float8_handler.convert_to_float8_training(whole_model)
130+
float8_handler.convert_to_float8_training(model)
131131

132132
# apply PT-D DP/TP parallelisms and activation checkpointing
133-
model_parts = [whole_model]
134-
model_parts = [
135-
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
136-
for m in model_parts
137-
]
138-
139-
init_device = "cuda"
140-
for model in model_parts:
141-
model.to_empty(device=init_device)
133+
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
142134

135+
model.to_empty(device="cuda")
143136
if not active_fake_mode():
144-
whole_model.init_weights()
137+
model.init_weights()
138+
model.train()
145139

146140
# build optimizer after applying parallelisms to the model
147-
optimizers = build_optimizers(model_parts, job_config)
141+
optimizers = build_optimizers([model], job_config)
148142
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
149143

150-
for model in model_parts:
151-
model.train()
152144
logger.info(f"Vocab size: {model_config.vocab_size}")
153145
# Create a dummy batch instead of loading from a dataset
154146
batch = (
@@ -165,24 +157,23 @@ def loss_fn(pred, labels):
165157
device="cuda",
166158
),
167159
)
168-
fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
160+
fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
169161
fsdp_memtracker.track_inputs(batch)
170162

171163
with fsdp_memtracker:
172164
for iter_idx in range(2):
173165
input_ids, labels = batch
174166
# train step
175167
with train_context():
176-
pred = whole_model(input_ids)
168+
pred = model(input_ids)
177169
loss = loss_fn(pred, labels)
178170
del pred
179171
loss.backward()
180172

181173
# clip gradients
182-
for model in model_parts:
183-
torch.nn.utils.clip_grad_norm_(
184-
model.parameters(), job_config.training.max_norm, foreach=True
185-
)
174+
torch.nn.utils.clip_grad_norm_(
175+
model.parameters(), job_config.training.max_norm, foreach=True
176+
)
186177
# sync float8 amaxes and scales
187178
float8_handler.sync_float8_amax_and_scale_history(model)
188179
# optimizer step

torchtitan/parallelisms/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88
from torchtitan.parallelisms.parallel_dims import ParallelDims
99
from torchtitan.parallelisms.parallelize_llama import parallelize_llama
1010
from torchtitan.parallelisms.pipeline_llama import pipeline_llama
11-
from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
1211

1312

1413
__all__ = [
15-
"build_pipeline_schedule",
1614
"models_parallelize_fns",
1715
"models_pipelining_fns",
1816
"ParallelDims",

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def parallelize_llama(
5151
and not job_config.training.compile
5252
):
5353
raise RuntimeError("Async TP requires --training.compile")
54-
model = apply_tp(
54+
apply_tp(
5555
model,
5656
world_mesh["tp"],
5757
loss_parallel=parallel_dims.loss_parallel_enabled,
@@ -60,7 +60,7 @@ def parallelize_llama(
6060
)
6161

6262
if job_config.activation_checkpoint.mode != "none":
63-
model = apply_ac(model, job_config.activation_checkpoint)
63+
apply_ac(model, job_config.activation_checkpoint)
6464

6565
# turn on per-TransformerBlock compile after AC wrapping and before FSDP
6666
if job_config.training.compile:
@@ -69,14 +69,14 @@ def parallelize_llama(
6969
"fused_rmsnorm is not compatible with torch.compile yet. "
7070
"Please use rmsnorm or layernorm."
7171
)
72-
model = apply_compile(model)
72+
apply_compile(model)
7373

7474
if parallel_dims.dp_enabled:
7575
if parallel_dims.dp_type == "fsdp":
7676
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
7777
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
7878

79-
model = apply_fsdp(
79+
apply_fsdp(
8080
model,
8181
dp_mesh,
8282
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
@@ -88,15 +88,13 @@ def parallelize_llama(
8888
else:
8989
if world_mesh.ndim > 1:
9090
raise RuntimeError("DDP has not supported > 1D parallelism")
91-
model = apply_ddp(
91+
apply_ddp(
9292
model,
9393
world_mesh,
9494
enable_compile=job_config.training.compile,
9595
enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
9696
)
9797

98-
return model
99-
10098

10199
def apply_tp(
102100
model: nn.Module,
@@ -110,7 +108,7 @@ def apply_tp(
110108
# transformer block's inputs)
111109
# 2. Parallelize the root norm layer over the sequence dim
112110
# 3. Parallelize the final linear output layer
113-
model = parallelize_module(
111+
parallelize_module(
114112
model,
115113
tp_mesh,
116114
{
@@ -192,7 +190,6 @@ def apply_tp(
192190
f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
193191
"Tensor Parallelism to the model"
194192
)
195-
return model
196193

197194

198195
# for selective op activation checkpointing
@@ -273,7 +270,6 @@ def apply_ac(model: nn.Module, ac_config):
273270
model.layers.register_module(layer_id, transformer_block)
274271

275272
logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
276-
return model
277273

278274

279275
def apply_compile(model: nn.Module):
@@ -286,7 +282,6 @@ def apply_compile(model: nn.Module):
286282
model.layers.register_module(layer_id, transformer_block)
287283

288284
logger.info("Compiling each TransformerBlock with torch.compile")
289-
return model
290285

291286

292287
def apply_fsdp(
@@ -329,8 +324,8 @@ def apply_fsdp(
329324
module._load_state_dict_pre_hooks.clear()
330325
assert len(module._state_dict_pre_hooks) <= 1
331326
module._state_dict_pre_hooks.clear()
327+
332328
logger.info("Applied FSDP to the model")
333-
return model
334329

335330

336331
def apply_ddp(
@@ -347,7 +342,6 @@ def apply_ddp(
347342
else:
348343
torch._dynamo.config.optimize_ddp = "ddp_optimizer"
349344

350-
model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
345+
replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
351346

352347
logger.info("Applied DDP to the model")
353-
return model

torchtitan/parallelisms/pipeline_llama.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file applies the PT-D pipeline parallelism to the Llama model.
88

99
import copy
10-
from typing import Union
10+
from typing import Callable, Union
1111

1212
import torch
1313
import torch.nn as nn
@@ -18,7 +18,10 @@
1818
from torchtitan.logging import logger
1919
from torchtitan.models.llama.model import ModelArgs
2020
from torchtitan.parallelisms.parallel_dims import ParallelDims
21-
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
21+
from torchtitan.parallelisms.pipelining_utils import (
22+
build_pipeline_schedule,
23+
stage_ids_this_rank,
24+
)
2225

2326

2427
DeviceType = Union[int, str, torch.device]
@@ -31,6 +34,7 @@ def pipeline_llama(
3134
job_config: JobConfig,
3235
device: DeviceType,
3336
model_config: ModelArgs,
37+
loss_fn: Callable[..., torch.Tensor],
3438
):
3539
split_mode = job_config.experimental.pipeline_parallel_split_mode
3640
valid_split_modes = ("manual", "tracer")
@@ -39,14 +43,18 @@ def pipeline_llama(
3943
f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
4044
)
4145
if split_mode == "manual":
42-
return pipeline_llama_manual(
46+
stages, models = pipeline_llama_manual(
4347
model, pp_mesh, parallel_dims, job_config, device, model_config
4448
)
4549
elif split_mode == "tracer":
46-
return pipeline_llama_tracer(
50+
stages, models = pipeline_llama_tracer(
4751
model, pp_mesh, parallel_dims, job_config, device, model_config
4852
)
4953

54+
pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
55+
56+
return pp_schedule, models
57+
5058

5159
def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
5260
"""Get meta tensors with the right input shapes used for tracing"""
@@ -218,4 +226,4 @@ def pipeline_llama_tracer(
218226
group=pp_mesh.get_group(),
219227
)
220228
)
221-
return (stages, models)
229+
return stages, models

torchtitan/parallelisms/pipelining_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torchtitan.logging import logger
1515

1616

17-
def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
17+
def build_pipeline_schedule(job_config, stages, loss_fn):
1818
looped_schedule = False
1919

2020
if job_config.experimental.pipeline_parallel_schedule == "1f1b":

0 commit comments

Comments
 (0)