46 changes: 46 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -419,6 +419,52 @@ def test_concurrent_load_kv(
return
raise TimeoutError("Took too long to complete async handshake.")

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
This test is only relevant for heterogeneous TP.
"""
vllm_config = create_vllm_config()

# Mock TP world size to 2 to force heterogeneous TP when
# remote_tp_size=1
with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2):
# Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
worker = connector.connector_worker

# Minimal local registration params used by add_remote_agent
worker.slot_size_bytes = 4096
worker.block_len = worker.slot_size_bytes * worker.block_size
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks

# Metadata with different kv_cache_layout than local worker
mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \
else "NHD"
meta = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
block_len=worker.block_len,
attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout,
)

# We don't check layout for homogeneous TP and MLA for now, as the
# whole block is moved.
worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1)


# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then
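A minimal sketch of the rule the new test exercises, under the assumption that tp_ratio is the local TP world size divided by the remote one (the helper below is hypothetical, not part of the connector API): with the local world size mocked to 2, remote_tp_size=2 gives a ratio of 1, so the layout is not compared, while remote_tp_size=1 gives a ratio of 2, so a mismatched layout should trip the assertion.

def layout_check_applies(local_tp_size: int, remote_tp_size: int) -> bool:
    # Hypothetical helper mirroring the `tp_ratio > 1` condition added in
    # add_remote_agent below; name and signature are illustrative only.
    tp_ratio = local_tp_size // remote_tp_size
    return tp_ratio > 1

assert not layout_check_applies(2, 2)  # homogeneous TP: mismatch tolerated
assert layout_check_applies(2, 1)      # heterogeneous TP: layouts must match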
13 changes: 10 additions & 3 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -30,6 +30,7 @@
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus

@@ -73,6 +74,7 @@ class NixlAgentMetadata(
num_blocks: int
block_len: int
attn_backend_name: str
kv_cache_layout: str


@dataclass
@@ -538,7 +540,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
attn_backend = backend_name_to_enum(self.backend_name)
self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
self.kv_cache_layout = get_kv_cache_layout()
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to
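For reference, a hedged reading of the layout string returned by get_kv_cache_layout() (an assumption for illustration, not something this diff states): the letters name the order of the per-block dimensions, with N the tokens in a block, H the KV heads, and D the head size.

import torch  # standalone illustration; the sizes below are made up

block_size, num_kv_heads, head_dim = 16, 8, 128
nhd_block = torch.empty(block_size, num_kv_heads, head_dim)  # "NHD"
hnd_block = torch.empty(num_kv_heads, block_size, head_dim)  # "HND"
# Same number of elements either way, but different strides, which is why
# both agents must agree on the layout before exchanging block offsets.
assert nhd_block.numel() == hnd_block.numel()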
@@ -839,7 +843,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
block_len=self.block_len,
attn_backend_name=self.backend_name)
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
@@ -900,8 +905,7 @@ def add_remote_agent(self,
self._tp_size[engine_id] = remote_tp_size
else:
assert self._tp_size[engine_id] == remote_tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
# TODO We may eventually want to skip enforcing the same attn backend.
assert nixl_agent_meta.attn_backend_name == self.backend_name

remote_agent_name = self.nixl_wrapper.add_remote_agent(
@@ -930,6 +934,9 @@
if self._use_flashinfer:
# Account for joint KV in FlashInfer.
remote_block_size //= 2
if tp_ratio > 1:
# Heterogeneous TP expects same kv_cache_layout.
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout

assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
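To see why the new assertion is gated on tp_ratio > 1, here is a small standalone illustration (it makes assumptions about the transfer pattern and is not code from this PR): under heterogeneous TP each decode rank pulls only its share of the KV heads out of every remote block, and whether that share is a single contiguous region depends on the layout, so both agents must use the same one.

import torch  # standalone illustration with made-up sizes

block_size, num_kv_heads, head_dim, tp_ratio = 16, 8, 128, 2
heads_per_rank = num_kv_heads // tp_ratio

hnd = torch.zeros(num_kv_heads, block_size, head_dim)  # one "HND" block
nhd = torch.zeros(block_size, num_kv_heads, head_dim)  # one "NHD" block

# Rank 0's head slice is one contiguous chunk under HND...
assert hnd[:heads_per_rank].is_contiguous()
# ...but is strided under NHD, where heads are interleaved token by token.
assert not nhd[:, :heads_per_rank].is_contiguous()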