Skip to content

Commit 35880ff

Browse files
committed
enable Context Parallel
ghstack-source-id: 90f1bde
Pull Request resolved: #592
1 parent eef8bb2 commit 35880ff

File tree

8 files changed: 103 additions (+103) and 15 deletions (-15).

estimation.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,7 @@ def estimate_memory(job_config: JobConfig):
6666
parallel_dims = ParallelDims(
6767
dp_shard=job_config.training.data_parallel_shard_degree,
6868
dp_replicate=job_config.training.data_parallel_replicate_degree,
69+
cp=job_config.experimental.context_parallel_degree,
6970
tp=job_config.training.tensor_parallel_degree,
7071
pp=job_config.experimental.pipeline_parallel_degree,
7172
world_size=world_size,

test_runner.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -316,6 +316,29 @@ def build_test_list():
316316
"hsdp+tp",
317317
ngpu=8,
318318
),
319+
OverrideDefinitions(
320+
[
321+
[
322+
"--training.data_parallel_shard_degree=2",
323+
"--experimental.context_parallel_degree=2",
324+
]
325+
],
326+
"FSDP+CP",
327+
"fsdp+cp",
328+
ngpu=4,
329+
),
330+
OverrideDefinitions(
331+
[
332+
[
333+
"--training.data_parallel_shard_degree=2",
334+
"--training.data_parallel_replicate_degree=2",
335+
"--experimental.context_parallel_degree=2",
336+
]
337+
],
338+
"HSDP+CP",
339+
"hsdp+cp",
340+
ngpu=8,
341+
),
319342
OverrideDefinitions(
320343
[
321344
[

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -326,6 +326,12 @@ def __init__(self):
326326
action="store_true",
327327
help="Enable CompiledAutograd to compile the backward.",
328328
)
329+
self.parser.add_argument(
330+
"--experimental.context_parallel_degree",
331+
type=int,
332+
default=1,
333+
help="Context parallelism degree. 1 means disabled.",
334+
)
329335
self.parser.add_argument(
330336
"--training.mixed_precision_param",
331337
type=str,

torchtitan/models/llama/model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -415,8 +415,8 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
415415
return precompute_freqs_cis(
416416
self.model_args.dim // self.model_args.n_heads,
417417
# Need to compute until at least the max token limit for generation
418-
# (use 2x max sequence length to be safe)
419-
self.model_args.max_seq_len * 2,
418+
# Note: removed the 2x relaxing in CP enablement
419+
self.model_args.max_seq_len,
420420
self.model_args.rope_theta,
421421
)
422422

torchtitan/parallelisms/parallel_dims.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
class ParallelDims:
1616
dp_replicate: int
1717
dp_shard: int
18+
cp: int
1819
tp: int
1920
pp: int
2021
world_size: int
@@ -24,36 +25,38 @@ def __post_init__(self):
2425
self._validate()
2526

2627
def _validate(self):
27-
dp_replicate, dp_shard, tp, pp = (
28+
dp_replicate, dp_shard, cp, tp, pp = (
2829
self.dp_replicate,
2930
self.dp_shard,
31+
self.cp,
3032
self.tp,
3133
self.pp,
3234
)
33-
for d in (dp_replicate, tp, pp):
35+
for d in (dp_replicate, cp, tp, pp):
3436
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
3537
assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
3638

3739
dp = dp_replicate * dp_shard
3840
if dp < 0:
39-
dp = self.world_size // (tp * pp)
41+
dp = self.world_size // (cp * tp * pp)
4042
self.dp_shard = dp_shard = dp // dp_replicate
4143

4244
assert dp_replicate >= 1
4345
assert dp_shard >= 1
46+
assert cp >= 1, cp
4447
assert tp >= 1, tp
4548
assert pp >= 1, pp
46-
assert dp_replicate * dp_shard * tp * pp == self.world_size, (
49+
assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
4750
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
48-
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
51+
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
4952
)
5053

5154
def build_mesh(self, device_type):
5255
dims = []
5356
names = []
5457
for d, name in zip(
55-
[self.pp, self.dp_replicate, self.dp_shard, self.tp],
56-
["pp", "dp_replicate", "dp_shard", "tp"],
58+
[self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
59+
["pp", "dp_replicate", "dp_shard", "cp", "tp"],
5760
strict=True,
5861
):
5962
if d > 1:
@@ -86,6 +89,10 @@ def dp_replicate_enabled(self):
8689
def dp_shard_enabled(self):
8790
return self.dp_shard > 1
8891

92+
@property
93+
def cp_enabled(self):
94+
return self.cp > 1
95+
8996
@property
9097
def tp_enabled(self):
9198
return self.tp > 1

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,11 @@
88
# training techniques (e.g. activation checkpointing and compile) to the Llama model.
99

1010
from collections import defaultdict
11+
from typing import Tuple
1112

1213
import torch
1314
import torch.nn as nn
15+
1416
from torch.distributed import DeviceMesh
1517
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
1618
from torch.distributed._composable.replicate import replicate
@@ -72,13 +74,18 @@ def parallelize_llama(
7274
)
7375
apply_compile(model)
7476

75-
if parallel_dims.dp_enabled:
77+
if parallel_dims.dp_enabled or parallel_dims.cp_enabled:
7678
if parallel_dims.dp_shard_enabled:
7779
if parallel_dims.dp_replicate_enabled:
7880
dp_mesh = world_mesh["dp_replicate", "dp_shard"]
7981
else:
8082
dp_mesh = world_mesh["dp"]
8183

84+
if parallel_dims.cp_enabled:
85+
dp_dim_names = dp_mesh.mesh_dim_names
86+
assert isinstance(dp_dim_names, Tuple)
87+
dp_mesh = world_mesh[(*dp_dim_names, "cp")]._flatten()
88+
8289
apply_fsdp(
8390
model,
8491
dp_mesh,

train.py

Lines changed: 48 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,11 @@
1010
from datetime import timedelta
1111

1212
import torch
13+
14+
from typing import List, Optional, Set
15+
from functools import partial
16+
17+
from torch.distributed.device_mesh import DeviceMesh
1318
from torch.distributed.elastic.multiprocessing.errors import record
1419

1520
from torchtitan import utils
@@ -28,17 +33,47 @@
2833
)
2934
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
3035

36+
try:
37+
from torch.distributed.tensor.experimental import context_parallel
38+
except ImportError:
39+
print(
40+
f"PyTorch version {torch.__version__} does not include the experimental "
41+
"Context Parallel API. Please update to a newer version."
42+
)
43+
44+
45+
def get_train_context(
46+
enable_loss_parallel: bool,
47+
enable_compiled_autograd: bool,
48+
cp_mesh: Optional[DeviceMesh] = None,
49+
):
50+
if cp_mesh is not None:
51+
context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)
3152

32-
def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
3353
@contextlib.contextmanager
34-
def context():
54+
def context(
55+
cp_buffers: List[torch.Tensor],
56+
cp_seq_dims: List[int],
57+
cp_no_restore_buffers: Set[torch.Tensor],
58+
):
3559
with contextlib.ExitStack() as stack:
3660
if enable_loss_parallel:
3761
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
62+
3863
if enable_compiled_autograd:
3964
stack.enter_context(
4065
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
4166
)
67+
68+
if cp_mesh is not None:
69+
stack.enter_context(
70+
context_parallel_ctx(
71+
buffers=cp_buffers,
72+
buffer_seq_dims=cp_seq_dims,
73+
no_restore_buffers=cp_no_restore_buffers,
74+
)
75+
)
76+
4277
yield
4378

4479
return context
@@ -61,6 +96,7 @@ def main(job_config: JobConfig):
6196
parallel_dims = ParallelDims(
6297
dp_shard=job_config.training.data_parallel_shard_degree,
6398
dp_replicate=job_config.training.data_parallel_replicate_degree,
99+
cp=job_config.experimental.context_parallel_degree,
64100
tp=job_config.training.tensor_parallel_degree,
65101
pp=job_config.experimental.pipeline_parallel_degree,
66102
world_size=world_size,
@@ -226,6 +262,7 @@ def loss_fn(pred, labels):
226262
train_context = get_train_context(
227263
parallel_dims.loss_parallel_enabled,
228264
job_config.experimental.enable_compiled_autograd,
265+
world_mesh["cp"] if parallel_dims.cp_enabled else None,
229266
)
230267

231268
# variables used to keep info for metrics logging
@@ -259,18 +296,24 @@ def loss_fn(pred, labels):
259296
data_load_start = time.perf_counter()
260297
batch = next(data_iterator)
261298
input_ids, labels = batch
262-
ntokens_since_last_log += labels.numel()
299+
ntokens_since_last_log += labels.numel() // parallel_dims.cp
263300
data_loading_times.append(time.perf_counter() - data_load_start)
264301

265302
input_ids = input_ids.cuda()
266303
labels = labels.cuda()
267304
optimizers.zero_grad()
268305

306+
training_context = train_context(
307+
cp_buffers=[input_ids, labels, model.freqs_cis],
308+
cp_seq_dims=[1, 1, 0],
309+
cp_no_restore_buffers={input_ids, labels},
310+
)
311+
269312
if parallel_dims.pp_enabled:
270313
# Pipeline Parallel forward / backward inside step() call
271314
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
272315

273-
with train_context():
316+
with training_context:
274317
if pp_mesh.get_local_rank() == 0:
275318
pp_schedule.step(input_ids)
276319
elif is_last_stage:
@@ -287,7 +330,7 @@ def loss_fn(pred, labels):
287330
)
288331
else:
289332
# Non-PP forward / backward
290-
with train_context():
333+
with training_context:
291334
pred = model(input_ids)
292335
loss = loss_fn(pred, labels)
293336
# pred.shape=(bs, seq_len, vocab_size)

train_configs/debug_model.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ compile = false
4242
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
4343

4444
[experimental]
45+
context_parallel_degree = 1
4546
pipeline_parallel_degree = 1
4647
enable_async_tensor_parallel = false
4748

0 commit comments

Comments
 (0)