 from datetime import timedelta
 
 import torch
+
+from typing import List, Optional, Set
+from functools import partial
+
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.elastic.multiprocessing.errors import record
 
 from torchtitan import utils
 )
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 
+try:
+    from torch.distributed.tensor.experimental import context_parallel
+except ImportError:
+    print(
+        f"PyTorch version {torch.__version__} does not include the experimental "
+        "Context Parallel API. Please update to a newer version."
+    )
+
+
+def get_train_context(
+    enable_loss_parallel: bool,
+    enable_compiled_autograd: bool,
+    cp_mesh: Optional[DeviceMesh] = None,
+):
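+    # the CP mesh is bound once here; the buffers to shard are supplied per training step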
+    if cp_mesh is not None:
+        context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)
 
-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
     @contextlib.contextmanager
-    def context():
+    def context(
+        cp_buffers: List[torch.Tensor],
+        cp_seq_dims: List[int],
+        cp_no_restore_buffers: Set[torch.Tensor],
+    ):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
             if enable_compiled_autograd:
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
+
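+            # shard the given buffers along their sequence dims across the CP mesh for the duration of the step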
+            if cp_mesh is not None:
+                stack.enter_context(
+                    context_parallel_ctx(
+                        buffers=cp_buffers,
+                        buffer_seq_dims=cp_seq_dims,
+                        no_restore_buffers=cp_no_restore_buffers,
+                    )
+                )
+
             yield
 
     return context
@@ -61,6 +96,7 @@ def main(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp_shard=job_config.training.data_parallel_shard_degree,
         dp_replicate=job_config.training.data_parallel_replicate_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -226,6 +262,7 @@ def loss_fn(pred, labels):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
+        world_mesh["cp"] if parallel_dims.cp_enabled else None,
     )
 
     # variables used to keep info for metrics logging
@@ -259,18 +296,24 @@ def loss_fn(pred, labels):
             data_load_start = time.perf_counter()
             batch = next(data_iterator)
             input_ids, labels = batch
-            ntokens_since_last_log += labels.numel()
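+            # with context parallel, each rank only processes 1/cp of the sequence tokens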
+            ntokens_since_last_log += labels.numel() // parallel_dims.cp
             data_loading_times.append(time.perf_counter() - data_load_start)
 
             input_ids = input_ids.cuda()
             labels = labels.cuda()
             optimizers.zero_grad()
 
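+            # shard inputs, labels, and the rotary freqs_cis along their sequence dims (1, 1, 0);
+            # input_ids and labels do not need to be restored to full sequences after the step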
+            training_context = train_context(
+                cp_buffers=[input_ids, labels, model.freqs_cis],
+                cp_seq_dims=[1, 1, 0],
+                cp_no_restore_buffers={input_ids, labels},
+            )
+
             if parallel_dims.pp_enabled:
                 # Pipeline Parallel forward / backward inside step() call
                 is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
-                with train_context():
+                with training_context:
                     if pp_mesh.get_local_rank() == 0:
                         pp_schedule.step(input_ids)
                     elif is_last_stage:
@@ -287,7 +330,7 @@ def loss_fn(pred, labels):
                 )
             else:
                 # Non-PP forward / backward
-                with train_context():
+                with training_context:
                     pred = model(input_ids)
                     loss = loss_fn(pred, labels)
                     # pred.shape=(bs, seq_len, vocab_size)