
Commit 485b6b3

TEMP: WIP seq par fixes
Signed-off-by: ProExpertProg <[email protected]>
1 parent: af31964

File tree: 3 files changed (+112, -3 lines)
tests/compile/test_fusions_e2e.py

Lines changed: 103 additions & 0 deletions
@@ -287,6 +287,109 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
     assert int(log_matches[1]) == matches.allreduce_fusion


+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "model_name, model_kwargs, backend, matches, custom_ops",
+    # Toggle RMSNorm and QuantFP8 for FP8 models
+    list(
+        flat_product(
+            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
+        )
+    )
+    # Toggle RMSNorm for FP4 models and unquant models
+    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
+)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="sequence parallel only tested on CUDA",
+)
+def test_tp2_attn_quant_async_tp(
+    model_name: str,
+    model_kwargs: dict,
+    backend: _Backend,
+    matches: Matches,
+    custom_ops: str,
+    inductor_graph_partition: bool,
+    caplog_mp_spawn,
+    monkeypatch,
+):
+    pytest.skip("In-progress")
+    if backend == _Backend.FLASHINFER and (
+        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
+    ):
+        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
+    custom_ops_list = custom_ops.split(",") if custom_ops else []
+
+    if inductor_graph_partition:
+        mode = CUDAGraphMode.FULL_AND_PIECEWISE
+        splitting_ops: list[str] | None = None
+    else:
+        mode = CUDAGraphMode.FULL_DECODE_ONLY
+        splitting_ops = []
+
+    # Disable compile cache to make sure custom passes run.
+    # Otherwise, we can't verify fusion happened through the logs.
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    # To capture subprocess logs, we need to know whether spawn or fork is used.
+    # Force spawn as it is more general.
+    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
+    compilation_config = CompilationConfig(
+        # Testing properties
+        use_inductor_graph_partition=inductor_graph_partition,
+        cudagraph_mode=mode,
+        custom_ops=custom_ops_list,
+        splitting_ops=splitting_ops,
+        # Common
+        level=CompilationMode.VLLM_COMPILE,
+        pass_config=PassConfig(
+            enable_attn_fusion=True,
+            enable_noop=True,
+            enable_sequence_parallelism=True,
+            enable_async_tp=True,
+        ),
+        # Inductor caches custom passes by default as well via uuid
+        inductor_compile_config={"force_disable_caches": True},
+    )
+
+    with caplog_mp_spawn(logging.DEBUG) as log_holder:
+        run_model(
+            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
+        )
+    log_matches = re.findall(
+        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.attention_fusion
+    assert int(log_matches[1]) == matches.attention_fusion
+
+    log_matches = re.findall(
+        r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.sequence_parallel
+    assert int(log_matches[1]) == matches.sequence_parallel
+
+    log_matches = re.findall(
+        r"collective_fusion.py:\d+] Replaced (\d+) patterns",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2, log_holder.text
+
+    assert int(log_matches[0]) == matches.async_tp
+    assert int(log_matches[1]) == matches.async_tp
+
+
 def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
     compilation_config = (
         compile_config
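
For reference, a minimal sketch of the log scraping the new test relies on. The sample text and counts below are hypothetical; in the test, log_holder.text holds the captured output of both TP ranks, which is why each pattern is asserted to appear exactly twice.

import re

# Hypothetical captured output from two TP ranks (illustration only); the real
# text comes from caplog_mp_spawn capturing the spawned worker processes.
sample_log = """
fusion_attn.py:123] Fused quant onto 48 attention nodes
sequence_parallelism.py:456] Replaced 65 patterns
collective_fusion.py:789] Replaced 96 patterns
fusion_attn.py:123] Fused quant onto 48 attention nodes
sequence_parallelism.py:456] Replaced 65 patterns
collective_fusion.py:789] Replaced 96 patterns
"""

# Same pattern as in the test: capture the count each pass reports per rank.
attn_counts = re.findall(
    r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", sample_log
)
assert len(attn_counts) == 2  # one log line per TP rank
assert all(int(n) == 48 for n in attn_counts)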

tests/compile/test_sequence_parallelism.py

Lines changed: 9 additions & 1 deletion
@@ -192,7 +192,8 @@ def ops_in_model(self):
 @pytest.mark.parametrize(
     "test_model_cls, custom_ops",
     [
-        (TestModel, ""),
+        (TestModel, "+rms_norm"),
+        (TestModel, "-rms_norm"),
         (TestQuantModel, "+quant_fp8"),
         (TestQuantModel, "-quant_fp8"),
     ],
@@ -202,6 +203,7 @@ def ops_in_model(self):
 @pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("enable_fusion", [True, False])
+@pytest.mark.parametrize("dynamic", [False, True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 def test_sequence_parallelism_pass(
     test_model_cls: type[torch.nn.Module],
@@ -211,6 +213,7 @@ def test_sequence_parallelism_pass(
     hidden_size: int,
     dtype: torch.dtype,
     enable_fusion: bool,
+    dynamic: bool,
 ):
     num_processes = 2

@@ -228,6 +231,7 @@ def run_torch_spawn(fn, nprocs):
                 hidden_size,
                 dtype,
                 enable_fusion,
+                dynamic,
             ),
             nprocs=nprocs,
         )
@@ -245,6 +249,7 @@ def sequence_parallelism_pass_on_test_model(
     hidden_size: int,
     dtype: torch.dtype,
     enable_fusion: bool,
+    dynamic: bool,
 ):
     current_platform.seed_everything(0)

@@ -316,6 +321,9 @@ def sequence_parallelism_pass_on_test_model(

     hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
     residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
+    if dynamic:
+        torch._dynamo.mark_dynamic(hidden_states, 0)
+        torch._dynamo.mark_dynamic(residual, 0)

     compiled_model_no_func = torch.compile(model, backend=backend_no_func)
     compiled_model_no_func(hidden_states, residual)
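
The new dynamic parametrization runs the pass with a symbolic batch dimension. Below is a minimal standalone sketch of what torch._dynamo.mark_dynamic does, outside the distributed harness; the toy function is an assumption and only stands in for the compiled test model.

import torch

def toy_fn(hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Stand-in for the test model's forward (assumption, not the real model).
    return (hidden_states + residual) * 2.0

compiled = torch.compile(toy_fn, backend="eager", fullgraph=True)

hidden_states = torch.randn(16, 8)
residual = torch.randn(16, 8)

# Marking dim 0 dynamic makes Dynamo trace the token dimension symbolically,
# so graph passes must match their patterns under dynamic shapes rather than
# the one concrete size seen at compile time.
torch._dynamo.mark_dynamic(hidden_states, 0)
torch._dynamo.mark_dynamic(residual, 0)

compiled(hidden_states, residual)
# A different leading size reuses the dynamic graph instead of recompiling.
compiled(torch.randn(32, 8), torch.randn(32, 8))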

vllm/config/vllm.py

Lines changed: 0 additions & 2 deletions
@@ -335,8 +335,6 @@ def __post_init__(self):
         # and requires it to be enabled.
         if self.compilation_config.pass_config.enable_async_tp:
             self.compilation_config.pass_config.enable_sequence_parallelism = True
-        if self.compilation_config.pass_config.enable_sequence_parallelism:
-            self.compilation_config.custom_ops.append("+rms_norm")

         if current_platform.support_static_graph_mode():
             # if cudagraph_mode is not explicitly set by users, set default
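
With the forced append removed, enabling sequence parallelism (or async TP, which implies it) no longer silently switches RMSNorm to the custom op, so the passes must work with both +rms_norm and -rms_norm; that is what the widened parametrizations in the two test files above cover. A hedged sketch of the explicit opt-in a user would now write to keep the old behavior; the import path is assumed to match vllm.config's exports.

from vllm.config import CompilationConfig, PassConfig

# Sketch only: the custom rms_norm op is now an explicit choice rather than a
# side effect of enabling sequence parallelism.
compilation_config = CompilationConfig(
    custom_ops=["+rms_norm"],  # opt in explicitly; no longer auto-added
    pass_config=PassConfig(
        enable_sequence_parallelism=True,
        enable_async_tp=True,
    ),
)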
