@@ -2,6 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 import itertools
+from collections.abc import Iterable
+from typing import Any

 import pytest
 import torch._dynamo
@@ -285,18 +287,25 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
     USE_INDUCTOR_GRAPH_PARTITION = [False]


+def flat_product(*iterables: Iterable[Any]):
+    """Flatten lists of tuples into cartesian product."""
+    for element in itertools.product(*iterables):
+        normalized = (e if isinstance(e, tuple) else [e] for e in element)
+        yield list(itertools.chain(*normalized))
+
+
 @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
 @pytest.mark.parametrize("head_size", [128])
 @pytest.mark.parametrize(
     "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize(
-    "backend, model, custom_ops",
-    # Test attention+quant_fp8 fusion with custom and torch impls
-    list(itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
+    "backend, model_name, model_class, custom_ops",
+    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
+    list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
     # quant_fp4 only has the custom impl
-    + list(itertools.product(BACKENDS_FP4, MODELS_FP4, [""])),
+    + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
 )
 @pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION)
 @pytest.mark.skipif(
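
Aside: a minimal, standalone sketch of what the new flat_product helper yields, using placeholder backend/model values rather than the real BACKENDS_FP8/MODELS_FP8 constants. Unlike a plain itertools.product, it splices tuple elements (here, a (model_name, model_class) pair) flat into each row, which is what lets the parametrize string above name model_name and model_class as separate arguments.

    import itertools
    from collections.abc import Iterable
    from typing import Any

    def flat_product(*iterables: Iterable[Any]):
        # Same helper as in the diff: cartesian product with tuple elements flattened.
        for element in itertools.product(*iterables):
            normalized = (e if isinstance(e, tuple) else [e] for e in element)
            yield list(itertools.chain(*normalized))

    backends = ["FLASHINFER"]                  # placeholder, not the real BACKENDS_FP8
    models = [("fp8-model", "Fp8ModelClass")]  # placeholder (model_name, model_class) pair
    print(next(iter(itertools.product(backends, models, ["+quant_fp8"]))))
    # ('FLASHINFER', ('fp8-model', 'Fp8ModelClass'), '+quant_fp8')  <- nested tuple
    print(next(flat_product(backends, models, ["+quant_fp8"])))
    # ['FLASHINFER', 'fp8-model', 'Fp8ModelClass', '+quant_fp8']    <- flattened row
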
@@ -310,7 +319,8 @@ def test_attention_quant_pattern(
     batch_size: int,
     dtype: torch.dtype,
     custom_ops: str,
-    model: tuple[str, type[AttentionQuantPatternModel]],
+    model_name: str,
+    model_class: type[AttentionQuantPatternModel],
     backend: _Backend,
     use_inductor_graph_partition: bool,
     dist_init,
@@ -319,7 +329,6 @@ def test_attention_quant_pattern(
319329    """Test AttentionStaticQuantPattern fusion pass""" 
320330
321331    custom_ops_list  =  custom_ops .split ("," ) if  custom_ops  else  []
322-     model_name , model_class  =  model 
323332
324333    device  =  torch .device ("cuda:0" )
325334    torch .manual_seed (42 )
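
For illustration, a sketch of why the manual unpacking line can be dropped: once each parametrize row is a flat sequence, pytest binds one value per name in the argnames string. Everything below (DummyAttnModel, the row values) is hypothetical and only mirrors the shape of the real test.

    import pytest

    class DummyAttnModel:  # hypothetical stand-in for an AttentionQuantPatternModel subclass
        pass

    # One flattened row per test case: backend, model_name, model_class, custom_ops.
    ROWS = [["FLASHINFER", "dummy/fp8-model", DummyAttnModel, "+quant_fp8"]]

    @pytest.mark.parametrize("backend, model_name, model_class, custom_ops", ROWS)
    def test_binding(backend, model_name, model_class, custom_ops):
        # pytest unpacked the row itself; no `model_name, model_class = model` needed.
        assert isinstance(model_name, str) and isinstance(model_class, type)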