diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index b00be7b83e12..ab9729aae2e9 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -7,13 +7,6 @@
 from typing import Optional
 from unittest.mock import patch
 
-import pytest
-
-try:
-    from nixl._api import nixl_agent as NixlWrapper
-except ImportError:
-    NixlWrapper = None
-
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
@@ -92,7 +85,8 @@ def test_prompt_less_than_block_size():
 class FakeNixlWrapper:
     """Mock implementation of NixlWrapper for testing.
 
-    We don't inherit from NixlWrapper because NixlWrapper could be None.
+    We don't inherit from nixl._api.nixl_agent because nixl may not be
+    installed.
     """
 
     AGENT_METADATA = b"fake_agent_metadata"
@@ -167,7 +161,7 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency
 
-    def _nixl_handshake(self, host: str, port: int):
+    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         # Mimic slow _nixl_handshake, as well as bypass zmq communication.
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
@@ -177,7 +171,7 @@ def _nixl_handshake(self, host: str, port: int):
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
-        self.add_remote_agent(
+        remote_agent_name = self.add_remote_agent(
             NixlAgentMetadata(
                 engine_id=self.REMOTE_ENGINE_ID,
                 agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@@ -187,40 +181,101 @@ def _nixl_handshake(self, host: str, port: int):
                 block_len=self.block_len,
                 attn_backend_name=self.backend_name,
             ))
-
-
-@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
-@patch(
-    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
-    FakeNixlWrapper)
-def test_multi_xfer_one_engine(
-    # dist_init is a fixture that initializes the distributed environment.
-    dist_init):
-    """Test case where multiple xfers are initiated to the same engine.
-
-    This test triggers the connector to load remote KV for the same
-    `request_id`. The transfer is not done immediately due to
-    `set_cycles_before_xfer_done`, so there is a state where there are multiple
-    transfer states for the same `request_id`, and `get_finished` should handle
-    it correctly (wait for all transfers to be done).
-    """
-    vllm_config = create_vllm_config()
-
-    request_id = "req_id"
-
-    # Test worker role in decode server.
-    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
-    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
-                                                         connector.engine_id,
-                                                         hand_shake_latency=0)
-    assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
-    connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
-    for i in range(4):
+        return {0: remote_agent_name}
+
+
+class TestNixlHandshake:
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_multi_xfer_one_engine(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test case where multiple xfers are initiated to the same engine.
+
+        This test triggers the connector to load remote KV for the same
+        `request_id`. The transfer is not done immediately due to
+        `set_cycles_before_xfer_done`, so there is a state where there are
+        multiple transfer states for the same `request_id`, and `get_finished`
+        should handle it correctly (wait for all transfers to be done).
+        """
+        vllm_config = create_vllm_config()
+
+        request_id = "req_id"
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id, hand_shake_latency=0)
+        assert isinstance(connector.connector_worker.nixl_wrapper,
+                          FakeNixlWrapper)
+        connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
+        num_xfers = 4
+        while True:
+            # For the same request_id, initiate multiple xfers across different
+            # round of `execute_model` calls.
+            metadata = NixlConnectorMetadata()
+            if num_xfers > 0:
+                num_xfers -= 1
+                metadata.add_new_req(
+                    request_id=request_id,
+                    local_block_ids=[
+                        num_xfers + 1, num_xfers + 2, num_xfers + 3
+                    ],
+                    kv_transfer_params={
+                        "remote_block_ids":
+                        [num_xfers + 4, num_xfers + 5, num_xfers + 6],
+                        "remote_engine_id":
+                        FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                        "remote_host":
+                        "localhost",
+                        "remote_port":
+                        1234,
+                    })
+            connector.bind_connector_metadata(metadata)
+
+            # Mimic maybe_setup_kv_connector in gpu_model_runner.
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+
+            # Mimic get_finished_kv_transfers in gpu_model_runner.
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                assert request_id in done_recving
+                break
+
+            connector.clear_connector_metadata()
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_async_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that NixlConnector's start_load_kv should be non-blocking."""
+
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
         metadata = NixlConnectorMetadata()
-        metadata.add_new_req(request_id=request_id,
-                             local_block_ids=[i + 1, i + 2, i + 3],
+        metadata.add_new_req(request_id="id",
+                             local_block_ids=[1, 2, 3],
                              kv_transfer_params={
-                                 "remote_block_ids": [i + 4, i + 5, i + 6],
+                                 "remote_block_ids": [4, 5, 6],
                                  "remote_engine_id":
                                  FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                                  "remote_host": "localhost",
@@ -228,19 +283,74 @@ def test_multi_xfer_one_engine(
                              })
         connector.bind_connector_metadata(metadata)
 
-        dummy_ctx = ForwardContext(
-            no_compile_layers={},
-            attn_metadata={},
-            virtual_engine=0,
-        )
-        _before_load = time.perf_counter()
-        connector.start_load_kv(dummy_ctx)
-        _after_load = time.perf_counter()
-        assert _after_load - _before_load < 0.1, "start_load_kv took " \
-            f"{_after_load - _before_load} seconds"
-
-    while True:
-        _, done_recving = connector.get_finished(finished_req_ids=set())
-        if len(done_recving) > 0:
-            assert request_id in done_recving
-            break
+        timeout = 2.5
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                return
+        raise TimeoutError("Took too long to complete async handshake.")
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_concurrent_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that multiple start_load_kv calls should occur concurrently."""
+
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
+        metadata = NixlConnectorMetadata()
+        total_reqs = 5
+        for i in range(total_reqs):
+            metadata.add_new_req(request_id=f"id_{i}",
+                                 local_block_ids=[1, 2, 3],
+                                 kv_transfer_params={
+                                     "remote_block_ids": [4, 5, 6],
+                                     "remote_engine_id":
+                                     FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                                     "remote_host": "localhost",
+                                     "remote_port": 1234,
+                                 })
+        connector.bind_connector_metadata(metadata)
+
+        timeout = 2.5 * total_reqs
+        cnt_finished_reqs = 0
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                cnt_finished_reqs += len(done_recving)
+                if cnt_finished_reqs == total_reqs:
+                    return
+        raise TimeoutError("Took too long to complete async handshake.")
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 65bdd7ae29d5..a962a9241d73 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -2,11 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
 import math
+import queue
 import threading
 import time
 import uuid
 from collections import defaultdict
 from collections.abc import Iterator
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
@@ -23,6 +25,7 @@
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
     get_tp_group)
 from vllm.distributed.utils import divide
+from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import _Backend
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
@@ -31,7 +34,6 @@
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
-    from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
@@ -71,7 +73,7 @@ class ReqMeta:
     remote_block_ids: list[int]
     remote_host: str
     remote_port: int
-    remote_engine_id: str
+    remote_engine_id: EngineId
 
 
 class NixlConnectorMetadata(KVConnectorMetadata):
@@ -81,7 +83,7 @@ def __init__(self):
 
     def add_new_req(
         self,
-        request_id: str,
+        request_id: ReqId,
         local_block_ids: list[int],
         kv_transfer_params: dict[str, Any],
     ):
@@ -102,7 +104,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
         self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
 
         if role == KVConnectorRole.SCHEDULER:
-            self.connector_scheduler : Optional[NixlConnectorScheduler] = \
+            self.connector_scheduler: Optional[NixlConnectorScheduler] = \
                 NixlConnectorScheduler(vllm_config, self.engine_id)
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
@@ -186,7 +188,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
         self.side_channel_port = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
-            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.data_parallel_rank *
            vllm_config.parallel_config.tensor_parallel_size)
 
         logger.info("Initializing NIXL Scheduler %s", engine_id)
@@ -343,7 +345,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Each TP rank listens/queries on the base_port + tp_rank.
         self.side_channel_port: int = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
-            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.data_parallel_rank *
             vllm_config.parallel_config.tensor_parallel_size)
 
         # Metadata.
@@ -386,8 +388,17 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._done_sending_count: defaultdict[ReqId,
                                               int] = defaultdict(lambda: 0)
 
-        # Background thread for establishing new connections.
+        # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
+        # Background thread for initializing new NIXL handshakes.
+        self._handshake_initiation_executor = ThreadPoolExecutor(
+            # NIXL is not guaranteed to be thread-safe, limit 1 worker.
+            max_workers=1,
+            thread_name_prefix="vllm-nixl-handshake-initiator")
+        self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]()
+        self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {}
+        # Protects _handshake_futures and _remote_agents.
+        self._handshake_lock = threading.RLock()
 
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
@@ -416,6 +427,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
 
+    def __del__(self):
+        """Cleanup background threads on destruction."""
+        self._handshake_initiation_executor.shutdown(wait=False)
+        if self._nixl_handshake_listener_t:
+            self._nixl_handshake_listener_t.join(timeout=0)
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, base_port: int,
@@ -443,7 +460,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                     "Connection listener got unexpected message %s", msg)
             sock.send_multipart((identity, b"", encoded_data))
 
-    def _nixl_handshake(self, host: str, port: int):
+    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
@@ -452,7 +469,7 @@
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
 
-        def handshake(path: str, rank: int) -> NixlAgentMetadata:
+        def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
             # Send query for the request.
             with zmq_ctx(zmq.REQ, path) as sock:
                 sock.send(GET_META_MSG)
@@ -462,19 +479,20 @@ def handshake(path: str, rank: int) -> NixlAgentMetadata:
                 got_metadata_time = time.perf_counter()
 
                 # Register Remote agent.
-                self.add_remote_agent(metadata, rank)
+                remote_agent_name = self.add_remote_agent(metadata, rank)
                 setup_agent_time = time.perf_counter()
 
                 logger.debug("NIXL handshake: get metadata took: %s",
                              got_metadata_time - start_time)
                 logger.debug("NIXL handshake: add agent took: %s",
                              setup_agent_time - got_metadata_time)
-            return metadata
+            return metadata, remote_agent_name
 
         # Handshake with remote agent-rank0 first to get the tp_size of remote
         path = make_zmq_path("tcp", host, port)
         logger.debug("Querying master rank metadata on path: %s", path)
-        metadata = handshake(path, 0)
+        rank_to_agent_name: dict[int, str] = {}
+        metadata, rank_to_agent_name[0] = handshake(path, 0)
 
         # Handshake only with the other TP remote the current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
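
[Editor's aside] The nixl_connector.py hunks in this diff move the NIXL handshake off the hot path: a single-worker ThreadPoolExecutor runs the handshake, one Future is kept per remote engine id, and requests are parked on a queue until their handshake future resolves. The standalone sketch below illustrates only that pattern, under the same assumptions the diff states (one worker because NIXL is not assumed thread-safe); every name in it (HandshakeDemo, _fake_handshake, load, poll) is hypothetical and is not part of the vLLM or NIXL APIs.

    # Illustrative only, not vLLM code: single-worker executor, one Future per
    # remote engine, and a queue of requests released once the handshake is done.
    import queue
    import threading
    import time
    from concurrent.futures import Future, ThreadPoolExecutor

    EngineId = str
    ReqId = str


    class HandshakeDemo:

        def __init__(self):
            # One worker serializes handshakes (NIXL is not assumed to be
            # thread-safe) while keeping them off the caller's critical path.
            self._executor = ThreadPoolExecutor(max_workers=1)
            self._futures: dict[EngineId, Future[str]] = {}
            self._agents: dict[EngineId, str] = {}
            self._ready: queue.Queue[ReqId] = queue.Queue()
            self._lock = threading.RLock()

        def _fake_handshake(self, engine_id: EngineId) -> str:
            time.sleep(0.2)  # stand-in for the slow ZMQ/NIXL exchange
            return f"agent-for-{engine_id}"

        def load(self, req_id: ReqId, engine_id: EngineId) -> None:
            """Non-blocking: start (or join) a handshake, or read right away."""
            if engine_id in self._agents:
                print(f"{req_id}: agent ready, reading blocks now")
                return
            with self._lock:
                fut = self._futures.get(engine_id)
                if fut is None:
                    fut = self._executor.submit(self._fake_handshake, engine_id)
                    self._futures[engine_id] = fut

                    def done(f: Future[str], eid: EngineId = engine_id) -> None:
                        # Runs when the handshake finishes; publish the agent.
                        with self._lock:
                            del self._futures[eid]
                            self._agents[eid] = f.result()

                    fut.add_done_callback(done)
                # Every request waiting on this engine is queued for later.
                fut.add_done_callback(lambda _f, r=req_id: self._ready.put(r))

        def poll(self) -> None:
            """Called periodically (like each model step) to drain ready work."""
            while not self._ready.empty():
                req = self._ready.get_nowait()
                print(f"{req}: handshake done, reading blocks")


    if __name__ == "__main__":
        demo = HandshakeDemo()
        demo.load("req-0", "remote-engine")
        demo.load("req-1", "remote-engine")  # joins the in-flight handshake
        time.sleep(0.5)
        demo.poll()

Capping the executor at one worker mirrors the "limit 1 worker" comment added in the diff, and the per-engine Future deduplicates concurrent handshakes to the same remote engine.
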
@@ -484,7 +502,10 @@ def handshake(path: str, rank: int) -> NixlAgentMetadata:
             path = make_zmq_path("tcp", host, port + p_remote_rank)
             logger.debug("Querying metadata on path: %s at remote rank %s",
                          path, p_remote_rank)
-            _ = handshake(path, p_remote_rank)
+            _, rank_to_agent_name[p_remote_rank] = handshake(
+                path, p_remote_rank)
+
+        return rank_to_agent_name
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -621,11 +642,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 daemon=True,
                 name="nixl_handshake_listener")
             self._nixl_handshake_listener_t.start()
-            ready_event.wait()
+            ready_event.wait()  # Wait for listener ZMQ socket to be ready.
 
     def add_remote_agent(self,
                          nixl_agent_meta: NixlAgentMetadata,
-                         remote_tp_rank: int = 0):
+                         remote_tp_rank: int = 0) -> str:
         """
         Add the remote NIXL agent and prepare the descriptors for reading cache
         blocks from remote.
@@ -666,8 +687,8 @@ def add_remote_agent(self,
         """ # noqa: E501
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
-        if remote_tp_rank in self._remote_agents.get(engine_id, ()):
-            return
+        if remote_tp_rank in self._remote_agents.get(engine_id, {}):
+            return self._remote_agents[engine_id][remote_tp_rank]
 
         if engine_id in self._tp_size:
             assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
@@ -677,9 +698,8 @@ def add_remote_agent(self,
             # layout and close outputs.
             assert nixl_agent_meta.attn_backend_name == self.backend_name
 
-        self._remote_agents[engine_id][
-            remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
-                nixl_agent_meta.agent_metadata)
+        remote_agent_name = self.nixl_wrapper.add_remote_agent(
+            nixl_agent_meta.agent_metadata)
 
         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match.
@@ -708,8 +728,9 @@ def add_remote_agent(self,
             "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
         )
 
-        assert self.block_size == remote_block_size, "Remote P worker with " \
-            "different block size is not supported"
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different block size is not supported "
+            f"{self.block_size=} {remote_block_size=}")
 
         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
@@ -748,7 +769,9 @@ def add_remote_agent(self,
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
         self.dst_xfer_side_handles[
             engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id][remote_tp_rank], descs)
+                remote_agent_name, descs)
+
+        return remote_agent_name
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
@@ -866,33 +889,68 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         We check for these trnxs to complete in each step().
         """
         for req_id, meta in metadata.requests.items():
+            remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
                 "Num local_block_ids: %s. Num remote_block_ids: %s.", req_id,
-                meta.remote_engine_id, len(meta.local_block_ids),
+                remote_engine_id, len(meta.local_block_ids),
                 len(meta.remote_block_ids))
-            self._read_blocks(
-                request_id=req_id,
-                dst_engine_id=meta.remote_engine_id,
-                local_block_ids=meta.local_block_ids,
-                remote_block_ids=meta.remote_block_ids,
-                remote_host=meta.remote_host,
-                remote_port=meta.remote_port,
-            )
+            if remote_engine_id not in self._remote_agents:
+                # Being optimistic to assume engine is usually ready, apply
+                # lock only when the optimistic check fails.
+                with self._handshake_lock:
+                    if remote_engine_id not in self._remote_agents:
+                        fut = self._handshake_futures.get(remote_engine_id)
+                        if fut is None:
+                            fut = self._handshake_initiation_executor.submit(
+                                self._nixl_handshake, meta.remote_host,
+                                meta.remote_port)
+                            self._handshake_futures[remote_engine_id] = fut
+
+                            def done_callback(f: Future[dict[int, str]],
                                              eid=remote_engine_id):
+                                with self._handshake_lock:
+                                    del self._handshake_futures[eid]
+                                    try:
+                                        self._remote_agents[eid] = f.result()
+                                    except Exception:
+                                        logger.exception(
+                                            "Handshake with %s failed", eid)
+
+                            fut.add_done_callback(done_callback)
+
+                        # TODO: handle failure state of future in the
+                        # callback, we want to fail the request in this case.
+                        def request_ready(_f: Future[Any],
                                          entry=(req_id, meta)):
+                            self._ready_requests.put(entry)
+
+                        fut.add_done_callback(request_ready)
+                continue
+            self._read_blocks_for_req(req_id, meta)
+
+        # Start transfers for requests whose handshakes have now finished.
+        while not self._ready_requests.empty():
+            self._read_blocks_for_req(*self._ready_requests.get_nowait())
+
+    def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
+        logger.debug(
+            "Remote agent %s available, calling _read_blocks for req %s",
+            meta.remote_engine_id, req_id)
+        self._read_blocks(
+            request_id=req_id,
+            dst_engine_id=meta.remote_engine_id,
+            local_block_ids=meta.local_block_ids,
+            remote_block_ids=meta.remote_block_ids,
+        )
 
     def _read_blocks(
         self,
         local_block_ids: list[int],
         remote_block_ids: list[int],
-        remote_host: str,
-        remote_port: int,
         dst_engine_id: str,
         request_id: str,
     ):
-        # NOTE(rob): this takes ~2s. We need to get this off the hotpath.
-        if dst_engine_id not in self._remote_agents:
-            self._nixl_handshake(remote_host, remote_port)
-
         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
         # after we detect the txn is complete (which means we cannot make the