29 changes: 23 additions & 6 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -7,6 +7,8 @@
from typing import Optional
from unittest.mock import patch

import pytest

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
@@ -161,7 +163,8 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency

def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
def _nixl_handshake(self, host: str, port: int,
remote_tp_size: int) -> dict[int, str]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
@@ -177,10 +180,10 @@ def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
tp_size=1,
block_len=self.block_len,
attn_backend_name=self.backend_name,
))
),
remote_tp_size=remote_tp_size)
return {0: remote_agent_name}


@@ -233,6 +236,8 @@ def test_multi_xfer_one_engine(
"localhost",
"remote_port":
1234,
"remote_tp_size":
1,
})
connector.bind_connector_metadata(metadata)

@@ -259,13 +264,23 @@ def test_multi_xfer_one_engine(
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
@pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
(1, 1),
(2, 1),
(4, 2),
(4, 4),
])
def test_async_load_kv(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
self,
# Fixture that initializes the distributed environment.
dist_init,
# Simulate consumer-producer TP sizes.
decode_tp_size,
prefill_tp_size):
"""Test that NixlConnector's start_load_kv should be non-blocking."""

vllm_config = create_vllm_config()
vllm_config.parallel_config.tensor_parallel_size = decode_tp_size

# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
@@ -280,6 +295,7 @@ def test_async_load_kv(
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": prefill_tp_size,
})
connector.bind_connector_metadata(metadata)

@@ -329,6 +345,7 @@ def test_concurrent_load_kv(
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
})
connector.bind_connector_metadata(metadata)

109 changes: 49 additions & 60 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -62,7 +62,6 @@ class NixlAgentMetadata(
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
tp_size: int
block_len: int
attn_backend_name: str

@@ -73,7 +72,8 @@ class ReqMeta:
remote_block_ids: list[int]
remote_host: str
remote_port: int
remote_engine_id: EngineId
remote_engine_id: str
tp_size: int


class NixlConnectorMetadata(KVConnectorMetadata):
@@ -93,6 +93,8 @@ def add_new_req(
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
# P workers don't need to receive tp_size from proxy here.
tp_size=kv_transfer_params.get("tp_size", 1),
)


@@ -317,7 +319,7 @@ def request_finished(
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
tp_size=self.vllm_config.parallel_config.tensor_parallel_size)
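
Side note, not part of the diff: a minimal sketch of the kv_transfer_params a prefill (P) instance would now hand back from request_finished, and of how the decode-side add_new_req above consumes them. Only the keys visible in these hunks are shown (other keys are omitted) and the concrete values are invented.

# Producer (P) side, per the request_finished hunk above:
kv_transfer_params = {
    "remote_engine_id": "prefill-engine-0",  # hypothetical engine id
    "remote_host": "10.0.0.5",
    "remote_port": 5600,
    "tp_size": 2,  # the prefill instance's tensor_parallel_size (new field)
}

# Consumer (D) side, per add_new_req above: tp_size falls back to 1 when the
# key is absent, since P workers never need the remote TP size themselves.
tp_size = kv_transfer_params.get("tp_size", 1)
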


class NixlConnectorWorker:
@@ -460,7 +462,8 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
"Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))

def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
def _nixl_handshake(self, host: str, port: int,
remote_tp_size: int) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance."""

start_time = time.perf_counter()
@@ -469,7 +472,7 @@ def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.

def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
def handshake(path: str, rank: int) -> str:
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
@@ -479,33 +482,25 @@ def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
got_metadata_time = time.perf_counter()

# Register Remote agent.
remote_agent_name = self.add_remote_agent(metadata, rank)
remote_agent_name = self.add_remote_agent(
metadata, rank, remote_tp_size)
setup_agent_time = time.perf_counter()

logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
return metadata, remote_agent_name
return remote_agent_name

# Handshake with remote agent-rank0 first to get the tp_size of remote
path = make_zmq_path("tcp", host, port)
logger.debug("Querying master rank metadata on path: %s", path)
rank_to_agent_name: dict[int, str] = {}
metadata, rank_to_agent_name[0] = handshake(path, 0)

# Handshake only with the other TP remote the current local rank will
# Handshake only with the remote TP rank that the current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
p_remote_rank = self.tp_rank // tp_ratio
if p_remote_rank > 0:
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s",
path, p_remote_rank)
_, rank_to_agent_name[p_remote_rank] = handshake(
path, p_remote_rank)

return rank_to_agent_name
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s", path,
p_remote_rank)
# Remote rank -> agent name.
return {p_remote_rank: handshake(path, p_remote_rank)}
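
Side note, not part of the diff: a minimal, self-contained sketch of the remote-rank selection implied by the tp_ratio arithmetic above, assuming the decode-side TP size is a multiple of the prefill-side TP size (the helper name below is hypothetical).

def select_remote_rank(local_tp_rank: int, local_tp_size: int,
                       remote_tp_size: int) -> int:
    # Each remote (prefill) rank serves tp_ratio local (decode) ranks.
    tp_ratio = local_tp_size // remote_tp_size
    return local_tp_rank // tp_ratio

# Homogeneous TP (D=2, P=2): each local rank pulls from the same remote rank.
assert [select_remote_rank(r, 2, 2) for r in range(2)] == [0, 1]
# Heterogeneous TP (D=4, P=2): two local ranks share each remote rank.
assert [select_remote_rank(r, 4, 2) for r in range(4)] == [0, 0, 1, 1]
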

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
@@ -632,7 +627,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
tp_size=self.world_size,
block_len=self.block_len,
attn_backend_name=self.backend_name)
ready_event = threading.Event()
@@ -646,7 +640,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):

def add_remote_agent(self,
nixl_agent_meta: NixlAgentMetadata,
remote_tp_rank: int = 0) -> str:
remote_tp_rank: int = 0,
remote_tp_size: int = 1) -> str:
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote.
@@ -691,9 +686,9 @@ def add_remote_agent(self,
return self._remote_agents[engine_id][remote_tp_rank]

if engine_id in self._tp_size:
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
assert self._tp_size[engine_id] == remote_tp_size
else:
self._tp_size[engine_id] = nixl_agent_meta.tp_size
self._tp_size[engine_id] = remote_tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name
@@ -743,33 +738,31 @@ def add_remote_agent(self,
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
p_remote_tp_rank = self.tp_rank // tp_ratio
# Only register the remote's descriptors if current rank pulls from it.
if p_remote_tp_rank == remote_tp_rank:
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \
if not (self.use_mla or is_kv_replicated) else 0
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
self.tp_rank)
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \
if not (self.use_mla or is_kv_replicated) else 0
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
self.tp_rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs)

return remote_agent_name
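
Side note, not part of the diff: a sketch of the address arithmetic used in the registration loop above, assuming the non-MLA, non-replicated-KV case where self.block_len == remote_block_len // tp_ratio (all names below are local to the sketch).

def remote_block_addrs(base_addr: int, num_blocks: int, remote_block_len: int,
                       local_block_len: int, local_tp_rank: int,
                       tp_ratio: int) -> list[int]:
    # Each local (decode) rank reads only its kv-head slice of every block.
    rank_offset = (local_tp_rank % tp_ratio) * local_block_len
    return [
        base_addr + block_id * remote_block_len + rank_offset
        for block_id in range(num_blocks)
    ]

# D TP=2 pulling from P TP=1 (tp_ratio=2): rank 1 reads the second half of
# each 512-byte remote block, i.e. offsets 256 and 768 from base_addr.
assert remote_block_addrs(0, 2, 512, 256, local_tp_rank=1,
                          tp_ratio=2) == [256, 768]
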

@@ -904,7 +897,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
if fut is None:
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake, meta.remote_host,
meta.remote_port)
meta.remote_port, meta.tp_size)
self._handshake_futures[remote_engine_id] = fut

def done_callback(f: Future[dict[int, str]],
@@ -944,13 +937,9 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
remote_block_ids=meta.remote_block_ids,
)

def _read_blocks(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
dst_engine_id: str,
request_id: str,
):
def _read_blocks(self, local_block_ids: list[int],
remote_block_ids: list[int], dst_engine_id: str,
request_id: str):
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the