
Commit 83a9631

ProExpertProg authored and albertoperdomo2 committed
[torch.compile] Enable attention and allreduce fusion without custom ops enabled (vllm-project#24604)
Signed-off-by: Luka Govedič <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Signed-off-by: Alberto Perdomo <[email protected]>
1 parent f589334 commit 83a9631

28 files changed, +1499 -701 lines
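For context, the fusion passes named in the commit title are driven by vLLM's compilation pass config. A minimal sketch of turning them on from Python follows; the pass-config flag names, the example model, and the dict-style compilation_config are assumptions based on vLLM's public API around this change, not something defined by this commit.

from vllm import LLM

# Hedged sketch: flag names are assumptions based on vLLM's
# CompilationConfig.pass_config and may differ slightly between releases.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model, not from this commit
    compilation_config={
        "pass_config": {
            "enable_attn_fusion": True,  # attention + output-quant fusion
            "enable_fusion": True,       # RMSNorm + quant fusion
            "enable_noop": True,         # noop elimination aids pattern matching
        },
        # Before this change, these passes also required the matching custom ops
        # (e.g. custom_ops: ["+rms_norm", "+quant_fp8"]) to be enabled; the point
        # of the commit is that the patterns now also match the native
        # torch.compile implementations. Allreduce+RMSNorm fusion is gated by a
        # separate flag (name assumed: enable_fi_allreduce_fusion) and needs TP > 1.
    },
)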

.buildkite/test-pipeline.yaml

Lines changed: 30 additions & 12 deletions
@@ -416,15 +416,16 @@ steps:
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s compile/piecewise/

-- label: PyTorch Fullgraph Test # 20min
-  timeout_in_minutes: 30
+- label: PyTorch Fullgraph Test # 22min
+  timeout_in_minutes: 35
   mirror_hardwares: [amdexperimental]
   torch_nightly: true
   source_file_dependencies:
   - vllm/
   - tests/compile
   commands:
   - pytest -v -s compile/test_full_graph.py
+  - pytest -v -s compile/test_fusions_e2e.py

 - label: Kernels Core Operation Test # 48min
   timeout_in_minutes: 75
@@ -807,8 +808,8 @@ steps:
   # Whisper needs spawn method to avoid deadlock
   - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper

-- label: Blackwell Test # 38 min
-  timeout_in_minutes: 60
+- label: Blackwell Test # 21 min
+  timeout_in_minutes: 30
   working_dir: "/vllm-workspace/"
   gpu: b200
   # optional: true
@@ -821,8 +822,6 @@ steps:
   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
   - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
   - vllm/v1/attention/backends/flashinfer.py
-  - vllm/compilation/fusion.py
-  - vllm/compilation/fusion_attn.py
   commands:
   - nvidia-smi
   - python3 examples/offline_inference/basic/chat.py
@@ -839,15 +838,32 @@ steps:
   - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
   - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
   - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
+  - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
+  - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
   - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
   - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
-  # Fusion
-  - pytest -v -s tests/compile/test_fusion_all_reduce.py
-  - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
   - pytest -v -s tests/kernels/moe/test_flashinfer.py
+
+- label: Blackwell Fusion Tests # 30 min
+  timeout_in_minutes: 40
+  working_dir: "/vllm-workspace/"
+  gpu: b200
+  source_file_dependencies:
+  - csrc/quantization/fp4/
+  - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+  - vllm/v1/attention/backends/flashinfer.py
+  - vllm/compilation/
+  # can affect pattern matching
+  - vllm/model_executor/layers/layernorm.py
+  - vllm/model_executor/layers/activation.py
+  - vllm/model_executor/layers/quantization/input_quant_fp8.py
+  commands:
+  - nvidia-smi
+  - pytest -v -s tests/compile/test_fusion_attn.py
   - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
-  - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
-  - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
+  # this runner has 2 GPUs available even though num_gpus=2 is not set
+  - pytest -v -s tests/compile/test_fusion_all_reduce.py
+  - pytest -v -s tests/compile/test_fusions_e2e.py

 - label: Blackwell GPT-OSS Eval
   timeout_in_minutes: 60
@@ -1100,14 +1116,16 @@ steps:
   - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4

 ##### H200 test #####
-- label: Distrubted Tests (H200) # optional
+- label: Distributed Tests (H200) # optional
   gpu: h200
   optional: true
   working_dir: "/vllm-workspace/"
   num_gpus: 2
   commands:
   - pytest -v -s tests/compile/test_async_tp.py
   - pytest -v -s tests/compile/test_sequence_parallelism.py
+  - pytest -v -s tests/compile/test_fusion_all_reduce.py
+  - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
   - pytest -v -s tests/distributed/test_context_parallel.py
   - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048

csrc/layernorm_kernels.cu

Lines changed: 2 additions & 0 deletions
@@ -392,6 +392,8 @@ void fused_add_rms_norm(torch::Tensor& input,    // [..., hidden_size]
                         torch::Tensor& residual,  // [..., hidden_size]
                         torch::Tensor& weight,    // [hidden_size]
                         double epsilon) {
+  TORCH_CHECK(weight.scalar_type() == input.scalar_type());
+  TORCH_CHECK(input.scalar_type() == residual.scalar_type());
   TORCH_CHECK(residual.is_contiguous());
   TORCH_CHECK(weight.is_contiguous());
   int hidden_size = input.size(-1);

csrc/layernorm_quant_kernels.cu

Lines changed: 2 additions & 0 deletions
@@ -229,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant(
     double epsilon) {
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(residual.is_contiguous());
+  TORCH_CHECK(residual.scalar_type() == input.scalar_type());
+  TORCH_CHECK(weight.scalar_type() == input.scalar_type());
   int hidden_size = input.size(-1);
   int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;

csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu

Lines changed: 4 additions & 0 deletions
@@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant(
   if (scale_ub.has_value()) {
     TORCH_CHECK(out.dtype() == kFp8Type);
   }
+  TORCH_CHECK(weight.dtype() == input.dtype());
   TORCH_CHECK(scales.dtype() == torch::kFloat32);
+  if (residual) {
+    TORCH_CHECK(residual->scalar_type() == input.scalar_type());
+  }

   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
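
The three kernel-entry checks added above all enforce one contract: weight and residual must share the input's dtype, which the fusion patterns that replace these ops assume. A rough Python-side illustration using vLLM's RMSNorm layer follows; the snippet is illustrative only and not part of the diff.

import torch

from vllm.model_executor.layers.layernorm import RMSNorm

# Illustrative sketch (not from this commit): keep input, residual and the
# RMSNorm weight in a single dtype, which the new TORCH_CHECKs assert kernel-side.
layer = RMSNorm(hidden_size=4096).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
residual = torch.zeros_like(x)

out, residual = layer(x, residual)  # ok: input, residual, weight all bf16

# A mismatched residual (e.g. fp16 next to bf16 weights) would now trip a
# kernel-side TORCH_CHECK when the fused custom op runs, instead of silently
# executing with mixed dtypes.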

tests/compile/backend.py

Lines changed: 22 additions & 3 deletions
@@ -3,16 +3,22 @@

 import weakref
 from collections.abc import Callable, Sequence
+from contextlib import nullcontext
 from copy import deepcopy

+import depyf
 from torch import fx
 from torch._ops import OpOverload
+from torch.fx._utils import lazy_format_graph_code

 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.inductor_pass import InductorPass
 from vllm.compilation.pass_manager import with_pattern_match_debug
 from vllm.compilation.vllm_inductor_pass import VllmInductorPass
 from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.logger import init_logger
+
+logger = init_logger("vllm.tests.compile.backend")


 class LazyInitPass(InductorPass):
@@ -45,20 +51,32 @@ class TestBackend:

     def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
         self.custom_passes = list(passes)
-        compile_config = get_current_vllm_config().compilation_config
-        self.inductor_config = compile_config.inductor_compile_config
+        vllm_config = get_current_vllm_config()
+        compile_config = vllm_config.compilation_config
+        # Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
+        self.inductor_config = deepcopy(compile_config.inductor_compile_config)
         self.inductor_config["force_disable_caches"] = True
         self.inductor_config["post_grad_custom_post_pass"] = self.post_pass

+        if debug_dump_path := vllm_config.compile_debug_dump_path():
+            logger.debug("Dumping depyf output to %s", debug_dump_path)
+            self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
+        else:
+            self.debug_ctx = nullcontext()
+
     def __call__(self, graph: fx.GraphModule, example_inputs):
         self.graph_pre_compile = deepcopy(graph)
         from torch._inductor.compile_fx import compile_fx

-        return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
+        with self.debug_ctx:
+            return compile_fx(
+                graph, example_inputs, config_patches=self.inductor_config
+            )

     @with_pattern_match_debug
     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
+        lazy_format_graph_code("graph_pre_pass", graph.owning_module)

         VllmInductorPass.dump_prefix = 0
         for pass_ in self.custom_passes:
@@ -68,6 +86,7 @@ def post_pass(self, graph: fx.Graph):
         VllmInductorPass.dump_prefix = None

         self.graph_post_pass = deepcopy(graph)
+        lazy_format_graph_code("graph_post_pass", graph.owning_module)
         # assign by reference, will reflect the final state of the graph
         self.final_graph = graph
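
As a usage sketch (not part of the diff): TestBackend is handed to torch.compile as the backend in vLLM's compile tests, so the passes under test run inside Inductor's post-grad hook and the captured graphs can be asserted on afterwards. The dummy pass, toy model, and config wiring below are illustrative assumptions; the real tests pass actual fusion passes here.

import torch
from torch import fx

from tests.compile.backend import TestBackend  # run from a vLLM source checkout
from vllm.config import VllmConfig, set_current_vllm_config


def dummy_pass(graph: fx.Graph) -> None:
    # Stand-in for a real fusion pass; real tests pass e.g. an attention+quant
    # or RMSNorm+quant fusion pass and then inspect the captured graphs.
    print(f"post-grad graph has {len(list(graph.nodes))} nodes")


with set_current_vllm_config(VllmConfig()):
    backend = TestBackend(dummy_pass)
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    compiled = torch.compile(model, backend=backend)
    compiled(torch.randn(4, 16))

# backend.graph_pre_pass / backend.graph_post_pass now hold the FX graphs
# before and after the custom passes, which the fusion tests assert on.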
