@@ -287,6 +287,112 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
     assert int(log_matches[1]) == matches.allreduce_fusion


+# TODO luka resolve
+CUSTOM_OPS_RMS_NORM = ["+rms_norm"]
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "model_name, model_kwargs, backend, matches, custom_ops",
+    # Toggle RMSNorm and QuantFP8 for FP8 models
+    list(
+        flat_product(
+            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
+        )
+    )
+    # Toggle RMSNorm for FP4 models and unquant models
+    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
+)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="sequence parallel only tested on CUDA",
+)
+def test_tp2_attn_quant_async_tp(
+    model_name: str,
+    model_kwargs: dict,
+    backend: _Backend,
+    matches: Matches,
+    custom_ops: str,
+    inductor_graph_partition: bool,
+    caplog_mp_spawn,
+    monkeypatch,
+):
+    if backend == _Backend.FLASHINFER and (
+        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
+    ):
+        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
+    custom_ops_list = custom_ops.split(",") if custom_ops else []
+
+    if inductor_graph_partition:
+        mode = CUDAGraphMode.FULL_AND_PIECEWISE
+        splitting_ops: list[str] | None = None
+    else:
+        mode = CUDAGraphMode.FULL_DECODE_ONLY
+        splitting_ops = []
+
+    # Disable the compile cache to make sure custom passes run.
+    # Otherwise, we can't verify fusion happened through the logs.
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    # To capture subprocess logs, we need to know whether spawn or fork is used.
+    # Force spawn as it is more general.
+    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
+    compilation_config = CompilationConfig(
+        # Testing properties
+        use_inductor_graph_partition=inductor_graph_partition,
+        cudagraph_mode=mode,
+        custom_ops=custom_ops_list,
+        splitting_ops=splitting_ops,
+        # Common
+        level=CompilationMode.VLLM_COMPILE,
+        pass_config=PassConfig(
+            enable_attn_fusion=True,
+            enable_noop=True,
+            enable_sequence_parallelism=True,
+            enable_async_tp=True,
+        ),
+        # Inductor caches custom passes by default as well via uuid
+        inductor_compile_config={"force_disable_caches": True},
+    )
+
+    with caplog_mp_spawn(logging.DEBUG) as log_holder:
+        run_model(
+            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
+        )
+    log_matches = re.findall(
+        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.attention_fusion
+    assert int(log_matches[1]) == matches.attention_fusion
+
+    log_matches = re.findall(
+        r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.sequence_parallel
+    assert int(log_matches[1]) == matches.sequence_parallel
+
+    log_matches = re.findall(
+        r"collective_fusion.py:\d+] Replaced (\d+) patterns",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.async_tp
+    assert int(log_matches[1]) == matches.async_tp
+
+
 def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
     compilation_config = (
         compile_config