From 238114263b8716553910d4c5e7e65fc03b97b82b Mon Sep 17 00:00:00 2001
From: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Date: Wed, 16 Jul 2025 11:28:48 -0700
Subject: [PATCH 1/2] [feat] Support chunked prefill on spec decode 2 model

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/llm_request.py |  1 +
 tensorrt_llm/_torch/pyexecutor/py_executor.py |  4 ++
 .../_torch/speculative/model_drafter.py       | 46 ++++++++++++++++---
 .../_torch/speculative/test_eagle3.py         | 42 +++++++++++------
 4 files changed, 71 insertions(+), 22 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index 461c5de941e..7a7e4510dd0 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -303,6 +303,7 @@ def __init__(
         self.py_batch_idx = None
         self.py_rewind_len = 0
         self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens
+        self.py_last_context_chunk = (None, None)
         self.py_last_draft_tokens = None
         self.py_num_accepted_draft_tokens = 0
         self.py_decoding_iter = 0
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index e5b302310fc..91a76f80319 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1723,6 +1723,10 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
         for request in scheduled_requests.context_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:  # skip failed requests
+                request.py_last_context_chunk = (
+                    request.context_current_position,
+                    request.context_current_position +
+                    request.context_chunk_size)
                 request.move_to_next_context_chunk()
                 if request.context_remaining_length == 0:
                     request.state = LlmRequestState.GENERATION_IN_PROGRESS
diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py
index ac195ccf515..bbdfc7a24d3 100644
--- a/tensorrt_llm/_torch/speculative/model_drafter.py
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -76,10 +76,17 @@ def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
     def _create_context_request(self, request: LlmRequest,
                                 input_tokens: Any) -> LlmRequest:
         """Create a context request for first-time drafting."""
-        return self._create_draft_request(request.py_request_id,
-                                          request.py_max_new_tokens,
-                                          input_tokens, request.sampling_config,
-                                          request.return_perf_metrics)
+        new_request = self._create_draft_request(request.py_request_id,
+                                                 request.py_max_new_tokens,
+                                                 input_tokens,
+                                                 request.sampling_config,
+                                                 request.return_perf_metrics)
+
+        begin_compute, end_compute = request.py_last_context_chunk
+        if begin_compute is not None:
+            new_request.context_current_position = begin_compute
+            new_request.context_chunk_size = end_compute - begin_compute
+        return new_request
 
     def _create_generation_request(self, request: LlmRequest,
                                    input_tokens: Any) -> LlmRequest:
@@ -94,10 +101,13 @@ def _create_generation_request(self, request: LlmRequest,
         new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
         return new_request
 
-    def _create_chunked_context_request(self, request: LlmRequest,
+    def _create_accepted_tokens_request(self, request: LlmRequest,
                                         input_tokens: Any,
                                         num_accepted_tokens: int) -> LlmRequest:
-        """Create a chunked context request when some tokens were accepted."""
+        """
+        Create a chunked context request for accepted tokens.
+        Only applicable if the draft model needs to recompute KV cache for accepted tokens (e.g. eagle 3)
+        """
         new_request = self._create_draft_request(request.py_request_id,
                                                  request.py_max_new_tokens,
                                                  input_tokens,
@@ -130,7 +140,7 @@ def _create_draft_request_for_request(
 
         # Tokens accepted - chunked context request
         else:
-            return self._create_chunked_context_request(request, input_tokens,
+            return self._create_accepted_tokens_request(request, input_tokens,
                                                         num_accepted_tokens)
 
     def _add_to_draft_batch(self, draft_batch: ScheduledRequests,
@@ -168,6 +178,22 @@ def _prepare_draft_batch(
         try:
             draft_batch = ScheduledRequests()
 
+            for request in scheduled_requests.context_requests:
+                if request.is_first_context_chunk:
+                    # Ignore requests which still need to be processed by the target model.
+                    continue
+
+                # We hit this path if we're doing chunked prefill. The target model processed
+                # a prefill chunk on the last iteration. Now, we need to fill in the KV cache
+                # for the draft model too.
+                all_tokens = request.get_tokens()[0]
+                input_tokens = self.spec_config.get_draft_model_prompt(
+                    all_tokens)
+
+                new_request = self._create_context_request(
+                    request, input_tokens)
+                self._add_to_draft_batch(draft_batch, new_request, request)
+
             for request in scheduled_requests.generation_requests:
                 if request.py_draft_pages_allocated == 0:
                     # No space for draft tokens
@@ -257,6 +283,12 @@ def _process_decoded_tokens(
         new_requests = []
         for req in draft_batch.all_requests():
             target_model_req = req_id_to_old_request[req.py_request_id]
+            if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                # This is a chunked prefill request and we have more prefill chunks
+                # to process. Defer adding draft tokens until the whole prompt is processed.
+                self.draft_seq_slot_manager.free_resources(req)
+                continue
+
             target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
             if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                     target_model_req.py_draft_tokens
diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py
index 0b093e3ad82..ffb8e33766a 100644
--- a/tests/unittest/_torch/speculative/test_eagle3.py
+++ b/tests/unittest/_torch/speculative/test_eagle3.py
@@ -14,21 +14,21 @@
 
 
 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
     [
-        [True, "TRTLLM", True, False, False],
-        [False, "TRTLLM", True, False, False],
-        [True, "TRTLLM", True, True, False],
-        [False, "TRTLLM", True, True, False],
-        [True, "FLASHINFER", True, False, False],
-        [False, "FLASHINFER", True, False, False],
-        [False, "TRTLLM", False, True, True],
-        [True, "TRTLLM", False, True, True],
+        [True, "TRTLLM", True, False, False, False],
+        [False, "TRTLLM", True, False, False, False],
+        [True, "FLASHINFER", True, False, False, False],
+        [False, "FLASHINFER", True, False, False, False],
+        [False, "TRTLLM", False, True, True, False],
+        [True, "TRTLLM", False, True, True, False],
+        [True, "TRTLLM", True, False, True, True],
+        [True, "TRTLLM", True, False, False, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
-                      use_one_model: bool):
+                      use_one_model: bool, enable_chunked_prefill: bool):
     # Eagle3 one model works with overlap scheduler and block reuse.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:
@@ -59,7 +59,11 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
         # that the draft model won't go above its max in warmup
        # in this test.
         max_seq_len=8192,
+        enable_chunked_prefill=enable_chunked_prefill,
     )
+    if enable_chunked_prefill:
+        # Use a small max_num_tokens so that the chunked prefill path gets exercised.
+        llm_common_config['max_num_tokens'] = 64
 
     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
@@ -71,7 +75,19 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
 
     # Acceptance rate tests
-    tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+    if enable_chunked_prefill:
+        # Use a long prompt for chunked prefill tests.
+        prompts = [
+            "The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
+        ]
+        tok_ids = llm_spec.tokenizer.encode(prompts[0])
+    else:
+        prompts = [
+            "The capital of France is",
+            "The president of the United States is",
+        ]
+        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+
     num_tokens = 0
     num_drafted = 0
     num_accepted = 0
@@ -88,10 +104,6 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     assert accept_rate > 0.15
 
     # Output tests
-    prompts = [
-        "The capital of France is",
-        "The president of the United States is",
-    ]
     sampling_params = SamplingParams(max_tokens=10, temperature=0)
 
     results_spec = llm_spec.generate(prompts, sampling_params)

From 61d9701f307c61ae62ad2b0447babea595d9f502 Mon Sep 17 00:00:00 2001
From: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Date: Wed, 23 Jul 2025 12:01:34 -0700
Subject: [PATCH 2/2] Fix get_draft_model_prompt call

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
---
 tensorrt_llm/_torch/speculative/model_drafter.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py
index 6981f9d63ad..318cce8c736 100644
--- a/tensorrt_llm/_torch/speculative/model_drafter.py
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -203,8 +203,8 @@ def _prepare_draft_batch(
             # a prefill chunk on the last iteration. Now, we need to fill in the KV cache
             # for the draft model too.
             all_tokens = request.get_tokens()[0]
-            input_tokens = self.spec_config.get_draft_model_prompt(
-                all_tokens)
+            input_tokens = get_draft_model_prompt(
+                self.spec_config.spec_dec_mode, all_tokens)
 
             new_request = self._create_context_request(
                 request, input_tokens)