@@ -287,6 +287,109 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
     assert int(log_matches[1]) == matches.allreduce_fusion


+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "model_name, model_kwargs, backend, matches, custom_ops",
+    # Toggle RMSNorm and QuantFP8 for FP8 models
+    list(
+        flat_product(
+            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
+        )
+    )
+    # Toggle RMSNorm for FP4 models and unquant models
+    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
+)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="sequence parallel only tested on CUDA",
+)
+def test_tp2_attn_quant_async_tp(
+    model_name: str,
+    model_kwargs: dict,
+    backend: _Backend,
+    matches: Matches,
+    custom_ops: str,
+    inductor_graph_partition: bool,
+    caplog_mp_spawn,
+    monkeypatch,
+):
+    pytest.skip("In-progress")
+    if backend == _Backend.FLASHINFER and (
+        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
+    ):
+        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
+    custom_ops_list = custom_ops.split(",") if custom_ops else []
+
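+    # With inductor graph partition, keep the default splitting ops and capture
+    # full + piecewise cudagraphs; otherwise leave the graph unsplit and only
+    # capture full cudagraphs for decode.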
+    if inductor_graph_partition:
+        mode = CUDAGraphMode.FULL_AND_PIECEWISE
+        splitting_ops: list[str] | None = None
+    else:
+        mode = CUDAGraphMode.FULL_DECODE_ONLY
+        splitting_ops = []
+
+    # Disable compile cache to make sure custom passes run.
+    # Otherwise, we can't verify fusion happened through the logs.
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    # To capture subprocess logs, we need to know whether spawn or fork is used.
+    # Force spawn as it is more general.
+    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
+    compilation_config = CompilationConfig(
+        # Testing properties
+        use_inductor_graph_partition=inductor_graph_partition,
+        cudagraph_mode=mode,
+        custom_ops=custom_ops_list,
+        splitting_ops=splitting_ops,
+        # Common
+        level=CompilationMode.VLLM_COMPILE,
+        pass_config=PassConfig(
+            enable_attn_fusion=True,
+            enable_noop=True,
+            enable_sequence_parallelism=True,
+            enable_async_tp=True,
+        ),
+        # Inductor caches custom passes by default as well via uuid
+        inductor_compile_config={"force_disable_caches": True},
+    )
+
+    with caplog_mp_spawn(logging.DEBUG) as log_holder:
+        run_model(
+            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
+        )
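+    # With tensor_parallel_size=2 and spawned workers, each pass should log its
+    # fusion/replacement count once per rank, hence two matches per pattern.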
+    log_matches = re.findall(
+        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.attention_fusion
+    assert int(log_matches[1]) == matches.attention_fusion
+
+    log_matches = re.findall(
+        r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.sequence_parallel
+    assert int(log_matches[1]) == matches.sequence_parallel
+
+    log_matches = re.findall(
+        r"collective_fusion.py:\d+] Replaced (\d+) patterns",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.async_tp
+    assert int(log_matches[1]) == matches.async_tp
+
+
 def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
     compilation_config = (
         compile_config