Commit f6429e4

Cleanup test_fusion_attn.py
Signed-off-by: Luka Govedič <[email protected]>
1 parent b5f89e5 commit f6429e4

File tree

1 file changed: +26 -28 lines changed

tests/compile/test_fusion_attn.py

Lines changed: 26 additions & 28 deletions
@@ -34,6 +34,7 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
 from vllm.platforms import current_platform
 from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.flashinfer import has_flashinfer
 from vllm.v1.kv_cache_interface import AttentionSpec

 FP8_DTYPE = current_platform.fp8_dtype()
@@ -238,52 +239,41 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         )


-MODELS_FP8 = []
-MODELS_FP4 = []
-HEADS = []
-SPLIT_ATTENTION = []
+MODELS_FP8: list[tuple[str, type]] = []
+MODELS_FP4: list[tuple[str, type]] = []
+HEADS: list[tuple[int, int]] = []
+SPLIT_ATTENTION: list[bool] = []
 BACKENDS_FP8: list[_Backend] = []
 BACKENDS_FP4: list[_Backend] = []

 if current_platform.is_cuda():
+    HEADS = [(64, 8), (40, 8)]
     MODELS_FP8 = [
         (
             "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
             TestAttentionFp8StaticQuantPatternModel,
         )
     ]
-    HEADS = [(64, 8), (40, 8)]
-    SPLIT_ATTENTION = [False]
-    BACKENDS_FP8 = [_Backend.TRITON_ATTN]
-
-    if current_platform.is_device_capability((10, 0)):
-        BACKENDS_FP8 += [_Backend.FLASHINFER]
-        BACKENDS_FP4 += [_Backend.FLASHINFER]
-        MODELS_FP4 += [
-            (
-                "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
-                TestAttentionNvfp4QuantPatternModel,
-            )
-        ]
+    MODELS_FP4 = [
+        (
+            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+            TestAttentionNvfp4QuantPatternModel,
+        )
+    ]
+    BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
+    BACKENDS_FP4 = [_Backend.FLASHINFER]

 elif current_platform.is_rocm():
+    HEADS = [(32, 8), (40, 8)]
     MODELS_FP8 = [
         ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
     ]
-    HEADS = [(32, 8), (40, 8)]
-    SPLIT_ATTENTION = [False, True]
     BACKENDS = [
-        _Backend.TRITON_ATTN,
         _Backend.ROCM_AITER_UNIFIED_ATTN,
         _Backend.ROCM_ATTN,
+        _Backend.TRITON_ATTN,
     ]

-# TODO(boyuan/luka): test inductor graph partition on rocm
-if is_torch_equal_or_newer("2.9.0.dev") and current_platform.is_cuda():
-    USE_INDUCTOR_GRAPH_PARTITION = [False, True]
-else:
-    USE_INDUCTOR_GRAPH_PARTITION = [False]
-

 @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
 @pytest.mark.parametrize("head_size", [128])
@@ -298,7 +288,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
     # quant_fp4 only has the custom impl
     + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
 )
-@pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION)
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
 @pytest.mark.skipif(
     not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
 )
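
Note: the decorator change above and the skip logic in the next hunk replace module-level gating (the USE_INDUCTOR_GRAPH_PARTITION list and the capability-filtered backend lists) with unconditional parametrization plus runtime pytest.skip calls. A minimal, self-contained sketch of that parametrize-then-skip pattern; the names and the machine_supports() helper are hypothetical, not part of this commit:

import pytest

BACKENDS = ["triton", "flashinfer"]  # parametrize every backend unconditionally


def machine_supports(backend: str) -> bool:
    # Stand-in for the real capability checks, e.g.
    # current_platform.is_device_capability((10, 0)) and has_flashinfer().
    return backend != "flashinfer"


@pytest.mark.parametrize("backend", BACKENDS)
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
def test_fusion_sketch(backend, use_inductor_graph_partition):
    if not machine_supports(backend):
        pytest.skip(f"{backend} is not supported on this machine")
    # ... the real test body would run here ...

Skipping at runtime keeps the set of collected test IDs identical on every machine and reports unsupported combinations as skips instead of silently collecting fewer cases.
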
@@ -318,6 +308,14 @@ def test_attention_quant_pattern(
     caplog_vllm,
 ):
     """Test AttentionStaticQuantPattern fusion pass"""
+    if backend == _Backend.FLASHINFER and (
+        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
+    ):
+        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
+
+    # TODO(boyuan/luka): test inductor graph partition on rocm
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")

     custom_ops_list = custom_ops.split(",") if custom_ops else []

@@ -435,7 +433,7 @@ def test_attention_quant_pattern(
     )

     # access the underlying `AttnFusionPass` on the `LazyInitPass`
-    assert attn_pass.pass_.matched_count == 1
+    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

     # Check attention ops in the graph before and after fusion
     attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
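
Note: the assertion change in the last hunk stops hard-coding a single expected match. sum() over a list of booleans counts the True entries (bool is an int subclass), so the fusion pass is now expected to match once per entry of attn_fusion_supported that is true, presumably one flag per parametrized layer. A trivial illustration with made-up values:

# Hypothetical values for illustration only, not taken from the test.
attn_fusion_supported = [True, False, True]
assert sum(attn_fusion_supported) == 2  # True counts as 1, False as 0
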
