Commit 47b65a5

[core] Multi Step Scheduling (#7000)
Co-authored-by: afeldman-nm <[email protected]>
1 parent dad961e commit 47b65a5
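
This commit lands multi-step scheduling: the scheduler hands a batch to the workers once and the model runner executes `--num-scheduler-steps` decode iterations before returning control, amortizing scheduling and CPU overhead across steps. As a hedged usage sketch (not part of the commit), the snippet below starts the OpenAI-compatible server with multi-step enabled and issues one greedy completion; it reuses the `RemoteOpenAIServer` helper from `tests/utils.py` that the new test further down imports, and the model name and prompt are placeholders.

# Hedged sketch, not from the commit: start an OpenAI-compatible server with
# multi-step decoding enabled and request one greedy completion. Assumes the
# vLLM repository root is on PYTHONPATH so tests/utils.py is importable.
import asyncio

from tests.utils import RemoteOpenAIServer

MODEL = "JackFram/llama-160m"  # example model, as used in the new test
SERVER_ARGS = [
    "--use-v2-block-manager",      # required for multi-step scheduling
    "--worker-use-ray",            # matches the new test's default args
    "--num-scheduler-steps", "8",  # run 8 decode steps per scheduler call
]


async def main() -> None:
    with RemoteOpenAIServer(MODEL, SERVER_ARGS) as server:
        client = server.get_async_client()
        out = await client.completions.create(model=MODEL,
                                              prompt="Hello, my name is",
                                              temperature=0,
                                              max_tokens=5)
        print(out.choices[0].text)


asyncio.run(main())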

13 files changed: +1004 -34 lines

.buildkite/test-pipeline.yaml

Lines changed: 9 additions & 0 deletions
@@ -311,6 +311,15 @@ steps:
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

+- label: Multi-step Tests (4 GPUs) # 10min
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 4
+  source_file_dependencies:
+  - vllm/
+  - tests/multi_step/test_correctness.py
+  commands:
+  - pytest -v -s multi_step/test_correctness.py
+
 - label: Pipeline Parallelism Test # 23min
   working_dir: "/vllm-workspace/tests"
   num_gpus: 4

tests/multi_step/__init__.py

Whitespace-only changes.

tests/multi_step/test_correctness.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Test the AsyncLLMEngine with multi-step decoding

from typing import List

import pytest

from ..utils import RemoteOpenAIServer

MODELS = [
    "JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
NUM_PROMPTS = [10]

DEFAULT_SERVER_ARGS: List[str] = [
    "--disable-log-requests",
    "--use-v2-block-manager",
    "--worker-use-ray",
    "--gpu-memory-utilization",
    "0.85",
    "--swap-space",
    "16",
]


async def completions_with_server_args(prompts: List[str], model_name: str,
                                       server_cli_args: List[str]):

    outputs = None
    with RemoteOpenAIServer(model_name, server_cli_args) as server:
        client = server.get_async_client()
        outputs = await client.completions.create(model=model_name,
                                                  prompt=prompts,
                                                  temperature=0,
                                                  stream=False,
                                                  max_tokens=5)
    assert outputs is not None

    return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size, pp_size", [
    (1, 1),
    (2, 2),
])
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
                          pp_size: int, eager_mode: bool,
                          num_scheduler_steps: int, num_prompts: int):

    prompts = example_prompts
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts

    server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
    ms_server_args = DEFAULT_SERVER_ARGS + \
        ["--num-scheduler-steps", f"{num_scheduler_steps}"]

    if eager_mode:
        ms_server_args.append("--enforce-eager")

    distributed_args = [
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
    ]

    ref_completions = await completions_with_server_args(
        prompts, model, server_args + distributed_args)
    test_completions = await completions_with_server_args(
        prompts, model, ms_server_args + distributed_args)

    def get_text_generations(completions):
        return [x.text for x in completions.choices]

    ref_generations = get_text_generations(ref_completions)
    test_generations = get_text_generations(test_completions)
    assert ref_generations == test_generations
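
The test above launches two OpenAI-compatible servers per parametrization, a single-step `--enforce-eager` baseline and a multi-step server with `--num-scheduler-steps`, and asserts that the greedy (temperature 0) completions are identical. As a hedged local-run sketch (not part of the commit), the CI command added above can also be invoked programmatically from the `tests/` directory:

# Hedged sketch: run the new multi-step correctness test from the tests/
# directory, mirroring the CI command in .buildkite/test-pipeline.yaml.
import sys

import pytest

sys.exit(pytest.main(["-v", "-s", "multi_step/test_correctness.py"]))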

tests/worker/test_model_input.py

Lines changed: 77 additions & 0 deletions
@@ -10,6 +10,7 @@
 from vllm.worker.embedding_model_runner import (
     ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+from vllm.worker.multi_step_model_runner import StatefulModelInput


 class MockAttentionBackend(AttentionBackend):
@@ -154,3 +155,79 @@ def test_embedding_model_runner_input():
                        None) == getattr(attn_metadata, field.name, None)
     # Pooling metadata is not broadcast.
     assert received_model_input.pooling_metadata is None
+
+
+def test_multi_step_model_runner_input():
+    sampling_metadata = SamplingMetadata(
+        ["seq_group"],
+        "selected_token_indices",
+        "categorized_sample_indices",
+        "num_prompts",
+    )
+    attn_metadata = AttentionMetadata(
+        num_prefills=1,
+        num_prefill_tokens=2,
+        num_decode_tokens=3,
+        slot_mapping=torch.zeros(1),
+    )
+    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
+        input_tokens=torch.ones(10),
+        input_positions=torch.ones(10),
+        sampling_metadata=sampling_metadata,
+        attn_metadata=attn_metadata)
+
+    model_input = StatefulModelInput(
+        frozen_model_input=frozen_model_input,
+        is_last_step=True,
+        is_first_multi_step=False,
+        current_step=4,
+        last_sampled_token_ids=torch.ones((10, 1)),
+        is_multi_step=True,
+        num_queries=8,
+        num_seqs=5,
+        cached_outputs=[],
+    )
+
+    assert isinstance(model_input, StatefulModelInput)
+
+    # Test round-trip serialization.
+    tensor_dict = model_input.as_broadcastable_tensor_dict()
+    attn_backend = MockAttentionBackend()
+    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
+        tensor_dict, attn_backend=attn_backend))
+
+    received_frozen_input = received_model_input.frozen_model_input
+
+    # Check that the received copy has correct values.
+    assert isinstance(received_model_input, StatefulModelInput)
+    assert received_frozen_input.input_tokens is not None
+    assert (received_frozen_input.input_tokens ==
+            frozen_model_input.input_tokens).all()
+    assert received_frozen_input.input_positions is not None
+    assert (received_frozen_input.input_positions ==
+            frozen_model_input.input_positions).all()
+    assert received_frozen_input.multi_modal_kwargs is None
+    assert (received_frozen_input.multi_modal_kwargs ==
+            frozen_model_input.multi_modal_kwargs)
+    assert received_frozen_input.lora_requests is None
+    assert (received_frozen_input.lora_requests ==
+            frozen_model_input.lora_requests)
+    assert received_frozen_input.lora_mapping is None
+    assert (
+        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
+    for field in dataclasses.fields(AttentionMetadata):
+        assert getattr(received_frozen_input.attn_metadata, field.name,
+                       None) == getattr(attn_metadata, field.name, None)
+    # For sampling metadata, only selected_token_indices is copied.
+    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
+            sampling_metadata.selected_token_indices)
+    assert received_frozen_input.sampling_metadata.seq_groups is None
+
+    # Check the non-frozen, multi-step fields.
+    assert received_model_input.is_last_step == model_input.is_last_step
+    assert (received_model_input.is_first_multi_step ==
+            model_input.is_first_multi_step)
+    assert received_model_input.current_step == model_input.current_step
+    assert (received_model_input.last_sampled_token_ids ==
+            model_input.last_sampled_token_ids).all()
+    assert received_model_input.is_multi_step == model_input.is_multi_step
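
For orientation (an interpretation, not text from the commit): `StatefulModelInput` bundles a frozen `ModelInputForGPUWithSamplingMetadata` with the mutable multi-step bookkeeping (`current_step`, `is_first_multi_step`/`is_last_step`, `last_sampled_token_ids`, `cached_outputs`). The test verifies that this state survives the `as_broadcastable_tensor_dict()` / `from_broadcasted_tensor_dict()` round trip, presumably so the driver can keep workers in sync across steps without rebuilding the full model input each time.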

vllm/engine/arg_utils.py

Lines changed: 6 additions & 1 deletion
@@ -853,6 +853,12 @@ def create_engine_config(self, ) -> EngineConfig:
                 "in low performance due to small KV cache space. Consider "
                 "setting --max-model-len to a smaller value.", max_model_len)

+        if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
+            self.use_v2_block_manager = True
+            logger.warning(
+                "Enabled BlockSpaceManagerV2 because it is "
+                "required for multi-step (--num-scheduler-steps > 1)")
+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
@@ -881,7 +887,6 @@ def create_engine_config(self, ) -> EngineConfig:
         )

         if self.num_scheduler_steps > 1:
-            raise NotImplementedError("Multi-step is not yet supported.")
             if speculative_config is not None:
                 raise ValueError("Speculative decoding is not supported with "
                                  "multi-step (--num-scheduler-steps > 1)")
