@@ -51,7 +51,7 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
-        model = apply_tp(
+        apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
@@ -60,7 +60,7 @@ def parallelize_llama(
         )

     if job_config.activation_checkpoint.mode != "none":
-        model = apply_ac(model, job_config.activation_checkpoint)
+        apply_ac(model, job_config.activation_checkpoint)

     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if job_config.training.compile:
@@ -69,14 +69,14 @@ def parallelize_llama(
                 "fused_rmsnorm is not compatible with torch.compile yet. "
                 "Please use rmsnorm or layernorm."
             )
-        model = apply_compile(model)
+        apply_compile(model)

     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
             dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
             assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

-            model = apply_fsdp(
+            apply_fsdp(
                 model,
                 dp_mesh,
                 param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
@@ -88,15 +88,13 @@ def parallelize_llama(
         else:
             if world_mesh.ndim > 1:
                 raise RuntimeError("DDP has not supported > 1D parallelism")
-            model = apply_ddp(
+            apply_ddp(
                 model,
                 world_mesh,
                 enable_compile=job_config.training.compile,
                 enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
             )

-    return model
-

 def apply_tp(
     model: nn.Module,
@@ -110,7 +108,7 @@ def apply_tp(
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
     # 3. Parallelize the final linear output layer
-    model = parallelize_module(
+    parallelize_module(
         model,
         tp_mesh,
         {
@@ -192,7 +190,6 @@ def apply_tp(
         f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
         "Tensor Parallelism to the model"
     )
-    return model


 # for selective op activation checkpointing
@@ -273,7 +270,6 @@ def apply_ac(model: nn.Module, ac_config):
         model.layers.register_module(layer_id, transformer_block)

     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
-    return model


 def apply_compile(model: nn.Module):
@@ -286,7 +282,6 @@ def apply_compile(model: nn.Module):
         model.layers.register_module(layer_id, transformer_block)

     logger.info("Compiling each TransformerBlock with torch.compile")
-    return model


 def apply_fsdp(
@@ -329,8 +324,8 @@ def apply_fsdp(
             module._load_state_dict_pre_hooks.clear()
             assert len(module._state_dict_pre_hooks) <= 1
             module._state_dict_pre_hooks.clear()
+
     logger.info("Applied FSDP to the model")
-    return model


 def apply_ddp(
@@ -347,7 +342,6 @@ def apply_ddp(
         else:
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"

-    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

     logger.info("Applied DDP to the model")
-    return model
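The return statements removed throughout this diff are redundant because each helper mutates the nn.Module it receives in place: wrapped TransformerBlocks are re-registered on the parent with register_module, and calls such as parallelize_module and replicate likewise modify the module they are given, so the caller's existing reference already points at the parallelized model. Below is a minimal single-process sketch of that in-place idiom; apply_everywhere is a hypothetical name for illustration only, not part of this change.

import torch.nn as nn

def apply_everywhere(model: nn.Module) -> None:
    # Hypothetical helper: wrap each child and re-register it on the same
    # parent, in place; no new top-level module is created, so there is
    # nothing meaningful to return.
    for name, child in model.named_children():
        model.register_module(name, nn.Sequential(child, nn.ReLU()))

model = nn.Sequential(nn.Linear(4, 4))
apply_everywhere(model)  # no `model = ...` reassignment needed
print(model)             # the caller's reference already sees the wrapped child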