@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
+import itertools

 import pytest
 import torch._dynamo
@@ -99,6 +100,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
         num_blocks = batch_size * max_blocks
         backend = self.attn.backend

+        # TODO use get_kv_cache_stride_order
         # Create dummy KV cache for the selected backend
         if backend == _Backend.ROCM_ATTN:
             # k/v as 1st dimention
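The new TODO hints at a follow-up: rather than hard-coding how each backend lays out the dummy KV cache, the test could ask the backend for its layout. A rough sketch of that idea (hypothetical wiring; it assumes the resolved backend class is in scope as `attn_backend` and exposes the `get_kv_cache_shape` / `get_kv_cache_stride_order` interface the TODO refers to, and that the model helper carries `num_kv_heads`, `head_size`, `kv_cache_dtype`, and `device`):

# Hypothetical sketch of the TODO above: derive the dummy KV-cache layout from
# the backend instead of special-casing each _Backend value.
kv_cache_shape = attn_backend.get_kv_cache_shape(
    num_blocks, self.block_size, self.num_kv_heads, self.head_size
)
try:
    stride_order = attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
    # Backend does not report a preferred layout; keep the logical order.
    stride_order = tuple(range(len(kv_cache_shape)))

# Allocate in the backend's physical (strided) order, then permute back so the
# tensor is still indexed in the logical kv_cache_shape order.
kv_cache = torch.empty(
    *(kv_cache_shape[i] for i in stride_order),
    dtype=self.kv_cache_dtype,
    device=self.device,
).permute(*(stride_order.index(i) for i in range(len(stride_order))))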
@@ -240,7 +242,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 MODELS_FP4 = []
 HEADS = []
 SPLIT_ATTENTION = []
-BACKENDS: list[_Backend] = []
+BACKENDS_FP8: list[_Backend] = []
+BACKENDS_FP4: list[_Backend] = []

 if current_platform.is_cuda():
     MODELS_FP8 = [
@@ -251,10 +254,11 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
     ]
     HEADS = [(64, 8), (40, 8)]
     SPLIT_ATTENTION = [False]
-    BACKENDS = []  # TODO [_Backend.TRITON_ATTN]
+    BACKENDS_FP8 = [_Backend.TRITON_ATTN]

     if current_platform.is_device_capability((10, 0)):
-        BACKENDS += [_Backend.FLASHINFER]
+        BACKENDS_FP8 += [_Backend.FLASHINFER]
+        BACKENDS_FP4 += [_Backend.FLASHINFER]
         MODELS_FP4 += [
             (
                 "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
@@ -288,13 +292,12 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize(
-    "model_name, model_class, custom_ops",
+    "backend, model, custom_ops",
     # Test attention+quant_fp8 fusion with custom and torch impls
-    [(*model, c) for model in MODELS_FP8 for c in ["+quant_fp8", "-quant_fp8"]]
+    list(itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
     # quant_fp4 only has the custom impl
-    + [(*model, c) for model in MODELS_FP4 for c in [""]],
+    + list(itertools.product(BACKENDS_FP4, MODELS_FP4, [""])),
 )
-@pytest.mark.parametrize("backend", BACKENDS)
 @pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION)
 @pytest.mark.skipif(
     not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
@@ -307,8 +310,7 @@ def test_attention_quant_pattern(
     batch_size: int,
     dtype: torch.dtype,
     custom_ops: str,
-    model_name: str,
-    model_class: type[AttentionQuantPatternModel],
+    model: tuple[str, type[AttentionQuantPatternModel]],
     backend: _Backend,
     use_inductor_graph_partition: bool,
     dist_init,
@@ -317,6 +319,7 @@ def test_attention_quant_pattern(
317319 """Test AttentionStaticQuantPattern fusion pass"""
318320
319321 custom_ops_list = custom_ops .split ("," ) if custom_ops else []
322+ model_name , model_class = model
320323
321324 device = torch .device ("cuda:0" )
322325 torch .manual_seed (42 )
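The parametrize change is the core of the diff: previously `backend` was a separate `@pytest.mark.parametrize` crossed with every (model_name, model_class, custom_ops) case, so FP4 models could be paired with backends that only cover FP8. Building the cases with itertools.product binds each quantization's models to its own backend list, and the model tuple is now passed through whole and unpacked inside the test. A small sketch of what the expanded cases look like (illustration only, using the names from this diff):

import itertools

# Illustration only: the (backend, model, custom_ops) triples generated by the
# new parametrize call.
cases = list(
    itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])
) + list(itertools.product(BACKENDS_FP4, MODELS_FP4, [""]))

# `model` stays a (model_name, model_class) tuple from MODELS_FP8 / MODELS_FP4,
# which is why the test body now unpacks it.
for backend, model, custom_ops in cases:
    model_name, model_class = model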