Commit 92b4a83

[feat] Support chunked prefill on spec decode 2 model
Signed-off-by: Mike Iovine <[email protected]>
1 parent fd6ce7f commit 92b4a83

4 files changed: +70 -20 lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 1 addition & 0 deletions
@@ -303,6 +303,7 @@ def __init__(
         self.py_batch_idx = None
         self.py_rewind_len = 0
         self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens
+        self.py_last_context_chunk = (None, None)
         self.py_last_draft_tokens = None
         self.py_num_accepted_draft_tokens = 0
         self.py_decoding_iter = 0
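
The new py_last_context_chunk field records the half-open token range [begin, end) of the most recent prefill chunk processed for the request; (None, None) means no chunk has been recorded yet. A minimal standalone illustration of those semantics (hypothetical example code, not part of the commit):

# Illustration only: semantics of py_last_context_chunk.
from typing import Optional, Tuple

# (None, None) until the executor records the first prefill chunk.
last_chunk: Tuple[Optional[int], Optional[int]] = (None, None)

# After the target model prefills tokens [0, 64), the executor stores:
last_chunk = (0, 64)

# The range is half-open, so the chunk length is end - begin.
begin, end = last_chunk
assert end - begin == 64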

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 0 deletions
@@ -1729,6 +1729,10 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
 
         for request in scheduled_requests.context_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:  # skip failed requests
+                request.py_last_context_chunk = (
+                    request.context_current_position,
+                    request.context_current_position +
+                    request.context_chunk_size)
                 request.move_to_next_context_chunk()
             if request.context_remaining_length == 0:
                 request.state = LlmRequestState.GENERATION_IN_PROGRESS
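
The chunk bounds are captured before move_to_next_context_chunk() runs, since that call advances context_current_position and would otherwise lose the window the target model just computed. A toy, self-contained sketch of that ordering (the ToyRequest class below is hypothetical; it only mimics the fields used in this hunk):

# Hypothetical sketch: why the chunk range is recorded before advancing the chunk.
class ToyRequest:

    def __init__(self, prompt_len: int, chunk_size: int):
        self.prompt_len = prompt_len
        self.context_chunk_size = chunk_size
        self.context_current_position = 0
        self.py_last_context_chunk = (None, None)

    @property
    def context_remaining_length(self) -> int:
        return self.prompt_len - self.context_current_position

    def move_to_next_context_chunk(self) -> None:
        # Advancing overwrites context_current_position, so the range must be
        # saved first.
        self.context_current_position += self.context_chunk_size
        self.context_chunk_size = min(self.context_chunk_size,
                                      self.context_remaining_length)


req = ToyRequest(prompt_len=150, chunk_size=64)
while req.context_remaining_length > 0:
    req.py_last_context_chunk = (req.context_current_position,
                                 req.context_current_position +
                                 req.context_chunk_size)
    req.move_to_next_context_chunk()

# Chunks recorded across iterations: (0, 64), (64, 128), (128, 150).
assert req.py_last_context_chunk == (128, 150)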

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 38 additions & 7 deletions
@@ -76,10 +76,16 @@ def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
     def _create_context_request(self, request: LlmRequest,
                                 input_tokens: Any) -> LlmRequest:
         """Create a context request for first-time drafting."""
-        return self._create_draft_request(request.py_request_id,
-                                          request.py_max_new_tokens,
-                                          input_tokens, request.sampling_config,
-                                          request.return_perf_metrics)
+        new_request = self._create_draft_request(request.py_request_id,
+                                                 request.py_max_new_tokens,
+                                                 input_tokens,
+                                                 request.sampling_config,
+                                                 request.return_perf_metrics)
+
+        begin_compute, end_compute = request.py_last_context_chunk
+        new_request.context_current_position = begin_compute
+        new_request.context_chunk_size = end_compute - begin_compute
+        return new_request
 
     def _create_generation_request(self, request: LlmRequest,
                                    input_tokens: Any) -> LlmRequest:
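
With this change, a draft-model context request created during chunked prefill is pinned to the same token window the target model just processed: its position is the recorded chunk start and its chunk size is the recorded range length. A small sketch of that mapping (DraftContextRequest and align_to_target_chunk are illustrative names, not APIs from this repository):

# Illustrative mapping from the target's last chunk to a draft context request.
from dataclasses import dataclass
from typing import Tuple


@dataclass
class DraftContextRequest:
    context_current_position: int = 0
    context_chunk_size: int = 0


def align_to_target_chunk(draft_req: DraftContextRequest,
                          last_chunk: Tuple[int, int]) -> DraftContextRequest:
    """Make the draft request prefill exactly the slice the target just processed."""
    begin_compute, end_compute = last_chunk
    draft_req.context_current_position = begin_compute
    draft_req.context_chunk_size = end_compute - begin_compute
    return draft_req


# Target prefilled tokens [64, 128) last iteration; the draft model now fills
# its own KV cache for that same window.
req = align_to_target_chunk(DraftContextRequest(), last_chunk=(64, 128))
assert (req.context_current_position, req.context_chunk_size) == (64, 64)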
@@ -94,10 +100,13 @@ def _create_generation_request(self, request: LlmRequest,
         new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
         return new_request
 
-    def _create_chunked_context_request(self, request: LlmRequest,
+    def _create_accepted_tokens_request(self, request: LlmRequest,
                                         input_tokens: Any,
                                         num_accepted_tokens: int) -> LlmRequest:
-        """Create a chunked context request when some tokens were accepted."""
+        """
+        Create a chunked context request for accepted tokens.
+        Only applicable if the draft model needs to recompute KV cache for accepted tokens (e.g. eagle 3)
+        """
         new_request = self._create_draft_request(request.py_request_id,
                                                  request.py_max_new_tokens,
                                                  input_tokens,
@@ -130,7 +139,7 @@ def _create_draft_request_for_request(
 
         # Tokens accepted - chunked context request
         else:
-            return self._create_chunked_context_request(request, input_tokens,
+            return self._create_accepted_tokens_request(request, input_tokens,
                                                         num_accepted_tokens)
 
     def _add_to_draft_batch(self, draft_batch: ScheduledRequests,
@@ -168,6 +177,22 @@ def _prepare_draft_batch(
         try:
             draft_batch = ScheduledRequests()
 
+            for request in scheduled_requests.context_requests:
+                if request.is_first_context_chunk:
+                    # Ignore requests which still need to be processed by the target model.
+                    continue
+
+                # We hit this path if we're doing chunked prefill. The target model processed
+                # a prefill chunk on the last iteration. Now, we need to fill in the KV cache
+                # for the draft model too.
+                all_tokens = request.get_tokens()[0]
+                input_tokens = self.spec_config.get_draft_model_prompt(
+                    all_tokens)
+
+                new_request = self._create_context_request(
+                    request, input_tokens)
+                self._add_to_draft_batch(draft_batch, new_request, request)
+
             for request in scheduled_requests.generation_requests:
                 if request.py_draft_pages_allocated == 0:
                     # No space for draft tokens
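
The new loop mirrors the target model's chunked prefill on the draft side: requests still on their first context chunk are skipped (the target has not completed a chunk to mirror yet), while requests past their first chunk get a draft-model context request so the draft KV cache keeps pace. A standalone sketch of that filtering (ToyContextRequest and requests_needing_draft_prefill are hypothetical stand-ins):

# Hypothetical sketch of which context requests trigger a draft-model prefill request.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyContextRequest:
    request_id: int
    is_first_context_chunk: bool


def requests_needing_draft_prefill(
        context_requests: List[ToyContextRequest]) -> List[int]:
    """Ids of requests whose KV cache the draft model still has to catch up on."""
    ids = []
    for request in context_requests:
        if request.is_first_context_chunk:
            # Nothing to mirror yet; the target model handles this chunk first.
            continue
        ids.append(request.request_id)
    return ids


batch = [
    ToyContextRequest(request_id=0, is_first_context_chunk=True),
    ToyContextRequest(request_id=1, is_first_context_chunk=False),
]
assert requests_needing_draft_prefill(batch) == [1]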
@@ -257,6 +282,12 @@ def _process_decoded_tokens(
         new_requests = []
         for req in draft_batch.all_requests():
             target_model_req = req_id_to_old_request[req.py_request_id]
+            if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                # This is a chunked prefill request and we have more prefill chunks
+                # to process. Defer adding draft tokens until the whole prompt is processed.
+                self.draft_seq_slot_manager.free_resources(req)
+                continue
+
             target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
             if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                     target_model_req.py_draft_tokens
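
While a target request is still mid-prefill (its state has not reached GENERATION_IN_PROGRESS), the draft outputs produced for it are discarded and the draft sequence slot is released; draft tokens are only attached once the whole prompt has been processed. A toy sketch of that gate (the State enum and helper below are hypothetical):

# Hypothetical sketch of the deferral: attach draft tokens only after prefill is done.
from enum import Enum, auto


class State(Enum):
    CONTEXT_IN_PROGRESS = auto()     # more prefill chunks remain
    GENERATION_IN_PROGRESS = auto()  # prompt fully processed, decoding


def should_attach_draft_tokens(target_state: State) -> bool:
    # Mirrors the new guard above: skip (and free resources for) draft requests
    # whose target request is still prefilling.
    return target_state is State.GENERATION_IN_PROGRESS


assert should_attach_draft_tokens(State.GENERATION_IN_PROGRESS)
assert not should_attach_draft_tokens(State.CONTEXT_IN_PROGRESS)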

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 27 additions & 13 deletions
@@ -14,19 +14,21 @@
 
 
 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
     [
-        [True, "TRTLLM", True, False, False],
-        [False, "TRTLLM", True, False, False],
-        [True, "FLASHINFER", True, False, False],
-        [False, "FLASHINFER", True, False, False],
-        [False, "TRTLLM", False, True, True],
-        [True, "TRTLLM", False, True, True],
+        [True, "TRTLLM", True, False, False, False],
+        [False, "TRTLLM", True, False, False, False],
+        [True, "FLASHINFER", True, False, False, False],
+        [False, "FLASHINFER", True, False, False, False],
+        [False, "TRTLLM", False, True, True, False],
+        [True, "TRTLLM", False, True, True, False],
+        [True, "TRTLLM", True, False, True, True],
+        [True, "TRTLLM", True, False, False, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
-                      use_one_model: bool):
+                      use_one_model: bool, enable_chunked_prefill: bool):
     # Eagle3 one model works with overlap scheduler and block reuse.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:
@@ -57,7 +59,11 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
         # that the draft model won't go above its max in warmup
         # in this test.
         max_seq_len=8192,
+        enable_chunked_prefill=enable_chunked_prefill,
     )
+    if enable_chunked_prefill:
+        # Use a small max_num_tokens so that the chunked prefill path gets exercised.
+        llm_common_config['max_num_tokens'] = 64
 
     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
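
Capping max_num_tokens at 64 forces the long test prompt to be prefilled over several iterations, which is what exercises the new draft-model chunking path. A rough sketch of the effect (the prompt length below is an assumed ballpark, not measured from the test):

# Rough, assumed arithmetic: number of prefill chunks for the chunked-prefill test.
import math

max_num_tokens = 64   # per-iteration token budget set by the test
prompt_len = 250      # assumed approximate token count of the long prompt
num_chunks = math.ceil(prompt_len / max_num_tokens)
assert num_chunks == 4  # several chunks, so the draft mirror path runs repeatedly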
@@ -69,7 +75,19 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
 
     # Acceptance rate tests
-    tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+    if enable_chunked_prefill:
+        # Use a long prompt for chunked prefill tests.
+        prompts = [
+            "The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
+        ]
+        tok_ids = llm_spec.tokenizer.encode(prompts[0])
+    else:
+        prompts = [
+            "The capital of France is",
+            "The president of the United States is",
+        ]
+        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+
     num_tokens = 0
     num_drafted = 0
     num_accepted = 0
@@ -86,10 +104,6 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     assert accept_rate > 0.15
 
     # Output tests
-    prompts = [
-        "The capital of France is",
-        "The president of the United States is",
-    ]
     sampling_params = SamplingParams(max_tokens=10, temperature=0)
 
     results_spec = llm_spec.generate(prompts, sampling_params)
