Commit 28aceb1

enable Context Parallel
ghstack-source-id: a0832f2
Pull Request resolved: #592
1 parent 36fba84 commit 28aceb1

File tree: 8 files changed (+146, −43 lines)


estimation.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ def estimate_memory(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp_shard=job_config.training.data_parallel_shard_degree,
         dp_replicate=job_config.training.data_parallel_replicate_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,

test_runner.py

Lines changed: 23 additions & 0 deletions
@@ -306,6 +306,29 @@ def build_test_list():
             "hsdp+tp",
             ngpu=8,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--experimental.context_parallel_degree=2",
+                ]
+            ],
+            "FSDP+CP",
+            "fsdp+cp",
+            ngpu=4,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--training.data_parallel_replicate_degree=2",
+                    "--experimental.context_parallel_degree=2",
+                ]
+            ],
+            "HSDP+CP",
+            "hsdp+cp",
+            ngpu=8,
+        ),
         OverrideDefinitions(
             [
                 [
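The two new flavors only request as many GPUs as their parallel degrees multiply out to. A standalone sanity check of that arithmetic (the dictionaries below are illustrative, not part of the test harness):

# FSDP+CP: 1 (replicate) * 2 (shard) * 2 (cp) * 1 (tp) * 1 (pp) = 4 GPUs
fsdp_cp = {"dp_replicate": 1, "dp_shard": 2, "cp": 2, "tp": 1, "pp": 1, "ngpu": 4}
# HSDP+CP: 2 * 2 * 2 * 1 * 1 = 8 GPUs
hsdp_cp = {"dp_replicate": 2, "dp_shard": 2, "cp": 2, "tp": 1, "pp": 1, "ngpu": 8}
for cfg in (fsdp_cp, hsdp_cp):
    degrees = cfg["dp_replicate"] * cfg["dp_shard"] * cfg["cp"] * cfg["tp"] * cfg["pp"]
    assert degrees == cfg["ngpu"], cfg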

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions
@@ -325,6 +325,12 @@ def __init__(self):
             action="store_true",
             help="Enable CompiledAutograd to compile the backward.",
         )
+        self.parser.add_argument(
+            "--experimental.context_parallel_degree",
+            type=int,
+            default=1,
+            help="Context parallelism degree. 1 means disabled.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
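A minimal sketch of how the new dotted flag parses (a throwaway argparse parser for illustration, not the real JobConfig wiring); argparse keeps the dot in the destination name, so the value is read back via vars():

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.context_parallel_degree",
    type=int,
    default=1,
    help="Context parallelism degree. 1 means disabled.",
)
args = parser.parse_args(["--experimental.context_parallel_degree", "2"])
assert vars(args)["experimental.context_parallel_degree"] == 2  # stays 1 when the flag is omitted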

torchtitan/models/llama/model.py

Lines changed: 2 additions & 2 deletions
@@ -415,8 +415,8 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
         return precompute_freqs_cis(
             self.model_args.dim // self.model_args.n_heads,
             # Need to compute until at least the max token limit for generation
-            # (use 2x max sequence length to be safe)
-            self.model_args.max_seq_len * 2,
+            # Note: removed the 2x relaxing in CP enablement
+            self.model_args.max_seq_len,
             self.model_args.rope_theta,
         )
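The 2x headroom is dropped presumably because freqs_cis now takes part in Context Parallel as a sequence-sharded buffer (it appears in cp_buffers with seq dim 0 in train.py below), so its length has to match the token sequence length for the per-rank shards to line up. A sketch with made-up sizes:

# Illustrative numbers only, not values from the repo.
max_seq_len, cp = 2048, 2
tokens_per_cp_rank = max_seq_len // cp       # slice of the sequence each CP rank keeps
freqs_rows_per_rank = max_seq_len // cp      # freqs_cis sharded on dim 0 must match it
assert tokens_per_cp_rank == freqs_rows_per_rank        # holds at length max_seq_len
assert (2 * max_seq_len) // cp != tokens_per_cp_rank    # breaks at the old 2x length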

torchtitan/parallelisms/parallel_dims.py

Lines changed: 21 additions & 7 deletions
@@ -15,6 +15,7 @@
 class ParallelDims:
     dp_replicate: int
     dp_shard: int
+    cp: int
     tp: int
     pp: int
     world_size: int
@@ -24,36 +25,38 @@ def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp_replicate, dp_shard, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp = (
             self.dp_replicate,
             self.dp_shard,
+            self.cp,
             self.tp,
             self.pp,
         )
-        for d in (dp_replicate, tp, pp):
+        for d in (dp_replicate, cp, tp, pp):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."

         dp = dp_replicate * dp_shard
         if dp < 0:
-            dp = self.world_size // (tp * pp)
+            dp = self.world_size // (cp * tp * pp)
         self.dp_shard = dp_shard = dp // dp_replicate

         assert dp_replicate >= 1
         assert dp_shard >= 1
+        assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert dp_replicate * dp_shard * tp * pp == self.world_size, (
+        assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
             f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
-            f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+            f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
         )

     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp_replicate, self.dp_shard, self.tp],
-            ["pp", "dp_replicate", "dp_shard", "tp"],
+            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
+            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
             if d > 1:
                 dims.append(d)
@@ -71,6 +74,13 @@ def build_mesh(self, device_type):
         # initialized
         if self.dp_replicate > 1 and self.dp_shard > 1:
             mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
+
+        if self.cp > 1:
+            if self.dp_replicate > 1 and self.dp_shard > 1:
+                mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
+            else:
+                mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
+
         return mesh

     @property
@@ -85,6 +95,10 @@ def dp_replicate_enabled(self):
     def dp_shard_enabled(self):
         return self.dp_shard > 1

+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1
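A worked example of what _validate and build_mesh do with these degrees, using a hypothetical 8-GPU HSDP+CP job like the integration test above:

world_size = 8
dp_replicate, dp_shard, cp, tp, pp = 2, 2, 2, 1, 1
assert dp_replicate * dp_shard * cp * tp * pp == world_size  # the check _validate enforces

# Only degrees > 1 become mesh dimensions, so build_mesh creates a
# ("dp_replicate", "dp_shard", "cp") mesh of shape (2, 2, 2), then flattens
# ("dp_replicate", "dp_shard") into "dp" and, because cp > 1,
# ("dp_replicate", "dp_shard", "cp") into "dp_cp".
zipped = zip((pp, dp_replicate, dp_shard, cp, tp), ("pp", "dp_replicate", "dp_shard", "cp", "tp"))
names = [name for d, name in zipped if d > 1]
assert names == ["dp_replicate", "dp_shard", "cp"]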

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 38 additions & 29 deletions
@@ -11,6 +11,7 @@

 import torch
 import torch.nn as nn
+
 from torch.distributed import DeviceMesh
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._composable.replicate import replicate
@@ -72,36 +73,44 @@ def parallelize_llama(
     )
     apply_compile(model)

-    if parallel_dims.dp_enabled:
-        if parallel_dims.dp_shard_enabled:
-            if parallel_dims.dp_replicate_enabled:
-                dp_mesh = world_mesh["dp_replicate", "dp_shard"]
-            else:
-                dp_mesh = world_mesh["dp"]
-
-            apply_fsdp(
-                model,
-                dp_mesh,
-                param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-                reduce_dtype=TORCH_DTYPE_MAP[
-                    job_config.training.mixed_precision_reduce
-                ],
-                tp_enabled=parallel_dims.tp_enabled,
-                pp_enabled=parallel_dims.pp_enabled,
-            )
-            if parallel_dims.dp_replicate_enabled:
-                logger.info("Applied HSDP to the model")
-            else:
-                logger.info("Applied FSDP to the model")
+    if parallel_dims.dp_shard_enabled:  # apply FSDP or HSDP, potentially with Context Parallel
+        dp_mesh_dim_names = (
+            ("dp_replicate", "dp_shard")
+            if parallel_dims.dp_replicate_enabled
+            else ("dp",)
+        )
+        dp_mesh = (
+            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
+            if parallel_dims.cp_enabled
+            else world_mesh[dp_mesh_dim_names]
+        )
+        apply_fsdp(
+            model,
+            dp_mesh,
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[
+                job_config.training.mixed_precision_reduce
+            ],
+            tp_enabled=parallel_dims.tp_enabled,
+            pp_enabled=parallel_dims.pp_enabled,
+        )
+
+        if parallel_dims.dp_replicate_enabled:
+            logger.info("Applied HSDP to the model")
         else:
-            if world_mesh.ndim > 1:
-                raise RuntimeError("DDP has not supported > 1D parallelism")
-            apply_ddp(
-                model,
-                world_mesh,
-                enable_compile=job_config.training.compile,
-                enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
-            )
+            logger.info("Applied FSDP to the model")
+
+        if parallel_dims.cp_enabled:
+            logger.info("Applied Context Parallel to the model")
+    elif parallel_dims.dp_replicate_enabled:
+        if world_mesh.ndim > 1:
+            raise RuntimeError("DDP has not supported > 1D parallelism")
+        apply_ddp(
+            model,
+            world_mesh,
+            enable_compile=job_config.training.compile,
+            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
+        )


 def apply_tp(
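To make the restructured branching easier to follow, here is a dependency-free sketch of the dispatch it performs (pick_dp_style is a hypothetical helper for illustration, not code from the repo):

def pick_dp_style(dp_replicate: int, dp_shard: int, cp: int) -> str:
    # FSDP/HSDP when dp_shard > 1, sharding over the flattened "dp_cp" mesh
    # when CP is on; plain DDP only when just dp_replicate > 1.
    if dp_shard > 1:
        style = "HSDP" if dp_replicate > 1 else "FSDP"
        base = ("dp_replicate", "dp_shard") if dp_replicate > 1 else ("dp",)
        mesh = "dp_cp (flattened)" if cp > 1 else "+".join(base)
        return f"{style} over {mesh}"
    if dp_replicate > 1:
        return "DDP over the world mesh"
    return "no data parallelism applied"

assert pick_dp_style(dp_replicate=1, dp_shard=2, cp=2) == "FSDP over dp_cp (flattened)"
assert pick_dp_style(dp_replicate=2, dp_shard=2, cp=1) == "HSDP over dp_replicate+dp_shard"
assert pick_dp_style(dp_replicate=2, dp_shard=1, cp=1) == "DDP over the world mesh"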

train.py

Lines changed: 54 additions & 5 deletions
@@ -10,7 +10,13 @@
 from datetime import timedelta

 import torch
+
+from typing import List, Optional, Set
+from functools import partial
+
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.nn.attention import SDPBackend, sdpa_kernel

 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -28,17 +34,52 @@
 )
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling

+try:
+    from torch.distributed.tensor.experimental import context_parallel
+except ImportError:
+    print(
+        f"PyTorch version {torch.__version__} does not include the experimental "
+        "Context Parallel API. Please update to a newer version."
+    )
+
+
+def get_train_context(
+    enable_loss_parallel: bool,
+    enable_compiled_autograd: bool,
+    cp_mesh: Optional[DeviceMesh] = None,
+):
+    if cp_mesh is not None:
+        context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)

-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
     @contextlib.contextmanager
-    def context():
+    def context(
+        cp_buffers: List[torch.Tensor],
+        cp_seq_dims: List[int],
+        cp_no_restore_buffers: Set[torch.Tensor],
+    ):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
             if enable_compiled_autograd:
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
+
+            if cp_mesh is not None:
+                # currently we only support these two SDP backends.
+                # TODO (xilunwu): support cuDNN backend
+                stack.enter_context(
+                    sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+                )
+                stack.enter_context(
+                    context_parallel_ctx(
+                        buffers=cp_buffers,
+                        buffer_seq_dims=cp_seq_dims,
+                        no_restore_buffers=cp_no_restore_buffers,
+                    )
+                )
+
             yield

     return context
@@ -70,6 +111,7 @@ def main(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp_shard=job_config.training.data_parallel_shard_degree,
         dp_replicate=job_config.training.data_parallel_replicate_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -235,6 +277,7 @@ def loss_fn(pred, labels):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
+        world_mesh["cp"] if parallel_dims.cp_enabled else None,
     )

     # variables used to keep info for metrics logging
@@ -268,18 +311,24 @@ def loss_fn(pred, labels):
             data_load_start = time.perf_counter()
             batch = next(data_iterator)
             input_ids, labels = batch
-            ntokens_since_last_log += labels.numel()
+            ntokens_since_last_log += labels.numel() // parallel_dims.cp
             data_loading_times.append(time.perf_counter() - data_load_start)

             input_ids = input_ids.cuda()
             labels = labels.cuda()
             optimizers.zero_grad()

+            training_context = train_context(
+                cp_buffers=[input_ids, labels, model.freqs_cis],
+                cp_seq_dims=[1, 1, 0],
+                cp_no_restore_buffers={input_ids, labels},
+            )
+
             if parallel_dims.pp_enabled:
                 # Pipeline Parallel forward / backward inside step() call
                 is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

-                with train_context():
+                with training_context:
                     if pp_mesh.get_local_rank() == 0:
                         pp_schedule.step(input_ids)
                     elif is_last_stage:
@@ -296,7 +345,7 @@ def loss_fn(pred, labels):
                 )
             else:
                 # Non-PP forward / backward
-                with train_context():
+                with training_context:
                     pred = model(input_ids)
                     loss = loss_fn(pred, labels)
                     # pred.shape=(bs, seq_len, vocab_size)
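Two things are worth noting about the loop change: train_context(...) is now called once per step with the CP buffers and the resulting context manager is entered as "with training_context:" in both the PP and non-PP paths, and the logged token count is divided by parallel_dims.cp because each CP rank computes on only 1/cp of every sequence. A back-of-the-envelope sketch with made-up sizes:

batch_size, seq_len, cp = 8, 8192, 2          # illustrative values, not repo defaults
tokens_in_local_batch = batch_size * seq_len  # what labels.numel() reports per rank
tokens_computed_per_rank = tokens_in_local_batch // cp
assert tokens_computed_per_rank == 32_768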

train_configs/debug_model.toml

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ compile = false
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

 [experimental]
+context_parallel_degree = 1
 pipeline_parallel_degree = 1
 enable_async_tensor_parallel = false
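For completeness, a standalone check (using the stdlib tomllib from Python 3.11+, not the repo's JobConfig loader) that the new key parses under the [experimental] table:

import tomllib

snippet = """
[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false
"""
config = tomllib.loads(snippet)
assert config["experimental"]["context_parallel_degree"] == 1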
