129 changes: 66 additions & 63 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -118,18 +118,17 @@ def selective_checkpointing_context_fn():


 def get_tp_parallel_strategy_for_transformer_block(
-    job_config: JobConfig,
-    model: nn.Module,
+    enable_float8: bool,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.

     This function handles the special case of using float8 with tensor parallelism.
     """
-    if job_config.training.enable_float8_linear:
-        # TODO(future PR): once float8 configuration supports delayed
+    if enable_float8:
+        # TODO(vkuzo): once float8 configuration supports delayed
         # scaling, add a check here to enforce supported float8 all-gather
         # configurations
-        # TODO(future PR): add the items below to __init__.py of torchao.float8,
+        # TODO(vkuzo): add the items below to __init__.py of torchao.float8,
         # and import from there
         from torchao.float8.float8_tensor_parallel import (
             Float8ColwiseParallel,
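Note: the helper no longer needs the full `JobConfig` or the model itself; callers hoist the float8 flag and pass it down. A caller-side sketch, mirroring the `apply_tp` hunk further down:

```python
# The three returned parallel styles are now selected from a single flag.
enable_float8 = job_config.training.enable_float8_linear
(
    rowwise_parallel_weight,
    colwise_parallel_weight,
    prepare_module_input,
) = get_tp_parallel_strategy_for_transformer_block(enable_float8)
```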
@@ -143,7 +142,7 @@ def get_tp_parallel_strategy_for_transformer_block(

 def pipeline_llama(
     model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -157,11 +156,11 @@
         )
     if split_mode == "manual":
         return pipeline_llama_manual(
-            model, world_mesh, parallel_dims, job_config, device, model_config
+            model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     elif split_mode == "tracer":
         return pipeline_llama_tracer(
-            model, world_mesh, parallel_dims, job_config, device, model_config
+            model, pp_mesh, parallel_dims, job_config, device, model_config
         )


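Note: `pipeline_llama` and both split-mode backends now take the pipeline submesh directly. The caller slices it once from the world mesh, roughly:

```python
# Caller-side sketch (see the train.py hunk at the bottom): the pp submesh is
# sliced from the world mesh once, instead of each pipeline function
# re-deriving it via world_mesh["pp"].
pp_mesh = world_mesh["pp"]
stages, model_parts = pipeline_llama(
    model, pp_mesh, parallel_dims, job_config, device, model_config
)
```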
@@ -184,7 +183,7 @@ def _mixed_precision_dtype(

 def pipeline_llama_manual(
     whole_model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -198,7 +197,6 @@ def pipeline_llama_manual(
     The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
     parallelism.
     """
-    pp_mesh = world_mesh["pp"]
     pp_rank = pp_mesh.get_local_rank()
     pp_size = pp_mesh.size()
     microbatches = (
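With `pp_mesh` passed in, the manual splitter needs only its local rank and size. For intuition, a hypothetical illustration of how rank and size typically map to stage ids under a looped (interleaved) placement; this is an assumption for illustration, not quoted from the elided code:

```python
# Hypothetical: with num_stages = 8 and pp_size = 4, rank 0 owns stages
# [0, 4], rank 1 owns [1, 5], and so on.
num_stages = 8  # assumed value for the example
stages_per_rank = num_stages // pp_size
stage_ids = [pp_rank + s * pp_size for s in range(stages_per_rank)]
```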
@@ -287,7 +285,7 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):

 def pipeline_llama_tracer(
     model: nn.Module,
-    world_mesh: DeviceMesh,
+    pp_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
     device: DeviceType,
@@ -306,7 +304,6 @@
             "To work around, set mixed_precision_param to float32."
         )

-    pp_mesh = world_mesh["pp"]
     pp_rank = pp_mesh.get_local_rank()
     pp_size = pp_mesh.size()
     microbatches = (
@@ -341,15 +338,12 @@

 def apply_tp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    tp_mesh: DeviceMesh,
+    loss_parallel: bool,
+    enable_float8: bool,
+    enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
-
-    tp_mesh = world_mesh["tp"]
-    loss_parallel = parallel_dims.loss_parallel_enabled
-
     # 1. Parallelize the embedding and shard its outputs (which are the first
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
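The numbered comments above describe the root-level TP plan. A minimal sketch of what such a plan looks like with the DTensor TP API; the module names and layouts here are assumptions for illustration, not quoted from the elided region:

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

model = parallelize_module(
    model,
    tp_mesh,
    {
        # 1. shard the embedding and emit sequence-sharded activations
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(), output_layouts=Shard(1)
        ),
        # 2. sequence-parallelize the root norm
        "norm": SequenceParallel(),
        # 3. column-parallelize the output projection; keep the output sharded
        #    on the class dim when loss_parallel is enabled
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Shard(-1) if loss_parallel else Replicate(),
            use_local_output=not loss_parallel,
        ),
    },
)
```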
@@ -377,7 +371,7 @@ def apply_tp(
         rowwise_parallel_weight,
         colwise_parallel_weight,
         prepare_module_input,
-    ) = get_tp_parallel_strategy_for_transformer_block(job_config, model)
+    ) = get_tp_parallel_strategy_for_transformer_block(enable_float8)

     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
@@ -416,7 +410,7 @@ def apply_tp(
         )

     # updates expressly for async tensor parallel
-    if job_config.experimental.enable_async_tensor_parallel:
+    if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group

         torch._dynamo.config.cache_size_limit = 10000
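Note: async TP relies on symmetric-memory collectives, which must be enabled per process group. A sketch of the enabling calls, assuming the elided lines follow the usual recipe for Inductor's micro-pipelined TP:

```python
# Assumed continuation: opt the TP process group into symmetric memory and
# let Inductor's micro-pipelining pass fuse the all-gather/matmul patterns.
torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)
```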
@@ -434,16 +428,14 @@ def apply_tp(
-            job_config.training.compile = True

     logger.info(
-        f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model"
+        f"Applied {'Async ' if enable_async_tp else ''}"
+        "Tensor Parallelism to the model"
     )
     return model


-def apply_ac(model: nn.Module, job_config: JobConfig):
+def apply_ac(model: nn.Module, ac_config: JobConfig):
     """Apply activation checkpointing to the model."""

-    ac_config = job_config.activation_checkpoint
-
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = checkpoint_wrapper(transformer_block, ac_config)
         model.layers.register_module(layer_id, transformer_block)
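Note: `apply_ac` now receives the activation-checkpoint sub-config directly, so the `ac_config: JobConfig` annotation is misleading; what is actually passed (see the `parallelize_llama` hunk below) is:

```python
if job_config.activation_checkpoint.mode != "none":
    model = apply_ac(model, job_config.activation_checkpoint)
```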
@@ -452,14 +444,8 @@ def apply_ac(model: nn.Module, job_config: JobConfig):
     return model


-def apply_compile(model: nn.Module, job_config: JobConfig):
+def apply_compile(model: nn.Module):
     """Apply torch.compile to each transformer block."""
-
-    if job_config.model.norm_type == "fused_rmsnorm":
-        raise NotImplementedError(
-            "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
-        )
-
     for layer_id, transformer_block in model.layers.named_children():
         # TODO: dynamic shapes have some issues, so we turn them off for now.
         # TODO: inline inbuilt nn modules does not work yet, enable it to accelerate
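A sketch of the elided loop body, assuming the TODOs above translate into compiling each block with dynamic shapes disabled:

```python
# Sketch: compile each transformer block individually and re-register it.
for layer_id, transformer_block in model.layers.named_children():
    transformer_block = torch.compile(transformer_block, dynamic=False)
    model.layers.register_module(layer_id, transformer_block)
```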
@@ -474,25 +460,19 @@ def apply_compile(model: nn.Module, job_config: JobConfig):

 def apply_fsdp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    dp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
+    pp_enabled: bool,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
-
-    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-
-    mp_policy = MixedPrecisionPolicy(
-        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-    )
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

     for layer_id, transformer_block in model.layers.items():
-        if parallel_dims.pp_enabled:
+        if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
             # all-gathers, which can be expensive and non-overlapped
             reshard_after_forward = False
@@ -505,11 +485,9 @@ def apply_fsdp(
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(
-        model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
-    )
+    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

-    if parallel_dims.pp_enabled:
+    if pp_enabled:
         # TODO
         # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
         # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
@@ -526,22 +504,19 @@

 def apply_ddp(
     model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: "ParallelDims",
-    job_config: JobConfig,
+    dp_mesh: DeviceMesh,
+    enable_compile: bool,
+    enable_compiled_autograd: bool,
 ):
-    if world_mesh.ndim > 1:
-        raise RuntimeError("DDP has not supported > 1D parallelism.")
-
-    if job_config.training.compile:
-        if job_config.experimental.enable_compiled_autograd:
+    if enable_compile:
+        if enable_compiled_autograd:
             torch._dynamo.config.optimize_ddp = (
                 "python_reducer_without_compiled_forward"
             )
         else:
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"

-    model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)
+    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

     logger.info("Applied DDP to the model")
     return model
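With the booleans passed explicitly, the dynamo DDP mode selection becomes trivially testable without constructing a `JobConfig`. A hypothetical extraction (helper name and return convention are illustrative, not part of the PR):

```python
from typing import Optional

# Hypothetical helper mapping the two flags to a dynamo optimize_ddp setting;
# None means "leave torch._dynamo.config.optimize_ddp at its default".
def ddp_optimize_mode(
    enable_compile: bool, enable_compiled_autograd: bool
) -> Optional[str]:
    if not enable_compile:
        return None
    if enable_compiled_autograd:
        return "python_reducer_without_compiled_forward"
    return "ddp_optimizer"
```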
@@ -562,18 +537,46 @@ def parallelize_llama(
     """

     if parallel_dims.tp_enabled:
-        model = apply_tp(model, world_mesh, parallel_dims, job_config)
+        model = apply_tp(
+            model,
+            world_mesh["tp"],
+            loss_parallel=parallel_dims.loss_parallel_enabled,
+            enable_float8=job_config.training.enable_float8_linear,
+            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
+        )

     if job_config.activation_checkpoint.mode != "none":
-        model = apply_ac(model, job_config)
+        model = apply_ac(model, job_config.activation_checkpoint)

     if job_config.training.compile:
-        model = apply_compile(model, job_config)
+        if job_config.model.norm_type == "fused_rmsnorm":
+            raise NotImplementedError(
+                "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
+            )
+        model = apply_compile(model)

     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
-            model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
+            dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+            assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+
+            model = apply_fsdp(
+                model,
+                dp_mesh,
+                param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+                reduce_dtype=TORCH_DTYPE_MAP[
+                    job_config.training.mixed_precision_reduce
+                ],
+                pp_enabled=parallel_dims.pp_enabled,
+            )
         else:
-            model = apply_ddp(model, world_mesh, parallel_dims, job_config)
+            if world_mesh.ndim > 1:
+                raise RuntimeError("DDP has not supported > 1D parallelism.")
+            model = apply_ddp(
+                model,
+                world_mesh,
+                enable_compile=job_config.training.compile,
+                enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
+            )

     return model
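Net effect of the refactor: `parallelize_llama` is now the only place that reads `JobConfig`; `apply_tp`, `apply_ac`, `apply_compile`, `apply_fsdp`, and `apply_ddp` take plain meshes, dtypes, and booleans. One thing this buys is direct unit testing of the helpers; a sketch under assumed harness names (`toy_model` is not part of the PR):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def smoke_test_apply_fsdp(toy_model):
    # Build a 1D "dp" mesh over all ranks and shard the toy model with the
    # same dtypes parallelize_llama would look up via TORCH_DTYPE_MAP.
    dp_mesh = init_device_mesh(
        "cuda", (dist.get_world_size(),), mesh_dim_names=("dp",)
    )
    return apply_fsdp(
        toy_model,
        dp_mesh,
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        pp_enabled=False,
    )
```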
2 changes: 1 addition & 1 deletion train.py
@@ -137,7 +137,7 @@ def main(job_config: JobConfig):

     if parallel_dims.pp_enabled:
         stages, model_parts = models_pipelining_fns[model_name](
-            whole_model, world_mesh, parallel_dims, job_config, device, model_config
+            whole_model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     else:
         # In 1D/2D cases or PP with simple schedules, model_parts is just one item
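Note: `pp_mesh` is presumably sliced from the world mesh earlier in `main` (not shown in this hunk), consistent with the new convention of passing submeshes down:

```python
# Assumed earlier in main(), not shown in this hunk:
pp_mesh = world_mesh["pp"]
```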