From 2a7e010cfb1394adafcce868810927921a50305b Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Wed, 18 Jun 2025 14:25:43 +0000
Subject: [PATCH 01/14] abort request timeout env var

Signed-off-by: NickLucche
---
 .../kv_transfer/kv_connector/v1/nixl_connector.py | 13 +++++++++++++
 vllm/envs.py                                      | 10 +++++++++-
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 56ae1acf8571..3159e8024107 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -826,6 +826,19 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                 "and %s requests done recving", self.tp_rank,
                 len(done_sending), len(done_recving))
 
+        # Handle timeout
+        # now = time.perf_counter()
+        # for req_id, finish_time in self._reqs_to_send.items():
+        #     if finish_time == -1:
+        #         # Request just finished, start timeout.
+        #         self._reqs_to_send[req_id] = now
+        #     elif now - finish_time >= envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT:
+        #         # Timeout exceed, abort request and clear.
+        #         aborted_req_ids.add(req_id)
+        #         if req_id in self._done_sending_count:
+        #             self._done_sending_count[req_id] += self.world_size
+        #         del self._reqs_to_send[req_id]
+
         if self.world_size == 1:
             return done_sending, done_recving
 
diff --git a/vllm/envs.py b/vllm/envs.py
index 0cc6792d72bb..ec6a4896774f 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -138,6 +138,7 @@
     VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
+    VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
 
 
 def get_default_cache_root():
@@ -953,7 +954,14 @@ def get_vllm_port() -> Optional[int]:
     # generations on machines < 100 for compressed-tensors
     # models
     "VLLM_USE_NVFP4_CT_EMULATIONS":
-    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")))
+    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
+
+    # Time (in seconds) after which the KV cache on the producer side is
+    # automatically cleared if no READ notification is received from the
+    # consumer. This is only applicable when using NixlConnector in a
+    # disaggregated decode-prefill setup.
+    "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
+    lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
 }
 
 # --8<-- [end:env-vars-definition]

From 195a309afeed25f7a071950f60102c1c84ddabfb Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Wed, 18 Jun 2025 14:29:40 +0000
Subject: [PATCH 02/14] test abort request with timeout

Signed-off-by: NickLucche
---
 tests/v1/kv_connector/unit/test_nixl_connector.py | 64 +++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index e30a250449aa..7b6e32edc577 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -9,10 +9,13 @@
 
 import pytest
 
+from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
 from vllm.forward_context import ForwardContext
+from vllm.llm_engine.llm_engine import LLM
+from vllm.llm_engine.scheduler.scheduler import SamplingParams
 
 from .utils import create_request, create_scheduler, create_vllm_config
 
@@ -371,3 +374,64 @@ def test_concurrent_load_kv(
             if cnt_finished_reqs == total_reqs:
                 return
     raise TimeoutError("Took too long to complete async handshake.")
+
+
+def test_abort_timeout_on_prefiller(monkeypatch):
+    """
+    Test lifecycle of an aborted Remote Prefill request hitting the timeout.
+    -----> P
+           |  {process request}
+    <-\--- |  {result is NOT delivered, eg proxy is down}
+           |
+           |
+           |  {eventually free blocks}
+    """
+    model_name = "Qwen/Qwen3-0.6B"
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="NixlConnector",
+        kv_role="kv_both",
+    )
+    timeout = 6
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+    monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
+    llm = LLM(
+        model=model_name,
+        enforce_eager=True,
+        gpu_memory_utilization=0.5,
+        kv_transfer_config=kv_transfer_config,
+    )
+    remote_prefill_opts = {
+        "do_remote_decode": True,
+        "do_remote_prefill": False,
+        "remote_engine_id": None,
+        "remote_block_ids": None,
+        "remote_host": None,
+        "remote_port": None,
+    }
+    # Simulate sidecar request
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=1,
+        extra_args={"kv_transfer_params": remote_prefill_opts})
+    scheduler = llm.llm_engine.engine_core.engine_core.scheduler
+
+    padding = "Just making this request a little longer so that we're sure "
+    "we're not hitting the small-request lower bound beneath which we don't "
+    "actually trigger the whole kv transfer, but rather just recompute the "
+    "blocks on D."
+    _ = llm.generate([f"What is the capital of Japan? {padding}"],
+                     sampling_params)
+
+    # Request finished but not freed
+    assert '0' in scheduler.pending_kv_free_req_ids
+    # Some other request
+    _ = llm.generate([f"What is the capital of Italy? {padding}"],
+                     sampling_params)
+    assert scheduler.pending_kv_free_req_ids == {"0", "1"}
+
+    # Wait for timeout and trigger another scheduler loop
+    time.sleep(timeout)
+    _ = llm.generate([f"What is the capital of France? {padding}"],
+                     sampling_params)
+    # Request-0 times out and is cleared!
+ assert '0' not in scheduler.pending_kv_free_req_ids From 17f44b88f2d6ec61c549dc9915bd6255dd2fcf1f Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 26 Jun 2025 14:16:16 +0000 Subject: [PATCH 03/14] wip Signed-off-by: NickLucche --- tests/v1/kv_connector/unit/test_nixl_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 7b6e32edc577..fe1c4e6b19f0 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -9,13 +9,13 @@ import pytest +from vllm import LLM from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, NixlConnectorWorker) from vllm.forward_context import ForwardContext -from vllm.llm_engine.llm_engine import LLM -from vllm.llm_engine.scheduler.scheduler import SamplingParams +from vllm.sampling_params import SamplingParams from .utils import create_request, create_scheduler, create_vllm_config From 676f3f63ae819905ac88cc863a330b9f7d269f1e Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 26 Jun 2025 16:12:01 +0000 Subject: [PATCH 04/14] remote timeout Signed-off-by: NickLucche --- .../kv_connector/unit/test_nixl_connector.py | 11 ++-- .../kv_connector/v1/nixl_connector.py | 55 +++++++++++++------ 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index fe1c4e6b19f0..0996192953d4 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -414,6 +414,8 @@ def test_abort_timeout_on_prefiller(monkeypatch): max_tokens=1, extra_args={"kv_transfer_params": remote_prefill_opts}) scheduler = llm.llm_engine.engine_core.engine_core.scheduler + req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks padding = "Just making this request a little longer so that we're sure " "we're not hitting the small-request lower bound beneath which we don't " @@ -423,15 +425,16 @@ def test_abort_timeout_on_prefiller(monkeypatch): sampling_params) # Request finished but not freed - assert '0' in scheduler.pending_kv_free_req_ids - # Some other request + assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks + # Some other request, 0 still not freed _ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params) - assert scheduler.pending_kv_free_req_ids == {"0", "1"} + assert '0' in req_to_blocks + assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks # Wait for timeout and trigger another scheduler loop time.sleep(timeout) _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params) # Request-0 times out and is cleared! 
- assert '0' not in scheduler.pending_kv_free_req_ids + assert '0' not in req_to_blocks diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 3159e8024107..608c5259edc1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import copy import math import queue import threading @@ -79,7 +80,8 @@ class ReqMeta: class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): - self.requests: dict[ReqId, ReqMeta] = {} + self.reqs_to_recv: dict[ReqId, ReqMeta] = {} + self.reqs_to_send: set[str] = set() def add_new_req( self, @@ -87,7 +89,7 @@ def add_new_req( local_block_ids: list[int], kv_transfer_params: dict[str, Any], ): - self.requests[request_id] = ReqMeta( + self.reqs_to_recv[request_id] = ReqMeta( local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], @@ -194,10 +196,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): vllm_config.parallel_config.tensor_parallel_size) logger.info("Initializing NIXL Scheduler %s", engine_id) - # Requests that need to start recv. + # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_send: set[str] = set() def get_num_new_matched_tokens( self, request: "Request", @@ -265,6 +268,9 @@ def update_state_after_alloc(self, request: "Request", assert num_external_tokens == 0 # Only trigger 1 KV transfer per request. params["do_remote_prefill"] = False + elif params is not None and params.get("do_remote_decode"): + # Prefill request on remote. It will be read from D upon completion + self._reqs_need_send.add(request.request_id) def build_connector_meta( self, @@ -281,8 +287,10 @@ def build_connector_meta( kv_transfer_params=req.kv_transfer_params, ) + meta.reqs_to_send = copy.copy(self._reqs_need_send) # Clear the list once workers start the transfers self._reqs_need_recv.clear() + self._reqs_need_send.clear() return meta @@ -394,6 +402,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # In progress transfers. # [req_id -> list[handle]] self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) + # Keep track of the time for requests that are waiting to be sent. + self._reqs_to_send: dict[ReqId, float] = {} # Complete transfer tracker. Used by the rank 0 to track finished # transactions on ranks 1 to N-1. @@ -826,18 +836,23 @@ def get_finished(self) -> tuple[set[str], set[str]]: "and %s requests done recving", self.tp_rank, len(done_sending), len(done_recving)) - # Handle timeout - # now = time.perf_counter() - # for req_id, finish_time in self._reqs_to_send.items(): - # if finish_time == -1: - # # Request just finished, start timeout. - # self._reqs_to_send[req_id] = now - # elif now - finish_time >= envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT: - # # Timeout exceed, abort request and clear. - # aborted_req_ids.add(req_id) - # if req_id in self._done_sending_count: - # self._done_sending_count[req_id] += self.world_size - # del self._reqs_to_send[req_id] + # Handle timeout to avoid stranding blocks on remote. 
+ now = time.perf_counter() + timed_out_requests: list[str] = [] + for req_id, finish_time in self._reqs_to_send.items(): + if finish_time < 0: + # Request just finished, start timeout. + self._reqs_to_send[req_id] = now + elif now - finish_time >= envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT: + # Timeout exceed, clear the request blocks. + timed_out_requests.append(req_id) + + for req_id in timed_out_requests: + # Skip communication with other ranks, but + if self.tp_rank == 0: + self._done_sending_count[req_id] += self.world_size + done_sending.add(req_id) + del self._reqs_to_send[req_id] if self.world_size == 1: return done_sending, done_recving @@ -870,7 +885,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: all_done_sending: set[str] = set() for req_id in list(self._done_sending_count.keys()): - if self._done_sending_count[req_id] == self.world_size: + if self._done_sending_count[req_id] >= self.world_size: del self._done_sending_count[req_id] all_done_sending.add(req_id) @@ -900,6 +915,7 @@ def _get_new_notifs(self) -> set[str]: tp_ratio): notified_req_ids.add(req_id) del self.consumer_notification_counts_by_req[req_id] + del self._reqs_to_send[req_id] return notified_req_ids def _pop_done_transfers( @@ -934,7 +950,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): Start loading by triggering non-blocking nixl_xfer. We check for these trnxs to complete in each step(). """ - for req_id, meta in metadata.requests.items(): + for req_id, meta in metadata.reqs_to_recv.items(): remote_engine_id = meta.remote_engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " @@ -956,6 +972,11 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): while not self._ready_requests.empty(): self._read_blocks_for_req(*self._ready_requests.get_nowait()) + # Track the request that are waiting to be read and abort on timeout. + # Set to -1 so that timeout does not depend on model latency. + for req_id in metadata.reqs_to_send: + self._reqs_to_send[req_id] = -1 + def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug( "Remote agent %s available, calling _read_blocks for req %s", From 528898f03f6f9aa0542ac11825c7aa0562137468 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 27 Jun 2025 13:13:14 +0000 Subject: [PATCH 05/14] time.monotonic Signed-off-by: NickLucche --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 608c5259edc1..f3a00088d372 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -837,7 +837,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: len(done_sending), len(done_recving)) # Handle timeout to avoid stranding blocks on remote. - now = time.perf_counter() + now = time.monotonic() timed_out_requests: list[str] = [] for req_id, finish_time in self._reqs_to_send.items(): if finish_time < 0: @@ -1081,8 +1081,7 @@ def _read_blocks(self, local_block_ids: list[int], # Use handle to check completion in future step(). 
# TODO (NickLucche) surface xfer elapsed time - self._recving_transfers[request_id].append( - (handle, time.perf_counter())) + self._recving_transfers[request_id].append((handle, time.monotonic())) def _get_block_descs_ids(self, engine_id: str, From e04945a80782db8098eeab1efa0cadea96a5e83f Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 30 Jun 2025 10:35:04 +0000 Subject: [PATCH 06/14] review optimization Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 49 +++++++++---------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f3a00088d372..1e062d972f5f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib -import copy import math import queue import threading @@ -81,7 +80,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): self.reqs_to_recv: dict[ReqId, ReqMeta] = {} - self.reqs_to_send: set[str] = set() + self.reqs_to_send: dict[ReqId, float] = {} def add_new_req( self, @@ -200,7 +199,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} - self._reqs_need_send: set[str] = set() + # Reqs to send and their expiration time + self._reqs_need_send: dict[ReqId, float] = {} def get_num_new_matched_tokens( self, request: "Request", @@ -268,9 +268,6 @@ def update_state_after_alloc(self, request: "Request", assert num_external_tokens == 0 # Only trigger 1 KV transfer per request. params["do_remote_prefill"] = False - elif params is not None and params.get("do_remote_decode"): - # Prefill request on remote. It will be read from D upon completion - self._reqs_need_send.add(request.request_id) def build_connector_meta( self, @@ -287,10 +284,11 @@ def build_connector_meta( kv_transfer_params=req.kv_transfer_params, ) - meta.reqs_to_send = copy.copy(self._reqs_need_send) # Clear the list once workers start the transfers self._reqs_need_recv.clear() - self._reqs_need_send.clear() + + meta.reqs_to_send = self._reqs_need_send + self._reqs_need_send = {} return meta @@ -333,6 +331,13 @@ def request_finished( # If prompt < block_size, no xfer so free blocks immediately. delay_free_blocks = len(computed_block_ids) > 0 + if delay_free_blocks and params.get("do_remote_decode"): + now = time.monotonic() + # Prefill request on remote. It will be read from D upon completion + self._reqs_need_send[ + request. + request_id] = now + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, @@ -402,7 +407,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # In progress transfers. # [req_id -> list[handle]] self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) - # Keep track of the time for requests that are waiting to be sent. + # Track the expiration time of requests that are waiting to be sent. self._reqs_to_send: dict[ReqId, float] = {} # Complete transfer tracker. 
Used by the rank 0 to track finished @@ -838,21 +843,13 @@ def get_finished(self) -> tuple[set[str], set[str]]: # Handle timeout to avoid stranding blocks on remote. now = time.monotonic() - timed_out_requests: list[str] = [] - for req_id, finish_time in self._reqs_to_send.items(): - if finish_time < 0: - # Request just finished, start timeout. - self._reqs_to_send[req_id] = now - elif now - finish_time >= envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT: - # Timeout exceed, clear the request blocks. - timed_out_requests.append(req_id) - - for req_id in timed_out_requests: - # Skip communication with other ranks, but - if self.tp_rank == 0: - self._done_sending_count[req_id] += self.world_size - done_sending.add(req_id) + while self._reqs_to_send: + req_id, expires = next(iter(self._reqs_to_send.items())) + # Sorted dict, oldest request are put first so we can exit early. + if now < expires: + break del self._reqs_to_send[req_id] + done_sending.add(req_id) if self.world_size == 1: return done_sending, done_recving @@ -972,10 +969,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): while not self._ready_requests.empty(): self._read_blocks_for_req(*self._ready_requests.get_nowait()) - # Track the request that are waiting to be read and abort on timeout. - # Set to -1 so that timeout does not depend on model latency. - for req_id in metadata.reqs_to_send: - self._reqs_to_send[req_id] = -1 + # Add to requests that are waiting to be read and track expiration. + self._reqs_to_send.update(metadata.reqs_to_send) def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug( From 101044fb3e3f67229928cd3df233ec25f1a251a3 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 30 Jun 2025 10:47:07 +0000 Subject: [PATCH 07/14] typo Signed-off-by: NickLucche --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1e062d972f5f..0f424096c4ab 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -845,7 +845,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: now = time.monotonic() while self._reqs_to_send: req_id, expires = next(iter(self._reqs_to_send.items())) - # Sorted dict, oldest request are put first so we can exit early. + # Sorted dict, oldest requests are put first so we can exit early. 
if now < expires: break del self._reqs_to_send[req_id] From ede261646b5c8fe9dbc200426f9d6734399100e8 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 1 Jul 2025 16:49:11 +0000 Subject: [PATCH 08/14] review + check req TTL decoder side and align clocks Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 64 +++++++++++++------ 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0f424096c4ab..5c5eee596bbe 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -64,6 +64,7 @@ class NixlAgentMetadata( num_blocks: int block_len: int attn_backend_name: str + remote_node_time: Optional[float] = None @dataclass @@ -74,6 +75,7 @@ class ReqMeta: remote_port: int remote_engine_id: str tp_size: int + request_ttl: float class NixlConnectorMetadata(KVConnectorMetadata): @@ -96,6 +98,7 @@ def add_new_req( remote_port=kv_transfer_params["remote_port"], # P workers don't need to receive tp_size from proxy here. tp_size=kv_transfer_params.get("tp_size", 1), + request_ttl=kv_transfer_params.get("request_ttl", -1), ) @@ -331,12 +334,10 @@ def request_finished( # If prompt < block_size, no xfer so free blocks immediately. delay_free_blocks = len(computed_block_ids) > 0 - if delay_free_blocks and params.get("do_remote_decode"): - now = time.monotonic() + if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion - self._reqs_need_send[ - request. - request_id] = now + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + self._reqs_need_send[request.request_id] = ( + time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) return delay_free_blocks, dict( do_remote_prefill=True, @@ -345,7 +346,8 @@ def request_finished( remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, - tp_size=self.vllm_config.parallel_config.tensor_parallel_size) + tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + request_ttl=self._reqs_need_send.get(request.request_id, -1)) class NixlConnectorWorker: @@ -457,6 +459,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) + # Map of remote agent name -> time offset to keep clocks synced. + self._remote_agent_time_offsets: dict[str, float] = {} + self._reqs_expired_ttl: set[ReqId] = set() + def __del__(self): """Cleanup background threads on destruction.""" self._handshake_initiation_executor.shutdown(wait=False) @@ -488,14 +494,16 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, if msg != GET_META_MSG: logger.warning( "Connection listener got unexpected message %s", msg) + + # Add current node time to the metadata for clock sync with D. + metadata.remote_node_time = time.perf_counter() + encoded_data = encoder.encode(metadata) sock.send_multipart((identity, b"", encoded_data)) def _nixl_handshake(self, host: str, port: int, remote_tp_size: int) -> dict[int, str]: """Do a NIXL handshake with a remote instance.""" - start_time = time.perf_counter() - # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. 
@@ -503,12 +511,19 @@ def _nixl_handshake(self, host: str, port: int, def handshake(path: str, rank: int) -> str: # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: + start_time = time.perf_counter() sock.send(GET_META_MSG) metadata_bytes = sock.recv() decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) + metadata: NixlAgentMetadata = decoder.decode(metadata_bytes) got_metadata_time = time.perf_counter() + # "Sync" clocks between local and remote by registering offset. + rtt = got_metadata_time - start_time + assert metadata.remote_node_time + self._remote_agent_time_offsets[metadata.engine_id] = ( + metadata.remote_node_time + rtt / 2 - got_metadata_time) + # Register Remote agent. remote_agent_name = self.add_remote_agent( metadata, rank, remote_tp_size) @@ -842,7 +857,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: len(done_sending), len(done_recving)) # Handle timeout to avoid stranding blocks on remote. - now = time.monotonic() + now = time.perf_counter() while self._reqs_to_send: req_id, expires = next(iter(self._reqs_to_send.items())) # Sorted dict, oldest requests are put first so we can exit early. @@ -851,6 +866,12 @@ def get_finished(self) -> tuple[set[str], set[str]]: del self._reqs_to_send[req_id] done_sending.add(req_id) + # Handle remote requests with expired TTL without attempting to read. + while self._reqs_expired_ttl: + req_id = next(iter(self._reqs_expired_ttl)) + done_recving.add(req_id) + self._reqs_expired_ttl.remove(req_id) + if self.world_size == 1: return done_sending, done_recving @@ -949,11 +970,6 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): """ for req_id, meta in metadata.reqs_to_recv.items(): remote_engine_id = meta.remote_engine_id - logger.debug( - "start_load_kv for request %s from remote engine %s. " - "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, - remote_engine_id, len(meta.local_block_ids), - len(meta.remote_block_ids)) if remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: @@ -973,9 +989,20 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): self._reqs_to_send.update(metadata.reqs_to_send) def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): + # Make sure request TTL is not expired before reading. + assert self._remote_agent_time_offsets[ + meta.remote_engine_id] is not None + remote_offset = self._remote_agent_time_offsets[meta.remote_engine_id] + if time.perf_counter() + remote_offset > meta.request_ttl: + logger.warning("Request remote TTL expired for request %s", req_id) + self._reqs_expired_ttl.add(req_id) + return + logger.debug( - "Remote agent %s available, calling _read_blocks for req %s", - meta.remote_engine_id, req_id) + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + meta.remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, @@ -1076,7 +1103,8 @@ def _read_blocks(self, local_block_ids: list[int], # Use handle to check completion in future step(). 
# TODO (NickLucche) surface xfer elapsed time - self._recving_transfers[request_id].append((handle, time.monotonic())) + self._recving_transfers[request_id].append( + (handle, time.perf_counter())) def _get_block_descs_ids(self, engine_id: str, From cd401430a09a29d34291aa095f06fb2fd3fe19f3 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 1 Jul 2025 17:05:16 +0000 Subject: [PATCH 09/14] test ttl Signed-off-by: NickLucche --- .../kv_connector/unit/test_nixl_connector.py | 128 +++++++++++++++++- 1 file changed, 127 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 0996192953d4..730afd17dbec 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -162,9 +162,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" - def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs): + def __init__(self, + *args, + hand_shake_latency: float = 1.8, + remote_agent_time_offset: float = 0.0, + **kwargs): super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency + self._remote_agent_time_offsets = { + self.REMOTE_ENGINE_ID: remote_agent_time_offset + } def _nixl_handshake(self, host: str, port: int, remote_tp_size: int) -> dict[int, str]: @@ -375,6 +382,125 @@ def test_concurrent_load_kv( return raise TimeoutError("Took too long to complete async handshake.") + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) + def test_ttl_expiration_on_decoder(self, dist_init): + """ + Test that decoder-side TTL expiration works correctly. + + This test verifies that: + 1. Requests with expired TTL are not processed for KV transfer (no _read_blocks called) + 2. Expired requests are automatically marked as finished + 3. Clock synchronization offset is properly handled (remote is N seconds ahead) + """ #noqa: E501 + # Remote is 100 seconds ahead. 
+ remote_agent_time_offset = 100.0 + vllm_config = create_vllm_config() + + # Test worker role in decode server + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + + class TTLTestNixlConnectorWorker(FakeNixlConnectorWorker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._read_blocks_called = set() + + def _read_blocks(self, local_block_ids, remote_block_ids, + dst_engine_id, request_id): + # Override to track if _read_blocks was called but don't + # actually read blocks + self._read_blocks_called.add(request_id) + + connector.connector_worker = TTLTestNixlConnectorWorker( + vllm_config, + connector.engine_id, + remote_agent_time_offset=remote_agent_time_offset) + + # Ensure the remote agent is already registered (skip handshake) + connector.connector_worker._remote_agents[ + FakeNixlConnectorWorker.REMOTE_ENGINE_ID] = { + 0: "test_agent" + } + + current_time = time.perf_counter() + + # Test Case 1: Request with expired TTL + expired_request_id = "expired_req" + expired_ttl = current_time # TTL expired (remote is ahead) + + metadata_expired = NixlConnectorMetadata() + metadata_expired.add_new_req( + request_id=expired_request_id, + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "tp_size": 1, + "request_ttl": expired_ttl, + }) + + # Test Case 2: Request with valid TTL + valid_request_id = "valid_req" + # 200 seconds from now + valid_ttl = current_time + remote_agent_time_offset + 200 + + metadata_valid = NixlConnectorMetadata() + metadata_valid.add_new_req( + request_id=valid_request_id, + local_block_ids=[7, 8, 9], + kv_transfer_params={ + "remote_block_ids": [10, 11, 12], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "tp_size": 1, + "request_ttl": valid_ttl, + }) + + # Process expired request + connector.bind_connector_metadata(metadata_expired) + connector.start_load_kv( + ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + )) + + # Check that expired request was added to _reqs_expired_ttl + assert expired_request_id in \ + connector.connector_worker._reqs_expired_ttl + + # Check that _read_blocks was NOT called for expired request + assert expired_request_id not in \ + connector.connector_worker._read_blocks_called + + # Check that expired request is marked as finished + _, done_recving = connector.get_finished(finished_req_ids=set()) + assert expired_request_id in done_recving + assert expired_request_id not in \ + connector.connector_worker._reqs_expired_ttl # Should be removed + + # Process valid request + connector.bind_connector_metadata(metadata_valid) + connector.start_load_kv( + ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + )) + + # Check that valid request was NOT added to _reqs_expired_ttl + assert valid_request_id not in \ + connector.connector_worker._reqs_expired_ttl + + # Check that _read_blocks WAS called for valid request + assert valid_request_id in \ + connector.connector_worker._read_blocks_called + def test_abort_timeout_on_prefiller(monkeypatch): """ From a13822377870c175eef12e442c48d2ffe4496648 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 3 Jul 2025 08:51:05 +0000 Subject: [PATCH 10/14] Revert "test ttl" This reverts commit 0d09b8b01c8f368e96f4a9765b5b9be1cbc06b43. 
Signed-off-by: NickLucche --- .../kv_connector/unit/test_nixl_connector.py | 128 +----------------- 1 file changed, 1 insertion(+), 127 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 730afd17dbec..0996192953d4 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -162,16 +162,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" - def __init__(self, - *args, - hand_shake_latency: float = 1.8, - remote_agent_time_offset: float = 0.0, - **kwargs): + def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs): super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency - self._remote_agent_time_offsets = { - self.REMOTE_ENGINE_ID: remote_agent_time_offset - } def _nixl_handshake(self, host: str, port: int, remote_tp_size: int) -> dict[int, str]: @@ -382,125 +375,6 @@ def test_concurrent_load_kv( return raise TimeoutError("Took too long to complete async handshake.") - @patch( - "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) - def test_ttl_expiration_on_decoder(self, dist_init): - """ - Test that decoder-side TTL expiration works correctly. - - This test verifies that: - 1. Requests with expired TTL are not processed for KV transfer (no _read_blocks called) - 2. Expired requests are automatically marked as finished - 3. Clock synchronization offset is properly handled (remote is N seconds ahead) - """ #noqa: E501 - # Remote is 100 seconds ahead. - remote_agent_time_offset = 100.0 - vllm_config = create_vllm_config() - - # Test worker role in decode server - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) - - class TTLTestNixlConnectorWorker(FakeNixlConnectorWorker): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._read_blocks_called = set() - - def _read_blocks(self, local_block_ids, remote_block_ids, - dst_engine_id, request_id): - # Override to track if _read_blocks was called but don't - # actually read blocks - self._read_blocks_called.add(request_id) - - connector.connector_worker = TTLTestNixlConnectorWorker( - vllm_config, - connector.engine_id, - remote_agent_time_offset=remote_agent_time_offset) - - # Ensure the remote agent is already registered (skip handshake) - connector.connector_worker._remote_agents[ - FakeNixlConnectorWorker.REMOTE_ENGINE_ID] = { - 0: "test_agent" - } - - current_time = time.perf_counter() - - # Test Case 1: Request with expired TTL - expired_request_id = "expired_req" - expired_ttl = current_time # TTL expired (remote is ahead) - - metadata_expired = NixlConnectorMetadata() - metadata_expired.add_new_req( - request_id=expired_request_id, - local_block_ids=[1, 2, 3], - kv_transfer_params={ - "remote_block_ids": [4, 5, 6], - "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, - "remote_host": "localhost", - "remote_port": 1234, - "tp_size": 1, - "request_ttl": expired_ttl, - }) - - # Test Case 2: Request with valid TTL - valid_request_id = "valid_req" - # 200 seconds from now - valid_ttl = current_time + remote_agent_time_offset + 200 - - metadata_valid = NixlConnectorMetadata() - metadata_valid.add_new_req( - request_id=valid_request_id, - local_block_ids=[7, 8, 9], - kv_transfer_params={ - "remote_block_ids": [10, 11, 12], - "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, - "remote_host": "localhost", - "remote_port": 
1234, - "tp_size": 1, - "request_ttl": valid_ttl, - }) - - # Process expired request - connector.bind_connector_metadata(metadata_expired) - connector.start_load_kv( - ForwardContext( - no_compile_layers={}, - attn_metadata={}, - virtual_engine=0, - )) - - # Check that expired request was added to _reqs_expired_ttl - assert expired_request_id in \ - connector.connector_worker._reqs_expired_ttl - - # Check that _read_blocks was NOT called for expired request - assert expired_request_id not in \ - connector.connector_worker._read_blocks_called - - # Check that expired request is marked as finished - _, done_recving = connector.get_finished(finished_req_ids=set()) - assert expired_request_id in done_recving - assert expired_request_id not in \ - connector.connector_worker._reqs_expired_ttl # Should be removed - - # Process valid request - connector.bind_connector_metadata(metadata_valid) - connector.start_load_kv( - ForwardContext( - no_compile_layers={}, - attn_metadata={}, - virtual_engine=0, - )) - - # Check that valid request was NOT added to _reqs_expired_ttl - assert valid_request_id not in \ - connector.connector_worker._reqs_expired_ttl - - # Check that _read_blocks WAS called for valid request - assert valid_request_id in \ - connector.connector_worker._read_blocks_called - def test_abort_timeout_on_prefiller(monkeypatch): """ From afa5852536a357f8de9bf7989b56ca284d3de34b Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 3 Jul 2025 08:51:07 +0000 Subject: [PATCH 11/14] Revert "review + check req TTL decoder side and align clocks" This reverts commit b7d5c64a024d8d7a4878f9fa90776e1fd849bbee. Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 64 ++++++------------- 1 file changed, 18 insertions(+), 46 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 5c5eee596bbe..0f424096c4ab 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -64,7 +64,6 @@ class NixlAgentMetadata( num_blocks: int block_len: int attn_backend_name: str - remote_node_time: Optional[float] = None @dataclass @@ -75,7 +74,6 @@ class ReqMeta: remote_port: int remote_engine_id: str tp_size: int - request_ttl: float class NixlConnectorMetadata(KVConnectorMetadata): @@ -98,7 +96,6 @@ def add_new_req( remote_port=kv_transfer_params["remote_port"], # P workers don't need to receive tp_size from proxy here. tp_size=kv_transfer_params.get("tp_size", 1), - request_ttl=kv_transfer_params.get("request_ttl", -1), ) @@ -334,10 +331,12 @@ def request_finished( # If prompt < block_size, no xfer so free blocks immediately. delay_free_blocks = len(computed_block_ids) > 0 - if delay_free_blocks: + if delay_free_blocks and params.get("do_remote_decode"): + now = time.monotonic() # Prefill request on remote. It will be read from D upon completion - self._reqs_need_send[request.request_id] = ( - time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) + self._reqs_need_send[ + request. 
+ request_id] = now + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT return delay_free_blocks, dict( do_remote_prefill=True, @@ -346,8 +345,7 @@ def request_finished( remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, - tp_size=self.vllm_config.parallel_config.tensor_parallel_size, - request_ttl=self._reqs_need_send.get(request.request_id, -1)) + tp_size=self.vllm_config.parallel_config.tensor_parallel_size) class NixlConnectorWorker: @@ -459,10 +457,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) - # Map of remote agent name -> time offset to keep clocks synced. - self._remote_agent_time_offsets: dict[str, float] = {} - self._reqs_expired_ttl: set[ReqId] = set() - def __del__(self): """Cleanup background threads on destruction.""" self._handshake_initiation_executor.shutdown(wait=False) @@ -494,16 +488,14 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, if msg != GET_META_MSG: logger.warning( "Connection listener got unexpected message %s", msg) - - # Add current node time to the metadata for clock sync with D. - metadata.remote_node_time = time.perf_counter() - encoded_data = encoder.encode(metadata) sock.send_multipart((identity, b"", encoded_data)) def _nixl_handshake(self, host: str, port: int, remote_tp_size: int) -> dict[int, str]: """Do a NIXL handshake with a remote instance.""" + start_time = time.perf_counter() + # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. @@ -511,19 +503,12 @@ def _nixl_handshake(self, host: str, port: int, def handshake(path: str, rank: int) -> str: # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: - start_time = time.perf_counter() sock.send(GET_META_MSG) metadata_bytes = sock.recv() decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata: NixlAgentMetadata = decoder.decode(metadata_bytes) + metadata = decoder.decode(metadata_bytes) got_metadata_time = time.perf_counter() - # "Sync" clocks between local and remote by registering offset. - rtt = got_metadata_time - start_time - assert metadata.remote_node_time - self._remote_agent_time_offsets[metadata.engine_id] = ( - metadata.remote_node_time + rtt / 2 - got_metadata_time) - # Register Remote agent. remote_agent_name = self.add_remote_agent( metadata, rank, remote_tp_size) @@ -857,7 +842,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: len(done_sending), len(done_recving)) # Handle timeout to avoid stranding blocks on remote. - now = time.perf_counter() + now = time.monotonic() while self._reqs_to_send: req_id, expires = next(iter(self._reqs_to_send.items())) # Sorted dict, oldest requests are put first so we can exit early. @@ -866,12 +851,6 @@ def get_finished(self) -> tuple[set[str], set[str]]: del self._reqs_to_send[req_id] done_sending.add(req_id) - # Handle remote requests with expired TTL without attempting to read. 
- while self._reqs_expired_ttl: - req_id = next(iter(self._reqs_expired_ttl)) - done_recving.add(req_id) - self._reqs_expired_ttl.remove(req_id) - if self.world_size == 1: return done_sending, done_recving @@ -970,6 +949,11 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): """ for req_id, meta in metadata.reqs_to_recv.items(): remote_engine_id = meta.remote_engine_id + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) if remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: @@ -989,20 +973,9 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): self._reqs_to_send.update(metadata.reqs_to_send) def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): - # Make sure request TTL is not expired before reading. - assert self._remote_agent_time_offsets[ - meta.remote_engine_id] is not None - remote_offset = self._remote_agent_time_offsets[meta.remote_engine_id] - if time.perf_counter() + remote_offset > meta.request_ttl: - logger.warning("Request remote TTL expired for request %s", req_id) - self._reqs_expired_ttl.add(req_id) - return - logger.debug( - "start_load_kv for request %s from remote engine %s. " - "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, - meta.remote_engine_id, len(meta.local_block_ids), - len(meta.remote_block_ids)) + "Remote agent %s available, calling _read_blocks for req %s", + meta.remote_engine_id, req_id) self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, @@ -1103,8 +1076,7 @@ def _read_blocks(self, local_block_ids: list[int], # Use handle to check completion in future step(). # TODO (NickLucche) surface xfer elapsed time - self._recving_transfers[request_id].append( - (handle, time.perf_counter())) + self._recving_transfers[request_id].append((handle, time.monotonic())) def _get_block_descs_ids(self, engine_id: str, From 772e7f4f311b71e493759a10829dd1a7e0c60317 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 3 Jul 2025 08:53:25 +0000 Subject: [PATCH 12/14] reivew Signed-off-by: NickLucche --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0f424096c4ab..67adb3e8a3c9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -331,12 +331,10 @@ def request_finished( # If prompt < block_size, no xfer so free blocks immediately. delay_free_blocks = len(computed_block_ids) > 0 - if delay_free_blocks and params.get("do_remote_decode"): - now = time.monotonic() + if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion - self._reqs_need_send[ - request. - request_id] = now + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + self._reqs_need_send[request.request_id] = time.perf_counter( + ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT return delay_free_blocks, dict( do_remote_prefill=True, @@ -842,7 +840,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: len(done_sending), len(done_recving)) # Handle timeout to avoid stranding blocks on remote. 
-        now = time.monotonic()
+        now = time.perf_counter()
         while self._reqs_to_send:
             req_id, expires = next(iter(self._reqs_to_send.items()))
             # Sorted dict, oldest requests are put first so we can exit early.
             if now < expires:
                 break
@@ -1076,7 +1074,8 @@ def _read_blocks(self, local_block_ids: list[int],
 
         # Use handle to check completion in future step().
         # TODO (NickLucche) surface xfer elapsed time
-        self._recving_transfers[request_id].append((handle, time.monotonic()))
+        self._recving_transfers[request_id].append(
+            (handle, time.perf_counter()))
 
     def _get_block_descs_ids(self,
                              engine_id: str,

From e249b4ca2a050384f18ae8408a70d7f7df68e40d Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 7 Jul 2025 10:06:33 +0000
Subject: [PATCH 13/14] mock nixl

Signed-off-by: NickLucche
---
 tests/v1/kv_connector/unit/test_nixl_connector.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 0996192953d4..94583b1b4e3c 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -376,6 +376,9 @@ def test_concurrent_load_kv(
     raise TimeoutError("Took too long to complete async handshake.")
 
 
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper)
 def test_abort_timeout_on_prefiller(monkeypatch):
     """
     Test lifecycle of an aborted Remote Prefill request hitting the timeout.

From 855c445bcf02ff21c536edb9e4e6251d7d74fb2f Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 7 Jul 2025 13:42:22 +0000
Subject: [PATCH 14/14] update tests

Signed-off-by: NickLucche
---
 tests/v1/kv_connector/unit/test_nixl_connector.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 94583b1b4e3c..e18c4975a322 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -44,9 +44,9 @@ def test_basic_interface():
 
     assert kv_connector_metadata is not None
     assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
-    assert len(kv_connector_metadata.requests) == 1
-    assert request_id in kv_connector_metadata.requests
-    req_meta = kv_connector_metadata.requests[request_id]
+    assert len(kv_connector_metadata.reqs_to_recv) == 1
+    assert request_id in kv_connector_metadata.reqs_to_recv
+    req_meta = kv_connector_metadata.reqs_to_recv[request_id]
 
     for block_id, block in zip(
             req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
@@ -81,7 +81,7 @@ def test_prompt_less_than_block_size():
     kv_connector_metadata = scheduler_output.kv_connector_metadata
     assert kv_connector_metadata is not None
     assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
-    assert len(kv_connector_metadata.requests) == 0
+    assert len(kv_connector_metadata.reqs_to_recv) == 0
 
     # This request should be scheduled regularly.
     assert len(scheduler_output.scheduled_new_reqs) == 1
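
Usage note, not part of the patch series: a minimal sketch of how a prefill-side deployment might opt into the new timeout once these patches are applied. The timeout value and model name below are placeholders; the env var defaults to 120 seconds (see PATCH 01/14).

    import os

    # Free producer-side KV blocks after 30s if the consumer never sends a
    # READ notification (placeholder value; the default is 120 seconds).
    os.environ["VLLM_NIXL_ABORT_REQUEST_TIMEOUT"] = "30"

    from vllm import LLM
    from vllm.config import KVTransferConfig

    # Prefill-side instance of a disaggregated prefill/decode deployment.
    llm = LLM(
        model="Qwen/Qwen3-0.6B",  # placeholder model
        kv_transfer_config=KVTransferConfig(
            kv_connector="NixlConnector",
            kv_role="kv_both",
        ),
    )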