diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 634c70a0c3..31eabc6c84 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -413,13 +413,27 @@ def apply_tp(
             parallelize_plan=layer_plan,
         )
 
+    # updates expressly for async tensor parallel
     if job_config.experimental.enable_async_tensor_parallel:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
+        torch._dynamo.config.cache_size_limit = 10000
+        logger.info(
+            "Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP"
+        )
+
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
 
-    logger.info("Applied Tensor Parallelism to the model")
+        if not job_config.training.compile:
+            logger.warning(
+                "Async TP requires compilation...auto enabling compile = True for this job to resolve."
+            )
+            job_config.training.compile = True
+
+    logger.info(
+        f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model"
+    )
     return model
 
 
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 7c8499767a..b36e9d0c7b 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -43,6 +43,7 @@ dataset = "c4_mini"  # supported datasets: c4_mini (45K), c4 (177M)
 
 [experimental]
 pipeline_parallel_degree = 1
+enable_async_tensor_parallel = false
 
 [checkpoint]
 enable_checkpoint = false
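For context, below is a minimal sketch (not part of the diff) of the PyTorch-side knobs that the new guarded branch in `apply_tp` flips when `experimental.enable_async_tensor_parallel = true` is set in the TOML config. The `enable_async_tp` helper name and the standalone `init_device_mesh` call are illustrative only; it assumes a recent PyTorch build with symmetric-memory support and is meant to run under torchrun with one rank per GPU.

```python
# Illustrative sketch only: mirrors the guarded branch in apply_tp above.
import torch
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from torch.distributed.device_mesh import init_device_mesh


def enable_async_tp(tp_degree: int) -> None:
    """Hypothetical helper that flips the knobs Async TP relies on."""
    # Build a 1-D tensor-parallel mesh over the local GPUs (assumes torchrun).
    tp_mesh = init_device_mesh("cuda", (tp_degree,), mesh_dim_names=("tp",))

    # Raise the dynamo cache limit so the extra recompiles triggered by
    # Async TP do not hit the default cache-size ceiling.
    torch._dynamo.config.cache_size_limit = 10000

    # Ask inductor to micro-pipeline TP collectives with surrounding matmuls.
    torch._inductor.config._micro_pipeline_tp = True

    # Register the TP process group for symmetric-memory based communication.
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    # Async TP only takes effect on compiled regions, which is why the diff
    # force-enables job_config.training.compile when it is off.
```

On the config side, a job would opt in by setting `enable_async_tensor_parallel = true` under `[experimental]`; the `compile` flag under `[training]` (which the new branch now enforces automatically) is assumed to follow torchtitan's existing config schema.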