diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index ab9729aae2e9..e30a250449aa 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -7,6 +7,8 @@
 from typing import Optional
 from unittest.mock import patch
 
+import pytest
+
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
@@ -161,7 +163,8 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency
 
-    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
+    def _nixl_handshake(self, host: str, port: int,
+                        remote_tp_size: int) -> dict[int, str]:
         # Mimic slow _nixl_handshake, as well as bypass zmq communication.
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
@@ -177,10 +180,10 @@ def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
                 agent_metadata=FakeNixlWrapper.AGENT_METADATA,
                 kv_caches_base_addr=[0],
                 num_blocks=1,
-                tp_size=1,
                 block_len=self.block_len,
                 attn_backend_name=self.backend_name,
-            ))
+            ),
+            remote_tp_size=remote_tp_size)
         return {0: remote_agent_name}
@@ -233,6 +236,8 @@ def test_multi_xfer_one_engine(
                     "localhost",
                     "remote_port":
                     1234,
+                    "remote_tp_size":
+                    1,
                 })
             connector.bind_connector_metadata(metadata)
@@ -259,13 +264,23 @@ def test_multi_xfer_one_engine(
     @patch(
         "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
         FakeNixlWrapper)
+    @pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
+        (1, 1),
+        (2, 1),
+        (4, 2),
+        (4, 4),
+    ])
     def test_async_load_kv(
-            self,
-            # dist_init is a fixture that initializes the distributed environment.
-            dist_init):
+            self,
+            # Fixture that initializes the distributed environment.
+            dist_init,
+            # Simulate consumer-producer TP sizes.
+            decode_tp_size,
+            prefill_tp_size):
         """Test that NixlConnector's start_load_kv should be non-blocking."""
         vllm_config = create_vllm_config()
+        vllm_config.parallel_config.tensor_parallel_size = decode_tp_size
 
         # Test worker role in decode server.
         connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
@@ -280,6 +295,7 @@ def test_async_load_kv(
                 FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                 "remote_host": "localhost",
                 "remote_port": 1234,
+                "remote_tp_size": prefill_tp_size,
             })
         connector.bind_connector_metadata(metadata)
@@ -329,6 +345,7 @@ def test_concurrent_load_kv(
                     FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                     "remote_host": "localhost",
                     "remote_port": 1234,
+                    "remote_tp_size": 1,
                 })
             connector.bind_connector_metadata(metadata)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index a962a9241d73..2f7334706246 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -62,7 +62,6 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
-    tp_size: int
    block_len: int
     attn_backend_name: str
 
@@ -73,7 +72,8 @@ class ReqMeta:
     remote_block_ids: list[int]
     remote_host: str
     remote_port: int
-    remote_engine_id: EngineId
+    remote_engine_id: str
+    tp_size: int
 
 
 class NixlConnectorMetadata(KVConnectorMetadata):
@@ -93,6 +93,8 @@ def add_new_req(
             remote_engine_id=kv_transfer_params["remote_engine_id"],
             remote_host=kv_transfer_params["remote_host"],
             remote_port=kv_transfer_params["remote_port"],
+            # P workers don't need to receive tp_size from proxy here.
+            tp_size=kv_transfer_params.get("tp_size", 1),
         )
 
@@ -317,7 +319,7 @@ def request_finished(
             remote_engine_id=self.engine_id,
             remote_host=self.side_channel_host,
             remote_port=self.side_channel_port,
-        )
+            tp_size=self.vllm_config.parallel_config.tensor_parallel_size)
 
 
 class NixlConnectorWorker:
@@ -460,7 +462,8 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                         "Connection listener got unexpected message %s", msg)
                 sock.send_multipart((identity, b"", encoded_data))
 
-    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
+    def _nixl_handshake(self, host: str, port: int,
+                        remote_tp_size: int) -> dict[int, str]:
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
@@ -469,7 +472,7 @@ def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
 
-        def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
+        def handshake(path: str, rank: int) -> str:
             # Send query for the request.
             with zmq_ctx(zmq.REQ, path) as sock:
                 sock.send(GET_META_MSG)
@@ -479,33 +482,25 @@ def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
                 got_metadata_time = time.perf_counter()
 
                 # Register Remote agent.
-                remote_agent_name = self.add_remote_agent(metadata, rank)
+                remote_agent_name = self.add_remote_agent(
+                    metadata, rank, remote_tp_size)
                 setup_agent_time = time.perf_counter()
 
                 logger.debug("NIXL handshake: get metadata took: %s",
                              got_metadata_time - start_time)
                 logger.debug("NIXL handshake: add agent took: %s",
                              setup_agent_time - got_metadata_time)
-                return metadata, remote_agent_name
+                return remote_agent_name
 
-        # Handshake with remote agent-rank0 first to get the tp_size of remote
-        path = make_zmq_path("tcp", host, port)
-        logger.debug("Querying master rank metadata on path: %s", path)
-        rank_to_agent_name: dict[int, str] = {}
-        metadata, rank_to_agent_name[0] = handshake(path, 0)
-
-        # Handshake only with the other TP remote the current local rank will
+        # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
+        tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
         p_remote_rank = self.tp_rank // tp_ratio
-        if p_remote_rank > 0:
-            path = make_zmq_path("tcp", host, port + p_remote_rank)
-            logger.debug("Querying metadata on path: %s at remote rank %s",
-                         path, p_remote_rank)
-            _, rank_to_agent_name[p_remote_rank] = handshake(
-                path, p_remote_rank)
-
-        return rank_to_agent_name
+        path = make_zmq_path("tcp", host, port + p_remote_rank)
+        logger.debug("Querying metadata on path: %s at remote rank %s", path,
+                     p_remote_rank)
+        # Remote rank -> agent name.
+        return {p_remote_rank: handshake(path, p_remote_rank)}
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -632,7 +627,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-            tp_size=self.world_size,
             block_len=self.block_len,
             attn_backend_name=self.backend_name)
         ready_event = threading.Event()
@@ -646,7 +640,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
 
     def add_remote_agent(self,
                          nixl_agent_meta: NixlAgentMetadata,
-                         remote_tp_rank: int = 0) -> str:
+                         remote_tp_rank: int = 0,
+                         remote_tp_size: int = 1) -> str:
         """
         Add the remote NIXL agent and prepare the descriptors for reading cache
         blocks from remote.
@@ -691,9 +686,9 @@ def add_remote_agent(self,
             return self._remote_agents[engine_id][remote_tp_rank]
 
         if engine_id in self._tp_size:
-            assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
+            assert self._tp_size[engine_id] == remote_tp_size
         else:
-            self._tp_size[engine_id] = nixl_agent_meta.tp_size
+            self._tp_size[engine_id] = remote_tp_size
         # We may eventually enable this after asserting equality in cache
         # layout and close outputs.
         assert nixl_agent_meta.attn_backend_name == self.backend_name
@@ -743,33 +738,31 @@ def add_remote_agent(self,
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        p_remote_tp_rank = self.tp_rank // tp_ratio
         # Only register the remote's descriptors if current rank pulls from it.
-        if p_remote_tp_rank == remote_tp_rank:
-            self.kv_caches_base_addr[
-                engine_id] = nixl_agent_meta.kv_caches_base_addr
-            rank_offset = self.tp_rank % tp_ratio * self.block_len \
-                if not (self.use_mla or is_kv_replicated) else 0
-            # Register all remote blocks, but only the corresponding kv heads.
-            for base_addr in nixl_agent_meta.kv_caches_base_addr:
-                for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_len
-                    # For each block, grab the heads chunk belonging to rank_i
-                    # of size remote_nheads // tp_ratio, which correspond to
-                    # self.block_len == remote_block_len//tp_ratio bytes.
-                    addr = base_addr + block_offset + rank_offset
-                    # (addr, len, device id)
-                    blocks_data.append((addr, self.block_len, remote_tp_rank))
-            logger.debug(
-                "Created %s blocks for dst engine %s with remote rank %s and "
-                "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
-                self.tp_rank)
+        self.kv_caches_base_addr[
+            engine_id] = nixl_agent_meta.kv_caches_base_addr
+        rank_offset = self.tp_rank % tp_ratio * self.block_len \
+            if not (self.use_mla or is_kv_replicated) else 0
+        # Register all remote blocks, but only the corresponding kv heads.
+        for base_addr in nixl_agent_meta.kv_caches_base_addr:
+            for block_id in range(nixl_agent_meta.num_blocks):
+                block_offset = block_id * nixl_agent_meta.block_len
+                # For each block, grab the heads chunk belonging to rank_i
+                # of size remote_nheads // tp_ratio, which correspond to
+                # self.block_len == remote_block_len//tp_ratio bytes.
+                addr = base_addr + block_offset + rank_offset
+                # (addr, len, device id)
+                blocks_data.append((addr, self.block_len, remote_tp_rank))
+        logger.debug(
+            "Created %s blocks for dst engine %s with remote rank %s and "
+            "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
+            self.tp_rank)
 
-            # Register with NIXL.
-            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-            self.dst_xfer_side_handles[
-                engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                    remote_agent_name, descs)
+        # Register with NIXL.
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+        self.dst_xfer_side_handles[
+            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                remote_agent_name, descs)
 
         return remote_agent_name
@@ -904,7 +897,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
                         if fut is None:
                             fut = self._handshake_initiation_executor.submit(
                                 self._nixl_handshake, meta.remote_host,
-                                meta.remote_port)
+                                meta.remote_port, meta.tp_size)
                             self._handshake_futures[remote_engine_id] = fut
 
                             def done_callback(f: Future[dict[int, str]],
@@ -944,13 +937,9 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
             remote_block_ids=meta.remote_block_ids,
         )
 
-    def _read_blocks(
-        self,
-        local_block_ids: list[int],
-        remote_block_ids: list[int],
-        dst_engine_id: str,
-        request_id: str,
-    ):
+    def _read_blocks(self, local_block_ids: list[int],
+                     remote_block_ids: list[int], dst_engine_id: str,
+                     request_id: str):
         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
         # after we detect the txn is complete (which means we cannot make the
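For reviewers, here is a small standalone sketch (not part of the patch; the helper remote_rank_and_offset and its signature are invented for illustration) of the decode-rank to prefill-rank mapping and the per-rank byte offset that _nixl_handshake and add_remote_agent now derive from remote_tp_size:

def remote_rank_and_offset(decode_tp_rank: int, decode_tp_size: int,
                           prefill_tp_size: int, local_block_len: int,
                           kv_replicated: bool = False) -> tuple[int, int]:
    # Mirrors tp_ratio = self._tp_size[engine_id] // remote_tp_size in the
    # patch; assumes the decode TP size is a multiple of the prefill TP size.
    tp_ratio = decode_tp_size // prefill_tp_size
    # Each decode rank handshakes with exactly one prefill rank.
    p_remote_rank = decode_tp_rank // tp_ratio
    # Offset into the remote block: each of the tp_ratio readers pulls its own
    # block_len-sized chunk of kv heads, unless KV is replicated across ranks
    # (the use_mla / is_kv_replicated case), where every reader starts at 0.
    rank_offset = (0 if kv_replicated else
                   (decode_tp_rank % tp_ratio) * local_block_len)
    return p_remote_rank, rank_offset

# E.g. DTP4 pulling from PTP2 with a 4096-byte local block_len:
#   d_rank 0 -> (p_rank 0, offset 0)     d_rank 1 -> (p_rank 0, offset 4096)
#   d_rank 2 -> (p_rank 1, offset 0)     d_rank 3 -> (p_rank 1, offset 4096)
for d_rank in range(4):
    print(d_rank, remote_rank_and_offset(d_rank, 4, 2, 4096))

This matches the (4, 2) case added to the test parametrization: two decode ranks share each prefill rank and read disjoint head chunks of its blocks.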