 
 from __future__ import annotations
 
+import itertools
 import time
 from collections import deque
 from collections.abc import Iterable
@@ -144,7 +145,7 @@ def schedule(self) -> SchedulerOutput:
         # uses structured decoding.
         structured_output_request_ids: dict[str, int] = {}
 
-        req_to_new_block_ids: dict[str, list[int]] = {}
+        req_to_new_block_ids: dict[str, list[list[int]]] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
         # Encoder-related.
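
Reading note: `req_to_new_block_ids` now maps each request to one list of block IDs per KV cache group instead of a single flat list. A minimal sketch of the value shape; the group layout shown here is an assumption for illustration, not taken from the diff.

```python
# Old value shape: one flat list of block IDs per request.
req_to_new_block_ids_old = {"req-0": [3, 7, 12]}

# New value shape: one inner list per KV cache group, e.g. a hypothetical
# group 0 for full-attention layers and group 1 for sliding-window layers.
req_to_new_block_ids_new = {"req-0": [[3, 7], [12]]}
```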
@@ -165,7 +166,8 @@ def schedule(self) -> SchedulerOutput:
                 req_index += 1
                 continue
 
-            num_new_tokens = (request.num_tokens_with_spec -
+            num_draft_tokens = len(request.draft_token_ids)
+            num_new_tokens = (request.num_tokens + num_draft_tokens -
                               request.num_computed_tokens)
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
@@ -196,7 +198,8 @@ def schedule(self) -> SchedulerOutput:
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
-                    num_new_tokens,
+                    num_new_tokens - num_draft_tokens,
+                    num_draft_tokens=num_draft_tokens,
                     num_lookahead_tokens=self.num_lookahead_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
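
Reading note: the two hunks above separate a running request's speculative draft tokens from its confirmed tokens, so the allocator can be told how many of the scheduled tokens are drafts. A small sketch of the arithmetic; the numbers are made up and only the formula mirrors the diff.

```python
# Hypothetical request state.
num_tokens = 100             # prompt + already-accepted output tokens
draft_token_ids = [7, 9, 4]  # speculative tokens proposed for this step
num_computed_tokens = 96     # tokens whose KV cache is already filled

num_draft_tokens = len(draft_token_ids)
# Tokens to schedule this step: confirmed backlog plus the drafts.
num_new_tokens = num_tokens + num_draft_tokens - num_computed_tokens
assert num_new_tokens == 7

# The allocator now receives the confirmed portion positionally and the
# draft count by keyword, so draft slots (which may be discarded if the
# drafts are rejected) can be accounted for separately.
num_confirmed_tokens = num_new_tokens - num_draft_tokens
assert num_confirmed_tokens == 4
```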
@@ -233,7 +236,7 @@ def schedule(self) -> SchedulerOutput:
                 # cycle to fill in the bitmask, which could be a big no-op.
                 structured_output_request_ids[request.request_id] = req_index
             req_to_new_block_ids[request.request_id] = [
-                b.block_id for b in new_blocks
+                [b.block_id for b in blocks] for blocks in new_blocks
             ]
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
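
Reading note: `allocate_slots` now returns blocks grouped per KV cache group, so the block IDs are collected with a nested comprehension. A self-contained sketch; `Block` here is a stand-in for the real KV cache block type, not the scheduler's actual class.

```python
from dataclasses import dataclass

@dataclass
class Block:  # stand-in for a KV cache block object
    block_id: int

# Hypothetical allocation result: one list of blocks per KV cache group.
new_blocks = [[Block(3), Block(7)], [Block(12)]]

# Old: [b.block_id for b in new_blocks]  -> flat list for a single group.
# New: one inner list of block IDs per group.
block_ids = [[b.block_id for b in blocks] for blocks in new_blocks]
assert block_ids == [[3, 7], [12]]
```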
@@ -330,7 +333,11 @@ def schedule(self) -> SchedulerOutput:
                     new_encoder_budget = encoder_budget
 
                 new_blocks = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens, num_computed_tokens, computed_blocks)
+                    request,
+                    num_new_tokens,
+                    new_computed_tokens=num_computed_tokens,
+                    new_computed_blocks=computed_blocks,
+                    num_lookahead_tokens=self.num_lookahead_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     break
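
Reading note: for requests taken from the waiting queue, the `allocate_slots` call switches to keyword arguments and also passes the lookahead budget. The stub below only mirrors the keyword names visible in the diff; its body and defaults are assumptions, and the real `KVCacheManager.allocate_slots` does the actual block accounting.

```python
def allocate_slots(request, num_new_tokens, *, new_computed_tokens=0,
                   new_computed_blocks=None, num_draft_tokens=0,
                   num_lookahead_tokens=0):
    # Stub: just report how many token slots the request would need.
    total = new_computed_tokens + num_new_tokens + num_lookahead_tokens
    print(f"reserve slots for {total} tokens")
    return []  # the real method returns new blocks, or None if it cannot

allocate_slots(request=None,
               num_new_tokens=8,
               new_computed_tokens=16,        # prefix-cache hits
               new_computed_blocks=[[], []],  # per-group cached blocks
               num_lookahead_tokens=4)        # extra slots for lookahead
```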
@@ -355,9 +362,9 @@ def schedule(self) -> SchedulerOutput:
 
                 if self.lora_config and request.lora_request:
                     scheduled_loras.add(request.lora_request.lora_int_id)
-                req_to_new_block_ids[request.request_id] = [
-                    b.block_id for b in computed_blocks + new_blocks
-                ]
+                req_to_new_block_ids[request.request_id] = [[
+                    b.block_id for b in itertools.chain(b1, b2)
+                ] for b1, b2 in zip(computed_blocks, new_blocks)]
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
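
Reading note: `computed_blocks` (prefix-cache hits) and `new_blocks` (fresh allocations) are now both grouped per KV cache group, so they are merged group by group with `zip` and `itertools.chain` instead of plain list concatenation, which is what the new `import itertools` at the top supports. A self-contained version of that comprehension with a stand-in block type:

```python
import itertools
from dataclasses import dataclass

@dataclass
class Block:  # stand-in for a KV cache block object
    block_id: int

# One inner list per KV cache group (two illustrative groups).
computed_blocks = [[Block(1), Block(2)], [Block(10)]]     # cache hits
new_blocks = [[Block(3)], [Block(11), Block(12)]]         # new allocations

# Per group: cached blocks first, then newly allocated blocks.
new_block_ids = [[b.block_id for b in itertools.chain(b1, b2)]
                 for b1, b2 in zip(computed_blocks, new_blocks)]
assert new_block_ids == [[1, 2, 3], [10, 11, 12]]
```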
@@ -463,7 +470,7 @@ def _make_cached_request_data(
         request: Request,
         num_scheduled_tokens: int,
         num_scheduled_spec_tokens: int,
-        new_block_ids: list[int],
+        new_block_ids: list[list[int]],
         resumed_from_preemption: bool,
     ) -> CachedRequestData:
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
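
Reading note: the cached request payload carries the same nested shape, so the `new_block_ids` annotation widens from `list[int]` to `list[list[int]]`. The dataclass below is only an illustrative stand-in, not the real `CachedRequestData`.

```python
from dataclasses import dataclass

@dataclass
class CachedRequestPayload:         # illustrative stand-in only
    req_id: str
    resumed_from_preemption: bool
    new_block_ids: list[list[int]]  # one list of block IDs per KV cache group

payload = CachedRequestPayload(
    req_id="req-0",
    resumed_from_preemption=False,
    new_block_ids=[[3, 7], [12]],   # group 0 and group 1
)
```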