 from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
 from vllm.platforms import current_platform
 from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.flashinfer import has_flashinfer
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -238,52 +239,41 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         )
 
 
-MODELS_FP8 = []
-MODELS_FP4 = []
-HEADS = []
-SPLIT_ATTENTION = []
+MODELS_FP8: list[tuple[str, type]] = []
+MODELS_FP4: list[tuple[str, type]] = []
+HEADS: list[tuple[int, int]] = []
+SPLIT_ATTENTION: list[bool] = []
 BACKENDS_FP8: list[_Backend] = []
 BACKENDS_FP4: list[_Backend] = []
 
 if current_platform.is_cuda():
+    HEADS = [(64, 8), (40, 8)]
     MODELS_FP8 = [
         (
             "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
             TestAttentionFp8StaticQuantPatternModel,
         )
     ]
-    HEADS = [(64, 8), (40, 8)]
-    SPLIT_ATTENTION = [False]
-    BACKENDS_FP8 = [_Backend.TRITON_ATTN]
-
-    if current_platform.is_device_capability((10, 0)):
-        BACKENDS_FP8 += [_Backend.FLASHINFER]
-        BACKENDS_FP4 += [_Backend.FLASHINFER]
-        MODELS_FP4 += [
-            (
-                "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
-                TestAttentionNvfp4QuantPatternModel,
-            )
-        ]
+    MODELS_FP4 = [
+        (
+            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+            TestAttentionNvfp4QuantPatternModel,
+        )
+    ]
+    BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
+    BACKENDS_FP4 = [_Backend.FLASHINFER]
 
 elif current_platform.is_rocm():
+    HEADS = [(32, 8), (40, 8)]
     MODELS_FP8 = [
         ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
     ]
-    HEADS = [(32, 8), (40, 8)]
-    SPLIT_ATTENTION = [False, True]
     BACKENDS = [
-        _Backend.TRITON_ATTN,
         _Backend.ROCM_AITER_UNIFIED_ATTN,
         _Backend.ROCM_ATTN,
+        _Backend.TRITON_ATTN,
     ]
 
-# TODO(boyuan/luka): test inductor graph partition on rocm
-if is_torch_equal_or_newer("2.9.0.dev") and current_platform.is_cuda():
-    USE_INDUCTOR_GRAPH_PARTITION = [False, True]
-else:
-    USE_INDUCTOR_GRAPH_PARTITION = [False]
-
 
 @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
 @pytest.mark.parametrize("head_size", [128])
@@ -298,7 +288,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
     # quant_fp4 only has the custom impl
     + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
 )
-@pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION)
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
 @pytest.mark.skipif(
     not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
 )
@@ -318,6 +308,14 @@ def test_attention_quant_pattern(
     caplog_vllm,
 ):
     """Test AttentionStaticQuantPattern fusion pass"""
+    if backend == _Backend.FLASHINFER and (
+        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
+    ):
+        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
+
+    # TODO(boyuan/luka): test inductor graph partition on rocm
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
 
     custom_ops_list = custom_ops.split(",") if custom_ops else []
 
@@ -435,7 +433,7 @@ def test_attention_quant_pattern(
     )
 
     # access the underlying `AttnFusionPass` on the `LazyInitPass`
-    assert attn_pass.pass_.matched_count == 1
+    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
 
     # Check attention ops in the graph before and after fusion
     attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))