Commit 73f5fcb
NIXL: re-work send timeout tracking on prefill side
In a prefill instance, we need to free KV blocks that have not been fetched after a timeout. See vllm-project#20139.

In vllm-project#26012, we're trying to deal with corner cases involved in doing this request timeout tracking on the worker side. This PR proposes moving all of it to the scheduler side, hopefully making the logic simpler.

Note the expiry timer is switched back to monotonic time because the timestamp is no longer sent across process boundaries.

Signed-off-by: Mark McLoughlin <[email protected]>
1 parent bb6d430 · commit 73f5fcb
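The scheduler-side bookkeeping this commit describes can be pictured with a small standalone sketch: an insertion-ordered dict mapping each delayed request to a monotonic expiry time and a consumer-notification count. This is a minimal illustration of the technique only; the names (ExpiryTracker, start, notify, expired, timeout_s) are hypothetical and are not the connector's actual API.

    # Hypothetical sketch of scheduler-side expiry tracking; not vLLM code.
    import time


    class ExpiryTracker:
        """Tracks prefill requests whose KV blocks await retrieval by decode workers."""

        def __init__(self, timeout_s: float):
            self.timeout_s = timeout_s
            # req_id -> (monotonic expiry time, consumer notifications seen so far).
            # Plain dicts preserve insertion order, so the oldest entry comes first.
            self._pending: dict[str, tuple[float, int]] = {}

        def start(self, req_id: str) -> None:
            # Called when prefill finishes and freeing the blocks is delayed.
            self._pending[req_id] = (time.monotonic() + self.timeout_s, 0)

        def notify(self, req_id: str, tp_ratio: int) -> bool:
            # Called once per decode worker that finished reading the blocks.
            # Returns True when all tp_ratio consumers are done and blocks can be freed.
            if req_id not in self._pending:
                return False  # already freed or expired
            expiry, count = self._pending[req_id]
            count += 1
            if count < tp_ratio:
                self._pending[req_id] = (expiry, count)
                return False
            del self._pending[req_id]
            return True

        def expired(self) -> list[str]:
            # Oldest entries first; stop at the first one that has not expired yet.
            now = time.monotonic()
            done: list[str] = []
            while self._pending:
                req_id, (expiry, _) = next(iter(self._pending.items()))
                if now < expiry:
                    break
                del self._pending[req_id]
                done.append(req_id)
            return done

Because every entry gets the same timeout, insertion order matches expiry order, so the expiry scan can stop at the first entry that has not yet expired.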

File tree

1 file changed: +74 −70 lines changed


vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 74 additions & 70 deletions
@@ -40,6 +40,7 @@
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.outputs import KVConnectorOutput
     from vllm.v1.request import Request
 
 Transfer = tuple[int, float]  # (xfer_handle, start_time)
@@ -107,8 +108,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
     def __init__(self):
         self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
         self.reqs_to_save: dict[ReqId, ReqMeta] = {}
-        self.reqs_to_send: dict[ReqId, float] = {}
-        self.reqs_in_batch: set[ReqId] = set()
 
     def add_new_req(
         self,
@@ -195,6 +194,14 @@ def build_connector_meta(
         assert self.connector_scheduler is not None
         return self.connector_scheduler.build_connector_meta(scheduler_output)
 
+    def update_connector_output(
+        self,
+        connector_output: "KVConnectorOutput",
+    ):
+        assert self.connector_scheduler is not None
+        return self.connector_scheduler.update_connector_output(
+            connector_output)
+
     def request_finished(
         self,
         request: "Request",
@@ -280,9 +287,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
         self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
-        # Reqs to send and their expiration time
-        self._reqs_need_send: dict[ReqId, float] = {}
-        self._reqs_in_batch: set[ReqId] = set()
+
+        # Requests that need to be sent for remote decode, along with:
+        # 1. an expiry time to avoid stranded KV blocks if they
+        #    are never fetched
+        # 2. a consumer notification count - with heterogeneous TP, P
+        #    must wait for all assigned D TP workers to finish reading
+        #    before safely freeing the blocks.
+        self._reqs_need_send: dict[ReqId, tuple[float, int]] = {}
 
     def get_num_new_matched_tokens(
         self, request: "Request",
@@ -330,8 +342,6 @@ def update_state_after_alloc(self, request: "Request",
330342
if not params:
331343
return
332344

333-
if params.get("do_remote_decode"):
334-
self._reqs_in_batch.add(request.request_id)
335345
if self.use_host_buffer and params.get("do_remote_decode"):
336346
# NOTE: when accelerator is not directly supported by Nixl,
337347
# prefilled blocks need to be saved to host memory before transfer.
@@ -395,17 +405,56 @@ def build_connector_meta(
                 save_to_host=True,
             )
 
-        meta.reqs_to_send = self._reqs_need_send
-        meta.reqs_in_batch = self._reqs_in_batch
-
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
         self._reqs_need_save.clear()
-        self._reqs_in_batch = set()
-        self._reqs_need_send = {}
 
         return meta
 
+    def update_connector_output(
+        self,
+        connector_output: "KVConnectorOutput",
+    ):
+        finished_sending: set[str] = set()
+
+        # Blocks sent - remove expiry timeout
+        for notif in (connector_output.finished_sending or ()):
+            req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
+            # Sent notifications received after we already timed out
+            if req_id not in self._reqs_need_send:
+                logger.debug(
+                    "Already finished or expired KV transfer for request %s",
+                    req_id)
+                continue
+
+            # Wait all consumers (D) to be done reading before freeing.
+            count = self._reqs_need_send[req_id][1] + 1
+            if count < int(tp_ratio):
+                self._reqs_need_send[req_id] = (
+                    self._reqs_need_send[req_id][0], count)
+                continue
+            logger.debug(
+                "KV transfer finished for request %s after "
+                "retrieval by %d decode worker(s).", req_id, count)
+            del self._reqs_need_send[req_id]
+            finished_sending.add(req_id)
+
+        # Mark as finished if the expiry timeout has passed
+        now = time.monotonic()
+        while self._reqs_need_send:
+            req_id, (expires, count) = next(iter(self._reqs_need_send.items()))
+            # Insertion-ordered dict; oldest first so we can exit early.
+            if now < expires:
+                break
+            logger.warning(
+                "Releasing expired KV blocks for request %s which were "
+                "retrieved by %d decode worker(s) within %d seconds.", req_id,
+                count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
+            del self._reqs_need_send[req_id]
+            finished_sending.add(req_id)
+
+        connector_output.finished_sending = finished_sending
+
     def request_finished(
         self,
         request: "Request",
@@ -435,8 +484,15 @@ def request_finished(
             params["do_remote_prefill"] = False
             return False, None
 
-        if (not params.get("do_remote_decode")
-                or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
+        if not params.get("do_remote_decode"):
+            return False, None
+
+        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
+            if request.request_id in self._reqs_need_send:
+                # Request aborted after we delayed freeing the blocks
+                logger.debug("Deleting KV transfer timeout for request %s",
+                             request.request_id)
+                del self._reqs_need_send[request.request_id]
             return False, None
 
         # TODO: check whether block_ids actually ever be 0. If not we could
@@ -445,8 +501,8 @@ def request_finished(
 
         if delay_free_blocks:
            # Prefill request on remote. It will be read from D upon completion
-            self._reqs_need_send[request.request_id] = time.perf_counter(
-            ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+            expiry = time.monotonic() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+            self._reqs_need_send[request.request_id] = (expiry, 0)
 
         return delay_free_blocks, dict(
             do_remote_prefill=True,
@@ -559,10 +615,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # [req_id -> list[handle]]
         self._recving_metadata: dict[ReqId, ReqMeta] = {}
         self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
-        # Track the expiration time of requests that are waiting to be sent.
-        self._reqs_to_send: dict[ReqId, float] = {}
-        # Set of requests that have been part of a batch, regardless of status.
-        self._reqs_to_process: set[ReqId] = set()
 
         # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@@ -601,9 +653,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
-        # With heterogeneous TP, P must wait for all assigned D TP workers to
-        # finish reading before safely freeing the blocks.
-        self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
         self.xfer_stats = NixlKVConnectorStats()
 
     @staticmethod
@@ -1113,22 +1162,6 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                 assert meta, f"{req_id} not found in recving_metadata list"
                 self.sync_recved_kv_to_device(req_id, meta)
 
-        # Handle timeout to avoid stranding blocks on remote.
-        now = time.perf_counter()
-        while self._reqs_to_send:
-            req_id, expires = next(iter(self._reqs_to_send.items()))
-            # Sorted dict, oldest requests are put first so we can exit early.
-            if now < expires:
-                break
-            count = self.consumer_notification_counts_by_req.pop(req_id, 0)
-            logger.warning(
-                "Releasing expired KV blocks for request %s which were "
-                "retrieved by %d decode worker(s) within %d seconds.", req_id,
-                count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
-            self._reqs_to_process.remove(req_id)
-            del self._reqs_to_send[req_id]
-            done_sending.add(req_id)
-
         return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
@@ -1140,23 +1173,8 @@ def _get_new_notifs(self) -> set[str]:
         notified_req_ids: set[str] = set()
         for notifs in self.nixl_wrapper.get_new_notifs().values():
             for notif in notifs:
-                req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
-                if (req_id not in self._reqs_to_send
-                        and req_id not in self._reqs_to_process):
-                    logger.error(
-                        "Potentially invalid KV blocks for "
-                        "unrecognized request %s were retrieved by "
-                        "a decode worker. They may have expired.", req_id)
-                    continue
-
-                self.consumer_notification_counts_by_req[req_id] += 1
-                # Wait all consumers (D) to be done reading before freeing.
-                if self.consumer_notification_counts_by_req[req_id] == int(
-                        tp_ratio):
-                    notified_req_ids.add(req_id)
-                    del self.consumer_notification_counts_by_req[req_id]
-                    self._reqs_to_process.remove(req_id)
-                    self._reqs_to_send.pop(req_id, None)
+                # Note - this is in req_id:tp_ratio format
+                notified_req_ids.add(notif)
         return notified_req_ids
 
     def _pop_done_transfers(
@@ -1217,20 +1235,6 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         while not self._ready_requests.empty():
             self._read_blocks_for_req(*self._ready_requests.get_nowait())
 
-        # Keep around the requests that have been part of a batch. This is
-        # needed because async scheduling pushes the misalignment between the
-        # moment in which requests expiration is set (P side) and the moment in
-        # which blocks are read from D. As P can now more easily lag behind D
-        # while processing the next batch, we make sure to only set an
-        # expiration for requests that have not been read from D yet.
-        for req_id in metadata.reqs_in_batch:
-            self._reqs_to_process.add(req_id)
-
-        # Add to requests that are waiting to be read and track expiration.
-        for req_id, expiration_time in metadata.reqs_to_send.items():
-            if req_id in self._reqs_to_process:
-                self._reqs_to_send[req_id] = expiration_time
-
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
             "Remote agent %s available, calling _read_blocks for req %s",
