
Commit b51c1cc

[2/N] Chunked prefill data update (#3538)
1 parent ce567a2 commit b51c1cc

File tree: 11 files changed, +272 −76 lines

benchmarks/benchmark_latency.py

Lines changed: 13 additions & 1 deletion
@@ -26,7 +26,9 @@ def main(args: argparse.Namespace):
         kv_cache_dtype=args.kv_cache_dtype,
         device=args.device,
         ray_workers_use_nsight=args.ray_workers_use_nsight,
-        download_dir=args.download_dir)
+        enable_chunked_prefill=args.enable_chunked_prefill,
+        download_dir=args.download_dir,
+        block_size=args.block_size)
 
     sampling_params = SamplingParams(
         n=args.n,
@@ -145,6 +147,16 @@ def run_to_completion(profile_dir: Optional[str] = None):
         default="cuda",
         choices=["cuda"],
         help='device type for vLLM execution, supporting CUDA only currently.')
+    parser.add_argument('--block-size',
+                        type=int,
+                        default=16,
+                        help='block size of key/value cache')
+    parser.add_argument(
+        '--enable-chunked-prefill',
+        type=bool,
+        default=False,
+        help='If True, the prefill requests can be chunked based on the '
+        'max_num_batched_tokens')
     parser.add_argument(
         "--ray-workers-use-nsight",
         action='store_true',
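
For orientation, here is a minimal sketch of the same two knobs driven through the Python API instead of the benchmark CLI. The model name and sampling settings are illustrative assumptions, and whether chunked prefill is fully functional at this point in the PR series is not established by this diff; the keyword arguments are simply the ones the benchmark now forwards to LLM.

```python
# Sketch only: mirrors how benchmark_latency.py now forwards the new flags.
# Model name and sampling settings are illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",     # assumed small model for illustration
    enable_chunked_prefill=True,   # new flag: allow prefills to be chunked
    block_size=16,                 # new flag: KV-cache block size (CLI default)
)
sampling_params = SamplingParams(n=1, temperature=0.0, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)
```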

tests/conftest.py

Lines changed: 4 additions & 0 deletions
@@ -256,6 +256,8 @@ def __init__(
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
+        block_size: int = 16,
+        enable_chunked_prefill: bool = False,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -266,6 +268,8 @@ def __init__(
             swap_space=0,
             disable_log_stats=disable_log_stats,
             tensor_parallel_size=tensor_parallel_size,
+            block_size=block_size,
+            enable_chunked_prefill=enable_chunked_prefill,
             **kwargs,
         )
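
A hedged sketch of how a test could exercise the two new constructor arguments through this wrapper. The fixture name vllm_runner, the model name, and the generate_greedy call follow vLLM's usual test conventions but are assumptions, not part of this diff.

```python
# Sketch only: fixture and helper names are assumptions based on vLLM's test layout.
def test_chunked_prefill_smoke(vllm_runner, example_prompts):
    vllm_model = vllm_runner(
        "facebook/opt-125m",          # assumed model
        block_size=16,                # new wrapper argument
        enable_chunked_prefill=True,  # new wrapper argument
    )
    outputs = vllm_model.generate_greedy(example_prompts, max_tokens=8)
    assert len(outputs) == len(example_prompts)
    del vllm_model                    # free GPU memory between tests
```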

tests/core/test_scheduler.py

Lines changed: 13 additions & 9 deletions
@@ -10,6 +10,10 @@
 from .utils import create_dummy_prompt
 
 
+def get_sequence_groups(scheduler_output):
+    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
+
+
 def test_scheduler_add_seq_group():
     block_size = 4
     scheduler_config = SchedulerConfig(100, 64, 1)
@@ -57,9 +61,9 @@ def test_scheduler_schedule_simple():
     cache_config.num_cpu_blocks = 8
     cache_config.num_gpu_blocks = 8
     scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
 
     # Add seq groups to scheduler.
-    running: List[SequenceGroup] = []
     for i in range(num_seq_group):
         _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
         scheduler.add_seq_group(seq_group)
@@ -68,15 +72,15 @@ def test_scheduler_schedule_simple():
     # Schedule seq groups prompts.
     num_tokens = block_size * num_seq_group
     seq_group_meta, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set(running)
+    assert set(get_sequence_groups(out)) == set(running)
     assert out.num_batched_tokens == num_tokens
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
     assert len(seq_group_meta) == num_seq_group
 
     # Schedule seq groups generation.
     seq_group_meta, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set(running)
+    assert set(get_sequence_groups(out)) == set(running)
     assert out.num_batched_tokens == num_seq_group
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
@@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort():
 
     # Schedule seq groups prompts.
     seq_group_meta, out = scheduler.schedule()
-    assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
+    assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
     assert out.num_batched_tokens == block_size * 2  # seq_a and seq_b
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
@@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort():
 
     # Schedule seq groups generation and preempt seq group b.
     seq_group_meta, out = scheduler.schedule()
-    assert out.scheduled_seq_groups == [seq_group_a]
+    assert get_sequence_groups(out) == [seq_group_a]
     assert out.num_batched_tokens == 1
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
@@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort():
     # Abort seq group a. Re-schedule seq group b prompt with recomputation.
     scheduler.abort_seq_group("1")
     seq_group_meta, out = scheduler.schedule()
-    assert out.scheduled_seq_groups == [seq_group_b]
+    assert get_sequence_groups(out) == [seq_group_b]
     assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
@@ -155,11 +159,11 @@ def test_scheduler_max_seqs():
 
     # Schedule seq groups prompts.
     _, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
+    assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
 
     # Schedule seq groups generation.
     _, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
+    assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
 
     # Append 2 more seq group
     scheduler.add_seq_group(all_seq_groups[1])
@@ -169,7 +173,7 @@ def test_scheduler_max_seqs():
     # Only 1 seq group should be scheduled since max_seq_group is 2
     # and one is prompting.
     _, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]])
+    assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])
 
 
 def test_scheduler_delay_factor():
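
The new get_sequence_groups helper exists because scheduler.schedule() no longer yields bare SequenceGroup objects: each entry in scheduled_seq_groups is a wrapper exposing a seq_group attribute. A minimal sketch of the wrapper shape the tests rely on; only the seq_group attribute is confirmed by this diff, while the token_chunk_size field is an assumption tied to the chunked-prefill series.

```python
from dataclasses import dataclass
from typing import Any

# Sketch of the wrapper implied by `s.seq_group` in get_sequence_groups().
@dataclass
class ScheduledSeqGroupSketch:
    seq_group: Any             # the underlying SequenceGroup (confirmed by the diff)
    token_chunk_size: int = 1  # assumed: tokens to process for this group this step

def get_sequence_groups(scheduler_output):
    # Unwrap back to plain SequenceGroups so set/list comparisons keep working.
    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
```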

tests/test_sequence.py

Lines changed: 23 additions & 1 deletion
@@ -1,6 +1,7 @@
 import pytest
 
-from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput
+from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
+                           SequenceOutput)
 
 
 @pytest.fixture
@@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs):
     sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
     assert sampler_output1 == sampler_output2
     assert sampler_output1 != sampler_output3
+
+
+def test_sequence_data_prefill():
+    seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
+    assert seq_data.get_num_uncomputed_tokens() == 4
+    assert seq_data.get_num_computed_tokens() == 0
+    # advance by 2
+    seq_data.update_num_computed_tokens(2)
+    assert seq_data.get_num_uncomputed_tokens() == 2
+    assert seq_data.get_num_computed_tokens() == 2
+
+    # advance by 1
+    seq_data.update_num_computed_tokens(1)
+    assert seq_data.get_num_uncomputed_tokens() == 1
+    assert seq_data.get_num_computed_tokens() == 3
+
+    # append tokens and reset, simulating recompute
+    seq_data.append_token_id(1, logprob=0.0)
+    seq_data.reset_num_computed_tokens()
+    assert seq_data.get_num_uncomputed_tokens() == 5
+    assert seq_data.get_num_computed_tokens() == 0
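
The new test pins down the bookkeeping contract behind chunked prefill: SequenceData counts how many tokens have been computed so far, the remainder (prompt plus appended output tokens) is uncomputed, and a reset marks everything for recomputation. Below is a toy sketch of state that satisfies exactly these asserts; it is not the vllm.sequence implementation.

```python
class SequenceDataSketch:
    """Toy model of the computed-token bookkeeping exercised by the test above."""

    def __init__(self, prompt_token_ids):
        self.prompt_token_ids = list(prompt_token_ids)
        self.output_token_ids = []
        self._num_computed_tokens = 0

    def get_len(self):
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def get_num_computed_tokens(self):
        return self._num_computed_tokens

    def get_num_uncomputed_tokens(self):
        return self.get_len() - self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens):
        self._num_computed_tokens += num_new_computed_tokens

    def reset_num_computed_tokens(self):
        # e.g. after preemption by recomputation, everything must be re-run
        self._num_computed_tokens = 0

    def append_token_id(self, token_id, logprob=0.0):
        self.output_token_ids.append(token_id)
```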

tests/worker/test_model_runner.py

Lines changed: 20 additions & 17 deletions
@@ -18,15 +18,16 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         prompt_len = i % (model_runner.block_size - 1) + 1
         prompt_lens.append(prompt_len)
-        seq_data = list(range(prompt_len))
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData(seq_data)},
-                sampling_params=SamplingParams(temperature=0),
-                block_tables=block_tables,
-            ))
+        seq_data = SequenceData(list(range(prompt_len)))
+        seq_group_metadata = SequenceGroupMetadata(
+            request_id=f"test_{i}",
+            is_prompt=True,
+            seq_data={0: seq_data},
+            sampling_params=SamplingParams(temperature=0),
+            block_tables=block_tables,
+        )
+        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
+        seq_group_metadata_list.append(seq_group_metadata)
 
     expected_selected_token_indices = []
     selected_token_start_idx = 0
@@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size):
         prompt_len = i % (model_runner.block_size - 1) + 1
         prompt_lens.append(prompt_len)
         seq_data = list(range(prompt_len))
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=False,
-                seq_data={0: SequenceData(seq_data)},
-                sampling_params=SamplingParams(temperature=0),
-                block_tables={0: [1]},
-            ))
+        seq_data = SequenceData(seq_data)
+        seq_group_metadata = SequenceGroupMetadata(
+            request_id=f"test_{i}",
+            is_prompt=False,
+            seq_data={0: seq_data},
+            sampling_params=SamplingParams(temperature=0),
+            block_tables={0: [1]},
+        )
+        assert seq_group_metadata.token_chunk_size == 1
+        seq_group_metadata_list.append(seq_group_metadata)
 
     input_tokens, input_positions, attn_metadata, _, _, _ = (
         model_runner._prepare_decode(seq_group_metadata_list))
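
The two new asserts spell out the default chunking contract: a prompt-stage SequenceGroupMetadata covers the entire prompt (token_chunk_size == seq_data.get_len()), while a decode-stage one covers exactly one token. The sketch below shows a property that would satisfy both asserts; it is an assumption about the shape of the behaviour, not a copy of the real class.

```python
class SequenceGroupMetadataSketch:
    """Toy stand-in illustrating the default token_chunk_size behaviour asserted above."""

    def __init__(self, is_prompt, seq_data, token_chunk_size=None):
        self.is_prompt = is_prompt
        self.seq_data = seq_data              # dict: seq_id -> SequenceData-like object
        self._token_chunk_size = token_chunk_size

    @property
    def token_chunk_size(self):
        if self._token_chunk_size is not None:
            # Explicit chunk size, e.g. a partial prefill chunk.
            return self._token_chunk_size
        if self.is_prompt:
            # Whole prompt by default (a prefill group has a single sequence).
            return next(iter(self.seq_data.values())).get_len()
        # Decode processes exactly one new token per step.
        return 1
```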

vllm/config.py

Lines changed: 4 additions & 0 deletions
@@ -533,6 +533,8 @@ class SchedulerConfig:
         delay_factor: Apply a delay (of delay factor multiplied by previous
             prompt latency) before scheduling next prompt.
         use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
+        enable_chunked_prefill: If True, prefill requests can be chunked based
+            on the remaining max_num_batched_tokens.
     """
 
     def __init__(
@@ -542,6 +544,7 @@ def __init__(
         max_model_len: int,
         use_v2_block_manager: bool = False,
         delay_factor: float = 0.0,
+        enable_chunked_prefill: bool = False,
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
@@ -553,6 +556,7 @@
         self.max_model_len = max_model_len
         self.delay_factor = delay_factor
         self.use_v2_block_manager = use_v2_block_manager
+        self.chunked_prefill_enabled = enable_chunked_prefill
         self._verify_args()
 
     def _verify_args(self) -> None:
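
A short usage sketch of the new SchedulerConfig flag, following the positional order (max_num_batched_tokens, max_num_seqs, max_model_len) already used as SchedulerConfig(100, 64, 1) in tests/core/test_scheduler.py. The concrete values are illustrative; note the flag is stored under the attribute name chunked_prefill_enabled.

```python
from vllm.config import SchedulerConfig

# Illustrative values only. max_num_batched_tokens is kept >= max_model_len as a
# conservative choice so any length-consistency check in _verify_args() passes.
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2048,
    max_num_seqs=64,
    max_model_len=2048,
    enable_chunked_prefill=True,
)
assert scheduler_config.chunked_prefill_enabled
```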
