@@ -30,6 +30,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import make_zmq_path, make_zmq_socket
+from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
 
@@ -73,6 +74,7 @@ class NixlAgentMetadata(
     num_blocks: int
     block_len: int
     attn_backend_name: str
+    kv_cache_layout: str
 
 
 @dataclass
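
The handshake metadata now also advertises the producer's KV-cache layout. A minimal sketch of what the struct carries after this change, assuming a plain dataclass and showing only the fields visible in this diff (the real NixlAgentMetadata has additional fields, and its base class is truncated in the hunk header above):

from dataclasses import dataclass

@dataclass
class AgentMetadataSketch:
    # Fields copied from this diff; other fields of the real class are omitted.
    num_blocks: int
    block_len: int
    attn_backend_name: str
    kv_cache_layout: str  # new: layout string reported by get_kv_cache_layout()
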
@@ -538,7 +540,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         attn_backend = backend_name_to_enum(self.backend_name)
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
         self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        self.kv_cache_layout = get_kv_cache_layout()
         logger.debug("Detected attention backend %s", self.backend_name)
+        logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
@@ -839,7 +843,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
             block_len=self.block_len,
-            attn_backend_name=self.backend_name)
+            attn_backend_name=self.backend_name,
+            kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -900,8 +905,7 @@ def add_remote_agent(self,
             self._tp_size[engine_id] = remote_tp_size
         else:
             assert self._tp_size[engine_id] == remote_tp_size
-        # We may eventually enable this after asserting equality in cache
-        # layout and close outputs.
+        # TODO We may eventually want to skip enforcing the same attn backend.
         assert nixl_agent_meta.attn_backend_name == self.backend_name
 
         remote_agent_name = self.nixl_wrapper.add_remote_agent(
@@ -930,6 +934,9 @@ def add_remote_agent(self,
         if self._use_flashinfer:
             # Account for joint KV in FlashInfer.
             remote_block_size //= 2
+        if tp_ratio > 1:
+            # Heterogeneous TP expects same kv_cache_layout.
+            assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
 
         assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
             "Remote P worker KV layer cache must be of shape [2, N, "
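
Taken together, the changes let the consumer side validate compatibility in add_remote_agent: the attention backend must always match (the TODO above notes this may be relaxed later), while the KV-cache layout is only enforced under heterogeneous TP (tp_ratio > 1), where the remote block_len is tp_ratio times the local one. A self-contained sketch of that check, with an illustrative function name and signature that are not the connector's actual code:

def check_remote_compat(local_backend: str, local_layout: str,
                        remote_backend: str, remote_layout: str,
                        tp_ratio: int) -> None:
    # The attention backend must always match on both sides of the handshake.
    assert remote_backend == local_backend, (
        f"Attention backend mismatch: {remote_backend} != {local_backend}")
    # The KV-cache layout only has to match for heterogeneous TP, where one
    # remote block covers tp_ratio local-sized chunks and both sides must
    # agree on the in-block layout to slice it correctly.
    if tp_ratio > 1:
        assert remote_layout == local_layout, (
            f"KV cache layout mismatch: {remote_layout} != {local_layout}")

For example, with hypothetical values, check_remote_compat("FLASH_ATTN", "NHD", "FLASH_ATTN", "NHD", tp_ratio=2) passes, while mismatched layouts with tp_ratio > 1 raise an AssertionError.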