From fc5b559ab5f903d72c027cba9a297c2c72710e0e Mon Sep 17 00:00:00 2001
From: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Date: Thu, 26 Jun 2025 11:32:04 -0700
Subject: [PATCH 1/2] [nvbug/5337601][fix] Fix disagg + speculative decoding

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 31 +++++++++++++------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 2aa50df07f1..67ae9f14d59 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -869,12 +869,19 @@ def _executor_loop(self):

                 self._pad_attention_dp_dummy_request()

-                if self.draft_model_engine is not None or is_ngram:
-                    self._prepare_draft_requests()
+                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
+                    self._prepare_draft_requests(self.active_requests)

                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
                 )

+                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
+                    # REVIEW: This might need to be changed. The reason we call prepare_draft_requests
+                    # on all active requests before scheduling is to make the scheduler aware of KV pages used
+                    # by draft tokens.
+                    self._prepare_draft_requests(
+                        fitting_disagg_gen_init_requests)
+
                 if self.kv_cache_transceiver:
                     # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
                     self._prepare_disagg_gen_init(
@@ -966,13 +973,14 @@ def _executor_loop(self):
                             iter_stats=iter_stats,
                             iter_start_time=iter_start_time))

-    def _prepare_draft_requests(self):
+    def _prepare_draft_requests(self, requests):
         try:
             # Set draft tokens here to make the KV cache manager
             # and scheduler aware of them.
-            for req in self.active_requests:
+            for req in requests:
                 # TODO: enable draft tokens in context phase
-                if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
+                                     LlmRequestState.DISAGG_GENERATION_INIT):
                     continue
                 req.py_last_draft_tokens = req.py_draft_tokens
                 max_draft_len = self.model_engine.spec_config.max_draft_tokens
@@ -1528,9 +1536,15 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
         disagg_gen_init_to_prepare.generation_requests = []
         disagg_gen_init_to_prepare.paused_requests = []

