Commit cdd1529

Flat product for better test names/visibility

Signed-off-by: Luka Govedič <[email protected]>

1 parent: d843a67

1 file changed (+15, -6 lines)

tests/compile/test_fusion_attn.py

Lines changed: 15 additions & 6 deletions
@@ -2,6 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 import itertools
+from collections.abc import Iterable
+from typing import Any
 
 import pytest
 import torch._dynamo
@@ -285,18 +287,25 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 USE_INDUCTOR_GRAPH_PARTITION = [False]
 
 
+def flat_product(*iterables: Iterable[Any]):
+    """Flatten lists of tuples into cartesian product."""
+    for element in itertools.product(*iterables):
+        normalized = (e if isinstance(e, tuple) else [e] for e in element)
+        yield list(itertools.chain(*normalized))
+
+
 @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
 @pytest.mark.parametrize("head_size", [128])
 @pytest.mark.parametrize(
     "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize(
-    "backend, model, custom_ops",
-    # Test attention+quant_fp8 fusion with custom and torch impls
-    list(itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
+    "backend, model_name, model_class, custom_ops",
+    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
+    list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
     # quant_fp4 only has the custom impl
-    + list(itertools.product(BACKENDS_FP4, MODELS_FP4, [""])),
+    + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
 )
 @pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION)
 @pytest.mark.skipif(
@@ -310,7 +319,8 @@ def test_attention_quant_pattern(
     batch_size: int,
     dtype: torch.dtype,
     custom_ops: str,
-    model: tuple[str, type[AttentionQuantPatternModel]],
+    model_name: str,
+    model_class: type[AttentionQuantPatternModel],
     backend: _Backend,
     use_inductor_graph_partition: bool,
     dist_init,
@@ -319,7 +329,6 @@ def test_attention_quant_pattern(
     """Test AttentionStaticQuantPattern fusion pass"""
 
     custom_ops_list = custom_ops.split(",") if custom_ops else []
-    model_name, model_class = model
 
     device = torch.device("cuda:0")
     torch.manual_seed(42)
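
For context, a minimal sketch of what flat_product buys here. The MODELS_FP8 entries are (model_name, model_class) tuples; itertools.product would pass each tuple through as one opaque parameter, while flat_product splices its elements into the argument list. The backend and model values below are hypothetical stand-ins, not the actual BACKENDS_FP8 / MODELS_FP8 contents:

import itertools
from collections.abc import Iterable
from typing import Any


def flat_product(*iterables: Iterable[Any]):
    """Flatten lists of tuples into cartesian product."""
    for element in itertools.product(*iterables):
        # Wrap non-tuple entries so every element can be chained uniformly.
        normalized = (e if isinstance(e, tuple) else [e] for e in element)
        yield list(itertools.chain(*normalized))


backends = ["TRITON_ATTN", "FLASHINFER"]            # hypothetical backend names
models = [("llama-fp8", "TestAttentionFp8Model")]   # (model_name, model_class) pairs

# itertools.product keeps each tuple nested as a single parameter:
#   ('TRITON_ATTN', ('llama-fp8', 'TestAttentionFp8Model'))
# flat_product splices the tuple's elements into the argument list:
#   ['TRITON_ATTN', 'llama-fp8', 'TestAttentionFp8Model']
print(list(flat_product(backends, models)))
# [['TRITON_ATTN', 'llama-fp8', 'TestAttentionFp8Model'],
#  ['FLASHINFER', 'llama-fp8', 'TestAttentionFp8Model']]

Because each flattened list lines up with the four names in "backend, model_name, model_class, custom_ops", pytest can render the model name directly in every generated test ID instead of an opaque model0/model1 placeholder, which is the "better test names/visibility" the commit message refers to.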
