
import pytest
import regex as re
+from typing import NamedTuple

from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams

from ..utils import flat_product, multi_gpu_test

-MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = []
-MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = []
-MODELS: list[tuple[str, dict[str, Any], _Backend]] = []  # tp-only
+
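+# Expected fusion counts for each (model, backend) combination exercised below.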
+class ModelBackendTestCase(NamedTuple):
+    model_name: str
+    model_kwargs: dict[str, Any]
+    backend: _Backend
+    attention_fusions: int
+    allreduce_fusions: Optional[int] = None
+
+
+MODELS_FP8: list[ModelBackendTestCase] = []
+MODELS_FP4: list[ModelBackendTestCase] = []
+MODELS: list[ModelBackendTestCase] = []  # tp-only

if current_platform.is_cuda():
-    MODELS_FP8 += [
-        (
-            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
-            {"max_model_len": 1024},
-            _Backend.TRITON_ATTN,
-        )
+    MODELS_FP8 = [
+        ModelBackendTestCase(
+            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.TRITON_ATTN,
+            attention_fusions=48,
+            allreduce_fusions=96,
+        ),
+        ModelBackendTestCase(
+            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
+            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+            backend=_Backend.FLASHINFER,
+            attention_fusions=48,
+            allreduce_fusions=96,
+        ),
    ]

-    if current_platform.is_device_capability((10, 0)) and has_flashinfer():
-        MODELS_FP8 += [
-            (
-                "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
-                {"kv_cache_dtype": "fp8", "max_model_len": 1024},
-                _Backend.FLASHINFER,
-            )
-        ]
-
-        MODELS_FP4 += [
-            (
-                "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
-                {"kv_cache_dtype": "fp8", "max_model_len": 1024},
-                _Backend.FLASHINFER,
-            )
-        ]
-
-    MODELS += [
-        (
-            "meta-llama/Llama-3.1-8B-Instruct",
-            {"max_model_len": 1024},
-            _Backend.FLASHINFER,
-        )
-    ]
+    MODELS_FP4 = [
+        ModelBackendTestCase(
+            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+            backend=_Backend.FLASHINFER,
+            attention_fusions=48,
+            allreduce_fusions=96,
+        ),
+    ]

-elif current_platform.is_rocm():
-    MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)]
+    # TP only
+    MODELS = [
+        ModelBackendTestCase(
+            model_name="meta-llama/Llama-3.1-8B-Instruct",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.TRITON_ATTN,
+            attention_fusions=0,
+            allreduce_fusions=64,
+        ),
+    ]

-INDUCTOR_GRAPH_PARTITION = (
-    [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False]
-)
+elif current_platform.is_rocm():
+    MODELS_FP8 = [
+        ModelBackendTestCase(
+            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.TRITON_ATTN,
+            attention_fusions=32,
+        ),
+        ModelBackendTestCase(
+            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.ROCM_ATTN,
+            attention_fusions=32,
+        ),
+        ModelBackendTestCase(
+            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.ROCM_AITER_FA,  # TODO ROCM_AITER_UNIFIED_ATTN
+            attention_fusions=32,
+        ),
+    ]

# TODO(luka) test both in nightly
CUSTOM_OPS_FP8 = ["-quant_fp8"]  # , "+quant_fp8"]


@pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, custom_ops",
+    "model_name, model_kwargs, backend, "
+    "attention_fusions, allreduce_fusions, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
    # quant_fp4 only has the custom impl
    + list(flat_product(MODELS_FP4, [""])),
)
-@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
def test_attn_quant(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: _Backend,
+    attention_fusions: int,
+    allreduce_fusions: int,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
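+    """Check that quant is fused onto the expected number of attention nodes."""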
+    if backend == _Backend.FLASHINFER and (
+        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
+    ):
+        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
@@ -120,7 +160,9 @@ def test_attn_quant(
    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(compilation_config, model_name, **model_kwargs)

-    assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text
+    assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, (
+        log_holder.text
+    )


# TODO(luka) test both in nightly
@@ -135,29 +177,35 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, custom_ops",
+    "model_name, model_kwargs, backend, "
+    "attention_fusions, allreduce_fusions, custom_ops",
    # Toggle RMSNorm and QuantFP8 for FP8 models
    list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"]))
    # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)))  # TODO
    # Toggle RMSNorm for FP4 models and unquant models
    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
-@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not has_flashinfer()
    or not current_platform.has_device_capability(90),
    reason="allreduce+rmsnorm fusion requires flashinfer",
)
def test_tp2_attn_quant_allreduce_rmsnorm(
-    model_name,
-    model_kwargs,
-    backend,
+    model_name: str,
+    model_kwargs: dict,
+    backend: _Backend,
+    attention_fusions: int,
+    allreduce_fusions: int,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
@@ -198,10 +246,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
        compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
    )

-    assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text
+    assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, (
+        log_holder.text
+    )

    matches = re.findall(
-        r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text
+        rf"\[collective_fusion.py:\d+] Replaced {allreduce_fusions} patterns",
+        log_holder.text,
    )
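+    # Expect one collective-fusion log line per TP rank (tensor_parallel_size=2).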
    assert len(matches) == 2, log_holder.text
