Commit 3ae0f19

PR #24604: Squashed commit of the following:
commit c03b29b
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 20:31:11 2025 -0400
    Remove inductor graph partition from unit test (included in e2e tests)
    Signed-off-by: Luka Govedič <[email protected]>

commit ae581e1
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 20:30:02 2025 -0400
    Fix attention fusion test numerics
    Signed-off-by: Luka Govedič <[email protected]>

commit a226864
Merge: e99a759 0a9ef0c
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 20:03:52 2025 -0400
    Merge branch 'main' into luka/custom-op-matching-2
    Signed-off-by: Luka Govedič <[email protected]>

commit e99a759
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 19:20:47 2025 -0400
    Break up B200 tests, move allreduce to H200
    Signed-off-by: Luka Govedič <[email protected]>

commit 876ef22
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 18:43:48 2025 -0400
    Fix tests, PR feedback
    Signed-off-by: Luka Govedič <[email protected]>

commit 6253d5b
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 13:18:03 2025 -0400
    Add e2e to L40 distributed, move tests to start of B200 distributed
    Signed-off-by: Luka Govedič <[email protected]>

commit de7405b
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 13:08:57 2025 -0400
    PR comments: add _custom_op suffix
    Signed-off-by: Luka Govedič <[email protected]>

commit 24f1298
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 13:08:13 2025 -0400
    PR comments: cleanup fusion passes, & matching
    Signed-off-by: Luka Govedič <[email protected]>

commit 7e6f5b3
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 13:06:19 2025 -0400
    add flat_product example
    Signed-off-by: Luka Govedič <[email protected]>

commit 532cbcf
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 12:56:07 2025 -0400
    Add comment to test_logger
    Signed-off-by: Luka Govedič <[email protected]>

commit 3943257
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 12:11:29 2025 -0400
    Restore original torch.Parameter behavior in RMSNorm
    Signed-off-by: Luka Govedič <[email protected]>

commit a3ebf0a
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 12:09:48 2025 -0400
    fix fp8 quant tests
    Signed-off-by: Luka Govedič <[email protected]>

commit db2b1c7
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 11:59:35 2025 -0400
    Smaller model for e2e fusion test
    Signed-off-by: Luka Govedič <[email protected]>

commit bcd95b5
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 11:54:47 2025 -0400
    Fix func test
    Signed-off-by: Luka Govedič <[email protected]>

commit bb0254a
Merge: 465ce58 136a17f
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 10:01:15 2025 -0400
    Merge branch 'main' into luka/custom-op-matching-2
    # Conflicts:
    # tests/utils_/test_utils.py
    Signed-off-by: Luka Govedič <[email protected]>

commit 465ce58
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 09:59:54 2025 -0400
    Update tests/compile/test_fusion.py
    Signed-off-by: Luka Govedič <[email protected]>

commit 2a6299c
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 04:12:01 2025 -0400
    Fix e2e test patterns
    Signed-off-by: Luka Govedič <[email protected]>

commit 8ffb474
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 03:25:26 2025 -0400
    Remove/fix TODOs
    Signed-off-by: Luka Govedič <[email protected]>

commit db16ee1
Merge: 12a7c6d f0862ea
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 03:10:21 2025 -0400
    Merge branch 'main' into luka/custom-op-matching-2

commit 12a7c6d
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 03:00:52 2025 -0400
    Tests & docs for flat_product
    Signed-off-by: Luka Govedič <[email protected]>

commit 8a363d3
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 02:43:03 2025 -0400
    Slight improvement for E2E fusion
    Signed-off-by: Luka Govedič <[email protected]>

commit f6429e4
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 02:40:43 2025 -0400
    Cleanup test_fusion_attn.py
    Signed-off-by: Luka Govedič <[email protected]>

commit b5f89e5
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 02:29:06 2025 -0400
    Cleanup test_full_graph.py
    Signed-off-by: Luka Govedič <[email protected]>

commit 97b3ff2
Merge: af1ffa7 8c851f6
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 02:08:00 2025 -0400
    Merge remote-tracking branch 'upstream/main' into luka/custom-op-matching-2
    Signed-off-by: Luka Govedič <[email protected]>

commit af1ffa7
Author: Luka Govedič <[email protected]>
Date: Wed Oct 15 01:54:18 2025 -0400
    PR review
    Signed-off-by: Luka Govedič <[email protected]>

commit 3547b87
Author: Luka Govedič <[email protected]>
Date: Sun Oct 12 11:11:14 2025 -0400
    fix sequence parallelism test
    Signed-off-by: Luka Govedič <[email protected]>

commit 26892df
Author: Luka Govedič <[email protected]>
Date: Sun Oct 12 11:03:35 2025 -0400
    fix pass manager test
    Signed-off-by: Luka Govedič <[email protected]>

commit 0d6e550
Author: Luka Govedič <[email protected]>
Date: Sun Oct 12 10:57:07 2025 -0400
    fix func test
    Signed-off-by: Luka Govedič <[email protected]>

commit 1b1a63e
Author: Luka Govedič <[email protected]>
Date: Sat Oct 11 14:33:46 2025 -0400
    Fix e2e allreduce fusion test
    Signed-off-by: Luka Govedič <[email protected]>

commit 52f78ce
Author: Luka Govedič <[email protected]>
Date: Sat Oct 11 08:38:42 2025 -0400
    Add allreduce test to 2-gpu test
    Signed-off-by: Luka Govedič <[email protected]>

commit 095277c
Author: Luka Govedič <[email protected]>
Date: Fri Oct 10 19:03:18 2025 -0400
    Simplify matcher utils by using RMSNorm.forward_static
    Signed-off-by: Luka Govedič <[email protected]>

commit c3264d8
Author: Luka Govedič <[email protected]>
Date: Fri Oct 10 18:36:15 2025 -0400
    Fix partial match rmsnorm+quant, fix allreduce+rmsnorm match
    Signed-off-by: Luka Govedič <[email protected]>

commit a1c7fdb
Author: Luka Govedič <[email protected]>
Date: Fri Oct 10 16:13:42 2025 -0400
    add more comprehensive testing for allreduce-rmsnorm, fix fp4 (-rmsnorm still failing)
    Signed-off-by: Luka Govedič <[email protected]>

commit 46ee626
Author: Luka Govedič <[email protected]>
Date: Fri Oct 10 13:51:13 2025 -0400
    add more comprehensive testing for quantfp8 (-rmsnorm+-quant still failing)
    Signed-off-by: Luka Govedič <[email protected]>

commit 32989d8
Author: Luka Govedič <[email protected]>
Date: Fri Oct 10 13:49:09 2025 -0400
    add pattern for final allreduce in model
    Signed-off-by: Luka Govedič <[email protected]>

commit 5619bc3
Author: Luka Govedič <[email protected]>
Date: Thu Oct 9 21:42:34 2025 -0400
    clean up e2e tests
    Signed-off-by: Luka Govedič <[email protected]>

commit 1756f67
Author: Luka Govedič <[email protected]>
Date: Sat Oct 4 00:06:13 2025 -0400
    add back fp4
    Signed-off-by: Luka Govedič <[email protected]>

commit c653d24
Author: Luka Govedič <[email protected]>
Date: Fri Oct 3 23:39:23 2025 -0400
    Fix spelling, precommit
    Signed-off-by: Luka Govedič <[email protected]>

commit 31d0127
Author: Luka Govedič <[email protected]>
Date: Fri Oct 3 13:01:13 2025 -0400
    Add e2e fusions to fullgraph test (should work with Triton backend), disable without flashinfer
    Signed-off-by: Luka Govedič <[email protected]>

commit 4dbfcf7
Author: Luka Govedič <[email protected]>
Date: Fri Oct 3 11:49:24 2025 -0400
    Move e2e tests to new file, add to test pipeline
    Signed-off-by: Luka Govedič <[email protected]>

commit d3f95fe
Author: Luka Govedič <[email protected]>
Date: Fri Oct 3 11:38:39 2025 -0400
    fullgraph allreduce test update requirements
    Signed-off-by: Luka Govedič <[email protected]>

commit c8675ff
Author: Luka Govedič <[email protected]>
Date: Thu Oct 2 22:18:24 2025 -0400
    log depyf folder, fix context for TestBackend, fix pattern dump
    Signed-off-by: Luka Govedič <[email protected]>

commit d09a278
Author: Luka Govedič <[email protected]>
Date: Thu Oct 2 22:16:24 2025 -0400
    allreduce fusion working with/without custom ops (with fp4)
    Signed-off-by: Luka Govedič <[email protected]>

commit b7f52bf
Author: Luka Govedič <[email protected]>
Date: Thu Oct 2 22:12:04 2025 -0400
    allreduce fusion working with/without custom ops (except fp4)
    Signed-off-by: Luka Govedič <[email protected]>

commit 54189a9
Author: Luka Govedič <[email protected]>
Date: Thu Oct 2 21:24:51 2025 -0400
    allreduce fusion working (custom ops on)
    Signed-off-by: Luka Govedič <[email protected]>

commit db479ae
Author: Luka Govedič <[email protected]>
Date: Thu Oct 2 16:51:30 2025 -0700
    TEMP allreduce fusion
    Signed-off-by: Luka Govedič <[email protected]>

commit 5fef180
Author: Luka Govedič <[email protected]>
Date: Thu Oct 2 16:35:31 2025 -0700
    clean up fullgraph tests
    Signed-off-by: Luka Govedič <[email protected]>

commit 7eb1364
Author: Luka Govedič <[email protected]>
Date: Thu Oct 2 19:26:48 2025 -0400
    Update csrc/layernorm_kernels.cu
    Signed-off-by: Luka Govedič <[email protected]>

commit 66a35a9
Author: Luka Govedič <[email protected]>
Date: Thu Oct 2 19:26:42 2025 -0400
    Update tests/compile/backend.py
    Signed-off-by: Luka Govedič <[email protected]>

commit 21a9f9f
Author: Luka Govedič <[email protected]>
Date: Wed Oct 1 19:02:24 2025 -0700
    Fixed tests, passing with 2.8, 2.9 tbd
    Signed-off-by: Luka Govedič <[email protected]>

commit a2aa978
Author: Luka Govedič <[email protected]>
Date: Wed Oct 1 11:21:02 2025 -0700
    Test for caplog utils
    Signed-off-by: Luka Govedič <[email protected]>

commit eb899a4
Author: Luka Govedič <[email protected]>
Date: Tue Sep 30 12:55:33 2025 -0700
    Temp MP workaround P3
    Signed-off-by: Luka Govedič <[email protected]>

commit ae7f56f
Author: Luka Govedič <[email protected]>
Date: Tue Sep 30 12:50:28 2025 -0700
    Temp MP workaround P2
    Signed-off-by: Luka Govedič <[email protected]>

commit 47b4688
Author: Luka Govedič <[email protected]>
Date: Sat Sep 27 07:38:52 2025 -0700
    TEMP working on caplog
    Signed-off-by: Luka Govedič <[email protected]>

commit d0b1b56
Author: Luka Govedič <[email protected]>
Date: Fri Sep 26 15:39:08 2025 -0700
    improve tests by adding more cases
    Signed-off-by: Luka Govedič <[email protected]>

commit 490ac86
Author: Luka Govedič <[email protected]>
Date: Fri Sep 26 13:24:01 2025 -0700
    Add TP=2 test (untested)
    Signed-off-by: Luka Govedič <[email protected]>

commit c6d6c3b
Author: Luka Govedič <[email protected]>
Date: Fri Sep 26 13:20:52 2025 -0700
    Refactor E2E attn fusion test
    Signed-off-by: Luka Govedič <[email protected]>

commit 141a37e
Author: Luka Govedič <[email protected]>
Date: Fri Sep 26 07:41:41 2025 -0700
    Fix rmsnorm
    Signed-off-by: Luka Govedič <[email protected]>

commit cdd1529
Author: Luka Govedič <[email protected]>
Date: Thu Sep 25 17:18:43 2025 -0700
    Flat product for better test names/visibility
    Signed-off-by: Luka Govedič <[email protected]>

commit d843a67
Author: Luka Govedič <[email protected]>
Date: Thu Sep 25 17:02:14 2025 -0700
    Add triton attn test to attn+quant fusion
    Signed-off-by: Luka Govedič <[email protected]>

commit 1277999
Author: Luka Govedič <[email protected]>
Date: Thu Sep 25 16:12:23 2025 -0700
    Remove V0 attn fusion test
    Signed-off-by: Luka Govedič <[email protected]>

commit 77835fd
Author: Luka Govedič <[email protected]>
Date: Thu Sep 25 16:12:11 2025 -0700
    Attention fusion works with custom ops
    Signed-off-by: Luka Govedič <[email protected]>

commit 1ae80c6
Author: Luka Govedič <[email protected]>
Date: Thu Sep 25 16:02:21 2025 -0700
    Move global vllm_config to pass manager
    Signed-off-by: Luka Govedič <[email protected]>

commit b172747
Author: Luka Govedič <[email protected]>
Date: Thu Sep 25 15:02:33 2025 -0700
    Functionalize attn+quant patterns
    Signed-off-by: Luka Govedič <[email protected]>

commit d96913a
Author: Luka Govedič <[email protected]>
Date: Thu Sep 25 16:06:25 2025 -0400
    Cleanup test_fusion.py, added extra layer of rms/quant
    Signed-off-by: Luka Govedič <[email protected]>

commit e6b394e
Author: Luka Govedič <[email protected]>
Date: Fri Sep 19 19:00:27 2025 -0700
    Add TODO
    Signed-off-by: Luka Govedič <[email protected]>

commit 05a65f3
Author: Luka Govedič <[email protected]>
Date: Thu Sep 18 13:21:46 2025 -0700
    ALL WORKS
    Signed-off-by: Luka Govedič <[email protected]>

commit 14fdc8b
Author: Luka Govedič <[email protected]>
Date: Thu Sep 18 12:32:27 2025 -0700
    quant with fix for pure torch, broke others
    Signed-off-by: Luka Govedič <[email protected]>

commit e151e6d
Author: Luka Govedič <[email protected]>
Date: Tue Sep 16 11:08:39 2025 -0700
    quant works except (torch,torch)
    Signed-off-by: Luka Govedič <[email protected]>

commit 8e4a56f
Author: Luka Govedič <[email protected]>
Date: Tue Sep 16 10:47:13 2025 -0700
    rms works fully now, had to remove more conversions (and add them in replacements). TODO pass to remove unnecessary conversions?
    Signed-off-by: Luka Govedič <[email protected]>

commit cdad3c0
Author: Luka Govedič <[email protected]>
Date: Fri Sep 12 12:11:48 2025 -0700
    TEMP: fixed rmsnorm issue (TODO assert dtypes in fused norm_quant kernels)
    Signed-off-by: Luka Govedič <[email protected]>

commit f3b4cf1
Author: Luka Govedič <[email protected]>
Date: Tue Sep 9 09:48:53 2025 -0700
    TEMP Mostly working
    Signed-off-by: Luka Govedič <[email protected]>

commit 21d7d67
Author: Luka Govedič <[email protected]>
Date: Sat Sep 6 14:35:13 2025 -0700
    Functionalized patterns in prep for utility
    Signed-off-by: Luka Govedič <[email protected]>

Signed-off-by: ProExpertProg <[email protected]>
1 parent 6f1222c commit 3ae0f19

28 files changed, +1498 -695 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 30 additions & 12 deletions
@@ -418,15 +418,16 @@ steps:
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s compile/piecewise/

-- label: PyTorch Fullgraph Test # 20min
-  timeout_in_minutes: 30
+- label: PyTorch Fullgraph Test # 22min
+  timeout_in_minutes: 35
   mirror_hardwares: [amdexperimental]
   torch_nightly: true
   source_file_dependencies:
   - vllm/
   - tests/compile
   commands:
   - pytest -v -s compile/test_full_graph.py
+  - pytest -v -s compile/test_fusions_e2e.py

 - label: Kernels Core Operation Test # 48min
   timeout_in_minutes: 75
@@ -808,8 +809,8 @@ steps:
   # Whisper needs spawn method to avoid deadlock
   - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper

-- label: Blackwell Test # 38 min
-  timeout_in_minutes: 60
+- label: Blackwell Test # TODO min
+  timeout_in_minutes: 70
   working_dir: "/vllm-workspace/"
   gpu: b200
   # optional: true
@@ -822,8 +823,6 @@ steps:
   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
   - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
   - vllm/v1/attention/backends/flashinfer.py
-  - vllm/compilation/fusion.py
-  - vllm/compilation/fusion_attn.py
   commands:
   - nvidia-smi
   - python3 examples/offline_inference/basic/chat.py
@@ -840,15 +839,32 @@ steps:
   - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
   - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
   - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
+  - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
+  - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
   - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
   - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
-  # Fusion
-  - pytest -v -s tests/compile/test_fusion_all_reduce.py
-  - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
   - pytest -v -s tests/kernels/moe/test_flashinfer.py
+
+- label: Blackwell Fusion Tests # TODO min
+  timeout_in_minutes: 70
+  working_dir: "/vllm-workspace/"
+  gpu: b200
+  source_file_dependencies:
+  - csrc/quantization/fp4/
+  - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+  - vllm/v1/attention/backends/flashinfer.py
+  - vllm/compilation/
+  # can affect pattern matching
+  - vllm/model_executor/layers/layernorm.py
+  - vllm/model_executor/layers/activation.py
+  - vllm/model_executor/layers/quantization/input_quant_fp8.py
+  commands:
+  - nvidia-smi
+  - pytest -v -s tests/compile/test_fusion_attn.py
   - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
-  - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
-  - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
+  # this runner has 2 GPUs available even though num_gpus=2 is not set
+  - pytest -v -s tests/compile/test_fusion_all_reduce.py
+  - pytest -v -s tests/compile/test_fusions_e2e.py

 - label: Blackwell GPT-OSS Eval
   timeout_in_minutes: 60
@@ -1103,14 +1119,16 @@ steps:
   - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4

 ##### H200 test #####
-- label: Distrubted Tests (H200) # optional
+- label: Distributed Tests (H200) # optional
   gpu: h200
   optional: true
   working_dir: "/vllm-workspace/"
   num_gpus: 2
   commands:
   - pytest -v -s tests/compile/test_async_tp.py
   - pytest -v -s tests/compile/test_sequence_parallelism.py
+  - pytest -v -s tests/compile/test_fusion_all_reduce.py
+  - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
   - pytest -v -s tests/distributed/test_context_parallel.py
   - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048

csrc/layernorm_kernels.cu

Lines changed: 2 additions & 0 deletions
@@ -392,6 +392,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
                         torch::Tensor& residual, // [..., hidden_size]
                         torch::Tensor& weight, // [hidden_size]
                         double epsilon) {
+  TORCH_CHECK(weight.scalar_type() == input.scalar_type());
+  TORCH_CHECK(input.scalar_type() == residual.scalar_type());
   TORCH_CHECK(residual.is_contiguous());
   TORCH_CHECK(weight.is_contiguous());
   int hidden_size = input.size(-1);

csrc/layernorm_quant_kernels.cu

Lines changed: 2 additions & 0 deletions
@@ -229,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant(
     double epsilon) {
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(residual.is_contiguous());
+  TORCH_CHECK(residual.scalar_type() == input.scalar_type());
+  TORCH_CHECK(weight.scalar_type() == input.scalar_type());
   int hidden_size = input.size(-1);
   int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;

csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu

Lines changed: 4 additions & 0 deletions
@@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant(
   if (scale_ub.has_value()) {
     TORCH_CHECK(out.dtype() == kFp8Type);
   }
+  TORCH_CHECK(weight.dtype() == input.dtype());
   TORCH_CHECK(scales.dtype() == torch::kFloat32);
+  if (residual) {
+    TORCH_CHECK(residual->scalar_type() == input.scalar_type());
+  }

   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
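The three kernel diffs above add `TORCH_CHECK` guards requiring `weight` and `residual` to share the dtype of `input`. The sketch below illustrates one plausible reason such guards matter, under the assumption (not shown in these hunks) that a kernel reads every buffer using the element type it dispatches on for `input`. It uses only the Python standard library; `reinterpret_f32_as_f16` and `check_same_dtype` are hypothetical helpers for illustration, not vLLM or PyTorch APIs.

```python
import struct


def reinterpret_f32_as_f16(value: float) -> tuple[float, ...]:
    """Read the 4 bytes of one float32 as if they were two float16 values."""
    raw = struct.pack("<f", value)  # little-endian float32
    return struct.unpack("<2e", raw)  # same bytes viewed as two float16


def check_same_dtype(**dtypes: str) -> None:
    """Hypothetical Python analogue of the added dtype guards."""
    if len(set(dtypes.values())) > 1:
        raise TypeError(f"dtype mismatch: {dtypes}")


# A float32 weight of 1.0 read through a float16 view no longer looks like 1.0.
print(reinterpret_f32_as_f16(1.0))

# Matching dtypes pass silently; a mismatch fails loudly before any "kernel" runs.
check_same_dtype(input="float16", weight="float16", residual="float16")
```

Failing early with an explicit check is usually preferable to silently producing wrong numerics, which is harder to trace back to a dtype mismatch.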

tests/compile/backend.py

Lines changed: 22 additions & 3 deletions
@@ -3,16 +3,22 @@

 import weakref
 from collections.abc import Callable, Sequence
+from contextlib import nullcontext
 from copy import deepcopy

+import depyf
 from torch import fx
 from torch._ops import OpOverload
+from torch.fx._utils import lazy_format_graph_code

 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.inductor_pass import InductorPass
 from vllm.compilation.pass_manager import with_pattern_match_debug
 from vllm.compilation.vllm_inductor_pass import VllmInductorPass
 from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.logger import init_logger
+
+logger = init_logger("vllm.tests.compile.backend")


 class LazyInitPass(InductorPass):
@@ -45,20 +51,32 @@ class TestBackend:

     def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
         self.custom_passes = list(passes)
-        compile_config = get_current_vllm_config().compilation_config
-        self.inductor_config = compile_config.inductor_compile_config
+        vllm_config = get_current_vllm_config()
+        compile_config = vllm_config.compilation_config
+        # Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
+        self.inductor_config = deepcopy(compile_config.inductor_compile_config)
         self.inductor_config["force_disable_caches"] = True
         self.inductor_config["post_grad_custom_post_pass"] = self.post_pass

+        if debug_dump_path := vllm_config.compile_debug_dump_path():
+            logger.debug("Dumping depyf output to %s", debug_dump_path)
+            self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
+        else:
+            self.debug_ctx = nullcontext()
+
     def __call__(self, graph: fx.GraphModule, example_inputs):
         self.graph_pre_compile = deepcopy(graph)
         from torch._inductor.compile_fx import compile_fx

-        return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
+        with self.debug_ctx:
+            return compile_fx(
+                graph, example_inputs, config_patches=self.inductor_config
+            )

     @with_pattern_match_debug
     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
+        lazy_format_graph_code("graph_pre_pass", graph.owning_module)

         VllmInductorPass.dump_prefix = 0
         for pass_ in self.custom_passes:
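The `__init__` change in this hunk deep-copies `inductor_compile_config` before writing `post_grad_custom_post_pass` into it. The stand-alone sketch below shows the aliasing problem that a copy avoids, assuming the config is a plain mutable dict shared by reference. `Backend`, `make_config`, and the `max_autotune` key are hypothetical stand-ins for illustration, not the vLLM classes.

```python
from copy import deepcopy


def make_config() -> dict:
    """Hypothetical stand-in for a config object holding an inductor dict."""
    return {"inductor_compile_config": {"max_autotune": False}}


class Backend:
    """Hypothetical stand-in for a test backend that registers its own hook."""

    def __init__(self, config: dict, copy_config: bool):
        inductor = config["inductor_compile_config"]
        # Copy first, then mutate, so the shared config is never touched.
        self.inductor_config = deepcopy(inductor) if copy_config else inductor
        self.inductor_config["post_grad_custom_post_pass"] = self.post_pass

    def post_pass(self, graph):
        return graph


# Without a copy, both backends write into the same dict: the last one wins.
shared = make_config()
a, b = Backend(shared, copy_config=False), Backend(shared, copy_config=False)
print(a.inductor_config["post_grad_custom_post_pass"].__self__ is b)  # True

# With a copy, each backend keeps its own hook and the shared config is clean.
shared = make_config()
c, d = Backend(shared, copy_config=True), Backend(shared, copy_config=True)
print(c.inductor_config["post_grad_custom_post_pass"].__self__ is c)  # True
```

The copy is taken before the bound method is stored, so the instance itself is never deep-copied.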
@@ -68,6 +86,7 @@ def post_pass(self, graph: fx.Graph):
         VllmInductorPass.dump_prefix = None

         self.graph_post_pass = deepcopy(graph)
+        lazy_format_graph_code("graph_post_pass", graph.owning_module)
         # assign by reference, will reflect the final state of the graph
         self.final_graph = graph
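The `tests/compile/backend.py` diff picks either a real debug context or `contextlib.nullcontext()` in `TestBackend.__init__`, so that `__call__` can always use a single `with` block. A minimal sketch of that pattern follows, with `fake_debug_session` as a hypothetical stand-in for the real debug context manager and `compile_graph` as a placeholder for the compile call. This sketch builds a fresh context per call, which is a simplification of the diff.

```python
from contextlib import contextmanager, nullcontext


@contextmanager
def fake_debug_session(path: str, log: list):
    """Hypothetical stand-in for a debug context that dumps to `path`."""
    log.append(f"enter {path}")
    try:
        yield
    finally:
        log.append(f"exit {path}")


def make_debug_ctx(dump_path, log: list):
    """Return a real context when a dump path is set, else a no-op context."""
    return fake_debug_session(dump_path, log) if dump_path else nullcontext()


def compile_graph(graph: str, log: list, dump_path=None) -> str:
    # The call site does not need to branch on whether debugging is enabled.
    with make_debug_ctx(dump_path, log):
        return f"compiled({graph})"
```

Calling `compile_graph("g", log)` leaves `log` empty, while passing a `dump_path` records one enter and one exit entry around the compile step.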
