
Commit 21a9f9f

Fixed tests, passing with 2.8; 2.9 TBD
Signed-off-by: Luka Govedič <[email protected]>
1 parent a2aa978 commit 21a9f9f

File tree

1 file changed: +25 -23 lines


tests/compile/test_full_graph.py

Lines changed: 25 additions & 23 deletions
@@ -10,12 +10,12 @@
 from typing import Any, Optional, Union
 
 import pytest
+import regex as re
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
 from vllm.platforms import current_platform
 from vllm.utils import is_torch_equal_or_newer
@@ -235,7 +235,8 @@ def test_fp8_kv_scale_compile(optimization_level: int):
 )
 
 # TODO(luka) test both in nightly
-CUSTOM_OPS_FP8 = ["-quant_fp8"]  # , "+quant_fp8"]
+# TODO(luka) change to -
+CUSTOM_OPS_FP8 = ["+quant_fp8"]  # , "+quant_fp8"]
 
 
 @pytest.mark.parametrize(
@@ -252,8 +253,7 @@ def test_e2e_fusion_attn_quant(
     backend: _Backend,
     custom_ops: str,
     inductor_graph_partition: bool,
-    caplog_vllm,
-    caplog_mp_workaround,
+    caplog_mp_spawn,
     monkeypatch,
 ):
     custom_ops_list = custom_ops.split(",") if custom_ops else []
@@ -269,7 +269,11 @@ def test_e2e_fusion_attn_quant(
     # Otherwise, we can't verify fusion happened through the logs.
     # Log capture also doesn't work with multiprocessing yet.
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
-    # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
+    # To capture subprocess logs, we need to know whether spawn or fork is used.
+    # Force spawn as it is more general.
+    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
 
     compilation_config = CompilationConfig(
         # Testing properties
@@ -284,18 +288,15 @@ def test_e2e_fusion_attn_quant(
         inductor_compile_config={"force_disable_caches": True},
     )
 
-    with (
-        caplog_vllm.at_level(logging.DEBUG),
-        caplog_mp_workaround(),
-        global_force_attn_backend_context_manager(backend),
-    ):
+    with caplog_mp_spawn(logging.DEBUG) as log_holder:
         run_model(compilation_config, model_name, **model_kwargs)
 
-    assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text
+    assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text
 
 
 # TODO(luka) test both in nightly
-CUSTOM_OPS_RMS_NORM = ["-rms_norm"]  # , "+rms_norm"]
+# TODO(luka) change to -
+CUSTOM_OPS_RMS_NORM = ["+rms_norm"]  # , "+rms_norm"]
 
 
 def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
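
Aside: the hunk above ends at custom_ops_product, whose body is outside this
diff. A minimal sketch consistent with its signature (an assumption, not the
file's actual implementation) takes the cartesian product of the flag lists and
joins each combination into the comma-separated custom_ops strings that the
tests later split with custom_ops.split(","):

import itertools
from collections.abc import Iterable


def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
    # ("+quant_fp8",) x ("+rms_norm",) -> "+quant_fp8,+rms_norm", etc.
    for combo in itertools.product(*custom_ops_lists):
        yield ",".join(combo)
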
@@ -321,14 +322,13 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
     or not current_platform.has_device_capability((10, 0)),
     reason="allreduce+rmsnorm fusion only supported on blackwell",
 )
-@pytest.mark.skip(reason="Still no solution for capturing logs from subprocess")
 def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm(
     model_name,
     model_kwargs,
     backend,
     custom_ops: str,
     inductor_graph_partition: bool,
-    caplog_vllm,
+    caplog_mp_spawn,
     monkeypatch,
 ):
     custom_ops_list = custom_ops.split(",") if custom_ops else []
@@ -344,8 +344,11 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm(
     # Otherwise, we can't verify fusion happened through the logs.
     # Log capture also doesn't work with multiprocessing yet.
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
-    # TODO
-    # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
+    # To capture subprocess logs, we need to know whether spawn or fork is used.
+    # Force spawn as it is more general.
+    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
 
     compilation_config = CompilationConfig(
         # Testing properties
@@ -364,18 +367,17 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm(
         inductor_compile_config={"force_disable_caches": True},
    )
 
-    with (
-        caplog_vllm.at_level(logging.DEBUG),
-        global_force_attn_backend_context_manager(backend),
-    ):
+    with caplog_mp_spawn(logging.DEBUG) as log_holder:
         run_model(
             compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
         )
 
-    assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text
+    assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text
 
-    # TODO fill in correct number
-    assert "Replaced 96 patterns" in caplog_vllm.text, caplog_vllm.text
+    matches = re.findall(
+        r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text
+    )
+    assert len(matches) == 2, log_holder.text
 
 
 def run_model(
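
Notes on the approach: spawned workers do not inherit the parent's in-memory
log handlers, which is why the old caplog_vllm / caplog_mp_workaround
combination could not see subprocess logs and the TP=2 test had to be skipped.
The caplog_mp_spawn fixture itself is defined elsewhere in the test suite and
is not part of this commit; a minimal sketch of the pattern it suggests (the
VLLM_TEST_LOG_FILE name and the file-based wiring below are assumptions, not
vLLM's actual mechanism):

import contextlib
import logging
import os
import tempfile


class LogHolder:
    """Collects subprocess log output; .text is filled on context exit."""

    def __init__(self) -> None:
        self.text = ""


@contextlib.contextmanager
def caplog_mp_spawn(level: int):
    # Spawned children share no in-memory handlers with the parent, so
    # capture must go through something both sides can see, e.g. a file.
    fd, path = tempfile.mkstemp(suffix=".log")
    os.close(fd)
    os.environ["VLLM_TEST_LOG_FILE"] = path  # hypothetical; assumed read by workers
    handler = logging.FileHandler(path)
    handler.setLevel(level)
    root = logging.getLogger()
    root.addHandler(handler)
    holder = LogHolder()
    try:
        yield holder
    finally:
        root.removeHandler(handler)
        handler.close()
        with open(path) as f:
            holder.text = f.read()
        os.unlink(path)
        os.environ.pop("VLLM_TEST_LOG_FILE", None)

Forcing VLLM_WORKER_MULTIPROC_METHOD=spawn makes this deterministic: under
fork the children would inherit the parent's handlers and a different capture
strategy would apply, so the fixture has to know which start method is in use.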

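The TP=2 assertion is also stricter than before: instead of a bare substring
check guarded by a TODO, it counts occurrences, presumably one per
tensor-parallel rank, since the test runs with tensor_parallel_size=2 and
expects exactly two matches. A self-contained illustration of the counting
logic (the sample records are fabricated for shape only; real vLLM log lines
differ):

import re  # the test imports the drop-in regex package as re

# Two hypothetical records, one per TP rank, mimicking the pass's log format.
log_text = (
    "DEBUG [collective_fusion.py:123] Replaced 96 patterns\n"
    "DEBUG [collective_fusion.py:123] Replaced 96 patterns\n"
)
matches = re.findall(r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_text)
assert len(matches) == 2  # exactly one match expected per tensor-parallel rank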