-        self.resource_manager.resource_managers[
-            ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
-                disagg_gen_init_to_prepare)
+        for resource_mgr_type in (
+                ResourceManagerType.KV_CACHE_MANAGER,
+                ResourceManagerType.SEQ_SLOT_MANAGER,
+                ResourceManagerType.SPEC_RESOURCE_MANAGER,
+                ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
+            if resource_mgr_type in self.resource_manager.resource_managers:
+                self.resource_manager.resource_managers[
+                    resource_mgr_type].prepare_resources(
+                        disagg_gen_init_to_prepare)

         # Trigger KV cache exchange for new disagg_gen_init_requests
         self._recv_disagg_gen_cache(fitting_disagg_gen_init_requests)
@@ -1790,7 +1804,6 @@ def _prepare_draft_batch(
                 # This is the first time the draft model is seeing this request.
                 # Prepare a context request. We discard the first token and take
                 # the newly decoded one - this is the convention for EAGLE 2 and 3.
-                assert num_draft_tokens == 0
                 new_request = LlmRequest(
                     request_id=request.py_request_id,
                     max_new_tokens=request.py_max_new_tokens,

From a628337afb7092da469f9aa1c2e5884a61474520 Mon Sep 17 00:00:00 2001
From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Date: Fri, 27 Jun 2025 15:48:21 -0700
Subject: [PATCH 2/2] Minor edit

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

Fix accuracy for disagg + eagle

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 14 +---
 tensorrt_llm/_torch/pyexecutor/sampler.py     | 46 ++++++------
 .../accuracy/test_disaggregated_serving.py    | 75 +++++++++++++++----
 .../test_lists/test-db/l0_dgx_h100.yml        |  1 +
 4 files changed, 90 insertions(+), 46 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 67ae9f14d59..13508f85821 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -869,19 +869,13 @@ def _executor_loop(self):

                 self._pad_attention_dp_dummy_request()

-                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
+                if self.draft_model_engine is not None or is_ngram or hasattr(
+                        self, 'drafter') and self.drafter is not None:
                     self._prepare_draft_requests(self.active_requests)

                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
                 )

-                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
-                    # REVIEW: This might need to be changed. The reason we call prepare_draft_requests
-                    # on all active requests before scheduling is to make the scheduler aware of KV pages used
-                    # by draft tokens.
-                    self._prepare_draft_requests(
-                        fitting_disagg_gen_init_requests)
-
                 if self.kv_cache_transceiver:
                     # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
                     self._prepare_disagg_gen_init(
@@ -978,7 +972,6 @@ def _prepare_draft_requests(self, requests):
             # Set draft tokens here to make the KV cache manager
             # and scheduler aware of them.
             for req in requests:
-                # TODO: enable draft tokens in context phase
                 if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
                                      LlmRequestState.DISAGG_GENERATION_INIT):
                     continue
@@ -1541,7 +1534,8 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
                 ResourceManagerType.SEQ_SLOT_MANAGER,
                 ResourceManagerType.SPEC_RESOURCE_MANAGER,
                 ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
-            if resource_mgr_type in self.resource_manager.resource_managers:
+            if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
+                    resource_mgr_type] is not None:
                 self.resource_manager.resource_managers[
                     resource_mgr_type].prepare_resources(
                         disagg_gen_init_to_prepare)
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 4106c5976b4..0637d693d75 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -307,30 +307,30 @@ def handle_logits(request: LlmRequest, tokens: list[int], count=1):
             if request.state != LlmRequestState.GENERATION_COMPLETE:
                 new_token = new_tokens_list[token_idx]
                 num_tokens = request.add_new_token(new_token, beam_idx)
-                if self._handle_stop_criteria(request, new_token, num_tokens,
-                                              beam_idx):
-                    continue
-
-                # Accept draft tokens (if we have any) if and only if they match the new
-                # token exactly.
-                num_accepted = 0
-                new_tokens = [new_token]
-                for draft_token in request.py_draft_tokens:
-                    if draft_token != new_token:
-                        # Reject.
-                        break
-                    num_accepted += 1
-                    new_token = new_tokens_list[token_idx + num_accepted]
-                    num_tokens = request.add_new_token(new_token, beam_idx)
-                    new_tokens.append(num_tokens)  # `num_tokens`->`new_token`
-
-                    if self._handle_stop_criteria(request, new_token,
+                if not self._handle_stop_criteria(request, new_token,
                                                   num_tokens, beam_idx):
-                        break
-                handle_logits(request, new_tokens, num_accepted)
-                request.py_decoding_iter += 1
-                request.py_num_accepted_draft_tokens = num_accepted
-                request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
+
+                    # Accept draft tokens (if we have any) if and only if they match the new
+                    # token exactly.
+                    num_accepted = 0
+                    new_tokens = [new_token]
+                    for draft_token in request.py_draft_tokens:
+                        if draft_token != new_token:
+                            # Reject.
+                            break
+                        num_accepted += 1
+                        new_token = new_tokens_list[token_idx + num_accepted]
+                        num_tokens = request.add_new_token(new_token, beam_idx)
+                        new_tokens.append(
+                            num_tokens)  # `num_tokens`->`new_token`
+
+                        if self._handle_stop_criteria(request, new_token,
+                                                      num_tokens, beam_idx):
+                            break
+                    handle_logits(request, new_tokens, num_accepted)
+                    request.py_decoding_iter += 1
+                    request.py_num_accepted_draft_tokens = num_accepted
+                    request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
             advance_idx(len(request.py_draft_tokens) + 1)

         for request in generation_requests:
diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py
index 30d5d55b325..6c1af5288aa 100644
--- a/tests/integration/defs/accuracy/test_disaggregated_serving.py
+++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -92,6 +92,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
         trtllm_serve_path, model_name, "--host", "localhost", "--backend",
         "pytorch"
     ]
+
     if tensor_parallel_size > 1:
         common_args.append(f"--tp_size={tensor_parallel_size}")

@@ -104,18 +105,22 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
         map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
-
-    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
-          popen(common_args + [
-              "--port", "8001", "--extra_llm_api_options",
-              ctx_server_config_path
-          ],
-                env=env_ctx) as ctx_server,
-          popen(common_args + [
-              "--port", "8002", "--extra_llm_api_options",
-              gen_server_config_path
-          ],
-                env=env_gen) as gen_server,
+    ctx_server_args = common_args + [
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+    ]
+    gen_server_args = common_args + [
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+    ]
+    if "max_num_tokens" in ctx_server_config:
+        ctx_server_args.append(
+            f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
+    if "max_num_tokens" in gen_server_config:
+        gen_server_args.append(
+            f"--max_num_tokens={gen_server_config['max_num_tokens']}")
+
+    with (MyThreadPoolExecutor(max_workers=16) as
+          thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as
+          ctx_server, popen(gen_server_args, env=env_gen) as gen_server,
           popen([
               trtllm_serve_path, "disaggregated", "-c",
               disaggregated_serving_config_path, "--server_start_timeout",
@@ -209,9 +214,53 @@ def test_auto_dtype(self, disable_overlap_scheduler):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.parametrize("overlap_scheduler", [False])
+    def test_eagle3(self, overlap_scheduler):
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 4,
+            "pytorch_weights_path":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False
+        }
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": not overlap_scheduler,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+

-@pytest.mark.timeout(3600)
 @pytest.mark.skip_less_device_memory(140000)
+@pytest.mark.timeout(3600)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index 489e5415d8e..7952636e501 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -39,6 +39,7 @@ l0_dgx_h100:
   - disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False]
   - test_e2e.py::test_ptp_quickstart_advanced_bs1
 - condition:
     ranges: