1010from typing import Any , Optional , Union
1111
1212import pytest
13+ import regex as re
1314import torch
1415
1516from tests .quantization .utils import is_quant_method_supported
1617from vllm import LLM , SamplingParams
1718from vllm .attention .backends .registry import _Backend
18- from vllm .attention .selector import global_force_attn_backend_context_manager
1919from vllm .config import CompilationConfig , CompilationLevel , CUDAGraphMode , PassConfig
2020from vllm .platforms import current_platform
2121from vllm .utils import is_torch_equal_or_newer
@@ -235,7 +235,8 @@ def test_fp8_kv_scale_compile(optimization_level: int):
235235)
236236
237237# TODO(luka) test both in nightly
238- CUSTOM_OPS_FP8 = ["-quant_fp8" ] # , "+quant_fp8"]
238+ # TODO(luka) change to -
239+ CUSTOM_OPS_FP8 = ["+quant_fp8" ] # , "-quant_fp8"]
239240
240241
241242@pytest .mark .parametrize (
@@ -252,8 +253,7 @@ def test_e2e_fusion_attn_quant(
252253 backend : _Backend ,
253254 custom_ops : str ,
254255 inductor_graph_partition : bool ,
255- caplog_vllm ,
256- caplog_mp_workaround ,
256+ caplog_mp_spawn ,
257257 monkeypatch ,
258258):
259259 custom_ops_list = custom_ops .split ("," ) if custom_ops else []
@@ -269,7 +269,11 @@ def test_e2e_fusion_attn_quant(
269269 # Otherwise, we can't verify fusion happened through the logs.
270270 # Log capture also doesn't work with multiprocessing yet.
271271 monkeypatch .setenv ("VLLM_DISABLE_COMPILE_CACHE" , "1" )
272- # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
272+
273+ # To capture subprocess logs, we need to know whether spawn or fork is used.
274+ # Force spawn as it is more general.
275+ monkeypatch .setenv ("VLLM_WORKER_MULTIPROC_METHOD" , "spawn" )
276+ monkeypatch .setenv ("VLLM_ATTENTION_BACKEND" , backend .name )
273277
274278 compilation_config = CompilationConfig (
275279 # Testing properties
@@ -284,18 +288,15 @@ def test_e2e_fusion_attn_quant(
284288 inductor_compile_config = {"force_disable_caches" : True },
285289 )
286290
287- with (
288- caplog_vllm .at_level (logging .DEBUG ),
289- caplog_mp_workaround (),
290- global_force_attn_backend_context_manager (backend ),
291- ):
291+ with caplog_mp_spawn (logging .DEBUG ) as log_holder :
292292 run_model (compilation_config , model_name , ** model_kwargs )
293293
294- assert "Fused quant onto 48 attention nodes" in caplog_vllm .text , caplog_vllm .text
294+ assert "Fused quant onto 48 attention nodes" in log_holder .text , log_holder .text
295295
296296
297297# TODO(luka) test both in nightly
298- CUSTOM_OPS_RMS_NORM = ["-rms_norm" ] # , "+rms_norm"]
298+ # TODO(luka) change to -
299+ CUSTOM_OPS_RMS_NORM = ["+rms_norm" ] # , "-rms_norm"]
299300
300301
301302def custom_ops_product (* custom_ops_lists : list [str ]) -> Iterable [str ]:
@@ -321,14 +322,13 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
321322 or not current_platform .has_device_capability ((10 , 0 )),
322323 reason = "allreduce+rmsnorm fusion only supported on blackwell" ,
323324)
324- @pytest .mark .skip (reason = "Still no solution for capturing logs from subprocess" )
325325def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm (
326326 model_name ,
327327 model_kwargs ,
328328 backend ,
329329 custom_ops : str ,
330330 inductor_graph_partition : bool ,
331- caplog_vllm ,
331+ caplog_mp_spawn ,
332332 monkeypatch ,
333333):
334334 custom_ops_list = custom_ops .split ("," ) if custom_ops else []
@@ -344,8 +344,11 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm(
344344 # Otherwise, we can't verify fusion happened through the logs.
345345 # Log capture also doesn't work with multiprocessing yet.
346346 monkeypatch .setenv ("VLLM_DISABLE_COMPILE_CACHE" , "1" )
347- # TODO
348- # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
347+
348+ # To capture subprocess logs, we need to know whether spawn or fork is used.
349+ # Force spawn as it is more general.
350+ monkeypatch .setenv ("VLLM_WORKER_MULTIPROC_METHOD" , "spawn" )
351+ monkeypatch .setenv ("VLLM_ATTENTION_BACKEND" , backend .name )
349352
350353 compilation_config = CompilationConfig (
351354 # Testing properties
@@ -364,18 +367,17 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm(
364367 inductor_compile_config = {"force_disable_caches" : True },
365368 )
366369
367- with (
368- caplog_vllm .at_level (logging .DEBUG ),
369- global_force_attn_backend_context_manager (backend ),
370- ):
370+ with caplog_mp_spawn (logging .DEBUG ) as log_holder :
371371 run_model (
372372 compilation_config , model_name , tensor_parallel_size = 2 , ** model_kwargs
373373 )
374374
375- assert "Fused quant onto 48 attention nodes" in caplog_vllm .text , caplog_vllm .text
375+ assert "Fused quant onto 48 attention nodes" in log_holder .text , log_holder .text
376376
377- # TODO fill in correct number
378- assert "Replaced 96 patterns" in caplog_vllm .text , caplog_vllm .text
377+ matches = re .findall (
378+ r"\[collective_fusion.py:\d+] Replaced 96 patterns" , log_holder .text
379+ )
380+ assert len (matches ) == 2 , log_holder .text
379381
380382
381383def run_model (
0 commit comments