This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit fb41934

stephanie-wang (Stephanie Wang) authored and committed
[Core] Refactor Worker and ModelRunner to consolidate control plane communication (vllm-project#5408)
Signed-off-by: Stephanie Wang <[email protected]>
Signed-off-by: Stephanie <[email protected]>
Co-authored-by: Stephanie <[email protected]>
1 parent f9775e9 commit fb41934

29 files changed: +1212 −577 lines changed
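
This commit consolidates driver-to-worker control plane communication by making the per-step model input a broadcastable dataclass. Below is a minimal sketch of the intended round trip, assuming the as_broadcastable_tensor_dict() / from_broadcasted_tensor_dict() API added in this commit; broadcast_fn is a hypothetical stand-in for the collective that actually ships the dict from the driver to the other workers.

# Sketch only, not taken from the diff: the intended driver/worker round trip.
# `broadcast_fn` is a hypothetical stand-in for the tensor-dict broadcast that
# vLLM uses between the driver and the other workers.
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


def driver_side(model_input, broadcast_fn):
    # The driver flattens its prepared ModelInput into a plain tensor dict ...
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    broadcast_fn(tensor_dict)  # ... and broadcasts it to the other workers.


def worker_side(tensor_dict, attn_backend):
    # Non-driver workers rebuild an equivalent ModelInput; the attention
    # backend supplies the concrete AttentionMetadata class to reconstruct.
    return ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend)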

tests/worker/test_model_input.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+import dataclasses
+from typing import List, Tuple, Type
+
+import torch
+
+from vllm.attention import AttentionMetadata
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.worker.embedding_model_runner import (
+    ModelInputForGPUWithPoolingMetadata)
+from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+
+class MockAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
+    @staticmethod
+    def get_impl_cls():
+        raise NotImplementedError
+
+    @staticmethod
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return AttentionMetadata
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        pass
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        pass
+
+
+def test_model_runner_input():
+    sampling_metadata = SamplingMetadata(
+        ["seq_group"],
+        "selected_token_indices",
+        "categorized_sample_indices",
+        "num_prompts",
+    )
+    attn_metadata = AttentionMetadata(
+        num_prefills=1,
+        num_prefill_tokens=2,
+        num_decode_tokens=3,
+        slot_mapping=torch.zeros(1),
+    )
+    model_input = ModelInputForGPUWithSamplingMetadata(
+        input_tokens=torch.ones(10),
+        input_positions=torch.ones(10),
+        sampling_metadata=sampling_metadata,
+        attn_metadata=attn_metadata)
+
+    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)
+
+    # Test round trip serialization.
+    tensor_dict = model_input.as_broadcastable_tensor_dict()
+    attn_backend = MockAttentionBackend()
+    received_model_input = (
+        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
+            tensor_dict, attn_backend=attn_backend))
+    # Check that received copy has correct values.
+    assert isinstance(received_model_input,
+                      ModelInputForGPUWithSamplingMetadata)
+    assert received_model_input.input_tokens is not None
+    assert (
+        received_model_input.input_tokens == model_input.input_tokens).all()
+    assert received_model_input.input_positions is not None
+    assert (received_model_input.input_positions == model_input.input_positions
+            ).all()
+    assert received_model_input.multi_modal_kwargs is None
+    assert (received_model_input.multi_modal_kwargs ==
+            model_input.multi_modal_kwargs)
+    assert received_model_input.lora_requests is None
+    assert received_model_input.lora_requests == model_input.lora_requests
+    assert received_model_input.lora_mapping is None
+    assert received_model_input.lora_mapping == model_input.lora_mapping
+    for field in dataclasses.fields(AttentionMetadata):
+        assert getattr(received_model_input.attn_metadata, field.name,
+                       None) == getattr(attn_metadata, field.name, None)
+    # For sampling metadata, only selected_token_indices is copied.
+    assert (received_model_input.sampling_metadata.selected_token_indices ==
+            sampling_metadata.selected_token_indices)
+    assert received_model_input.sampling_metadata.seq_groups is None
+
+
+def test_embedding_model_runner_input():
+    pooling_metadata = PoolingMetadata(
+        seq_groups=[[0]],
+        seq_data={},
+        prompt_lens=[1],
+    )
+    attn_metadata = AttentionMetadata(
+        num_prefills=1,
+        num_prefill_tokens=2,
+        num_decode_tokens=3,
+        slot_mapping=torch.zeros(1),
+    )
+    model_input = ModelInputForGPUWithPoolingMetadata(
+        input_tokens=torch.ones(10),
+        input_positions=torch.ones(10),
+        pooling_metadata=pooling_metadata,
+        attn_metadata=attn_metadata)
+
+    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
+
+    # Test round trip serialization.
+    tensor_dict = model_input.as_broadcastable_tensor_dict()
+    attn_backend = MockAttentionBackend()
+    received_model_input = (
+        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
+            tensor_dict, attn_backend=attn_backend))
+    # Check that received copy has correct values.
+    assert isinstance(received_model_input,
+                      ModelInputForGPUWithPoolingMetadata)
+    assert received_model_input.input_tokens is not None
+    assert (
+        received_model_input.input_tokens == model_input.input_tokens).all()
+    assert received_model_input.input_positions is not None
+    assert (received_model_input.input_positions == model_input.input_positions
+            ).all()
+    assert received_model_input.multi_modal_kwargs is None
+    assert (received_model_input.multi_modal_kwargs ==
+            model_input.multi_modal_kwargs)
+    assert received_model_input.lora_requests is None
+    assert received_model_input.lora_requests == model_input.lora_requests
+    assert received_model_input.lora_mapping is None
+    assert received_model_input.lora_mapping == model_input.lora_mapping
+    for field in dataclasses.fields(AttentionMetadata):
+        assert getattr(received_model_input.attn_metadata, field.name,
+                       None) == getattr(attn_metadata, field.name, None)
+    # Pooling metadata is not broadcast.
+    assert received_model_input.pooling_metadata is None

tests/worker/test_model_runner.py

Lines changed: 30 additions & 27 deletions
@@ -66,12 +66,13 @@ def test_prepare_prompt(batch_size):
         expected_selected_token_indices.append(selected_token_start_idx +
                                                seq_len - 1)
         selected_token_start_idx += seq_len
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
     input_tokens = model_input.input_tokens
     input_positions = model_input.input_positions
     attn_metadata = model_input.attn_metadata
     return_seq_lens = model_input.seq_lens
-    slot_mapping = model_input.slot_mapping
+    slot_mapping = attn_metadata.slot_mapping
     assert return_seq_lens == seq_lens
     assert len(slot_mapping) == len(input_tokens)

@@ -179,10 +180,11 @@ def test_prepare_decode_cuda_graph(batch_size):
         assert seq_group_metadata.token_chunk_size == 1
         seq_group_metadata_list.append(seq_group_metadata)

-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
     input_tokens, input_positions, attn_metadata, slot_mapping = (
         model_input.input_tokens, model_input.input_positions,
-        model_input.attn_metadata, model_input.slot_mapping)
+        model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)

     expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
@@ -264,32 +266,29 @@ def test_empty_seq_group():
         enforce_eager=False,
     )
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
-    input_tokens, input_positions, attn_metadata, slot_mapping = (
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
+    input_tokens, input_positions, attn_metadata = (
         model_input.input_tokens,
         model_input.input_positions,
         model_input.attn_metadata,
-        model_input.slot_mapping,
     )
-    assert len(input_tokens) == 0
-    assert len(input_positions) == 0
+    assert input_tokens is None
+    assert input_positions is None
     assert attn_metadata is None
-    assert len(slot_mapping) == 0
-
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
-    (input_tokens, input_positions, attn_metadata, slot_mapping,
-     return_seq_lens) = (
-         model_input.input_tokens,
-         model_input.input_positions,
-         model_input.attn_metadata,
-         model_input.slot_mapping,
-         model_input.seq_lens,
-     )
-    assert len(input_tokens) == 0
-    assert len(input_positions) == 0
+
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
+    (input_tokens, input_positions, attn_metadata, return_seq_lens) = (
+        model_input.input_tokens,
+        model_input.input_positions,
+        model_input.attn_metadata,
+        model_input.seq_lens,
+    )
+    assert input_tokens is None
+    assert input_positions is None
     assert attn_metadata is None
-    assert len(slot_mapping) == 0
-    assert len(return_seq_lens) == 0
+    assert return_seq_lens is None


 @pytest.fixture
@@ -358,8 +357,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         seq_group_metadata_list.append(seq_group_metadata)
         decode_metadata_list.append(seq_group_metadata)

-    (input_tokens, input_positions, attn_metadata, _, _, _,
-     _) = model_runner.prepare_input_tensors(seq_group_metadata_list)
+    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
+    (input_tokens, input_positions, attn_metadata) = (
+        model_input.input_tokens,
+        model_input.input_positions,
+        model_input.attn_metadata,
+    )

     prefill_meta_actual = attn_metadata.prefill_metadata
     decode_meta_actual = attn_metadata.decode_metadata
@@ -372,7 +375,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):

     # Verify attn metadata is consistent. We don't need to test individual
     # values here because they are tested above.
-    attn_metadata = model_runner._prepare_model_input(
+    attn_metadata = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list).attn_metadata

     for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),

vllm/attention/backends/abstract.py

Lines changed: 5 additions & 1 deletion
@@ -21,9 +21,13 @@ def get_impl_cls() -> Type["AttentionImpl"]:

     @staticmethod
     @abstractmethod
-    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError

+    @classmethod
+    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
+        return cls.get_metadata_cls()(*args, **kwargs)
+
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
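
For illustration only: with the interface above, a hypothetical backend just declares which metadata class it uses, and the concrete make_metadata() inherited from AttentionBackend instantiates that class, so call sites are unchanged. A minimal sketch, not part of this diff; the class and constructor arguments mirror the MockAttentionBackend test shown earlier.

# Hypothetical backend sketch; the remaining abstract methods (get_name,
# get_impl_cls, get_kv_cache_shape, swap_blocks, copy_blocks) are omitted,
# so this class is illustrative rather than a complete backend.
from typing import Type

import torch

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend


class MyAttentionBackend(AttentionBackend):

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata


# Callers are unchanged: make_metadata(), now a concrete classmethod on the
# base class, instantiates whatever get_metadata_cls() returns.
metadata = MyAttentionBackend.make_metadata(num_prefills=1,
                                            num_prefill_tokens=2,
                                            num_decode_tokens=3,
                                            slot_mapping=torch.zeros(1))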

vllm/attention/backends/blocksparse_attn.py

Lines changed: 2 additions & 2 deletions
@@ -90,8 +90,8 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
         return BlocksparseFlashAttentionImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
-        return BlocksparseFlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return BlocksparseFlashAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/flash_attn.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
-        return FlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/flashinfer.py

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,8 @@ def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
-        return FlashInferMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashInferMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/ipex_attn.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
         return IpexAttnBackendImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "IpexAttnMetadata":
-        return IpexAttnMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
+        return IpexAttnMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/pallas.py

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
         return PallasAttentionBackendImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "PallasMetadata":
-        return PallasMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["PallasMetadata"]:
+        return PallasMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
-        return ROCmFlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return ROCmFlashAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/torch_sdpa.py

Lines changed: 2 additions & 2 deletions
@@ -31,8 +31,8 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
-        return TorchSDPAMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return TorchSDPAMetadata

     @staticmethod
     def get_kv_cache_shape(
