diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 11a8188fd6..e86f93b98d 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -118,18 +118,17 @@ def selective_checkpointing_context_fn():
 
 
 def get_tp_parallel_strategy_for_transformer_block(
-    job_config: JobConfig,
-    model: nn.Module,
+    enable_float8: bool,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.
 
     This function handles the special case of using float8 with tensor parallelism.
     """
-    if job_config.training.enable_float8_linear:
-        # TODO(future PR): once float8 configuration supports delayed
+    if enable_float8:
+        # TODO(vkuzo): once float8 configuration supports delayed
         # scaling, add a check here to enforce supported float8 all-gather
         # configurations
-        # TODO(future PR): add the items below to __init__.py of torchao.float8,
+        # TODO(vkuzo): add the items below to __init__.py of torchao.float8,
         # and import from there
         from torchao.float8.float8_tensor_parallel import (
             Float8ColwiseParallel,
@@ -143,7 +142,7 @@ def get_tp_parallel_strategy_for_transformer_block(
 
 def pipeline_llama(
     model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -157,11 +156,11 @@ def pipeline_llama(
     )
     if split_mode == "manual":
         return pipeline_llama_manual(
-            model, world_mesh, parallel_dims, job_config, device, model_config
+            model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     elif split_mode == "tracer":
         return pipeline_llama_tracer(
-            model, world_mesh, parallel_dims, job_config, device, model_config
+            model, pp_mesh, parallel_dims, job_config, device, model_config
         )
@@ -184,7 +183,7 @@ def _mixed_precision_dtype(
 
 def pipeline_llama_manual(
     whole_model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -198,7 +197,6 @@ def pipeline_llama_manual(
     The stage object is used to create a pipeline schedule, and the model object
     can be used for applying SPMD parallelism.
     """
-    pp_mesh = world_mesh["pp"]
    pp_rank = pp_mesh.get_local_rank()
     pp_size = pp_mesh.size()
     microbatches = (
@@ -287,7 +285,7 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
 
 def pipeline_llama_tracer(
     model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -306,7 +304,6 @@
             "To work around, set mixed_precision_param to float32."
         )
 
-    pp_mesh = world_mesh["pp"]
     pp_rank = pp_mesh.get_local_rank()
     pp_size = pp_mesh.size()
     microbatches = (
@@ -341,15 +338,12 @@
 
 def apply_tp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    tp_mesh: DeviceMesh,
+    loss_parallel: bool,
+    enable_float8: bool,
+    enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
-
-    tp_mesh = world_mesh["tp"]
-    loss_parallel = parallel_dims.loss_parallel_enabled
-
     # 1. Parallelize the embedding and shard its outputs (which are the first
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
@@ -377,7 +371,7 @@ def apply_tp(
         rowwise_parallel_weight,
         colwise_parallel_weight,
         prepare_module_input,
-    ) = get_tp_parallel_strategy_for_transformer_block(job_config, model)
+    ) = get_tp_parallel_strategy_for_transformer_block(enable_float8)
 
     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
@@ -416,7 +410,7 @@
     )
 
     # updates expressly for async tensor parallel
-    if job_config.experimental.enable_async_tensor_parallel:
+    if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
         torch._dynamo.config.cache_size_limit = 10000
@@ -434,16 +428,14 @@
         job_config.training.compile = True
 
     logger.info(
-        f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model"
+        f"Applied {'Async ' if enable_async_tp else ''}"
+        "Tensor Parallelism to the model"
     )
     return model
 
 
-def apply_ac(model: nn.Module, job_config: JobConfig):
+def apply_ac(model: nn.Module, ac_config: JobConfig):
     """Apply activation checkpointing to the model."""
-
-    ac_config = job_config.activation_checkpoint
-
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = checkpoint_wrapper(transformer_block, ac_config)
         model.layers.register_module(layer_id, transformer_block)
@@ -452,14 +444,8 @@
 
 
-def apply_compile(model: nn.Module, job_config: JobConfig):
+def apply_compile(model: nn.Module):
     """Apply torch.compile to each transformer block."""
-
-    if job_config.model.norm_type == "fused_rmsnorm":
-        raise NotImplementedError(
-            "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
-        )
-
     for layer_id, transformer_block in model.layers.named_children():
         # TODO: dynamic shape have some issues so we turn it off for now.
         # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
@@ -474,25 +460,19 @@
 
 def apply_fsdp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    dp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
+    pp_enabled: bool,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
-
-    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-
-    mp_policy = MixedPrecisionPolicy(
-        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-    )
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
     for layer_id, transformer_block in model.layers.items():
-        if parallel_dims.pp_enabled:
+        if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
             # all-gathers, which can be expensive and non-overlapped
             reshard_after_forward = False
@@ -505,11 +485,9 @@
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(
-        model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
-    )
+    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
 
-    if parallel_dims.pp_enabled:
+    if pp_enabled:
         # TODO
         # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
         # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
@@ -526,22 +504,19 @@
 
 def apply_ddp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    dp_mesh: DeviceMesh,
+    enable_compile: bool,
+    enable_compiled_autograd: bool,
 ):
-    if world_mesh.ndim > 1:
-        raise RuntimeError("DDP has not supported > 1D parallelism.")
-
-    if job_config.training.compile:
-        if job_config.experimental.enable_compiled_autograd:
+    if enable_compile:
+        if enable_compiled_autograd:
             torch._dynamo.config.optimize_ddp = (
                 "python_reducer_without_compiled_forward"
             )
         else:
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"
 
-    model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)
+    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
     logger.info("Applied DDP to the model")
     return model
@@ -562,18 +537,46 @@ def parallelize_llama(
     """
 
     if parallel_dims.tp_enabled:
-        model = apply_tp(model, world_mesh, parallel_dims, job_config)
+        model = apply_tp(
+            model,
+            world_mesh["tp"],
+            loss_parallel=parallel_dims.loss_parallel_enabled,
+            enable_float8=job_config.training.enable_float8_linear,
+            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
+        )
 
     if job_config.activation_checkpoint.mode != "none":
-        model = apply_ac(model, job_config)
+        model = apply_ac(model, job_config.activation_checkpoint)
 
     if job_config.training.compile:
-        model = apply_compile(model, job_config)
+        if job_config.model.norm_type == "fused_rmsnorm":
+            raise NotImplementedError(
+                "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
+            )
+        model = apply_compile(model)
 
     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
-            model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
+            dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+            assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+
+            model = apply_fsdp(
+                model,
+                dp_mesh,
+                param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+                reduce_dtype=TORCH_DTYPE_MAP[
+                    job_config.training.mixed_precision_reduce
+                ],
+                pp_enabled=parallel_dims.pp_enabled,
+            )
         else:
-            model = apply_ddp(model, world_mesh, parallel_dims, job_config)
+            if world_mesh.ndim > 1:
+                raise RuntimeError("DDP has not supported > 1D parallelism.")
+            model = apply_ddp(
+                model,
+                world_mesh,
+                enable_compile=job_config.training.compile,
+                enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
+            )
 
     return model
diff --git a/train.py b/train.py
index eef7401df7..92e29058b7 100644
--- a/train.py
+++ b/train.py
@@ -137,7 +137,7 @@ def main(job_config: JobConfig):
 
     if parallel_dims.pp_enabled:
         stages, model_parts = models_pipelining_fns[model_name](
-            whole_model, world_mesh, parallel_dims, job_config, device, model_config
+            whole_model, pp_mesh, parallel_dims, job_config, device, model_config
        )
     else:
         # In 1D/2D cases or PP with simple schedules, model_parts is just one item
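
The net effect of the parallelize_llama.py changes is that each apply_* helper now declares exactly the inputs it reads, instead of digging through JobConfig. A minimal usage sketch of the new signatures, not part of the patch: the 2x2 mesh, the dtype choices, and build_llama_model() are illustrative assumptions, not torchtitan APIs.

    # Sketch: wiring the decoupled helpers by hand, without a JobConfig.
    # Assumes 4 ranks (e.g. torchrun --nproc_per_node 4) and an already-built
    # torchtitan llama model; build_llama_model() is a hypothetical stand-in.
    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torchtitan.parallelisms.parallelize_llama import apply_fsdp, apply_tp

    world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    model = build_llama_model()  # hypothetical helper returning the Transformer

    model = apply_tp(
        model,
        world_mesh["tp"],  # callers now pass the TP submesh, not the whole mesh
        loss_parallel=True,
        enable_float8=False,
        enable_async_tp=False,
    )
    model = apply_fsdp(
        model,
        world_mesh["dp"],
        param_dtype=torch.bfloat16,  # explicit dtypes replace TORCH_DTYPE_MAP lookups
        reduce_dtype=torch.float32,
        pp_enabled=False,  # keeps reshard_after_forward=True for the root module
    )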
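A side benefit of the narrower signatures: helpers that only touch model.layers can be smoke-tested on a toy module with a plain namespace config, with no JobConfig, mesh, or distributed init. A sketch under that assumption; ToyModel and ac_cfg are illustrative, and mode="full" follows the modes handled by torchtitan's checkpoint_wrapper.

    # Sketch: exercising apply_ac/apply_compile in isolation on a dummy model.
    import types
    import torch.nn as nn
    from torchtitan.parallelisms.parallelize_llama import apply_ac, apply_compile

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # mirrors the .layers container both helpers iterate over
            self.layers = nn.ModuleDict({str(i): nn.Linear(8, 8) for i in range(2)})

    ac_cfg = types.SimpleNamespace(mode="full")  # only the field apply_ac reads here
    model = apply_ac(ToyModel(), ac_cfg)  # wraps each block in checkpoint_wrapper
    model = apply_compile(model)  # per-block torch.compile; the fused_rmsnorm
                                  # guard now lives in the caller, per this patch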