diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index e30a250449aa..e18c4975a322 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -9,10 +9,13 @@
 import pytest
 
+from vllm import LLM
+from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
 from vllm.forward_context import ForwardContext
+from vllm.sampling_params import SamplingParams
 
 from .utils import create_request, create_scheduler, create_vllm_config
 
@@ -41,9 +44,9 @@ def test_basic_interface():
     assert kv_connector_metadata is not None
     assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
 
-    assert len(kv_connector_metadata.requests) == 1
-    assert request_id in kv_connector_metadata.requests
-    req_meta = kv_connector_metadata.requests[request_id]
+    assert len(kv_connector_metadata.reqs_to_recv) == 1
+    assert request_id in kv_connector_metadata.reqs_to_recv
+    req_meta = kv_connector_metadata.reqs_to_recv[request_id]
 
     for block_id, block in zip(
             req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
@@ -78,7 +81,7 @@ def test_prompt_less_than_block_size():
     kv_connector_metadata = scheduler_output.kv_connector_metadata
     assert kv_connector_metadata is not None
     assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
-    assert len(kv_connector_metadata.requests) == 0
+    assert len(kv_connector_metadata.reqs_to_recv) == 0
 
     # This request should be scheduled regularly.
     assert len(scheduler_output.scheduled_new_reqs) == 1
@@ -371,3 +374,70 @@ def test_concurrent_load_kv(
             if cnt_finished_reqs == total_reqs:
                 return
     raise TimeoutError("Took too long to complete async handshake.")
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper)
+def test_abort_timeout_on_prefiller(monkeypatch):
+    """
+    Test lifecycle of a remote prefill request that hits the abort timeout.
+    -----> P
+           | {process request}
+    <-\--- | {result is NOT delivered, e.g. proxy is down}
+           |
+           |
+           | {eventually free blocks}
+    """
+    model_name = "Qwen/Qwen3-0.6B"
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="NixlConnector",
+        kv_role="kv_both",
+    )
+    timeout = 6
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+    monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
+    llm = LLM(
+        model=model_name,
+        enforce_eager=True,
+        gpu_memory_utilization=0.5,
+        kv_transfer_config=kv_transfer_config,
+    )
+    remote_prefill_opts = {
+        "do_remote_decode": True,
+        "do_remote_prefill": False,
+        "remote_engine_id": None,
+        "remote_block_ids": None,
+        "remote_host": None,
+        "remote_port": None,
+    }
+    # Simulate a remote-decode request as relayed by a proxy/sidecar.
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=1,
+        extra_args={"kv_transfer_params": remote_prefill_opts})
+    scheduler = llm.llm_engine.engine_core.engine_core.scheduler
+    req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
+        0].req_to_blocks
+
+    padding = ("Just making this request a little longer so that we're "
+               "sure we're not hitting the small-request lower bound beneath "
+               "which we don't actually trigger the whole kv transfer, but "
+               "rather just recompute the blocks on D.")
+    _ = llm.generate([f"What is the capital of Japan? {padding}"],
+                     sampling_params)
+
+    # The request finished, but its blocks are not freed yet.
+    assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks
+    # Run another request; request 0 is still not freed.
+    _ = llm.generate([f"What is the capital of Italy? {padding}"],
+                     sampling_params)
+    assert '0' in req_to_blocks
+    assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks
+
+    # Wait for the timeout, then trigger another scheduler loop.
+    time.sleep(timeout)
+    _ = llm.generate([f"What is the capital of France? {padding}"],
+                     sampling_params)
+    # Request 0 has timed out and its blocks are cleared.
+    assert '0' not in req_to_blocks
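The scheduler-side half of the mechanism exercised above is in the connector diff that follows: request_finished() stamps each delayed-free request with an absolute deadline (time.perf_counter() is monotonic, so wall-clock adjustments cannot shorten or extend it), and build_connector_meta() hands the accumulated map to the worker. Below is a minimal standalone sketch of that handoff; the names mirror the patch, but TIMEOUT_S and the reduced class are illustrative stand-ins, not the actual implementation:

import time

TIMEOUT_S = 120.0  # stand-in for envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT


class SchedulerSideSketch:
    """Illustrative reduction of the scheduler-side bookkeeping."""

    def __init__(self) -> None:
        # req_id -> absolute deadline, stamped when the request finishes.
        self._reqs_need_send: dict[str, float] = {}

    def request_finished(self, req_id: str, delay_free_blocks: bool) -> None:
        # Blocks stay allocated until D reads them or the deadline passes.
        if delay_free_blocks:
            self._reqs_need_send[req_id] = time.perf_counter() + TIMEOUT_S

    def build_connector_meta(self) -> dict[str, float]:
        # Hand the pending map to the worker and reset local state, mirroring
        # the meta.reqs_to_send assignment in the diff below.
        reqs_to_send, self._reqs_need_send = self._reqs_need_send, {}
        return reqs_to_send

Because deadlines are stamped at finish time, the map's insertion order is also its expiry order, which the worker-side scan in get_finished() relies on.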
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 56ae1acf8571..67adb3e8a3c9 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -79,7 +79,8 @@ class ReqMeta:
 class NixlConnectorMetadata(KVConnectorMetadata):
 
     def __init__(self):
-        self.requests: dict[ReqId, ReqMeta] = {}
+        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
+        self.reqs_to_send: dict[ReqId, float] = {}
 
     def add_new_req(
         self,
@@ -87,7 +88,7 @@ def add_new_req(
         local_block_ids: list[int],
         kv_transfer_params: dict[str, Any],
     ):
-        self.requests[request_id] = ReqMeta(
+        self.reqs_to_recv[request_id] = ReqMeta(
             local_block_ids=local_block_ids,
             remote_block_ids=kv_transfer_params["remote_block_ids"],
             remote_engine_id=kv_transfer_params["remote_engine_id"],
@@ -194,10 +195,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)
 
-        # Requests that need to start recv.
+        # Requests that need to start recv/send.
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
+        # Requests to send, mapped to their expiration timestamps.
+        self._reqs_need_send: dict[ReqId, float] = {}
 
     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -284,6 +287,9 @@ def build_connector_meta(
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
 
+        meta.reqs_to_send = self._reqs_need_send
+        self._reqs_need_send = {}
+
         return meta
 
     def request_finished(
@@ -325,6 +331,11 @@ def request_finished(
         # If prompt < block_size, no xfer so free blocks immediately.
         delay_free_blocks = len(computed_block_ids) > 0
 
+        if delay_free_blocks:
+            # Prefill request on remote: free the blocks if D never reads them.
+            self._reqs_need_send[request.request_id] = (
+                time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
+
         return delay_free_blocks, dict(
             do_remote_prefill=True,
             do_remote_decode=False,
@@ -394,6 +405,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # In progress transfers.
         # [req_id -> list[handle]]
         self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
+        # Track the expiration time of requests that are waiting to be sent.
+        self._reqs_to_send: dict[ReqId, float] = {}
 
         # Complete transfer tracker. Used by the rank 0 to track finished
         # transactions on ranks 1 to N-1.
@@ -826,6 +839,16 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                 "and %s requests done recving", self.tp_rank,
                 len(done_sending), len(done_recving))
 
+        # Handle timed-out requests to avoid stranding blocks on the remote.
+        now = time.perf_counter()
+        while self._reqs_to_send:
+            req_id, expires = next(iter(self._reqs_to_send.items()))
+            # Insertion order matches expiry order, so we can exit early.
+            if now < expires:
+                break
+            del self._reqs_to_send[req_id]
+            done_sending.add(req_id)
+
         if self.world_size == 1:
             return done_sending, done_recving
 
@@ -857,7 +880,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
 
         all_done_sending: set[str] = set()
         for req_id in list(self._done_sending_count.keys()):
-            if self._done_sending_count[req_id] == self.world_size:
+            if self._done_sending_count[req_id] >= self.world_size:
                 del self._done_sending_count[req_id]
                 all_done_sending.add(req_id)
 
@@ -887,6 +910,7 @@ def _get_new_notifs(self) -> set[str]:
                         tp_ratio):
                     notified_req_ids.add(req_id)
                     del self.consumer_notification_counts_by_req[req_id]
+                    self._reqs_to_send.pop(req_id, None)  # may have expired
         return notified_req_ids
 
     def _pop_done_transfers(
@@ -921,7 +945,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         Start loading by triggering non-blocking nixl_xfer.
         We check for these trnxs to complete in each step().
         """
-        for req_id, meta in metadata.requests.items():
+        for req_id, meta in metadata.reqs_to_recv.items():
             remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
@@ -943,6 +967,9 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         while not self._ready_requests.empty():
             self._read_blocks_for_req(*self._ready_requests.get_nowait())
 
+        # Add requests that are waiting to be read and track their expiration.
+        self._reqs_to_send.update(metadata.reqs_to_send)
+
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
             "Remote agent %s available, calling _read_blocks for req %s",
diff --git a/vllm/envs.py b/vllm/envs.py
index 0cc6792d72bb..ec6a4896774f 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -138,6 +138,7 @@
     VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
+    VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
 
 
 def get_default_cache_root():
@@ -953,7 +954,14 @@ def get_vllm_port() -> Optional[int]:
     # generations on machines < 100 for compressed-tensors
     # models
     "VLLM_USE_NVFP4_CT_EMULATIONS":
-    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")))
+    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
+
+    # Time (in seconds) after which the KV cache on the producer side is
+    # automatically cleared if no READ notification is received from the
+    # consumer. Only applicable when using NixlConnector in a disaggregated
+    # prefill-decode setup.
+    "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
+    lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
 }
 
 # --8<-- [end:env-vars-definition]
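On the worker side, the get_finished() loop above drains expired entries from _reqs_to_send and reports them through done_sending, so timed-out requests are freed via the same path as successfully transferred ones. Since entries arrive with monotonically increasing deadlines, the insertion-ordered dict doubles as a FIFO expiry queue and the scan can stop at the first unexpired entry. A runnable sketch of that pattern in isolation (the sample request ids are hypothetical):

import time


def expire_pending_sends(reqs_to_send: dict[str, float]) -> set[str]:
    """Pop every entry whose deadline has passed; stop at the first live one."""
    done: set[str] = set()
    now = time.perf_counter()
    while reqs_to_send:
        req_id, expires = next(iter(reqs_to_send.items()))
        if now < expires:
            break  # dicts preserve insertion order, so the rest are newer
        del reqs_to_send[req_id]
        done.add(req_id)
    return done


# Hypothetical contents: one deadline already passed, one still pending.
pending = {
    "req-0": time.perf_counter() - 1.0,
    "req-1": time.perf_counter() + 60.0,
}
assert expire_pending_sends(pending) == {"req-0"}
assert list(pending) == ["req-1"]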