@@ -47,14 +47,6 @@ def __init__(
         # Sampling
         self.sampler = sampler
 
-    def _should_process_request(self, request: LlmRequest) -> bool:
-        """Check if request should be processed for drafting."""
-        return request.py_draft_pages_allocated > 0  # type: ignore
-
-    def _exceeds_max_sequence_length(self, request: LlmRequest) -> bool:
-        """Check if the request exceeds maximum sequence length for drafting."""
-        return request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len
-
     def _create_draft_request(self, request_id: int, max_new_tokens: int,
                               input_tokens: Optional[List],
                               sampling_config: SamplingConfig,
@@ -81,10 +73,6 @@ def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
 
         return num_draft_tokens, num_accepted_tokens
 
-    def _get_draft_model_input(self, request: LlmRequest) -> Any:
-        """Get input tokens for draft model."""
-        return self.spec_config.get_draft_model_prompt(request.get_tokens()[0])
-
     def _create_context_request(self, request: LlmRequest,
                                 input_tokens: Any) -> LlmRequest:
         """Create a context request for first-time drafting."""
@@ -116,10 +104,6 @@ def _create_chunked_context_request(self, request: LlmRequest,
                                                  request.sampling_config,
                                                  request.return_perf_metrics)
         new_request.context_chunk_size = num_accepted_tokens + 1
-        new_request.context_current_position = len(
-            input_tokens) - num_accepted_tokens - 1
-        # Note: Original code has duplicate assignment (appears to be a bug, but keeping it)
-        new_request.context_chunk_size = num_accepted_tokens + 1
         new_request.context_current_position = len(
             input_tokens) - num_accepted_tokens - 1
         return new_request
@@ -129,7 +113,8 @@ def _create_draft_request_for_request(
129113 """Create a draft request based on the original request state."""
130114 num_draft_tokens , num_accepted_tokens = self ._initialize_draft_tokens (
131115 request )
132- input_tokens = self ._get_draft_model_input (request )
116+ input_tokens = self .spec_config .get_draft_model_prompt (
117+ request .get_tokens ()[0 ])
133118
134119 # First time seeing this request - context request
135120 if request .max_beam_num_tokens - 1 == request .py_prompt_len :
@@ -184,11 +169,18 @@ def _prepare_draft_batch(
         draft_batch = ScheduledRequests()
 
         for request in scheduled_requests.generation_requests:
-            if not self._should_process_request(request):
+            if request.py_draft_pages_allocated == 0:
+                # No space for draft tokens
                 continue
 
-            # Stop drafting when we hit the max seqlen
-            if self._exceeds_max_sequence_length(request):
+            # Stop drafting when we hit the max seqlen. We still need dummy draft
+            # tokens attached to the requests to make sure everything works properly
+            # with CUDA graph. These dummy tokens are already added by
+            # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
+            # that we want to do spec decoding, so no need to do anything else here.
+            # This makes the perf for this case suboptimal, but that's OK - this is
+            # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
+            if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
                 continue
 
             draft_request = self._create_draft_request_for_request(request)
@@ -255,17 +247,8 @@ def _update_request_states(self,
 
     def _update_requests(self, sample_state: SampleState) -> None:
         """Update requests with sample state."""
-        try:
-            if self.sampler is not None:
-                self.sampler.update_requests(sample_state)
-        except Exception as e:
-            logger.error(f"Error updating requests: {str(e)}")
-
-    def _handle_errors(self, error_msg: str) -> None:
-        """Handle errors during draft token generation."""
-        logger.error(f"Draft token generation error: {error_msg}")
-        # For now, just log the error. In a full implementation, this could
-        # clean up resources, notify other components, etc.
+        if self.sampler is not None:
+            self.sampler.update_requests(sample_state)
 
     def _process_decoded_tokens(
             self, draft_batch: ScheduledRequests,
@@ -277,7 +260,7 @@ def _process_decoded_tokens(
             target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
             if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                     target_model_req.py_draft_tokens
-            ) < target_model_req.py_draft_pages_allocated:  # type: ignore
+            ) < target_model_req.py_draft_pages_allocated:
                 new_requests.append(req)
             else:
                 self.draft_seq_slot_manager.free_resources(req)