
Commit c3792ef

Remove some useless code in ModelDrafter
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent feb1b43 commit c3792ef

1 file changed: +15, -32 lines

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 15 additions & 32 deletions
@@ -47,14 +47,6 @@ def __init__(
         # Sampling
         self.sampler = sampler

-    def _should_process_request(self, request: LlmRequest) -> bool:
-        """Check if request should be processed for drafting."""
-        return request.py_draft_pages_allocated > 0  # type: ignore
-
-    def _exceeds_max_sequence_length(self, request: LlmRequest) -> bool:
-        """Check if the request exceeds maximum sequence length for drafting."""
-        return request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len
-
     def _create_draft_request(self, request_id: int, max_new_tokens: int,
                               input_tokens: Optional[List],
                               sampling_config: SamplingConfig,
@@ -81,10 +73,6 @@ def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:

         return num_draft_tokens, num_accepted_tokens

-    def _get_draft_model_input(self, request: LlmRequest) -> Any:
-        """Get input tokens for draft model."""
-        return self.spec_config.get_draft_model_prompt(request.get_tokens()[0])
-
     def _create_context_request(self, request: LlmRequest,
                                 input_tokens: Any) -> LlmRequest:
         """Create a context request for first-time drafting."""
@@ -116,10 +104,6 @@ def _create_chunked_context_request(self, request: LlmRequest,
                                                  request.sampling_config,
                                                  request.return_perf_metrics)
         new_request.context_chunk_size = num_accepted_tokens + 1
-        new_request.context_current_position = len(
-            input_tokens) - num_accepted_tokens - 1
-        # Note: Original code has duplicate assignment (appears to be a bug, but keeping it)
-        new_request.context_chunk_size = num_accepted_tokens + 1
         new_request.context_current_position = len(
             input_tokens) - num_accepted_tokens - 1
         return new_request
@@ -129,7 +113,8 @@ def _create_draft_request_for_request(
         """Create a draft request based on the original request state."""
         num_draft_tokens, num_accepted_tokens = self._initialize_draft_tokens(
             request)
-        input_tokens = self._get_draft_model_input(request)
+        input_tokens = self.spec_config.get_draft_model_prompt(
+            request.get_tokens()[0])

         # First time seeing this request - context request
         if request.max_beam_num_tokens - 1 == request.py_prompt_len:
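The one-line _get_draft_model_input wrapper removed earlier is now a direct call at its only call site. A small sketch of the shape of that call, with hypothetical _StubSpecConfig/_StubRequest classes standing in for the real SpecConfig and LlmRequest (the stub's prompt-building behavior is a placeholder, not the real logic):

from typing import List


class _StubSpecConfig:
    """Hypothetical stand-in for the speculative-decoding config."""

    def get_draft_model_prompt(self, tokens: List[int]) -> List[int]:
        # Placeholder: the real method derives the draft-model prompt
        # from the target request's tokens.
        return list(tokens)


class _StubRequest:
    """Hypothetical stand-in for LlmRequest; get_tokens() returns per-beam lists."""

    def __init__(self, beam0_tokens: List[int]):
        self._tokens = [beam0_tokens]

    def get_tokens(self) -> List[List[int]]:
        return self._tokens


spec_config = _StubSpecConfig()
request = _StubRequest([101, 7, 42, 9])

# Shape of the inlined call after this commit: beam 0's tokens go straight
# into spec_config.get_draft_model_prompt, with no wrapper method in between.
input_tokens = spec_config.get_draft_model_prompt(request.get_tokens()[0])
assert input_tokens == [101, 7, 42, 9]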
@@ -184,11 +169,18 @@ def _prepare_draft_batch(
         draft_batch = ScheduledRequests()

         for request in scheduled_requests.generation_requests:
-            if not self._should_process_request(request):
+            if request.py_draft_pages_allocated == 0:
+                # No space for draft tokens
                 continue

-            # Stop drafting when we hit the max seqlen
-            if self._exceeds_max_sequence_length(request):
+            # Stop drafting when we hit the max seqlen. We still need dummy draft
+            # tokens attached to the requests to make sure everything works properly
+            # with CUDA graph. These dummy tokens are already added by
+            # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
+            # that we want to do spec decoding, so no need to do anything else here.
+            # This makes the perf for this case suboptimal, but that's OK - this is
+            # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
+            if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
                 continue

             draft_request = self._create_draft_request_for_request(request)
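Both removed helper predicates now live inline in this loop. A self-contained sketch of the resulting filtering logic, using a hypothetical _StubGenRequest dataclass and a made-up max_seq_len in place of the real LlmRequest and draft_model_engine:

from dataclasses import dataclass


@dataclass
class _StubGenRequest:
    """Hypothetical stand-in for a generation LlmRequest."""
    py_draft_pages_allocated: int
    max_beam_num_tokens: int


DRAFT_MAX_SEQ_LEN = 8192  # made-up stand-in for draft_model_engine.max_seq_len

requests = [
    _StubGenRequest(py_draft_pages_allocated=0, max_beam_num_tokens=100),   # no draft pages
    _StubGenRequest(py_draft_pages_allocated=4, max_beam_num_tokens=8193),  # hit max seqlen
    _StubGenRequest(py_draft_pages_allocated=4, max_beam_num_tokens=100),   # gets drafted
]

drafted = []
for request in requests:
    if request.py_draft_pages_allocated == 0:
        # No space for draft tokens
        continue
    if request.max_beam_num_tokens - 1 >= DRAFT_MAX_SEQ_LEN:
        # Stop drafting at max seqlen; per the comment in the hunk above, dummy
        # draft tokens were already attached elsewhere, so CUDA graphs and the
        # scheduler still see spec decoding for this request.
        continue
    drafted.append(request)

assert [r.max_beam_num_tokens for r in drafted] == [100]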
@@ -255,17 +247,8 @@ def _update_request_states(self,

     def _update_requests(self, sample_state: SampleState) -> None:
         """Update requests with sample state."""
-        try:
-            if self.sampler is not None:
-                self.sampler.update_requests(sample_state)
-        except Exception as e:
-            logger.error(f"Error updating requests: {str(e)}")
-
-    def _handle_errors(self, error_msg: str) -> None:
-        """Handle errors during draft token generation."""
-        logger.error(f"Draft token generation error: {error_msg}")
-        # For now, just log the error. In a full implementation, this could
-        # clean up resources, notify other components, etc.
+        if self.sampler is not None:
+            self.sampler.update_requests(sample_state)

     def _process_decoded_tokens(
             self, draft_batch: ScheduledRequests,
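With the try/except and the _handle_errors helper removed, _update_requests reduces to a None guard plus a direct delegation, so a sampler exception now propagates to the caller instead of being logged and swallowed. A tiny sketch of that pattern with a hypothetical _StubSampler in place of the real sampler:

from typing import List, Optional


class _StubSampler:
    """Hypothetical stand-in for the drafter's sampler."""

    def __init__(self):
        self.seen: List[object] = []

    def update_requests(self, sample_state: object) -> None:
        self.seen.append(sample_state)


def update_requests(sampler: Optional[_StubSampler], sample_state: object) -> None:
    # Mirrors the post-change body: guard against a missing sampler, then delegate.
    # Any exception raised by update_requests now surfaces to the caller.
    if sampler is not None:
        sampler.update_requests(sample_state)


sampler = _StubSampler()
update_requests(sampler, {"token_ids": [1, 2, 3]})   # delegated
update_requests(None, {"token_ids": [4]})            # no-op when there is no sampler
assert len(sampler.seen) == 1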
@@ -277,7 +260,7 @@ def _process_decoded_tokens(
             target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
             if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                     target_model_req.py_draft_tokens
-            ) < target_model_req.py_draft_pages_allocated:  # type: ignore
+            ) < target_model_req.py_draft_pages_allocated:
                 new_requests.append(req)
             else:
                 self.draft_seq_slot_manager.free_resources(req)
