
Commit 5619bc3

clean up e2e tests
Signed-off-by: Luka Govedič <[email protected]>
1 parent 1756f67 commit 5619bc3

File tree

1 file changed: +99 −48 lines


tests/compile/test_fusions_e2e.py

Lines changed: 99 additions & 48 deletions
@@ -10,6 +10,7 @@
 
 import pytest
 import regex as re
+from black.cache import NamedTuple
 
 from tests.v1.attention.utils import _Backend
 from vllm import LLM, SamplingParams
@@ -20,72 +21,111 @@
 
 from ..utils import flat_product, multi_gpu_test
 
-MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = []
-MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = []
-MODELS: list[tuple[str, dict[str, Any], _Backend]] = []  # tp-only
+
+class ModelBackendTestCase(NamedTuple):
+    model_name: str
+    model_kwargs: dict[str, Any]
+    backend: _Backend
+    attention_fusions: int
+    allreduce_fusions: Optional[int] = None
+
+
+MODELS_FP8: list[ModelBackendTestCase] = []
+MODELS_FP4: list[ModelBackendTestCase] = []
+MODELS: list[ModelBackendTestCase] = []  # tp-only
 
 if current_platform.is_cuda():
-    MODELS_FP8 += [
-        (
-            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
-            {"max_model_len": 1024},
-            _Backend.TRITON_ATTN,
-        )
+    MODELS_FP8 = [
+        ModelBackendTestCase(
+            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.TRITON_ATTN,
+            attention_fusions=48,
+            allreduce_fusions=96,
+        ),
+        ModelBackendTestCase(
+            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
+            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+            backend=_Backend.FLASHINFER,
+            attention_fusions=48,
+            allreduce_fusions=96,
+        ),
     ]
 
-    if current_platform.is_device_capability((10, 0)) and has_flashinfer():
-        MODELS_FP8 += [
-            (
-                "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
-                {"kv_cache_dtype": "fp8", "max_model_len": 1024},
-                _Backend.FLASHINFER,
-            )
-        ]
-
-        MODELS_FP4 += [
-            (
-                "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
-                {"kv_cache_dtype": "fp8", "max_model_len": 1024},
-                _Backend.FLASHINFER,
-            )
-        ]
-
-        MODELS += [
-            (
-                "meta-llama/Llama-3.1-8B-Instruct",
-                {"max_model_len": 1024},
-                _Backend.FLASHINFER,
-            )
-        ]
+    MODELS_FP4 = [
+        ModelBackendTestCase(
+            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+            backend=_Backend.FLASHINFER,
+            attention_fusions=48,
+            allreduce_fusions=96,
+        ),
+    ]
 
-elif current_platform.is_rocm():
-    MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)]
+    # TP only
+    MODELS = [
+        ModelBackendTestCase(
+            model_name="meta-llama/Llama-3.1-8B-Instruct",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.TRITON_ATTN,
+            attention_fusions=0,
+            allreduce_fusions=64,
+        ),
+    ]
 
-INDUCTOR_GRAPH_PARTITION = (
-    [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False]
-)
+elif current_platform.is_rocm():
+    MODELS_FP8 = [
+        ModelBackendTestCase(
+            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.TRITON_ATTN,
+            attention_fusions=32,
+        ),
+        ModelBackendTestCase(
+            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.ROCM_ATTN,
+            attention_fusions=32,
+        ),
+        ModelBackendTestCase(
+            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
+            model_kwargs=dict(max_model_len=1024),
+            backend=_Backend.ROCM_AITER_FA,  # TODO ROCM_AITER_UNIFIED_ATTN
+            attention_fusions=32,
+        ),
+    ]
 
 # TODO(luka) test both in nightly
 CUSTOM_OPS_FP8 = ["-quant_fp8"]  # , "+quant_fp8"]
 
 
 @pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, custom_ops",
+    "model_name, model_kwargs, backend, "
+    "attention_fusions, allreduce_fusions, custom_ops",
     # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
     list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
     # quant_fp4 only has the custom impl
     + list(flat_product(MODELS_FP4, [""])),
 )
-@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
 def test_attn_quant(
     model_name: str,
     model_kwargs: dict[str, Any],
     backend: _Backend,
+    attention_fusions: int,
+    allreduce_fusions: int,
     custom_ops: str,
     inductor_graph_partition: bool,
     caplog_mp_spawn,
     monkeypatch,
 ):
+    if backend == _Backend.FLASHINFER and (
+        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
+    ):
+        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
     custom_ops_list = custom_ops.split(",") if custom_ops else []
 
     if inductor_graph_partition:
@@ -120,7 +160,9 @@ def test_attn_quant(
     with caplog_mp_spawn(logging.DEBUG) as log_holder:
         run_model(compilation_config, model_name, **model_kwargs)
 
-    assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text
+    assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, (
+        log_holder.text
+    )
 
 
 # TODO(luka) test both in nightly
@@ -135,29 +177,35 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, custom_ops",
+    "model_name, model_kwargs, backend, "
+    "attention_fusions, allreduce_fusions, custom_ops",
     # Toggle RMSNorm and QuantFP8 for FP8 models
     list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"]))
     # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)))  # TODO
     # Toggle RMSNorm for FP4 models and unquant models
     + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
 )
-@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
 @pytest.mark.skipif(
     not current_platform.is_cuda()
     or not has_flashinfer()
     or not current_platform.has_device_capability(90),
     reason="allreduce+rmsnorm fusion requires flashinfer",
 )
 def test_tp2_attn_quant_allreduce_rmsnorm(
-    model_name,
-    model_kwargs,
-    backend,
+    model_name: str,
+    model_kwargs: dict,
+    backend: _Backend,
+    attention_fusions: int,
+    allreduce_fusions: int,
     custom_ops: str,
     inductor_graph_partition: bool,
     caplog_mp_spawn,
     monkeypatch,
 ):
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
     custom_ops_list = custom_ops.split(",") if custom_ops else []
 
     if inductor_graph_partition:
@@ -198,10 +246,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
         compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
     )
 
-    assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text
+    assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, (
+        log_holder.text
+    )
 
     matches = re.findall(
-        r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text
+        rf"\[collective_fusion.py:\d+] Replaced {allreduce_fusions} patterns",
+        log_holder.text,
     )
     assert len(matches) == 2, log_holder.text
 
