Commit 8a72389

Add support of DDP and experimental CompiledAutograd
Summary: Address the comments in #319 and resubmit the PR to fit the current code base.

Test Plan:

```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600 --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: e1019fb
Pull Request resolved: #432
1 parent 3fca883 commit 8a72389

6 files changed: +73, -8 lines

estimation.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -67,6 +67,7 @@ def estimate_memory(job_config: JobConfig):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
+        dp_type=job_config.training.data_parallel_type,
     )

     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
```

test_runner.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -273,6 +273,17 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_type ddp",
+                    "--experimental.enable_compiled_autograd",
+                ]
+            ],
+            "CompiledDDP",
+            "compiled_ddp",
+            ngpu=4,
+        ),
     ]
     return integration_tests_flavors
```

torchtitan/config_manager.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -312,6 +312,17 @@ def __init__(self):
                 The default value will be the number of pipeline stages, if unspecified.
                 """,
         )
+        self.parser.add_argument(
+            "--training.data_parallel_type",
+            type=str,
+            default="fsdp",
+            help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
+        )
+        self.parser.add_argument(
+            "--experimental.enable_compiled_autograd",
+            action="store_true",
+            help="Enable CompiledAutograd to compile the backward.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
```
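
The two new knobs are plain argparse options, so they can be exercised directly through `JobConfig`. A minimal sketch, assuming `JobConfig.parse_args` accepts a list of CLI-style arguments and that the section/attribute layout matches how `job_config` is read elsewhere in this commit:

```python
# Hedged usage sketch; JobConfig.parse_args and the attribute layout below are
# assumptions based on how job_config is used in the other files of this commit.
from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--training.data_parallel_type",
        "ddp",
        "--experimental.enable_compiled_autograd",
    ]
)

assert config.training.data_parallel_type == "ddp"
assert config.experimental.enable_compiled_autograd is True
```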

torchtitan/parallelisms/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -28,8 +28,10 @@ class ParallelDims:
     pp: int
     world_size: int
     enable_loss_parallel: bool
+    dp_type: str

     def __post_init__(self):
+        self.dp_type = self.dp_type.lower()
         self._validate()

     def _validate(self):
@@ -42,6 +44,7 @@ def _validate(self):
         assert (
             dp * tp * pp == self.world_size
         ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert self.dp_type in ("fsdp", "ddp")

     def build_mesh(self, device_type):
         dims = []
```
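
The `ParallelDims` change is small: a new `dp_type` field that is lowercased in `__post_init__` and then validated against `("fsdp", "ddp")`. A minimal sketch with illustrative values, not taken from a real run:

```python
# Illustrative sketch of the new dp_type field; the dp/tp/pp values are
# placeholders chosen to satisfy dp * tp * pp == world_size.
from torchtitan.parallelisms import ParallelDims

parallel_dims = ParallelDims(
    dp=8,
    tp=1,
    pp=1,
    world_size=8,
    enable_loss_parallel=True,
    dp_type="DDP",  # mixed case on purpose; __post_init__ lowercases it
)

assert parallel_dims.dp_type == "ddp"  # anything other than fsdp/ddp fails validation
```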

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 25 additions & 3 deletions
```diff
@@ -12,8 +12,9 @@
 from typing import Dict, Tuple

 import torch
-
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+
+from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
@@ -452,7 +453,7 @@ def apply_compile(model, job_config: JobConfig):
     return model


-def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
+def apply_fsdp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
@@ -489,6 +490,24 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model


+def apply_ddp(model, world_mesh, parallel_dims, job_config: JobConfig):
+    if world_mesh.ndim > 1:
+        raise RuntimeError("DDP has not supported > 1D parallelism.")
+
+    if job_config.training.compile:
+        if job_config.experimental.enable_compiled_autograd:
+            torch._dynamo.config.optimize_ddp = (
+                "python_reducer_without_compiled_forward"
+            )
+        else:
+            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+
+    model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)
+
+    logger.info("Applied DDP to the model")
+    return model
+
+
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply tensor parallelism, activation checkpointing, torch.compile, and data
@@ -508,6 +527,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         model = apply_compile(model, job_config)

     if parallel_dims.dp_enabled:
-        model = apply_dp(model, world_mesh, parallel_dims, job_config)
+        if parallel_dims.dp_type == "fsdp":
+            model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
+        else:
+            model = apply_ddp(model, world_mesh, parallel_dims, job_config)

     return model
```
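
For context, here is a standalone sketch of the DDP path that `apply_ddp` takes, outside torchtitan. The toy module, the mesh size, and the `enable_compiled_autograd` flag are placeholders; the `replicate` call and the `optimize_ddp` settings mirror the branch above that runs when `training.compile` is enabled:

```python
# Standalone sketch of composable DDP (torch.distributed._composable.replicate),
# assuming the script is launched with torchrun on 8 GPUs so init_device_mesh
# can build a 1D mesh (apply_ddp rejects meshes with more than one dimension).
import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cuda", (8,))  # 1D data-parallel mesh
model = nn.Linear(16, 16).cuda()

# Stand-in for job_config.experimental.enable_compiled_autograd under training.compile.
enable_compiled_autograd = True
if enable_compiled_autograd:
    # Keep the DDP reducer in Python so CompiledAutograd can trace the backward.
    torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward"
else:
    torch._dynamo.config.optimize_ddp = "ddp_optimizer"

model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)
```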

train.py

Lines changed: 22 additions & 5 deletions
```diff
@@ -135,6 +135,22 @@ def zero_grad(self):
     return OptimizersContainer([_build_optimizer(model) for model in model_parts])


+def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+    @contextlib.contextmanager
+    def context():
+        with contextlib.ExitStack() as stack:
+            if enable_loss_parallel:
+                stack.enter_context(loss_parallel())
+            if enable_compiled_autograd:
+                stack.enter_context(
+                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
+                )
+
+            yield
+
+    return context
+
+
 # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
 @record
 def main(job_config: JobConfig):
@@ -157,6 +173,7 @@ def main(job_config: JobConfig):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
+        dp_type=job_config.training.data_parallel_type,
     )
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.cuda.set_device(device)
@@ -191,9 +208,9 @@ def main(job_config: JobConfig):
         dp_rank,
     )

-    # loss_parallel enables dispatching to efficient loss operators
-    loss_parallel_ctx = (
-        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
+    train_context = get_train_context(
+        parallel_dims.loss_parallel_enabled,
+        job_config.experimental.enable_compiled_autograd,
     )

     # loss fn can be shared by pipeline-parallel or non-pp execution
@@ -362,7 +379,7 @@ def loss_fn(pred, labels):
            # pipeline parallel forward / backward inside step() call
            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

-            with loss_parallel_ctx():
+            with train_context():
                if pp_mesh.get_local_rank() == 0:
                    pp_schedule.step(input_ids)
                elif is_last_stage:
@@ -379,7 +396,7 @@ def loss_fn(pred, labels):
                )
        else:
            # Non-PP forward / backward
-            with loss_parallel_ctx():
+            with train_context():
                pred = model(input_ids)
                loss = loss_fn(pred, labels)
                # pred.shape=(bs, seq_len, vocab_size)
```
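
`get_train_context` simply stacks the optional `loss_parallel()` and compiled-autograd context managers into one context that wraps forward and backward. A short usage sketch, assuming `train.py` at the repo root is importable; the toy model, inputs, and loss below stand in for the real objects built in `main()`:

```python
# Hedged usage sketch of get_train_context as added above; both features are
# kept off so the context is a no-op ExitStack and the snippet runs on CPU.
import torch
import torch.nn as nn

from train import get_train_context  # assumes torchtitan's train.py is on the path

model = nn.Linear(8, 4)
input_ids = torch.randn(2, 8)
labels = torch.randint(0, 4, (2,))
loss_fn = nn.CrossEntropyLoss()

train_context = get_train_context(
    enable_loss_parallel=False,   # loss_parallel() needs a tensor-parallel mesh
    enable_compiled_autograd=False,
)

with train_context():
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()               # traced by CompiledAutograd when enabled
```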
