 from datetime import timedelta

 import torch
+
+from typing import List, Optional, Set
+from functools import partial
+
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.nn.attention import SDPBackend, sdpa_kernel

 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
 )
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling

+try:
+    from torch.distributed.tensor.experimental import context_parallel
+except ImportError:
+    print(
+        f"PyTorch version {torch.__version__} does not include the experimental "
+        "Context Parallel API. Please update to a newer version."
+    )
+
+
+def get_train_context(
+    enable_loss_parallel: bool,
+    enable_compiled_autograd: bool,
+    cp_mesh: Optional[DeviceMesh] = None,
+):
+    if cp_mesh is not None:
+        context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)

-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
     @contextlib.contextmanager
-    def context():
+    def context(
+        cp_buffers: List[torch.Tensor],
+        cp_seq_dims: List[int],
+        cp_no_restore_buffers: Set[torch.Tensor],
+    ):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
             if enable_compiled_autograd:
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
+
+            if cp_mesh is not None:
+                # currently we only support these two SDP backends.
+                # TODO (xilunwu): support cuDNN backend
+                stack.enter_context(
+                    sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+                )
+                stack.enter_context(
+                    context_parallel_ctx(
+                        buffers=cp_buffers,
+                        buffer_seq_dims=cp_seq_dims,
+                        no_restore_buffers=cp_no_restore_buffers,
+                    )
+                )
+
             yield

     return context
@@ -70,6 +111,7 @@ def main(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp_shard=job_config.training.data_parallel_shard_degree,
         dp_replicate=job_config.training.data_parallel_replicate_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -235,6 +277,7 @@ def loss_fn(pred, labels):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
+        world_mesh["cp"] if parallel_dims.cp_enabled else None,
     )

     # variables used to keep info for metrics logging
@@ -268,18 +311,24 @@ def loss_fn(pred, labels):
             data_load_start = time.perf_counter()
             batch = next(data_iterator)
             input_ids, labels = batch
-            ntokens_since_last_log += labels.numel()
+            ntokens_since_last_log += labels.numel() // parallel_dims.cp
            data_loading_times.append(time.perf_counter() - data_load_start)

             input_ids = input_ids.cuda()
             labels = labels.cuda()
             optimizers.zero_grad()

+            training_context = train_context(
+                cp_buffers=[input_ids, labels, model.freqs_cis],
+                cp_seq_dims=[1, 1, 0],
+                cp_no_restore_buffers={input_ids, labels},
+            )
+
             if parallel_dims.pp_enabled:
                 # Pipeline Parallel forward / backward inside step() call
                 is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

-                with train_context():
+                with training_context:
                     if pp_mesh.get_local_rank() == 0:
                         pp_schedule.step(input_ids)
                     elif is_last_stage:
@@ -296,7 +345,7 @@ def loss_fn(pred, labels):
                 )
             else:
                 # Non-PP forward / backward
-                with train_context():
+                with training_context:
                     pred = model(input_ids)
                     loss = loss_fn(pred, labels)
                     # pred.shape=(bs, seq_len, vocab_size)
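
For orientation, the pieces above compose roughly as follows; a minimal sketch outside the diff, assuming an 8-rank torchrun launch, a 4 x 2 dp/cp mesh layout, and dummy tensors in place of a real batch and model (none of these names or shapes come from the commit itself):

import torch
from torch.distributed.device_mesh import init_device_mesh

# Assumed layout: 8 ranks split into 4-way data parallel x 2-way context parallel.
world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "cp"))

# Build the per-step context factory, using get_train_context as defined above.
train_context = get_train_context(
    enable_loss_parallel=False,
    enable_compiled_autograd=False,
    cp_mesh=world_mesh["cp"],
)

# Dummy batch; the real training loop also passes model.freqs_cis with seq dim 0.
input_ids = torch.randint(0, 32000, (8, 4096), device="cuda")
labels = torch.randint(0, 32000, (8, 4096), device="cuda")

# Each step registers its sequence-dim buffers so they are sharded along the
# sequence dimension (dim 1 here) across the cp mesh for the step's duration.
training_context = train_context(
    cp_buffers=[input_ids, labels],
    cp_seq_dims=[1, 1],
    cp_no_restore_buffers={input_ids, labels},
)
with training_context:
    ...  # forward / backward, as in the training loop above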