From 0c3e7bf297e03d0775d92c5e12a0b88246bcaa75 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Thu, 27 Jun 2024 21:35:59 -0700 Subject: [PATCH 1/7] [Dist][Inference] U-haul TP and distribute utils code to TorchChat --- build/builder.py | 23 +++- distributed/__init__.py | 2 + distributed/parallel_config.py | 48 ++++++++ distributed/parallelize_llama.py | 119 ++++++++++++++++++++ distributed/utils.py | 182 +++++++++++++++++++++++++++++++ 5 files changed, 368 insertions(+), 6 deletions(-) create mode 100644 distributed/__init__.py create mode 100644 distributed/parallel_config.py create mode 100644 distributed/parallelize_llama.py create mode 100644 distributed/utils.py diff --git a/build/builder.py b/build/builder.py index 409013ceb..c7453ce16 100644 --- a/build/builder.py +++ b/build/builder.py @@ -21,6 +21,7 @@ from build.model import Transformer from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype +from distributed import parallelize_llama, ParallelDims, ParallelConfig @dataclass @@ -36,7 +37,7 @@ class BuilderArgs: device: Optional[str] = None precision: torch.dtype = torch.float32 setup_caches: bool = False - use_tp: bool = False + use_distributed: bool = False is_chat_model: bool = False prefill_possible: bool = False @@ -141,7 +142,7 @@ def from_args(cls, args): # -> BuilderArgs: device=args.device, precision=dtype, setup_caches=(args.output_dso_path or args.output_pte_path), - use_tp=False, + use_distributed=False, is_chat_model=is_chat_model, ) @@ -346,11 +347,21 @@ def _load_model(builder_args, only_config=False): else: model = _load_model_default(builder_args) - if builder_args.use_tp: - from tp import apply_tp + if builder_args.use_distributed: + # init distributed + world_size = int(os.environ["WORLD_SIZE"]) + parallel_config = ParallelConfig() + parallel_dims = ParallelDims( + tp=parallel_config.tp_degree, + pp=parallel_config.pp_degree, + world_size=world_size, + ) + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + torch.cuda.set_device(device) + init_distributed(job_config) - print("Applying tensor parallel to model ...") - apply_tp(model) + print("Applying model parallel to model ...") + parallelize_llama(model) model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() diff --git a/distributed/__init__.py b/distributed/__init__.py new file mode 100644 index 000000000..8b8ddb92d --- /dev/null +++ b/distributed/__init__.py @@ -0,0 +1,2 @@ +from distributed.parallelize_llama import parallelize_llama +from distributed.parallel_config import ParallelConfig, ParallelDims diff --git a/distributed/parallel_config.py b/distributed/parallel_config.py new file mode 100644 index 000000000..eb90e36fe --- /dev/null +++ b/distributed/parallel_config.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass, field +from torch.distributed.device_mesh import init_device_mesh + +@dataclass +class ParallelConfig: + name: str = field(default="") + fp8_linear: str = field(default="") + tp_degree: int = field(default=1) + pp_degree: int = field(default=1) + + +@dataclass +class ParallelDims: + tp: int + pp: int + world_size: int + + def __post_init__(self): + self._validate() + + def _validate(self): + tp, pp = self.tp, self.pp + assert tp >= 1, tp + assert pp >= 1, pp + assert ( + tp * pp == self.world_size + ), f"Invalid parallel dims: tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" + + def build_mesh(self, device_type): + dims = [] + names = [] + for d, name in zip( + [self.pp, self.tp], ["pp", "tp"], 
strict=True + ): + if d > 1: + dims.append(d) + names.append(name) + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + names = tuple(names) + return init_device_mesh(device_type, dims, mesh_dim_names=names) + + @property + def tp_enabled(self): + return self.tp > 1 + + @property + def pp_enabled(self): + return self.pp > 1 diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py new file mode 100644 index 000000000..4ebaea026 --- /dev/null +++ b/distributed/parallelize_llama.py @@ -0,0 +1,119 @@ +from typing import Tuple +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +from distributed.parallel_config import ParallelConfig + + +def get_tp_parallel_strategy( + config: ParallelConfig, +) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]: + """Get the parallel strategy for the transformer model. + + This function handles the special case of using float8 with tensor parallelism. + """ + if config.fp8_linear == "dynamic": + from float8_experimental.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput + return RowwiseParallel, ColwiseParallel, PrepareModuleInput + + +def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig): + """ + Apply tensor parallelism. + """ + + tp_mesh = world_mesh["tp"] + ( + row_parallel_strategy, + col_parallel_strategy, + prepare_module_input, + ) = get_tp_parallel_strategy(config) + loss_parallel = parallel_dims.loss_parallel_enabled + + # 1. Parallelize the first embedding and the last linear proj layer + # 2. Parallelize the root norm layer over the sequence dim + # 3. 
Shard the first transformer block's inputs + model = parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "output": col_parallel_strategy( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + "norm": SequenceParallel(), + }, + ) + + # Apply tensor + sequence parallelism to every transformer block + for layer_id, transformer_block in model.layers.items(): + layer_plan = { + "attention": prepare_module_input( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attention.wq": col_parallel_strategy(), + "attention.wk": col_parallel_strategy(), + "attention.wv": col_parallel_strategy(), + "attention.wo": row_parallel_strategy(output_layouts=Shard(1)), + "attention_norm": SequenceParallel(), + "feed_forward": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": col_parallel_strategy(), + "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)), + "feed_forward.w3": col_parallel_strategy(), + "ffn_norm": SequenceParallel(), + } + + # Adjust attention module to use the local number of heads + attn_layer = transformer_block.attention + attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() + attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info("Applied Tensor Parallelism to the model") + return model + + + + +def parallelize_llama(model, world_mesh, parallel_dims, config: ParallelConfig): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + if parallel_dims.tp_enabled: + model = apply_tp(model, world_mesh, parallel_dims, job_config) + + # only enable TP for now. + # if job_config.training.compile: + # model = apply_compile(model, job_config) + + return model diff --git a/distributed/utils.py b/distributed/utils.py new file mode 100644 index 000000000..c29836601 --- /dev/null +++ b/distributed/utils.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +from dataclasses import dataclass +from datetime import timedelta +from typing import Union + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch.distributed.device_mesh import DeviceMesh +from torchtitan.logging_utils import logger +from torchtitan.parallelisms import ParallelDims + + +def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float: + tensor = torch.tensor(x).cuda() + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh) + + +def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float: + tensor = torch.tensor(x).cuda() + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh) + + +def _warn_overwrite_env(env, val): + if env in os.environ: + logger.warning( + f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config" + ) + os.environ[env] = val + + +def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int: + """ + Returns global rank 0 in non-pipeline-parallel configs, and returns the global + rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled. + """ + if parallel_dims.pp_enabled: + assert ( + world_mesh.mesh_dim_names[0] == "pp" + ), "get_metrics_rank assumes pp is the outer mesh dim" + pp_mesh = world_mesh["pp"] + pp_size = pp_mesh.size() + metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1)) + else: + metrics_log_rank = 0 + + return metrics_log_rank + + +def set_pg_timeouts(timeout, world_mesh): + """ + Sets the timeout for all PGs in the provided mesh, and the default (world) group. + + Note: synchronizes via a barrier, before changing the timeouts. This is important, becuase + otherwise you may face a race where the slow rank has not reached the timeout reduction point + yet due to slow operations permitted under the old timeout value, but other faster ranks may + start issueing collectives under the new shorter timeout and then immediately timeout. + """ + logger.info( + f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}" + ) + # Ensure that all the ranks have reached the point of setting the new timeout- + # otherwise, some ranks may issue collectives with the new/shorter timeout and + # those may time out, before other ranks have finished with initialization done + # under the old/slow timeout. + torch.distributed.barrier() + torch.cuda.synchronize() + + groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] + + # None represents the 'default' PG, not part of the mesh + groups.append(None) + for group in groups: + torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) + + +TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE" +TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE" +DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT" +ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING" +SKIP_CLEANUP = "3" + + +def init_distributed(job_config): + # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup) + # to get flight recorder dumps. 
See https://github.com/pytorch/pytorch/issues/121055 + # This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle + # behavior differences + _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP) + + # enable torch nccl flight recorder in the mode that would dump files if timeout is detected + _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size)) + if job_config.comm.trace_buf_size > 0: + # dump on timeout by default if trace buffer is enabled + _warn_overwrite_env(DUMP_ON_TIMEOUT, "1") + dump_dir = f"{job_config.job.dump_folder}/comm_trace" + os.makedirs(dump_dir, exist_ok=True) + _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_") + + torch.distributed.init_process_group( + "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds) + ) + + # to mitigate the memory issue that collectives using + # async_op=True hold memory longer than they should + # such as those in tensor parallelism + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + +def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int: + num_params = sum(p.numel() for p in model.parameters()) + if exclude_embedding: + num_params -= model.tok_embeddings.weight.numel() + return num_params + + +def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int: + l, h, q, t = ( + model_config.n_layers, + model_config.n_heads, + model_config.dim // model_config.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. we follow the convention and do not account for sparsity in causal attention + flop_per_token = 6 * num_params + 12 * l * h * q * t + + return flop_per_token + + +# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU +def get_peak_flops(device_name: str) -> int: + if "A100" in device_name: + # data from https://www.nvidia.com/en-us/data-center/a100/ + return 312e12 + elif "H100" in device_name: + # data from https://www.nvidia.com/en-us/data-center/h100/ + # NOTE: Specifications are one-half lower without sparsity. 
+ if "NVL" in device_name: + return 1979e12 + elif "PCIe" in device_name: + return 756e12 + else: # for SXM and other variants + return 989e12 + else: # for other GPU types, assume A100 + return 312e12 + + +@dataclass(frozen=True) +class Color: + black = "\033[30m" + red = "\033[31m" + green = "\033[32m" + yellow = "\033[33m" + blue = "\033[34m" + magenta = "\033[35m" + cyan = "\033[36m" + white = "\033[37m" + reset = "\033[39m" + + +@dataclass(frozen=True) +class NoColor: + black = "" + red = "" + green = "" + yellow = "" + blue = "" + magenta = "" + cyan = "" + white = "" + reset = "" From 1b8a8804af7f01060ec543e09b62dc8380d43260 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Thu, 27 Jun 2024 21:58:34 -0700 Subject: [PATCH 2/7] Remove unnecessary code and add comment --- distributed/__init__.py | 6 ++ distributed/parallel_config.py | 6 ++ distributed/parallelize_llama.py | 9 ++- distributed/utils.py | 104 ------------------------------- 4 files changed, 19 insertions(+), 106 deletions(-) diff --git a/distributed/__init__.py b/distributed/__init__.py index 8b8ddb92d..50d497d1d 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -1,2 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + from distributed.parallelize_llama import parallelize_llama from distributed.parallel_config import ParallelConfig, ParallelDims diff --git a/distributed/parallel_config.py b/distributed/parallel_config.py index eb90e36fe..f5f0d787b 100644 --- a/distributed/parallel_config.py +++ b/distributed/parallel_config.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + from dataclasses import dataclass, field from torch.distributed.device_mesh import init_device_mesh diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index 4ebaea026..013f1dcf5 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + from typing import Tuple from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -86,6 +92,7 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig): # Adjust attention module to use the local number of heads attn_layer = transformer_block.attention attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() + attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size() attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() parallelize_module( @@ -98,8 +105,6 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig): return model - - def parallelize_llama(model, world_mesh, parallel_dims, config: ParallelConfig): """ Apply tensor parallelism, activation checkpointing, torch.compile, and data diff --git a/distributed/utils.py b/distributed/utils.py index c29836601..ec6699acf 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -5,26 +5,9 @@ # LICENSE file in the root directory of this source tree. 
import os -from dataclasses import dataclass from datetime import timedelta -from typing import Union import torch -import torch.distributed._functional_collectives as funcol -import torch.distributed.distributed_c10d as c10d -from torch.distributed.device_mesh import DeviceMesh -from torchtitan.logging_utils import logger -from torchtitan.parallelisms import ParallelDims - - -def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float: - tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh) - - -def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float: - tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh) def _warn_overwrite_env(env, val): @@ -35,24 +18,6 @@ def _warn_overwrite_env(env, val): os.environ[env] = val -def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int: - """ - Returns global rank 0 in non-pipeline-parallel configs, and returns the global - rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled. - """ - if parallel_dims.pp_enabled: - assert ( - world_mesh.mesh_dim_names[0] == "pp" - ), "get_metrics_rank assumes pp is the outer mesh dim" - pp_mesh = world_mesh["pp"] - pp_size = pp_mesh.size() - metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1)) - else: - metrics_log_rank = 0 - - return metrics_log_rank - - def set_pg_timeouts(timeout, world_mesh): """ Sets the timeout for all PGs in the provided mesh, and the default (world) group. @@ -111,72 +76,3 @@ def init_distributed(job_config): # async_op=True hold memory longer than they should # such as those in tensor parallelism os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - -def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int: - num_params = sum(p.numel() for p in model.parameters()) - if exclude_embedding: - num_params -= model.tok_embeddings.weight.numel() - return num_params - - -def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int: - l, h, q, t = ( - model_config.n_layers, - model_config.n_heads, - model_config.dim // model_config.n_heads, - seq_len, - ) - # Reasoning behind the factor of 12 for the self-attention part of the formula: - # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) - # 2. the flash attention does 1 more matmul recomputation in the backward - # but recomputation should not be counted in calculating MFU (+0) - # 3. each matmul performs 1 multiplication and 1 addition (*2) - # 4. we follow the convention and do not account for sparsity in causal attention - flop_per_token = 6 * num_params + 12 * l * h * q * t - - return flop_per_token - - -# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU -def get_peak_flops(device_name: str) -> int: - if "A100" in device_name: - # data from https://www.nvidia.com/en-us/data-center/a100/ - return 312e12 - elif "H100" in device_name: - # data from https://www.nvidia.com/en-us/data-center/h100/ - # NOTE: Specifications are one-half lower without sparsity. 
- if "NVL" in device_name: - return 1979e12 - elif "PCIe" in device_name: - return 756e12 - else: # for SXM and other variants - return 989e12 - else: # for other GPU types, assume A100 - return 312e12 - - -@dataclass(frozen=True) -class Color: - black = "\033[30m" - red = "\033[31m" - green = "\033[32m" - yellow = "\033[33m" - blue = "\033[34m" - magenta = "\033[35m" - cyan = "\033[36m" - white = "\033[37m" - reset = "\033[39m" - - -@dataclass(frozen=True) -class NoColor: - black = "" - red = "" - green = "" - yellow = "" - blue = "" - magenta = "" - cyan = "" - white = "" - reset = "" From fd378c9b84b139150e110af94e7cf52d2c29f8d5 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Mon, 1 Jul 2024 10:01:37 -0700 Subject: [PATCH 3/7] Add Torchrun script and enable distributed for that script --- build/builder.py | 3 ++- cli.py | 5 +++++ config/model_config.py | 1 + distributed/run_dist_inference.sh | 31 +++++++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 1 deletion(-) create mode 100755 distributed/run_dist_inference.sh diff --git a/build/builder.py b/build/builder.py index c7453ce16..1139c0ab1 100644 --- a/build/builder.py +++ b/build/builder.py @@ -142,7 +142,7 @@ def from_args(cls, args): # -> BuilderArgs: device=args.device, precision=dtype, setup_caches=(args.output_dso_path or args.output_pte_path), - use_distributed=False, + use_distributed=args.distributed, is_chat_model=is_chat_model, ) @@ -347,6 +347,7 @@ def _load_model(builder_args, only_config=False): else: model = _load_model_default(builder_args) + # TODO: ongoing work to support loading model from checkpoint if builder_args.use_distributed: # init distributed world_size = int(os.environ["WORLD_SIZE"]) diff --git a/cli.py b/cli.py index 24f6d6ed0..3c8a503d7 100644 --- a/cli.py +++ b/cli.py @@ -56,6 +56,11 @@ def add_arguments_for_verb(parser, verb: str): action="store_true", help="Whether to start an interactive chat session", ) + parser.add_argument( + "--distributed", + action="store_true", + help="Whether to enable distributed inference", + ) parser.add_argument( "--gui", action="store_true", diff --git a/config/model_config.py b/config/model_config.py index aa6f24e79..2e479beb7 100644 --- a/config/model_config.py +++ b/config/model_config.py @@ -3,6 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. + import json from dataclasses import dataclass, field from enum import Enum diff --git a/distributed/run_dist_inference.sh b/distributed/run_dist_inference.sh new file mode 100755 index 000000000..3268750d9 --- /dev/null +++ b/distributed/run_dist_inference.sh @@ -0,0 +1,31 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# libUV is a scalable backend for TCPStore which is used in processGroup +# rendezvous. This is the recommended backend for distributed training. +export USE_LIBUV=1 + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./run_dist_inference.sh + +NGPU=${NGPU:-"8"} + +# TODO: We need to decide how to log for inference. 
+# by default log just rank 0 output, +LOG_RANK=${LOG_RANK:-0} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +torchchat.py chat llama3 --distributed $overrides From 8774c3442066afc0cfa02037e390d2c2c9475660 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Mon, 1 Jul 2024 10:06:11 -0700 Subject: [PATCH 4/7] Remove unnecessary changes --- config/model_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/config/model_config.py b/config/model_config.py index 2e479beb7..aa6f24e79 100644 --- a/config/model_config.py +++ b/config/model_config.py @@ -3,7 +3,6 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - import json from dataclasses import dataclass, field from enum import Enum From b6823a38b596041d9fe1d2bdc8ee91a263bef5b1 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Mon, 1 Jul 2024 11:02:39 -0700 Subject: [PATCH 5/7] Remove ununsed function --- distributed/utils.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index ec6699acf..71b68f94a 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -18,33 +18,6 @@ def _warn_overwrite_env(env, val): os.environ[env] = val -def set_pg_timeouts(timeout, world_mesh): - """ - Sets the timeout for all PGs in the provided mesh, and the default (world) group. - - Note: synchronizes via a barrier, before changing the timeouts. This is important, becuase - otherwise you may face a race where the slow rank has not reached the timeout reduction point - yet due to slow operations permitted under the old timeout value, but other faster ranks may - start issueing collectives under the new shorter timeout and then immediately timeout. - """ - logger.info( - f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}" - ) - # Ensure that all the ranks have reached the point of setting the new timeout- - # otherwise, some ranks may issue collectives with the new/shorter timeout and - # those may time out, before other ranks have finished with initialization done - # under the old/slow timeout. 
- torch.distributed.barrier() - torch.cuda.synchronize() - - groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] - - # None represents the 'default' PG, not part of the mesh - groups.append(None) - for group in groups: - torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) - - TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE" TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE" DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT" From 910261f351eac3157765fa2be2ee77734a29d3fa Mon Sep 17 00:00:00 2001 From: fduwjj Date: Mon, 1 Jul 2024 16:01:45 -0700 Subject: [PATCH 6/7] Add comments and further clean up the code --- build/builder.py | 8 +-- distributed/__init__.py | 2 +- distributed/parallel_config.py | 8 --- distributed/parallelize_llama.py | 93 +++++++++++++++++--------------- 4 files changed, 55 insertions(+), 56 deletions(-) diff --git a/build/builder.py b/build/builder.py index 1139c0ab1..bd3ef5f4a 100644 --- a/build/builder.py +++ b/build/builder.py @@ -21,7 +21,7 @@ from build.model import Transformer from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype -from distributed import parallelize_llama, ParallelDims, ParallelConfig +from distributed import parallelize_llama, ParallelDims @dataclass @@ -351,10 +351,10 @@ def _load_model(builder_args, only_config=False): if builder_args.use_distributed: # init distributed world_size = int(os.environ["WORLD_SIZE"]) - parallel_config = ParallelConfig() + # TODO: To make tp, pp degree configurable parallel_dims = ParallelDims( - tp=parallel_config.tp_degree, - pp=parallel_config.pp_degree, + tp=8, + pp=1, world_size=world_size, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") diff --git a/distributed/__init__.py b/distributed/__init__.py index 50d497d1d..64cd5f22d 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -5,4 +5,4 @@ # LICENSE file in the root directory of this source tree. from distributed.parallelize_llama import parallelize_llama -from distributed.parallel_config import ParallelConfig, ParallelDims +from distributed.parallel_config import ParallelDims diff --git a/distributed/parallel_config.py b/distributed/parallel_config.py index f5f0d787b..d1d8aa9c7 100644 --- a/distributed/parallel_config.py +++ b/distributed/parallel_config.py @@ -7,14 +7,6 @@ from dataclasses import dataclass, field from torch.distributed.device_mesh import init_device_mesh -@dataclass -class ParallelConfig: - name: str = field(default="") - fp8_linear: str = field(default="") - tp_degree: int = field(default=1) - pp_degree: int = field(default=1) - - @dataclass class ParallelDims: tp: int diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index 013f1dcf5..a64527e12 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -13,39 +13,35 @@ SequenceParallel, ) -from distributed.parallel_config import ParallelConfig +import torch.nn as nn +from distributed.parallel_config import ParallelDims +from torch.distributed.device_mesh import DeviceMesh -def get_tp_parallel_strategy( - config: ParallelConfig, -) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]: - """Get the parallel strategy for the transformer model. - - This function handles the special case of using float8 with tensor parallelism. 
- """ - if config.fp8_linear == "dynamic": - from float8_experimental.float8_tensor_parallel import ( - Float8ColwiseParallel, - Float8RowwiseParallel, - PrepareFloat8ModuleInput, - ) - - return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput - return RowwiseParallel, ColwiseParallel, PrepareModuleInput - - -def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig): +def apply_tp( + model: nn.Module, + world_mesh: DeviceMesh, +) -> nn.Module: """ - Apply tensor parallelism. + Apply tensor parallelism to the given model. More details can be + found in https://pytorch.org/tutorials/intermediate/TP_tutorial.html. + + NOTE: The way we apply tp is based on the assumption that the model is a LLaMA model. + One needs to change the ``parallelize_plan`` we pass in to the TP api if the model + is not a LLaMA model. + + + Args: + module (:class:`nn.Module`): + Module to be parallelized. + world_mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. + Return: + A :class:`nn.Module` object tensor-parallelized. """ tp_mesh = world_mesh["tp"] - ( - row_parallel_strategy, - col_parallel_strategy, - prepare_module_input, - ) = get_tp_parallel_strategy(config) - loss_parallel = parallel_dims.loss_parallel_enabled # 1. Parallelize the first embedding and the last linear proj layer # 2. Parallelize the root norm layer over the sequence dim @@ -58,10 +54,10 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig): input_layouts=Replicate(), output_layouts=Shard(1), ), - "output": col_parallel_strategy( + "output": ColwiseParallel( input_layouts=Shard(1), - output_layouts=Shard(-1) if loss_parallel else Replicate(), - use_local_output=not loss_parallel, + output_layouts=Replicate(), + use_local_output=True, ), "norm": SequenceParallel(), }, @@ -74,18 +70,18 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig): input_layouts=(Shard(1), None), desired_input_layouts=(Replicate(), None), ), - "attention.wq": col_parallel_strategy(), - "attention.wk": col_parallel_strategy(), - "attention.wv": col_parallel_strategy(), - "attention.wo": row_parallel_strategy(output_layouts=Shard(1)), + "attention.wq": ColwiseParallel(), + "attention.wk": ColwiseParallel(), + "attention.wv": ColwiseParallel(), + "attention.wo": RowwiseParallel(output_layouts=Shard(1)), "attention_norm": SequenceParallel(), "feed_forward": prepare_module_input( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), ), - "feed_forward.w1": col_parallel_strategy(), - "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)), - "feed_forward.w3": col_parallel_strategy(), + "feed_forward.w1": ColwiseParallel(), + "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), + "feed_forward.w3": ColwiseParallel(), "ffn_norm": SequenceParallel(), } @@ -105,20 +101,31 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig): return model -def parallelize_llama(model, world_mesh, parallel_dims, config: ParallelConfig): +def parallelize_llama( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, +) -> nn.Module: """ Apply tensor parallelism, activation checkpointing, torch.compile, and data parallelism to the model. NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. + + Args: + module (:class:`nn.Module`): + Module to be parallelized. 
+ world_mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. + parallel_dims (:class:`ParallelDims`): + The object of the util class which contains the degree for each parallelism. + Return: + A :class:`nn.Module` object parallelized. """ if parallel_dims.tp_enabled: - model = apply_tp(model, world_mesh, parallel_dims, job_config) - - # only enable TP for now. - # if job_config.training.compile: - # model = apply_compile(model, job_config) + model = apply_tp(model, world_mesh, parallel_dims) return model From 8393020aafa6e13e5a6b24caf8e8590ced5a8174 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Mon, 1 Jul 2024 19:51:12 -0700 Subject: [PATCH 7/7] Edit comments --- distributed/parallelize_llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index a64527e12..e2b73d0dd 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -107,8 +107,7 @@ def parallelize_llama( parallel_dims: ParallelDims, ) -> nn.Module: """ - Apply tensor parallelism, activation checkpointing, torch.compile, and data - parallelism to the model. + Apply tensor parallelism and other parallelism(TODO) to the model for inference. NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory.
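
Usage sketch (not part of the patches above): the snippet below shows how the pieces introduced in this series are intended to compose after PATCH 7/7, roughly following what builder.py does when --distributed is passed. It assumes a torchrun launch (so WORLD_SIZE and LOCAL_RANK are set), an existing build.model.Transformer instance, and that the remaining loose ends in the new modules are filled in (parallel_config.py uses a logger it does not import, and parallelize_llama.py references Shard, Replicate, and prepare_module_input without importing or defining them after PATCH 6/7). init_process_group is called directly here as a stand-in for distributed.utils.init_distributed, which expects a job_config carrying comm settings; the helper name load_and_parallelize is hypothetical.

import os

import torch

from build.model import Transformer
from distributed import parallelize_llama, ParallelDims


def load_and_parallelize(model: Transformer) -> Transformer:
    """Hypothetical helper: tensor-parallelize a torchchat Transformer for inference."""
    # torchrun sets WORLD_SIZE / LOCAL_RANK for every rank.
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Stand-in for distributed.utils.init_distributed(job_config), which also
    # configures the NCCL flight-recorder env vars before creating the group.
    torch.distributed.init_process_group("nccl")

    # Tensor parallelism across all ranks, no pipeline parallelism
    # (mirrors the hard-coded tp=8, pp=1 in builder.py after PATCH 6/7).
    parallel_dims = ParallelDims(tp=world_size, pp=1, world_size=world_size)

    # Builds a 1-D ("tp",) device mesh via init_device_mesh.
    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    # Shards tok_embeddings/output plus the attention and feed-forward weights,
    # and scales n_heads / n_local_heads / n_kv_heads down to their local counts.
    model = parallelize_llama(model, world_mesh, parallel_dims)

    return model.to(device=device).eval()

For multi-GPU runs, distributed/run_dist_inference.sh (PATCH 3/7) launches this flow under torchrun using the --distributed flag added to cli.py.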