4040if TYPE_CHECKING :
4141 from vllm .attention .backends .abstract import AttentionMetadata
4242 from vllm .v1 .core .kv_cache_manager import KVCacheBlocks
43+ from vllm .v1 .outputs import KVConnectorOutput
4344 from vllm .v1 .request import Request
4445
4546Transfer = tuple [int , float ] # (xfer_handle, start_time)
@@ -107,8 +108,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
107108 def __init__ (self ):
108109 self .reqs_to_recv : dict [ReqId , ReqMeta ] = {}
109110 self .reqs_to_save : dict [ReqId , ReqMeta ] = {}
110- self .reqs_to_send : dict [ReqId , float ] = {}
111- self .reqs_in_batch : set [ReqId ] = set ()
112111
113112 def add_new_req (
114113 self ,
@@ -195,6 +194,14 @@ def build_connector_meta(
195194 assert self .connector_scheduler is not None
196195 return self .connector_scheduler .build_connector_meta (scheduler_output )
197196
197+ def update_connector_output (
198+ self ,
199+ connector_output : "KVConnectorOutput" ,
200+ ):
201+ assert self .connector_scheduler is not None
202+ return self .connector_scheduler .update_connector_output (
203+ connector_output )
204+
198205 def request_finished (
199206 self ,
200207 request : "Request" ,
@@ -280,9 +287,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
280287 # the scheduler. Used to make metadata passed to Worker.
281288 self ._reqs_need_recv : dict [ReqId , tuple [Request , list [int ]]] = {}
282289 self ._reqs_need_save : dict [ReqId , tuple [Request , list [int ]]] = {}
283- # Reqs to send and their expiration time
284- self ._reqs_need_send : dict [ReqId , float ] = {}
285- self ._reqs_in_batch : set [ReqId ] = set ()
290+
291+ # Requests that need to be sent for remote decode, along with:
292+ # 1. an expiry time to avoid stranded KV blocks if they
293+ # are never fetched
294+ # 2. a consumer notification count - with heterogeneous TP, P
295+ # must wait for all assigned D TP workers to finish reading
296+ # before safely freeing the blocks.
297+ self ._reqs_need_send : dict [ReqId , tuple [float , int ]] = {}
286298
287299 def get_num_new_matched_tokens (
288300 self , request : "Request" ,
@@ -330,8 +342,6 @@ def update_state_after_alloc(self, request: "Request",
330342 if not params :
331343 return
332344
333- if params .get ("do_remote_decode" ):
334- self ._reqs_in_batch .add (request .request_id )
335345 if self .use_host_buffer and params .get ("do_remote_decode" ):
336346 # NOTE: when accelerator is not directly supported by Nixl,
337347 # prefilled blocks need to be saved to host memory before transfer.
@@ -395,17 +405,56 @@ def build_connector_meta(
395405 save_to_host = True ,
396406 )
397407
398- meta .reqs_to_send = self ._reqs_need_send
399- meta .reqs_in_batch = self ._reqs_in_batch
400-
401408 # Clear the list once workers start the transfers
402409 self ._reqs_need_recv .clear ()
403410 self ._reqs_need_save .clear ()
404- self ._reqs_in_batch = set ()
405- self ._reqs_need_send = {}
406411
407412 return meta
408413
414+ def update_connector_output (
415+ self ,
416+ connector_output : "KVConnectorOutput" ,
417+ ):
418+ finished_sending : set [str ] = set ()
419+
420+ # Blocks sent - remove expiry timeout
421+ for notif in (connector_output .finished_sending or ()):
422+ req_id , tp_ratio = notif .decode ("utf-8" ).rsplit (":" , 1 )
423+ # Sent notifications received after we already timed out
424+ if req_id not in self ._reqs_need_send :
425+ logger .debug (
426+ "Already finished or expired KV transfer for request %s" ,
427+ req_id )
428+ continue
429+
430+ # Wait all consumers (D) to be done reading before freeing.
431+ count = self ._reqs_need_send [req_id ][1 ] + 1
432+ if count < int (tp_ratio ):
433+ self ._reqs_need_send [req_id ] = (
434+ self ._reqs_need_send [req_id ][0 ], count )
435+ continue
436+ logger .debug (
437+ "KV transfer finished for request %s after "
438+ "retrieval by %d decode worker(s)." , req_id , count )
439+ del self ._reqs_need_send [req_id ]
440+ finished_sending .add (req_id )
441+
442+ # Mark as finished if the expiry timeout has passed
443+ now = time .monotonic ()
444+ while self ._reqs_need_send :
445+ req_id , (expires , count ) = next (iter (self ._reqs_need_send .items ()))
446+ # Insertion-ordered dict; oldest first so we can exit early.
447+ if now < expires :
448+ break
449+ logger .warning (
450+ "Releasing expired KV blocks for request %s which were "
451+ "retrieved by %d decode worker(s) within %d seconds." , req_id ,
452+ count , envs .VLLM_NIXL_ABORT_REQUEST_TIMEOUT )
453+ del self ._reqs_need_send [req_id ]
454+ finished_sending .add (req_id )
455+
456+ connector_output .finished_sending = finished_sending
457+
409458 def request_finished (
410459 self ,
411460 request : "Request" ,
@@ -435,8 +484,15 @@ def request_finished(
435484 params ["do_remote_prefill" ] = False
436485 return False , None
437486
438- if (not params .get ("do_remote_decode" )
439- or request .status != RequestStatus .FINISHED_LENGTH_CAPPED ):
487+ if not params .get ("do_remote_decode" ):
488+ return False , None
489+
490+ if request .status != RequestStatus .FINISHED_LENGTH_CAPPED :
491+ if request .request_id in self ._reqs_need_send :
492+ # Request aborted after we delayed freeing the blocks
493+ logger .debug ("Deleting KV transfer timeout for request %s" ,
494+ request .request_id )
495+ del self ._reqs_need_send [request .request_id ]
440496 return False , None
441497
442498 # TODO: check whether block_ids actually ever be 0. If not we could
@@ -445,8 +501,8 @@ def request_finished(
445501
446502 if delay_free_blocks :
447503 # Prefill request on remote. It will be read from D upon completion
448- self . _reqs_need_send [ request . request_id ] = time .perf_counter (
449- ) + envs . VLLM_NIXL_ABORT_REQUEST_TIMEOUT
504+ expiry = time .monotonic () + envs . VLLM_NIXL_ABORT_REQUEST_TIMEOUT
505+ self . _reqs_need_send [ request . request_id ] = ( expiry , 0 )
450506
451507 return delay_free_blocks , dict (
452508 do_remote_prefill = True ,
@@ -559,10 +615,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
559615 # [req_id -> list[handle]]
560616 self ._recving_metadata : dict [ReqId , ReqMeta ] = {}
561617 self ._recving_transfers = defaultdict [ReqId , list [Transfer ]](list )
562- # Track the expiration time of requests that are waiting to be sent.
563- self ._reqs_to_send : dict [ReqId , float ] = {}
564- # Set of requests that have been part of a batch, regardless of status.
565- self ._reqs_to_process : set [ReqId ] = set ()
566618
567619 # Background thread for handling new handshake requests.
568620 self ._nixl_handshake_listener_t : Optional [threading .Thread ] = None
@@ -601,9 +653,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
601653 logger .debug ("Detected kv cache layout %s" , self .kv_cache_layout )
602654
603655 self ._tp_size : dict [EngineId , int ] = {self .engine_id : self .world_size }
604- # With heterogeneous TP, P must wait for all assigned D TP workers to
605- # finish reading before safely freeing the blocks.
606- self .consumer_notification_counts_by_req = defaultdict [ReqId , int ](int )
607656 self .xfer_stats = NixlKVConnectorStats ()
608657
609658 @staticmethod
@@ -1113,22 +1162,6 @@ def get_finished(self) -> tuple[set[str], set[str]]:
11131162 assert meta , f"{ req_id } not found in recving_metadata list"
11141163 self .sync_recved_kv_to_device (req_id , meta )
11151164
1116- # Handle timeout to avoid stranding blocks on remote.
1117- now = time .perf_counter ()
1118- while self ._reqs_to_send :
1119- req_id , expires = next (iter (self ._reqs_to_send .items ()))
1120- # Sorted dict, oldest requests are put first so we can exit early.
1121- if now < expires :
1122- break
1123- count = self .consumer_notification_counts_by_req .pop (req_id , 0 )
1124- logger .warning (
1125- "Releasing expired KV blocks for request %s which were "
1126- "retrieved by %d decode worker(s) within %d seconds." , req_id ,
1127- count , envs .VLLM_NIXL_ABORT_REQUEST_TIMEOUT )
1128- self ._reqs_to_process .remove (req_id )
1129- del self ._reqs_to_send [req_id ]
1130- done_sending .add (req_id )
1131-
11321165 return done_sending , done_recving
11331166
11341167 def _get_new_notifs (self ) -> set [str ]:
@@ -1140,23 +1173,8 @@ def _get_new_notifs(self) -> set[str]:
11401173 notified_req_ids : set [str ] = set ()
11411174 for notifs in self .nixl_wrapper .get_new_notifs ().values ():
11421175 for notif in notifs :
1143- req_id , tp_ratio = notif .decode ("utf-8" ).rsplit (":" , 1 )
1144- if (req_id not in self ._reqs_to_send
1145- and req_id not in self ._reqs_to_process ):
1146- logger .error (
1147- "Potentially invalid KV blocks for "
1148- "unrecognized request %s were retrieved by "
1149- "a decode worker. They may have expired." , req_id )
1150- continue
1151-
1152- self .consumer_notification_counts_by_req [req_id ] += 1
1153- # Wait all consumers (D) to be done reading before freeing.
1154- if self .consumer_notification_counts_by_req [req_id ] == int (
1155- tp_ratio ):
1156- notified_req_ids .add (req_id )
1157- del self .consumer_notification_counts_by_req [req_id ]
1158- self ._reqs_to_process .remove (req_id )
1159- self ._reqs_to_send .pop (req_id , None )
1176+ # Note - this is in req_id:tp_ratio format
1177+ notified_req_ids .add (notif )
11601178 return notified_req_ids
11611179
11621180 def _pop_done_transfers (
@@ -1217,20 +1235,6 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
12171235 while not self ._ready_requests .empty ():
12181236 self ._read_blocks_for_req (* self ._ready_requests .get_nowait ())
12191237
1220- # Keep around the requests that have been part of a batch. This is
1221- # needed because async scheduling pushes the misalignment between the
1222- # moment in which requests expiration is set (P side) and the moment in
1223- # which blocks are read from D. As P can now more easily lag behind D
1224- # while processing the next batch, we make sure to only set an
1225- # expiration for requests that have not been read from D yet.
1226- for req_id in metadata .reqs_in_batch :
1227- self ._reqs_to_process .add (req_id )
1228-
1229- # Add to requests that are waiting to be read and track expiration.
1230- for req_id , expiration_time in metadata .reqs_to_send .items ():
1231- if req_id in self ._reqs_to_process :
1232- self ._reqs_to_send [req_id ] = expiration_time
1233-
12341238 def _read_blocks_for_req (self , req_id : str , meta : ReqMeta ):
12351239 logger .debug (
12361240 "Remote agent %s available, calling _read_blocks for req %s" ,
0 commit comments