
Commit 9e08696

Add E2E test for SeqPar+asyncTP (TODO rms_norm still needed)
Signed-off-by: ProExpertProg <[email protected]>
1 parent 96dd629 commit 9e08696

File tree: 1 file changed (+106, -0 lines)


tests/compile/test_fusions_e2e.py

Lines changed: 106 additions & 0 deletions
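To exercise just the new test locally (an assumption on setup: a CUDA host with at least two GPUs, since the test is decorated with @multi_gpu_test(num_gpus=2)), a standard pytest keyword selector should work:

    pytest tests/compile/test_fusions_e2e.py -k test_tp2_attn_quant_async_tp
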
@@ -287,6 +287,112 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
     assert int(log_matches[1]) == matches.allreduce_fusion


+# TODO luka resolve
+CUSTOM_OPS_RMS_NORM = ["+rms_norm"]
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "model_name, model_kwargs, backend, matches, custom_ops",
+    # Toggle RMSNorm and QuantFP8 for FP8 models
+    list(
+        flat_product(
+            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
+        )
+    )
+    # Toggle RMSNorm for FP4 models and unquant models
+    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
+)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="sequence parallel only tested on CUDA",
+)
+def test_tp2_attn_quant_async_tp(
+    model_name: str,
+    model_kwargs: dict,
+    backend: _Backend,
+    matches: Matches,
+    custom_ops: str,
+    inductor_graph_partition: bool,
+    caplog_mp_spawn,
+    monkeypatch,
+):
+    if backend == _Backend.FLASHINFER and (
+        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
+    ):
+        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
+    custom_ops_list = custom_ops.split(",") if custom_ops else []
+
+    if inductor_graph_partition:
+        mode = CUDAGraphMode.FULL_AND_PIECEWISE
+        splitting_ops: list[str] | None = None
+    else:
+        mode = CUDAGraphMode.FULL_DECODE_ONLY
+        splitting_ops = []
+
+    # Disable compile cache to make sure custom passes run.
+    # Otherwise, we can't verify fusion happened through the logs.
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    # To capture subprocess logs, we need to know whether spawn or fork is used.
+    # Force spawn as it is more general.
+    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
+    compilation_config = CompilationConfig(
+        # Testing properties
+        use_inductor_graph_partition=inductor_graph_partition,
+        cudagraph_mode=mode,
+        custom_ops=custom_ops_list,
+        splitting_ops=splitting_ops,
+        # Common
+        level=CompilationMode.VLLM_COMPILE,
+        pass_config=PassConfig(
+            enable_attn_fusion=True,
+            enable_noop=True,
+            enable_sequence_parallelism=True,
+            enable_async_tp=True,
+        ),
+        # Inductor caches custom passes by default as well via uuid
+        inductor_compile_config={"force_disable_caches": True},
+    )
+
+    with caplog_mp_spawn(logging.DEBUG) as log_holder:
+        run_model(
+            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
+        )
+    log_matches = re.findall(
+        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.attention_fusion
+    assert int(log_matches[1]) == matches.attention_fusion
+
+    log_matches = re.findall(
+        r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.sequence_parallel
+    assert int(log_matches[1]) == matches.sequence_parallel
+
+    log_matches = re.findall(
+        r"collective_fusion.py:\d+] Replaced (\d+) patterns",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.async_tp
+    assert int(log_matches[1]) == matches.async_tp
+
+
 def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
     compilation_config = (
         compile_config
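
The test verifies fusion through captured subprocess logs: each compiler pass logs how many patterns it replaced, and the test counts and checks those lines. Below is a self-contained illustration of that mechanism; the sample log text is fabricated for demonstration, and the expectation of exactly two matches mirrors the tensor_parallel_size=2 setup, presumably one log line per tensor-parallel worker.

import re

# Fabricated sample of the kind of output captured by caplog_mp_spawn.
sample_log = (
    "DEBUG ... collective_fusion.py:451] Replaced 12 patterns\n"
    "DEBUG ... collective_fusion.py:451] Replaced 12 patterns\n"
)

log_matches = re.findall(
    r"collective_fusion.py:\d+] Replaced (\d+) patterns", sample_log
)
assert len(log_matches) == 2  # one line per TP rank (TP=2)
assert all(int(m) == 12 for m in log_matches)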