
Commit d843a67

Add triton attn test to attn+quant fusion

Signed-off-by: Luka Govedič <[email protected]>

1 parent 1277999

File tree: 1 file changed (+12, −9)


tests/compile/test_fusion_attn.py

Lines changed: 12 additions & 9 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
+import itertools

 import pytest
 import torch._dynamo
@@ -99,6 +100,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
         num_blocks = batch_size * max_blocks
         backend = self.attn.backend

+        # TODO use get_kv_cache_stride_order
         # Create dummy KV cache for the selected backend
         if backend == _Backend.ROCM_ATTN:
             # k/v as 1st dimention
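The added TODO points at get_kv_cache_stride_order as a way to drop the per-backend branching when allocating the dummy KV cache. As a rough, hedged sketch of why that branching exists today (the shapes below are illustrative assumptions, not vLLM's actual layouts, and make_dummy_kv_cache is a hypothetical helper, not an API in the repo):

import torch

def make_dummy_kv_cache(
    backend: str,
    num_blocks: int,
    block_size: int = 16,
    num_kv_heads: int = 8,
    head_size: int = 128,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    # Hedged sketch only: each backend expects the K/V dimension in a
    # different position, so the test has to allocate the cache per backend.
    if backend == "ROCM_ATTN":
        # k/v as the 1st dimension (assumed layout for illustration)
        shape = (2, num_blocks, block_size, num_kv_heads, head_size)
    else:
        # k/v after the block dimension (assumed layout for illustration)
        shape = (num_blocks, 2, block_size, num_kv_heads, head_size)
    return torch.zeros(shape, dtype=dtype)

A stride-order helper would let the test build one shape generically instead of enumerating backends, which is what the TODO suggests.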
@@ -240,7 +242,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 MODELS_FP4 = []
 HEADS = []
 SPLIT_ATTENTION = []
-BACKENDS: list[_Backend] = []
+BACKENDS_FP8: list[_Backend] = []
+BACKENDS_FP4: list[_Backend] = []

 if current_platform.is_cuda():
     MODELS_FP8 = [
@@ -251,10 +254,11 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
     ]
     HEADS = [(64, 8), (40, 8)]
     SPLIT_ATTENTION = [False]
-    BACKENDS = []  # TODO [_Backend.TRITON_ATTN]
+    BACKENDS_FP8 = [_Backend.TRITON_ATTN]

     if current_platform.is_device_capability((10, 0)):
-        BACKENDS += [_Backend.FLASHINFER]
+        BACKENDS_FP8 += [_Backend.FLASHINFER]
+        BACKENDS_FP4 += [_Backend.FLASHINFER]
         MODELS_FP4 += [
             (
                 "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
@@ -288,13 +292,12 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize(
-    "model_name, model_class, custom_ops",
+    "backend, model, custom_ops",
     # Test attention+quant_fp8 fusion with custom and torch impls
-    [(*model, c) for model in MODELS_FP8 for c in ["+quant_fp8", "-quant_fp8"]]
+    list(itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
     # quant_fp4 only has the custom impl
-    + [(*model, c) for model in MODELS_FP4 for c in [""]],
+    + list(itertools.product(BACKENDS_FP4, MODELS_FP4, [""])),
 )
-@pytest.mark.parametrize("backend", BACKENDS)
 @pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION)
 @pytest.mark.skipif(
     not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
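Because the backend now rides along in the same parametrize call, FP8 and FP4 models can be paired with different backend sets, which the removed standalone @pytest.mark.parametrize("backend", BACKENDS) could not express. A small sketch of how the itertools.product call expands, with toy stand-ins for the backend members and the (model_name, model_class) tuples:

import itertools

BACKENDS_FP8 = ["TRITON_ATTN", "FLASHINFER"]
MODELS_FP8 = [("model-a-FP8", "ModelA"), ("model-b-FP8", "ModelB")]

cases = list(itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
# Each case is a (backend, model, custom_ops) triple, e.g.:
#   ("TRITON_ATTN", ("model-a-FP8", "ModelA"), "+quant_fp8")
# Unlike the old [(*model, c) ...] comprehension, the model tuple stays intact,
# which is why the test signature below switches to a single `model` parameter.
assert len(cases) == len(BACKENDS_FP8) * len(MODELS_FP8) * 2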
@@ -307,8 +310,7 @@ def test_attention_quant_pattern(
     batch_size: int,
     dtype: torch.dtype,
     custom_ops: str,
-    model_name: str,
-    model_class: type[AttentionQuantPatternModel],
+    model: tuple[str, type[AttentionQuantPatternModel]],
     backend: _Backend,
     use_inductor_graph_partition: bool,
     dist_init,
@@ -317,6 +319,7 @@ def test_attention_quant_pattern(
     """Test AttentionStaticQuantPattern fusion pass"""

     custom_ops_list = custom_ops.split(",") if custom_ops else []
+    model_name, model_class = model

     device = torch.device("cuda:0")
     torch.manual_seed(42)
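The (model_name, model_class) pair now arrives as a single `model` parameter and is unpacked at the top of the test. A minimal, hypothetical sketch of that calling convention (toy names only, not the repo's test):

def run_case(model: tuple[str, type]) -> None:
    # Unpack the parametrized (name, class) pair, mirroring the diff above.
    model_name, model_class = model
    print(f"running {model_name} with {model_class.__name__}")

run_case(("toy-model-FP8", object))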

0 commit comments