Merged
226 changes: 168 additions & 58 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -7,13 +7,6 @@
from typing import Optional
from unittest.mock import patch

import pytest

try:
from nixl._api import nixl_agent as NixlWrapper
except ImportError:
NixlWrapper = None

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
@@ -92,7 +85,8 @@ def test_prompt_less_than_block_size():
class FakeNixlWrapper:
"""Mock implementation of NixlWrapper for testing.

We don't inherit from NixlWrapper because NixlWrapper could be None.
We don't inherit from nixl._api.nixl_agent because nixl may not be
installed.
"""

AGENT_METADATA = b"fake_agent_metadata"
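
    # Note on the pattern above: the fake relies on duck typing rather than
    # subclassing, so this module imports cleanly even when the nixl package
    # is absent. A minimal sketch of the idea — class and method names here
    # are illustrative, not the real nixl._api.nixl_agent surface:
    class FakeBackend:
        """Stands in for a backend that may not be importable."""

        def __init__(self, agent_name: str, *args, **kwargs):
            self.agent_name = agent_name

        def get_agent_metadata(self) -> bytes:  # hypothetical method name
            return b"fake_agent_metadata"

    # Tests then substitute the fake at the import site, e.g.:
    # @patch("some.module.Backend", FakeBackend)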
@@ -167,7 +161,7 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency

def _nixl_handshake(self, host: str, port: int):
def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
@@ -177,7 +171,7 @@ def _nixl_handshake(self, host: str, port: int):
self.num_blocks = 1
self.dst_num_blocks[self.engine_id] = self.num_blocks

self.add_remote_agent(
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@@ -187,60 +181,176 @@ def _nixl_handshake(self, host: str, port: int):
block_len=self.block_len,
attn_backend_name=self.backend_name,
))


@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_multi_xfer_one_engine(
# dist_init is a fixture that initializes the distributed environment.
dist_init):
"""Test case where multiple xfers are initiated to the same engine.

This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to
`set_cycles_before_xfer_done`, so there is a state where there are multiple
transfer states for the same `request_id`, and `get_finished` should handle
it correctly (wait for all transfers to be done).
"""
vllm_config = create_vllm_config()

request_id = "req_id"

# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
for i in range(4):
return {0: remote_agent_name}
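
    # The `dict[int, str]` return type added in this PR appears to map the
    # remote (tensor-parallel) rank to the registered agent name — the fake
    # returns {0: remote_agent_name}. A caller-side sketch under that
    # assumption (pick_rank0_agent is hypothetical, not part of this module):
    def pick_rank0_agent(remote_agents: dict[int, str]) -> str:
        # e.g. remote_agents == {0: "remote_agent_name"} from the fake above
        return remote_agents[0]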


class TestNixlHandshake:

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_multi_xfer_one_engine(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
"""Test case where multiple xfers are initiated to the same engine.

This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to
`set_cycles_before_xfer_done`, so there is a state where there are
multiple transfer states for the same `request_id`, and `get_finished`
should handle it correctly (wait for all transfers to be done).
"""
vllm_config = create_vllm_config()

request_id = "req_id"

# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
assert isinstance(connector.connector_worker.nixl_wrapper,
FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
num_xfers = 4
while True:
# For the same request_id, initiate multiple xfers across different
# rounds of `execute_model` calls.
metadata = NixlConnectorMetadata()
if num_xfers > 0:
num_xfers -= 1
metadata.add_new_req(
request_id=request_id,
local_block_ids=[
num_xfers + 1, num_xfers + 2, num_xfers + 3
],
kv_transfer_params={
"remote_block_ids":
[num_xfers + 4, num_xfers + 5, num_xfers + 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host":
"localhost",
"remote_port":
1234,
})
connector.bind_connector_metadata(metadata)

# Mimic maybe_setup_kv_connector in gpu_model_runner.
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"

# Mimic get_finished_kv_transfers in gpu_model_runner.
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0:
assert request_id in done_recving
break

connector.clear_connector_metadata()
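
    # The property this test pins down is that `get_finished` must not report
    # a request until every one of its outstanding transfers has completed.
    # A minimal sketch of that bookkeeping, with hypothetical names (the
    # connector's actual internals are not shown in this diff):
    from collections import defaultdict

    # req_id -> set of in-flight transfer handles (hypothetical structure)
    inflight: dict[str, set[int]] = defaultdict(set)

    def xfer_started(req_id: str, handle: int) -> None:
        inflight[req_id].add(handle)

    def xfer_completed(req_id: str, handle: int) -> set[str]:
        # Report a request as done only once ALL its handles have finished.
        inflight[req_id].discard(handle)
        done = {r for r, handles in inflight.items() if not handles}
        for r in done:
            del inflight[r]
        return done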

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_async_load_kv(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
"""Test that NixlConnector's start_load_kv should be non-blocking."""

vllm_config = create_vllm_config()

# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id)
metadata = NixlConnectorMetadata()
metadata.add_new_req(request_id=request_id,
local_block_ids=[i + 1, i + 2, i + 3],
metadata.add_new_req(request_id="id",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [i + 4, i + 5, i + 6],
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
})
connector.bind_connector_metadata(metadata)

dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"

while True:
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0:
assert request_id in done_recving
break
timeout = 2.5
start = time.perf_counter()
while time.perf_counter() - start < timeout:
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
time.sleep(0.5) # backoff for the async handshake to complete.
connector.bind_connector_metadata(NixlConnectorMetadata())
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0:
return
raise TimeoutError("Took too long to complete async handshake.")
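
    # This test and the next share the same poll-with-backoff loop; if it
    # keeps recurring, it could be factored into a helper along these lines.
    # A sketch only: `wait_for_recv` is not part of this module, and it omits
    # the per-iteration start_load_kv call and its latency assertion that the
    # tests perform.
    def wait_for_recv(connector, expected: int, timeout: float) -> None:
        """Poll get_finished until `expected` requests finish receiving."""
        deadline = time.perf_counter() + timeout
        finished: set[str] = set()
        while time.perf_counter() < deadline:
            time.sleep(0.5)  # back off so the async handshake can progress
            connector.bind_connector_metadata(NixlConnectorMetadata())
            _, done_recving = connector.get_finished(finished_req_ids=set())
            finished |= done_recving
            if len(finished) >= expected:
                return
        raise TimeoutError("Took too long to complete async handshake.")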

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_concurrent_load_kv(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
"""Test that multiple start_load_kv calls should occur concurrently."""

vllm_config = create_vllm_config()

# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id)
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
metadata.add_new_req(request_id=f"id_{i}",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
})
connector.bind_connector_metadata(metadata)

timeout = 2.5 * total_reqs
cnt_finished_reqs = 0
start = time.perf_counter()
while time.perf_counter() - start < timeout:
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
time.sleep(0.5) # backoff for the async handshake to complete.
connector.bind_connector_metadata(NixlConnectorMetadata())
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0:
cnt_finished_reqs += len(done_recving)
if cnt_finished_reqs == total_reqs:
return
raise TimeoutError("Took too long to complete async handshake.")