@@ -100,21 +100,19 @@ def test_models(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
 
     if model in V1_SUPPORTED_MODELS:
-        with monkeypatch.context() as m:
-            m.setenv("VLLM_USE_V1", "1")
-            with vllm_runner(model,
-                             max_num_seqs=MAX_NUM_SEQS,
-                             enable_prefix_caching=False) as vllm_model:
-                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                    example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
     else:
         vllm_v1_outputs = None
 
@@ -137,7 +135,7 @@ def test_models(
     )
 
 
-@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_batching(
@@ -147,10 +145,6 @@ def test_batching(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    if model in V0_UNSUPPORTED_MODELS:
-        pytest.skip(
-            f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
-
     try:
         model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
         model_info.check_available_online(on_fail="skip")
@@ -188,29 +182,32 @@ def test_chunked_prefill(
     max_tokens: int,
     num_logprobs: int,
     chunked_prefill_token_size: int,
+    monkeypatch,
 ) -> None:
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=True,
-                     max_num_batched_tokens=max_num_batched_tokens,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        chunked = vllm_model.generate_greedy_logprobs(example_prompts,
-                                                      max_tokens, num_logprobs)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model,
+                         enable_chunked_prefill=True,
+                         max_num_batched_tokens=max_num_batched_tokens,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=False,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        non_chunked = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model,
+                         enable_chunked_prefill=False,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            non_chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    check_logprobs_close(
-        outputs_0_lst=chunked,
-        outputs_1_lst=non_chunked,
-        name_0="chunked",
-        name_1="non_chunked",
-    )
+        check_logprobs_close(
+            outputs_0_lst=chunked,
+            outputs_1_lst=non_chunked,
+            name_0="chunked",
+            name_1="non_chunked",
+        )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -281,25 +278,29 @@ def test_models_preemption_recompute(
     example_prompts,
     model: str,
     max_tokens: int,
+    monkeypatch,
 ) -> None:
     """
     Tests that outputs are identical with and w/o preemptions (recompute).
     """
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        scheduler = vllm_model.llm.llm_engine.scheduler[0]
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
-        preempt_vllm_outputs = vllm_model.generate_greedy(
-            example_prompts, max_tokens)
-
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-        check_outputs_equal(
-            outputs_0_lst=preempt_vllm_outputs,
-            outputs_1_lst=vllm_outputs,
-            name_0="vllm_preepmtions",
-            name_1="vllm",
-        )
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            scheduler = vllm_model.llm.llm_engine.scheduler[0]
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
+            preempt_vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
+
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+
+            check_outputs_equal(
+                outputs_0_lst=preempt_vllm_outputs,
+                outputs_1_lst=vllm_outputs,
+                name_0="vllm_preepmtions",
+                name_1="vllm",
+            )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -402,24 +403,18 @@ def test_full_cuda_graph(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        with vllm_runner(model,
-                         max_num_seqs=MAX_NUM_SEQS,
-                         compilation_config={'full_cuda_graph': True},
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
     if hf_outputs is not None and vllm_v0_outputs is not None:
         check_logprobs_close(
@@ -466,24 +461,20 @@ def test_fp32_state(
     else:
         hf_outputs = None
 
-    with vllm_runner(model,
-                     max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="float32") as vllm_model:
-        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        m.setenv("VLLM_USE_V1", "0")
         with vllm_runner(model,
                          max_num_seqs=MAX_NUM_SEQS,
-                         mamba_ssm_cache_dtype="float32",
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                         mamba_ssm_cache_dtype="float32") as vllm_model:
+            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
 
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32") as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
     if hf_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
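Note for reviewers unfamiliar with the pinning pattern used throughout this diff: it relies on pytest's built-in `monkeypatch` fixture, whose `context()` manager scopes the environment override and restores the previous state on exit, so runners created after the block see the process default again. A minimal, self-contained sketch of just that mechanism (the test name and assertions are illustrative only, not part of this PR or of vLLM):

```python
import os


def test_env_is_pinned_and_restored(monkeypatch):
    # Hypothetical standalone example of the env-pinning pattern in this diff.
    before = os.environ.get("VLLM_USE_V1")
    with monkeypatch.context() as m:
        # The override is visible only inside this block, so a V0 runner
        # constructed here would see VLLM_USE_V1=0 ...
        m.setenv("VLLM_USE_V1", "0")
        assert os.environ["VLLM_USE_V1"] == "0"
    # ... while anything constructed after the block (e.g. a V1 runner)
    # sees the original environment again.
    assert os.environ.get("VLLM_USE_V1") == before
```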