From 1e3f167fdac346b520da42339fdbb0921f60837c Mon Sep 17 00:00:00 2001
From: Less Wright
Date: Sat, 20 Jul 2024 16:27:37 -0700
Subject: [PATCH 1/3] add required torch.compile cache for async tp, enhance ux

---
 torchtitan/parallelisms/parallelize_llama.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 634c70a0c..51091ced2 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -413,15 +413,21 @@ def apply_tp(
             parallelize_plan=layer_plan,
         )
 
+    # updates expressly for async tensor parallel
     if job_config.experimental.enable_async_tensor_parallel:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+        torch._dynamo.config.cache_size_limit = 10000
+        logger.info("Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP")
 
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
 
-    logger.info("Applied Tensor Parallelism to the model")
-    return model
+        if not job_config.training.compile:
+            logger.warning(f"Async TP requires compilation...auto enabling compile = True for this job to resolve.")
+            job_config.training.compile = True
 
+    logger.info(f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model")
+    return model
 
 def apply_ac(model: nn.Module, job_config: JobConfig):
     """Apply activation checkpointing to the model."""

From d74df6a19201421c0013f68545222e68f1608628 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Sat, 20 Jul 2024 16:36:11 -0700
Subject: [PATCH 2/3] add enable_async_tp to debug_model config to showcase usage

---
 train_configs/debug_model.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 7c8499767..b36e9d0c7 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -43,6 +43,7 @@ dataset = "c4_mini"  # supported datasets: c4_mini (45K), c4 (177M)
 
 [experimental]
 pipeline_parallel_degree = 1
+enable_async_tensor_parallel = false
 
 [checkpoint]
 enable_checkpoint = false

From 65e0f96a6f3b443b316b1a61539f1681d2ef0a1c Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Sat, 20 Jul 2024 16:44:42 -0700
Subject: [PATCH 3/3] lint - fix line too long

---
 torchtitan/parallelisms/parallelize_llama.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 51091ced2..31eabc6c8 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -416,19 +416,27 @@ def apply_tp(
     # updates expressly for async tensor parallel
     if job_config.experimental.enable_async_tensor_parallel:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
         torch._dynamo.config.cache_size_limit = 10000
-        logger.info("Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP")
+        logger.info(
+            "Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP"
+        )
 
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
 
         if not job_config.training.compile:
-            logger.warning(f"Async TP requires compilation...auto enabling compile = True for this job to resolve.")
+            logger.warning(
+                "Async TP requires compilation...auto enabling compile = True for this job to resolve."
+            )
             job_config.training.compile = True
 
-    logger.info(f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model")
+    logger.info(
+        f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model"
+    )
     return model
 
+
 def apply_ac(model: nn.Module, job_config: JobConfig):
     """Apply activation checkpointing to the model."""
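
Usage note: the sketch below shows the torchtitan TOML knobs this series touches, assuming the standard train_configs layout. The [training] entries are inferred, not part of the diff: compile is implied by job_config.training.compile in PATCH 1/3, and the tensor_parallel_degree value is only illustrative.

    [training]
    compile = true                       # Async TP needs torch.compile; PATCH 1/3 auto-enables it if left false
    tensor_parallel_degree = 2           # assumed: TP must span more than one rank for the TP plan to matter

    [experimental]
    enable_async_tensor_parallel = true  # new knob; PATCH 2/3 adds it to debug_model.toml defaulting to false

With the flag on, apply_tp raises torch._dynamo.config.cache_size_limit to 10000, enables torch._inductor.config._micro_pipeline_tp, and registers the TP process group for symmetric memory, which lets the compiled collectives be micro-pipelined against the sharded matmuls.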