 
 from __future__ import annotations
 
-import itertools
-import logging
 import tempfile
-from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 import pytest
-import regex as re
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.attention.backends.registry import _Backend
 from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
 from vllm.platforms import current_platform
 from vllm.utils import is_torch_equal_or_newer
-from vllm.utils.flashinfer import has_flashinfer
 
-from ..utils import create_new_process_for_each_test, flat_product, multi_gpu_test
+from ..utils import create_new_process_for_each_test
 
 
 def models_list(*, all: bool = True, keywords: list[str] | None = None):
@@ -189,194 +183,6 @@ def test_fp8_kv_scale_compile(optimization_level: int):
     run_model(optimization_level, model, **model_kwargs)
 
 
-MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = []
-MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = []
-MODELS: list[tuple[str, dict[str, Any], _Backend]] = []  # tp-only
-
-if current_platform.is_cuda():
-    MODELS_FP8 += [
-        (
-            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
-            {"max_model_len": 1024},
-            _Backend.TRITON_ATTN,
-        )
-    ]
-
-    if current_platform.is_device_capability((10, 0)):
-        MODELS_FP8 += [
-            (
-                "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
-                {"kv_cache_dtype": "fp8", "max_model_len": 1024},
-                _Backend.FLASHINFER,
-            )
-        ]
-
-        MODELS_FP4 += [
-            (
-                "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
-                {"kv_cache_dtype": "fp8", "max_model_len": 1024},
-                _Backend.FLASHINFER,
-            )
-        ]
-
-        MODELS += [
-            (
-                "meta-llama/Llama-3.1-8B-Instruct",
-                {"max_model_len": 1024},
-                _Backend.FLASHINFER,
-            )
-        ]
-
-elif current_platform.is_rocm():
-    MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)]
-
-INDUCTOR_GRAPH_PARTITION = (
-    [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False]
-)
-
-# TODO(luka) test both in nightly
-CUSTOM_OPS_FP8 = ["-quant_fp8"]  # , "+quant_fp8"]
-
-
-@pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, custom_ops",
-    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
-    list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
-    # quant_fp4 only has the custom impl
-    + list(flat_product(MODELS_FP4, [""])),
-)
-@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
-def test_e2e_fusion_attn_quant(
-    model_name: str,
-    model_kwargs: dict[str, Any],
-    backend: _Backend,
-    custom_ops: str,
-    inductor_graph_partition: bool,
-    caplog_mp_spawn,
-    monkeypatch,
-):
-    custom_ops_list = custom_ops.split(",") if custom_ops else []
-
-    if inductor_graph_partition:
-        mode = CUDAGraphMode.FULL_AND_PIECEWISE
-        splitting_ops: Optional[list[str]] = None
-    else:
-        mode = CUDAGraphMode.FULL_DECODE_ONLY
-        splitting_ops = []
-
-    # Disable, compile cache to make sure custom passes run.
-    # Otherwise, we can't verify fusion happened through the logs.
-    # Log capture also doesn't work with multiprocessing yet.
-    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
-
-    # To capture subprocess logs, we need to know whether spawn or fork is used.
-    # Force spawn as it is more general.
-    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
-
-    compilation_config = CompilationConfig(
-        # Testing properties
-        custom_ops=custom_ops_list,
-        use_inductor_graph_partition=inductor_graph_partition,
-        cudagraph_mode=mode,
-        splitting_ops=splitting_ops,
-        # Common
-        level=CompilationLevel.PIECEWISE,
-        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
-        # Inductor caches custom passes by default as well via uuid
-        inductor_compile_config={"force_disable_caches": True},
-    )
-
-    with caplog_mp_spawn(logging.DEBUG) as log_holder:
-        run_model(compilation_config, model_name, **model_kwargs)
-
-    assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text
-
-
-# TODO(luka) test both in nightly
-# TODO(luka) change to -
-CUSTOM_OPS_RMS_NORM = ["+rms_norm"]  # , "+rms_norm"]
-
-
-def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
-    for op_list in itertools.product(*custom_ops_lists):
-        yield ",".join(op_list)
-
-
-@multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, custom_ops",
-    # Toggle RMSNorm and QuantFP8 for FP8 models
-    list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"]))
-    # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO
-    # Toggle RMSNorm for FP4 models and unquant models
-    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
-)
-@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
-@pytest.mark.skipif(
-    not current_platform.is_cuda()
-    or not has_flashinfer()
-    or not current_platform.has_device_capability(90),
-    reason="allreduce+rmsnorm fusion requires flashinfer",
-)
-def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm(
-    model_name,
-    model_kwargs,
-    backend,
-    custom_ops: str,
-    inductor_graph_partition: bool,
-    caplog_mp_spawn,
-    monkeypatch,
-):
-    custom_ops_list = custom_ops.split(",") if custom_ops else []
-
-    if inductor_graph_partition:
-        mode = CUDAGraphMode.FULL_AND_PIECEWISE
-        splitting_ops: Optional[list[str]] = None
-    else:
-        mode = CUDAGraphMode.FULL_DECODE_ONLY
-        splitting_ops = []
-
-    # Disable, compile cache to make sure custom passes run.
-    # Otherwise, we can't verify fusion happened through the logs.
-    # Log capture also doesn't work with multiprocessing yet.
-    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
-
-    # To capture subprocess logs, we need to know whether spawn or fork is used.
-    # Force spawn as it is more general.
-    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
-
-    compilation_config = CompilationConfig(
-        # Testing properties
-        use_inductor_graph_partition=inductor_graph_partition,
-        cudagraph_mode=mode,
-        custom_ops=custom_ops_list,
-        splitting_ops=splitting_ops,
-        # Common
-        level=CompilationLevel.PIECEWISE,
-        pass_config=PassConfig(
-            enable_attn_fusion=True,
-            enable_noop=True,
-            enable_fi_allreduce_fusion=True,
-        ),
-        # Inductor caches custom passes by default as well via uuid
-        inductor_compile_config={"force_disable_caches": True},
-    )
-
-    with caplog_mp_spawn(logging.DEBUG) as log_holder:
-        run_model(
-            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
-        )
-
-    assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text
-
-    matches = re.findall(
-        r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text
-    )
-    assert len(matches) == 2, log_holder.text
-
-
 def run_model(
     compile_config: Union[int, CompilationConfig], model: str, **model_kwargs
 ):