@@ -118,18 +118,17 @@ def selective_checkpointing_context_fn():
 
 
 def get_tp_parallel_strategy_for_transformer_block(
-    job_config: JobConfig,
-    model: nn.Module,
+    enable_float8: bool,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.
 
     This function handles the special case of using float8 with tensor parallelism.
     """
-    if job_config.training.enable_float8_linear:
-        # TODO(future PR): once float8 configuration supports delayed
+    if enable_float8:
+        # TODO(vkuzo): once float8 configuration supports delayed
         # scaling, add a check here to enforce supported float8 all-gather
         # configurations
-        # TODO(future PR): add the items below to __init__.py of torchao.float8,
+        # TODO(vkuzo): add the items below to __init__.py of torchao.float8,
         # and import from there
         from torchao.float8.float8_tensor_parallel import (
             Float8ColwiseParallel,
@@ -143,7 +142,7 @@ def get_tp_parallel_strategy_for_transformer_block(
 
 def pipeline_llama(
     model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -157,11 +156,11 @@ def pipeline_llama(
         )
     if split_mode == "manual":
         return pipeline_llama_manual(
-            model, world_mesh, parallel_dims, job_config, device, model_config
+            model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     elif split_mode == "tracer":
         return pipeline_llama_tracer(
-            model, world_mesh, parallel_dims, job_config, device, model_config
+            model, pp_mesh, parallel_dims, job_config, device, model_config
         )
 
 
@@ -184,7 +183,7 @@ def _mixed_precision_dtype(
 
 def pipeline_llama_manual(
     whole_model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -198,7 +197,6 @@ def pipeline_llama_manual(
     The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
     parallelism.
     """
-    pp_mesh = world_mesh["pp"]
     pp_rank = pp_mesh.get_local_rank()
     pp_size = pp_mesh.size()
     microbatches = (
@@ -287,7 +285,7 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
 
 def pipeline_llama_tracer(
     model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -306,7 +304,6 @@ def pipeline_llama_tracer(
306304 "To work around, set mixed_precision_param to float32."
307305 )
308306
309- pp_mesh = world_mesh ["pp" ]
310307 pp_rank = pp_mesh .get_local_rank ()
311308 pp_size = pp_mesh .size ()
312309 microbatches = (
@@ -341,15 +338,12 @@ def pipeline_llama_tracer(
 
 def apply_tp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    tp_mesh: DeviceMesh,
+    loss_parallel: bool,
+    enable_float8: bool,
+    enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
-
-    tp_mesh = world_mesh["tp"]
-    loss_parallel = parallel_dims.loss_parallel_enabled
-
     # 1. Parallelize the embedding and shard its outputs (which are the first
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
@@ -377,7 +371,7 @@ def apply_tp(
         rowwise_parallel_weight,
         colwise_parallel_weight,
         prepare_module_input,
-    ) = get_tp_parallel_strategy_for_transformer_block(job_config, model)
+    ) = get_tp_parallel_strategy_for_transformer_block(enable_float8)
 
     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
@@ -416,7 +410,7 @@ def apply_tp(
         )
 
     # updates expressly for async tensor parallel
-    if job_config.experimental.enable_async_tensor_parallel:
+    if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
         torch._dynamo.config.cache_size_limit = 10000
@@ -434,16 +428,14 @@ def apply_tp(
         job_config.training.compile = True
 
     logger.info(
-        f"Applied {'Async ' if job_config.experimental.enable_async_tensor_parallel else ''}Tensor Parallelism to the model"
+        f"Applied {'Async ' if enable_async_tp else ''}"
+        "Tensor Parallelism to the model"
     )
     return model
 
 
-def apply_ac(model: nn.Module, job_config: JobConfig):
+def apply_ac(model: nn.Module, ac_config: JobConfig):
     """Apply activation checkpointing to the model."""
-
-    ac_config = job_config.activation_checkpoint
-
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = checkpoint_wrapper(transformer_block, ac_config)
         model.layers.register_module(layer_id, transformer_block)
@@ -452,14 +444,8 @@ def apply_ac(model: nn.Module, job_config: JobConfig):
     return model
 
 
-def apply_compile(model: nn.Module, job_config: JobConfig):
+def apply_compile(model: nn.Module):
     """Apply torch.compile to each transformer block."""
-
-    if job_config.model.norm_type == "fused_rmsnorm":
-        raise NotImplementedError(
-            "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
-        )
-
     for layer_id, transformer_block in model.layers.named_children():
         # TODO: dynamic shape have some issues so we turn it off for now.
         # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
@@ -474,25 +460,19 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
 
 def apply_fsdp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    dp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
+    pp_enabled: bool,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
-
-    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-
-    mp_policy = MixedPrecisionPolicy(
-        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-    )
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
     for layer_id, transformer_block in model.layers.items():
-        if parallel_dims.pp_enabled:
+        if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
             # all-gathers, which can be expensive and non-overlapped
             reshard_after_forward = False
@@ -505,11 +485,9 @@ def apply_fsdp(
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(
-        model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
-    )
+    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
 
-    if parallel_dims.pp_enabled:
+    if pp_enabled:
         # TODO
         # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
         # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
@@ -526,22 +504,19 @@ def apply_fsdp(
 
 def apply_ddp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    dp_mesh: DeviceMesh,
+    enable_compile: bool,
+    enable_compiled_autograd: bool,
 ):
-    if world_mesh.ndim > 1:
-        raise RuntimeError("DDP has not supported > 1D parallelism.")
-
-    if job_config.training.compile:
-        if job_config.experimental.enable_compiled_autograd:
+    if enable_compile:
+        if enable_compiled_autograd:
             torch._dynamo.config.optimize_ddp = (
                 "python_reducer_without_compiled_forward"
             )
         else:
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"
 
-    model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)
+    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
     logger.info("Applied DDP to the model")
     return model
@@ -562,18 +537,46 @@ def parallelize_llama(
562537 """
563538
564539 if parallel_dims .tp_enabled :
565- model = apply_tp (model , world_mesh , parallel_dims , job_config )
540+ model = apply_tp (
541+ model ,
542+ world_mesh ["tp" ],
543+ loss_parallel = parallel_dims .loss_parallel_enabled ,
544+ enable_float8 = job_config .training .enable_float8_linear ,
545+ enable_async_tp = job_config .experimental .enable_async_tensor_parallel ,
546+ )
566547
567548 if job_config .activation_checkpoint .mode != "none" :
568- model = apply_ac (model , job_config )
549+ model = apply_ac (model , job_config . activation_checkpoint )
569550
570551 if job_config .training .compile :
571- model = apply_compile (model , job_config )
552+ if job_config .model .norm_type == "fused_rmsnorm" :
553+ raise NotImplementedError (
554+ "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
555+ )
556+ model = apply_compile (model )
572557
573558 if parallel_dims .dp_enabled :
574559 if parallel_dims .dp_type == "fsdp" :
575- model = apply_fsdp (model , world_mesh , parallel_dims , job_config )
560+ dp_mesh = world_mesh ["dp" ] if world_mesh .ndim > 1 else world_mesh
561+ assert dp_mesh .mesh_dim_names == ("dp" ,), dp_mesh .mesh_dim_names
562+
563+ model = apply_fsdp (
564+ model ,
565+ dp_mesh ,
566+ param_dtype = TORCH_DTYPE_MAP [job_config .training .mixed_precision_param ],
567+ reduce_dtype = TORCH_DTYPE_MAP [
568+ job_config .training .mixed_precision_reduce
569+ ],
570+ pp_enabled = parallel_dims .pp_enabled ,
571+ )
576572 else :
577- model = apply_ddp (model , world_mesh , parallel_dims , job_config )
573+ if world_mesh .ndim > 1 :
574+ raise RuntimeError ("DDP has not supported > 1D parallelism." )
575+ model = apply_ddp (
576+ model ,
577+ world_mesh ,
578+ enable_compile = job_config .training .compile ,
579+ enable_compiled_autograd = job_config .experimental .enable_compiled_autograd ,
580+ )
578581
579582 return model
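A side effect of this refactor is that the apply_* helpers no longer depend on a JobConfig or the full world_mesh, so they can be driven with plain values. The sketch below illustrates that under stated assumptions: the import path, mesh shape, dtypes, and build_model() placeholder are not from this commit; only the apply_tp/apply_fsdp signatures come from the diff above.

# Illustrative usage sketch (not part of the commit).
import torch
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.parallelisms.parallelize_llama import apply_fsdp, apply_tp  # assumed path

# 2-D device mesh: data-parallel x tensor-parallel (e.g. 8 GPUs as 2 x 4)
world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

model = build_model()  # placeholder for an already-constructed Transformer

model = apply_tp(
    model,
    world_mesh["tp"],
    loss_parallel=True,
    enable_float8=False,
    enable_async_tp=False,
)
model = apply_fsdp(
    model,
    world_mesh["dp"],
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    pp_enabled=False,
)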