
Commit 7119d0c

[BE][2/n] use proper method signatures in parallelize_llama
ghstack-source-id: 17a1ee9
Pull Request resolved: #495
Parent commit: 3ddce59

File tree: 2 files changed (+67, -64 lines)


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 66 additions & 63 deletions
@@ -118,18 +118,17 @@ def selective_checkpointing_context_fn():
 
 
 def get_tp_parallel_strategy_for_transformer_block(
-    job_config: JobConfig,
-    model: nn.Module,
+    enable_float8: bool,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.
 
     This function handles the special case of using float8 with tensor parallelism.
     """
-    if job_config.training.enable_float8_linear:
-        # TODO(future PR): once float8 configuration supports delayed
+    if enable_float8:
+        # TODO(vkuzo): once float8 configuration supports delayed
         # scaling, add a check here to enforce supported float8 all-gather
         # configurations
-        # TODO(future PR): add the items below to __init__.py of torchao.float8,
+        # TODO(vkuzo): add the items below to __init__.py of torchao.float8,
         # and import from there
         from torchao.float8.float8_tensor_parallel import (
             Float8ColwiseParallel,
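
Only the float8 branch is visible in this hunk. Below is a minimal sketch of the whole helper after the change, assuming the non-float8 fallback returns the stock DTensor parallel styles; the fallback branch and two of the float8 names are not shown in the diff and are assumptions.

```python
from typing import Tuple

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def get_tp_parallel_strategy_for_transformer_block(
    enable_float8: bool,
) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
    """Return (rowwise, colwise, prepare_input) styles; float8 variants if enabled."""
    if enable_float8:
        # Float8ColwiseParallel appears in the hunk; the other two names are
        # assumed to come from torchao.float8.float8_tensor_parallel as well.
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
```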
@@ -143,7 +142,7 @@ def get_tp_parallel_strategy_for_transformer_block(
 
 def pipeline_llama(
     model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -157,11 +156,11 @@ def pipeline_llama(
     )
     if split_mode == "manual":
         return pipeline_llama_manual(
-            model, world_mesh, parallel_dims, job_config, device, model_config
+            model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     elif split_mode == "tracer":
         return pipeline_llama_tracer(
-            model, world_mesh, parallel_dims, job_config, device, model_config
+            model, pp_mesh, parallel_dims, job_config, device, model_config
         )
 
 
@@ -184,7 +183,7 @@ def _mixed_precision_dtype(
 
 def pipeline_llama_manual(
     whole_model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -198,7 +197,6 @@ def pipeline_llama_manual(
     The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
     parallelism.
     """
-    pp_mesh = world_mesh["pp"]
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    microbatches = (
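
The `pp_mesh = world_mesh["pp"]` line removed here moves to the caller. A hedged sketch of what the call site looks like after the change, using names already present in the file (the return shape follows the train.py hunk below):

```python
# Hypothetical call site: `world_mesh`, `model`, `parallel_dims`, `job_config`,
# `device`, and `model_config` come from the surrounding training setup.
pp_mesh = world_mesh["pp"]  # slice the pipeline sub-mesh once, up front
stages, model_parts = pipeline_llama(
    model, pp_mesh, parallel_dims, job_config, device, model_config
)
```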
@@ -287,7 +285,7 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
 
 def pipeline_llama_tracer(
     model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -306,7 +304,6 @@ def pipeline_llama_tracer(
             "To work around, set mixed_precision_param to float32."
         )
 
-    pp_mesh = world_mesh["pp"]
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    microbatches = (
@@ -341,15 +338,12 @@ def pipeline_llama_tracer(
 
 def apply_tp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    tp_mesh: DeviceMesh,
+    loss_parallel: bool,
+    enable_float8: bool,
+    enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
-
-    tp_mesh = world_mesh["tp"]
-    loss_parallel = parallel_dims.loss_parallel_enabled
-
     # 1. Parallelize the embedding and shard its outputs (which are the first
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
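
Because the new `apply_tp` signature takes a mesh and three booleans instead of the whole `JobConfig`, it can be driven with plain values. A hypothetical standalone call (mesh size, device type, and flag values are placeholders):

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 1-D tensor-parallel mesh; `model` is the Transformer built
# earlier in the training script.
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

model = apply_tp(
    model,
    tp_mesh,
    loss_parallel=True,     # parallel_dims.loss_parallel_enabled at the real call site
    enable_float8=False,    # job_config.training.enable_float8_linear
    enable_async_tp=False,  # job_config.experimental.enable_async_tensor_parallel
)
```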
@@ -377,7 +371,7 @@ def apply_tp(
         rowwise_parallel_weight,
         colwise_parallel_weight,
         prepare_module_input,
-    ) = get_tp_parallel_strategy_for_transformer_block(job_config, model)
+    ) = get_tp_parallel_strategy_for_transformer_block(enable_float8)
 
     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
@@ -416,7 +410,7 @@ def apply_tp(
     )
 
     # updates expressly for async tensor parallel
-    if job_config.experimental.enable_async_tensor_parallel:
+    if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
         torch._dynamo.config.cache_size_limit = 10000
@@ -434,16 +428,14 @@ def apply_tp(
         job_config.training.compile = True
 
     logger.info(
-        f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model"
+        f"Applied {'Async ' if enable_async_tp else ''}"
+        "Tensor Parallelism to the model"
     )
     return model
 
 
-def apply_ac(model: nn.Module, job_config: JobConfig):
+def apply_ac(model: nn.Module, ac_config: JobConfig):
     """Apply activation checkpointing to the model."""
-
-    ac_config = job_config.activation_checkpoint
-
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = checkpoint_wrapper(transformer_block, ac_config)
         model.layers.register_module(layer_id, transformer_block)
@@ -452,14 +444,8 @@ def apply_ac(model: nn.Module, job_config: JobConfig):
     return model
 
 
-def apply_compile(model: nn.Module, job_config: JobConfig):
+def apply_compile(model: nn.Module):
     """Apply torch.compile to each transformer block."""
-
-    if job_config.model.norm_type == "fused_rmsnorm":
-        raise NotImplementedError(
-            "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
-        )
-
     for layer_id, transformer_block in model.layers.named_children():
         # TODO: dynamic shape have some issues so we turn it off for now.
         # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
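
For reference, a minimal sketch of what the now parameter-free `apply_compile` plausibly does; the per-block loop is visible in the hunk, while the exact `torch.compile` options are an assumption.

```python
import torch
import torch.nn as nn


def apply_compile(model: nn.Module) -> nn.Module:
    """Compile each TransformerBlock separately to keep individual graphs small."""
    for layer_id, transformer_block in model.layers.named_children():
        # dynamic=False mirrors the "dynamic shape" TODO above; options are illustrative.
        compiled_block = torch.compile(transformer_block, dynamic=False)
        model.layers.register_module(layer_id, compiled_block)
    return model
```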
@@ -474,25 +460,19 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
 
 def apply_fsdp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    dp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
+    pp_enabled: bool,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
-
-    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-
-    mp_policy = MixedPrecisionPolicy(
-        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-    )
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
     for layer_id, transformer_block in model.layers.items():
-        if parallel_dims.pp_enabled:
+        if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
             # all-gathers, which can be expensive and non-overlapped
             reshard_after_forward = False
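
With the dtype lookup moved to the caller, `apply_fsdp` receives concrete `torch.dtype` values. A sketch of the policy construction with example dtypes; the FSDP2 import path and mesh setup are assumptions, and the real dtypes come from `TORCH_DTYPE_MAP` at the call site.

```python
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy  # assumed FSDP2 import path
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 1-D data-parallel mesh; size and device type are placeholders.
dp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))

# Example dtypes; the real call site maps config strings through TORCH_DTYPE_MAP.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
```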
@@ -505,11 +485,9 @@ def apply_fsdp(
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(
-        model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
-    )
+    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
 
-    if parallel_dims.pp_enabled:
+    if pp_enabled:
         # TODO
         # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
         # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
@@ -526,22 +504,19 @@ def apply_fsdp(
 
 def apply_ddp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    dp_mesh: DeviceMesh,
+    enable_compile: bool,
+    enable_compiled_autograd: bool,
 ):
-    if world_mesh.ndim > 1:
-        raise RuntimeError("DDP has not supported > 1D parallelism.")
-
-    if job_config.training.compile:
-        if job_config.experimental.enable_compiled_autograd:
+    if enable_compile:
+        if enable_compiled_autograd:
             torch._dynamo.config.optimize_ddp = (
                 "python_reducer_without_compiled_forward"
            )
         else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
 
-    model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)
+    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
     logger.info("Applied DDP to the model")
     return model
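
The DDP path follows the same pattern; note that the `world_mesh.ndim > 1` guard removed here now lives at the call site in `parallelize_llama` (final hunk below). A hypothetical call with the flags spelled out:

```python
# `dp_mesh` must be 1-D here; the >1D check is performed by the caller.
model = apply_ddp(
    model,
    dp_mesh,
    enable_compile=job_config.training.compile,
    enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
)
```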
@@ -562,18 +537,46 @@ def parallelize_llama(
     """
 
     if parallel_dims.tp_enabled:
-        model = apply_tp(model, world_mesh, parallel_dims, job_config)
+        model = apply_tp(
+            model,
+            world_mesh["tp"],
+            loss_parallel=parallel_dims.loss_parallel_enabled,
+            enable_float8=job_config.training.enable_float8_linear,
+            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
+        )
 
     if job_config.activation_checkpoint.mode != "none":
-        model = apply_ac(model, job_config)
+        model = apply_ac(model, job_config.activation_checkpoint)
 
     if job_config.training.compile:
-        model = apply_compile(model, job_config)
+        if job_config.model.norm_type == "fused_rmsnorm":
+            raise NotImplementedError(
+                "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
+            )
+        model = apply_compile(model)
 
     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
-            model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
+            dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+            assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+
+            model = apply_fsdp(
+                model,
+                dp_mesh,
+                param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+                reduce_dtype=TORCH_DTYPE_MAP[
+                    job_config.training.mixed_precision_reduce
+                ],
+                pp_enabled=parallel_dims.pp_enabled,
+            )
         else:
-            model = apply_ddp(model, world_mesh, parallel_dims, job_config)
+            if world_mesh.ndim > 1:
+                raise RuntimeError("DDP has not supported > 1D parallelism.")
+            model = apply_ddp(
+                model,
+                world_mesh,
+                enable_compile=job_config.training.compile,
+                enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
+            )
 
     return model
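
Taken together, the commit narrows each helper to exactly the inputs it consumes. Reconstructed from the hunks above, the post-commit signatures are approximately:

```python
# Signatures reconstructed from this diff (type hints and bodies omitted).
def apply_tp(model, tp_mesh, loss_parallel, enable_float8, enable_async_tp): ...
def apply_ac(model, ac_config): ...
def apply_compile(model): ...
def apply_fsdp(model, dp_mesh, param_dtype, reduce_dtype, pp_enabled): ...
def apply_ddp(model, dp_mesh, enable_compile, enable_compiled_autograd): ...
def pipeline_llama(model, pp_mesh, parallel_dims, job_config, device, model_config): ...
def get_tp_parallel_strategy_for_transformer_block(enable_float8): ...
```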

train.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def main(job_config: JobConfig):
 
     if parallel_dims.pp_enabled:
         stages, model_parts = models_pipelining_fns[model_name](
-            whole_model, world_mesh, parallel_dims, job_config, device, model_config
+            whole_model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     else:
         # In 1D/2D cases or PP with simple schedules, model_parts is just one item
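
On the train.py side only the argument changes; the mesh slicing that used to live inside the pipelining helpers is assumed to happen in `main` before this call (the slicing line itself is outside this hunk):

```python
if parallel_dims.pp_enabled:
    # Assumed context: slice the "pp" sub-mesh from the world mesh once here,
    # since pipeline_llama no longer does it internally.
    pp_mesh = world_mesh["pp"]
    stages, model_parts = models_pipelining_fns[model_name](
        whole_model, pp_mesh, parallel_dims, job_config, device, model_config
    )
```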
