1010from .utils import create_dummy_prompt
1111
1212
13+ def get_sequence_groups (scheduler_output ):
14+ return [s .seq_group for s in scheduler_output .scheduled_seq_groups ]
15+
16+
1317def test_scheduler_add_seq_group ():
1418 block_size = 4
1519 scheduler_config = SchedulerConfig (100 , 64 , 1 )
@@ -57,9 +61,9 @@ def test_scheduler_schedule_simple():
5761 cache_config .num_cpu_blocks = 8
5862 cache_config .num_gpu_blocks = 8
5963 scheduler = Scheduler (scheduler_config , cache_config , None )
64+ running : List [SequenceGroup ] = []
6065
6166 # Add seq groups to scheduler.
62- running : List [SequenceGroup ] = []
6367 for i in range (num_seq_group ):
6468 _ , seq_group = create_dummy_prompt (str (i ), prompt_length = block_size )
6569 scheduler .add_seq_group (seq_group )
@@ -68,15 +72,15 @@ def test_scheduler_schedule_simple():
6872 # Schedule seq groups prompts.
6973 num_tokens = block_size * num_seq_group
7074 seq_group_meta , out = scheduler .schedule ()
71- assert set (out . scheduled_seq_groups ) == set (running )
75+ assert set (get_sequence_groups ( out ) ) == set (running )
7276 assert out .num_batched_tokens == num_tokens
7377 assert (not out .blocks_to_copy and not out .blocks_to_swap_in
7478 and not out .blocks_to_swap_out )
7579 assert len (seq_group_meta ) == num_seq_group
7680
7781 # Schedule seq groups generation.
7882 seq_group_meta , out = scheduler .schedule ()
79- assert set (out . scheduled_seq_groups ) == set (running )
83+ assert set (get_sequence_groups ( out ) ) == set (running )
8084 assert out .num_batched_tokens == num_seq_group
8185 assert (not out .blocks_to_copy and not out .blocks_to_swap_in
8286 and not out .blocks_to_swap_out )
@@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort():
100104
101105 # Schedule seq groups prompts.
102106 seq_group_meta , out = scheduler .schedule ()
103- assert out . scheduled_seq_groups == [seq_group_a , seq_group_b ]
107+ assert get_sequence_groups ( out ) == [seq_group_a , seq_group_b ]
104108 assert out .num_batched_tokens == block_size * 2 # seq_a and seq_b
105109 assert (not out .blocks_to_copy and not out .blocks_to_swap_in
106110 and not out .blocks_to_swap_out )
@@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort():
115119
116120 # Schedule seq groups generation and preempt seq group b.
117121 seq_group_meta , out = scheduler .schedule ()
118- assert out . scheduled_seq_groups == [seq_group_a ]
122+ assert get_sequence_groups ( out ) == [seq_group_a ]
119123 assert out .num_batched_tokens == 1
120124 assert (not out .blocks_to_copy and not out .blocks_to_swap_in
121125 and not out .blocks_to_swap_out )
@@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort():
125129 # Abort seq group a. Re-schedule seq group b prompt with recomputation.
126130 scheduler .abort_seq_group ("1" )
127131 seq_group_meta , out = scheduler .schedule ()
128- assert out . scheduled_seq_groups == [seq_group_b ]
132+ assert get_sequence_groups ( out ) == [seq_group_b ]
129133 assert out .num_batched_tokens == 5 # 4 prompt + 1 generation.
130134 assert (not out .blocks_to_copy and not out .blocks_to_swap_in
131135 and not out .blocks_to_swap_out )
@@ -155,11 +159,11 @@ def test_scheduler_max_seqs():
155159
156160 # Schedule seq groups prompts.
157161 _ , out = scheduler .schedule ()
158- assert set (out . scheduled_seq_groups ) == set ([all_seq_groups [0 ]])
162+ assert set (get_sequence_groups ( out ) ) == set ([all_seq_groups [0 ]])
159163
160164 # Schedule seq groups generation.
161165 _ , out = scheduler .schedule ()
162- assert set (out . scheduled_seq_groups ) == set ([all_seq_groups [0 ]])
166+ assert set (get_sequence_groups ( out ) ) == set ([all_seq_groups [0 ]])
163167
164168 # Append 2 more seq group
165169 scheduler .add_seq_group (all_seq_groups [1 ])
@@ -169,7 +173,7 @@ def test_scheduler_max_seqs():
169173 # Only 1 seq group should be scheduled since max_seq_group is 2
170174 # and one is prompting.
171175 _ , out = scheduler .schedule ()
172- assert set (out . scheduled_seq_groups ) == set ([all_seq_groups [1 ]])
176+ assert set (get_sequence_groups ( out ) ) == set ([all_seq_groups [1 ]])
173177
174178
175179def test_scheduler_delay_factor ():
0 commit comments