diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py
index 9e928781277..13740f3d3c5 100644
--- a/examples/llm-api/quickstart_advanced.py
+++ b/examples/llm-api/quickstart_advanced.py
@@ -1,10 +1,11 @@
 import argparse

 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
-                                 DraftTargetDecodingConfig, EagleDecodingConfig,
-                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
-                                 NGramDecodingConfig, TorchCompileConfig)
+from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
+                                 CudaGraphConfig, DraftTargetDecodingConfig,
+                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
+                                 MTPDecodingConfig, NGramDecodingConfig,
+                                 TorchCompileConfig)

 example_prompts = [
     "Hello, my name is",
@@ -57,6 +58,13 @@ def add_llm_args(parser):
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')
+    parser.add_argument('--attention_dp_enable_balance',
+                        default=False,
+                        action='store_true')
+    parser.add_argument('--attention_dp_time_out_iters', type=int, default=0)
+    parser.add_argument('--attention_dp_batching_wait_iters',
+                        type=int,
+                        default=0)
     parser.add_argument('--enable_trtllm_sampler',
                         default=False,
                         action='store_true')
@@ -196,6 +204,12 @@ def setup_llm(args, **kwargs):
         enable_padding=args.cuda_graph_padding_enabled,
     ) if args.use_cuda_graph else None

+    attention_dp_config = AttentionDpConfig(
+        enable_balance=args.attention_dp_enable_balance,
+        timeout_iters=args.attention_dp_time_out_iters,
+        batching_wait_iters=args.attention_dp_batching_wait_iters,
+    )
+
     llm = LLM(
         model=args.model_dir,
         backend='pytorch',
@@ -218,6 +232,7 @@ def setup_llm(args, **kwargs):
         max_batch_size=args.max_batch_size,
         max_num_tokens=args.max_num_tokens,
         enable_attention_dp=args.enable_attention_dp,
+        attention_dp_config=attention_dp_config,
         tensor_parallel_size=args.tp_size,
         pipeline_parallel_size=args.pp_size,
         moe_expert_parallel_size=args.moe_ep_size,
diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index 7f759d6796d..ff0fb204f1f 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -132,6 +132,9 @@ def __init__(
         self.pytorch_backend_config.enable_iter_perf_stats = False
         self.pytorch_backend_config.enable_iter_req_stats = False
         self.pytorch_backend_config.stream_interval = 1
+        self.pytorch_backend_config.attention_dp_enable_balance = False
+        self.pytorch_backend_config.attention_dp_time_out_iters = 50
+        self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
         self.iter_counter = 0

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
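For orientation, here is a minimal usage sketch of the option wired up above, written against the LLM API rather than the example script. The model path, TP size, and the concrete values are placeholders for illustration, not part of this change; the defaults shown match the AttentionDpConfig fields introduced in llm_args.py below.

```python
# Hypothetical usage sketch; the checkpoint path and parallelism sizes are placeholders.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import AttentionDpConfig

llm = LLM(
    model="/path/to/model",      # placeholder checkpoint directory
    backend='pytorch',
    tensor_parallel_size=4,
    enable_attention_dp=True,    # balancing only takes effect when attention DP is enabled
    attention_dp_config=AttentionDpConfig(
        enable_balance=True,
        timeout_iters=50,        # same defaults as AttentionDpConfig below
        batching_wait_iters=10,
    ),
)
```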
diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py
index 483d220c2e1..0770643ae35 100644
--- a/tensorrt_llm/_torch/pyexecutor/config.py
+++ b/tensorrt_llm/_torch/pyexecutor/config.py
@@ -46,6 +46,10 @@ class PyTorchConfig:
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None

+    attention_dp_enable_balance: bool = False
+    attention_dp_time_out_iters: int = 50
+    attention_dp_batching_wait_iters: int = 10
+
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index cbc13acd522..17d203be62c 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -177,6 +177,9 @@ def __init__(self,
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.stream_interval = model_engine.pytorch_backend_config.stream_interval
+        self.attention_dp_enable_balance = model_engine.pytorch_backend_config.attention_dp_enable_balance
+        self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
+        self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -215,6 +218,9 @@ def __init__(self,
             self.draft_model_engine.warmup(self.resource_manager)

         self.is_shutdown = False
+        self.max_batch_size = max_batch_size
+        self.adp_ctx_waiting_iters_count = 0
+        self.adp_ctx_batching_wait_iters_count = 0

         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1131,13 +1137,68 @@ def _add_kv_cache_events(self):
         # to be transferred to main thread when user needs them.
         kv_cache_manager.flush_iteration_events()

+    def _balance_adp_requests(self, context_requests: list[LlmRequest],
+                              generation_requests: list[LlmRequest]):
+        balanced_context_requests = context_requests
+        num_scheduled_context_requests = len(context_requests)
+        num_scheduled_generation_requests = len(generation_requests)
+        num_scheduled_tokens = sum(
+            [len(req.get_tokens(0))
+             for req in context_requests]) + num_scheduled_generation_requests
+        responses_list = self.dist.tp_allgather([
+            num_scheduled_context_requests, num_scheduled_generation_requests,
+            num_scheduled_tokens
+        ])
+        all_ranks_num_scheduled_context_requests = [
+            response[0] for response in responses_list
+        ]
+        all_ranks_num_scheduled_generation_requests = [
+            response[1] for response in responses_list
+        ]
+        all_ranks_have_free_ctx_slots = all([
+            num_gen < self.max_batch_size
+            for num_gen in all_ranks_num_scheduled_generation_requests
+        ])
+        all_ranks_have_ctx_requests = all([
+            num_ctx > 0 for num_ctx in all_ranks_num_scheduled_context_requests
+        ])
+        all_ranks_have_gen_requests = all([
+            num_gen > 0
+            for num_gen in all_ranks_num_scheduled_generation_requests
+        ])
+
+        if self.attention_dp_enable_balance:
+            # wait for all ranks to have context requests
+            if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
+                self.adp_ctx_waiting_iters_count = 0
+                # balance number of context requests across ranks
+                if all_ranks_have_gen_requests:
+                    if self.adp_ctx_batching_wait_iters_count < self.attention_dp_batching_wait_iters:
+                        self.adp_ctx_batching_wait_iters_count += 1
+                        balanced_context_requests = []
+                    else:
+                        self.adp_ctx_batching_wait_iters_count = 0
+            else:
+                self.adp_ctx_waiting_iters_count += 1
+                balanced_context_requests = []
+                timeout_reached = self.adp_ctx_waiting_iters_count >= self.attention_dp_time_out_iters
+                if timeout_reached or not all_ranks_have_gen_requests:
+                    self.adp_ctx_waiting_iters_count = 0
+                    balanced_context_requests = context_requests
+        return balanced_context_requests
+
     @nvtx_range("_schedule")
     def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
             self.active_requests, self.inflight_req_ids)
-        scheduled_requests = ScheduledRequests()
+        scheduled_context_requests = scheduler_output.context_requests
+        if self.enable_attention_dp and self.attention_dp_enable_balance:
+            scheduled_context_requests = self._balance_adp_requests(
+                scheduler_output.context_requests,
+                scheduler_output.generation_requests)

-        scheduled_requests.context_requests = scheduler_output.context_requests
+        scheduled_requests = ScheduledRequests()
+        scheduled_requests.context_requests = scheduled_context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
         scheduled_requests.paused_requests = scheduler_output.paused_requests
         return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py
index bef7ded9948..a5f49917886 100644
--- a/tensorrt_llm/llmapi/__init__.py
+++ b/tensorrt_llm/llmapi/__init__.py
@@ -4,15 +4,16 @@
 from .build_cache import BuildCacheConfig
 from .llm import LLM, RequestOutput
 # yapf: disable
-from .llm_args import (AutoDecodingConfig, BatchingType, CacheTransceiverConfig,
-                       CalibConfig, CapacitySchedulerPolicy,
-                       ContextChunkingPolicy, CudaGraphConfig,
-                       DraftTargetDecodingConfig, DynamicBatchConfig,
-                       EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
-                       KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
-                       MedusaDecodingConfig, MoeConfig, MTPDecodingConfig,
-                       NGramDecodingConfig, SchedulerConfig, TorchCompileConfig,
-                       TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig)
+from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
+                       CacheTransceiverConfig, CalibConfig,
+                       CapacitySchedulerPolicy, ContextChunkingPolicy,
+                       CudaGraphConfig, DraftTargetDecodingConfig,
+                       DynamicBatchConfig, EagleDecodingConfig,
+                       ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
+                       LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
+                       MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
+                       TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
+                       UserProvidedDecodingConfig)
 from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
                         QuantConfig)
 from .mpi_session import MpiCommSession
@@ -54,4 +55,5 @@
     'TorchLlmArgs',
     'TrtLlmArgs',
     'AutoDecodingConfig',
+    'AttentionDpConfig',
 ]
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 1c836264e22..b7d46ed6fa2 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -187,6 +187,23 @@ def from_dict(cls, data: dict):
         return cls(**data)


+class AttentionDpConfig(StrictBaseModel):
+    """
+    Configuration for attention DP.
+    """
+    enable_balance: bool = Field(default=False,
+                                 description="Whether to enable balance.")
+    timeout_iters: int = Field(
+        default=50, description="The number of iterations to timeout.")
+    batching_wait_iters: int = Field(
+        default=10,
+        description="The number of iterations to wait for batching.")
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+
 @dataclass
 class _ParallelConfig:
     ''' The model distribution configs for LLM. '''
@@ -1988,6 +2005,11 @@ class TorchLlmArgs(BaseLlmArgs):
         Note that each CUDA graph can use up to 200 MB of extra memory.",
         status="beta")
+    attention_dp_config: Optional[AttentionDpConfig] = Field(
+        default=None,
+        description="Optimized load-balancing for the DP Attention scheduler.",
+        status="beta")
+
     disable_overlap_scheduler: bool = Field(
         default=False,
         description="Disable the overlap scheduler.",
         status="beta")
@@ -2253,6 +2275,29 @@ def warn_on_unstable_feature_usage(self) -> 'TorchLlmArgs':

         return self

+    @model_validator(mode='after')
+    def validate_attention_dp_config(self) -> 'TorchLlmArgs':
+        """Validate attention DP configuration.
+
+        Ensures that:
+        1. If attention_dp_config.enable_balance is true, attention_dp_config.batching_wait_iters must be greater or equal to 0
+        2. If attention_dp_config.enable_balance is true, attention_dp_config.timeout_iters must be greater or equal to 0
+        """
+        if self.attention_dp_config is None:
+            return self
+
+        config = self.attention_dp_config
+        if config.enable_balance:
+            if config.batching_wait_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.batching_wait_iters must be greater or equal to 0 when enable_balance is true"
+                )
+            if config.timeout_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.timeout_iters must be greater or equal to 0 when enable_balance is true"
+                )
+        return self
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
         from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2303,7 +2348,16 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             enable_min_latency=self.enable_min_latency,
             stream_interval=self.stream_interval,
             force_dynamic_quantization=self.force_dynamic_quantization,
-            allreduce_strategy=self.allreduce_strategy)
+            allreduce_strategy=self.allreduce_strategy,
+            attention_dp_enable_balance=bool(
+                self.attention_dp_config is not None
+                and self.attention_dp_config.enable_balance),
+            attention_dp_time_out_iters=self.attention_dp_config.timeout_iters
+            if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['timeout_iters'].default,
+            attention_dp_batching_wait_iters=self.attention_dp_config.
+            batching_wait_iters if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['batching_wait_iters'].default)


 def update_llm_args_with_extra_dict(
@@ -2320,6 +2374,7 @@ def update_llm_args_with_extra_dict(
         "speculative_config": DecodingBaseConfig,
         "lora_config": LoraConfig,
         "moe_config": MoeConfig,
+        "attention_dp_config": AttentionDpConfig,
     }
     for field_name, field_type in field_mapping.items():
         if field_name in llm_args_dict:
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index a6f44d431fd..4459213348b 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -1800,6 +1800,45 @@ def test_ptp_quickstart_advanced_auto(llm_root, llm_venv, model_name,
             _check_mem_usage(running_log, [27.0, 0, 0, 0])


+@skip_post_blackwell
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.skip_less_device(4)
+@pytest.mark.parametrize("model_name,model_path", [
+    pytest.param(
+        'DeepSeek-V3-Lite-FP8', 'DeepSeek-V3-Lite/fp8', marks=skip_pre_hopper),
+])
+def test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance(
+        llm_root, llm_venv, model_name, model_path):
+    print(f"Testing {model_name}.")
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
+    with tempfile.NamedTemporaryFile(mode='w+t',
+                                     suffix=f".{model_name}.log",
+                                     dir="./",
+                                     delete=True,
+                                     delete_on_close=True) as running_log:
+        llm_venv.run_cmd([
+            str(example_root / "quickstart_advanced.py"),
+            "--model_dir",
+            f"{llm_models_root()}/{model_path}",
+            "--moe_tp_size=1",
+            "--moe_ep_size=4",
+            "--tp_size=4",
+            "--use_cuda_graph",
+            "--enable_attention_dp",
+            f"--kv_cache_fraction={_MEM_FRACTION_95}",
+            "--max_batch_size=1",
+            "--max_seq_len=3000",
+            "--disable_kv_cache_reuse",
+            "--attention_dp_enable_balance",
+            "--attention_dp_time_out_iters",
+            "10",
+            "--attention_dp_batching_wait_iters",
+            "10",
+        ],
+                         stdout=running_log)
+        _check_mem_usage(running_log, [106.3, 0, 0, 0], 8)
+
+
 @skip_post_blackwell
 @pytest.mark.skip_less_device_memory(110000)
 @pytest.mark.skip_less_device(8)
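The scheduler change in py_executor.py above boils down to a small per-iteration decision over the allgathered per-rank request counts. The sketch below is a simplified, hypothetical rendition of that decision as a pure function; the names, the free-standing form, and the explicit state object are illustrative rather than the executor's real interface (the real method also gathers token counts and returns the context request list itself).

```python
# Simplified sketch of the _balance_adp_requests decision; names and the
# free-standing function form are illustrative, not the real executor API.
from dataclasses import dataclass


@dataclass
class AdpBalanceState:
    waiting_iters: int = 0    # iterations spent holding back context work
    batching_iters: int = 0   # iterations spent accumulating context requests


def schedule_context_this_iter(num_ctx_per_rank, num_gen_per_rank,
                               max_batch_size, timeout_iters,
                               batching_wait_iters,
                               state: AdpBalanceState) -> bool:
    # Mirrors the allgathered conditions computed in _balance_adp_requests.
    all_free_slots = all(n < max_batch_size for n in num_gen_per_rank)
    all_have_ctx = all(n > 0 for n in num_ctx_per_rank)
    all_have_gen = all(n > 0 for n in num_gen_per_rank)

    if all_free_slots and all_have_ctx:
        state.waiting_iters = 0
        if all_have_gen:
            if state.batching_iters < batching_wait_iters:
                # keep accumulating context requests for a few more iterations
                state.batching_iters += 1
                return False
            state.batching_iters = 0
        return True

    # some rank has no context work or no free slot: hold back until timeout
    state.waiting_iters += 1
    if state.waiting_iters >= timeout_iters or not all_have_gen:
        state.waiting_iters = 0
        return True
    return False
```

Returning False here corresponds to _balance_adp_requests handing back an empty context list for that iteration, so the rank runs only its generation requests until the timeout or batching window expires.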
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index 99fa084bda5..3a8e6aa9c98 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -51,6 +51,7 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
   - test_e2e.py::test_ptp_quickstart_advanced_bs1
+  - test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
 - condition:
     ranges:
       system_gpu_count:
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 7f8485d097a..984c8953ecd 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -79,6 +79,10 @@ methods:
       annotation: Optional[tensorrt_llm.llmapi.llm_args.CudaGraphConfig]
       default: null
      status: beta
+    attention_dp_config:
+      annotation: Optional[tensorrt_llm.llmapi.llm_args.AttentionDpConfig]
+      default: null
+      status: beta
     checkpoint_loader:
       annotation: Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader]
       default: null
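Because update_llm_args_with_extra_dict now maps the attention_dp_config key to AttentionDpConfig, the same settings can also be supplied as a plain dict, for example from an extra-options file. A small sketch of that dict form, restating the defaults added above:

```python
from tensorrt_llm.llmapi import AttentionDpConfig

# Dict form as it might appear under an "attention_dp_config" key; the values
# here simply restate the defaults introduced in llm_args.py above.
cfg = AttentionDpConfig.from_dict({
    "enable_balance": True,
    "timeout_iters": 50,
    "batching_wait_iters": 10,
})
assert cfg.enable_balance and cfg.batching_wait_iters == 10

# TorchLlmArgs.validate_attention_dp_config rejects negative iteration counts
# once enable_balance is true, so a value such as timeout_iters=-1 would raise
# ValueError at argument validation time.
```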