Skip to content

Commit ca4f7b2

Browse files
committed
PR vllm-project#26952: Squashed commit of the following:
commit 574cddf Merge: c1dfad6 e6ba200 Author: Boyuan Feng <[email protected]> Date: Thu Oct 16 11:53:09 2025 -0700 Merge branch 'main' into bf/disable-partition-in-custom-op commit c1dfad6 Author: Boyuan Feng <[email protected]> Date: Wed Oct 15 18:05:06 2025 -0700 nit Signed-off-by: Boyuan Feng <[email protected]> commit 6f9339a Author: Boyuan Feng <[email protected]> Date: Wed Oct 15 17:58:07 2025 -0700 use torch.compile options Signed-off-by: Boyuan Feng <[email protected]> commit 0ab7175 Author: Boyuan Feng <[email protected]> Date: Wed Oct 15 17:49:17 2025 -0700 lint Signed-off-by: Boyuan Feng <[email protected]> commit d5d36c3 Author: Boyuan Feng <[email protected]> Date: Wed Oct 15 17:22:53 2025 -0700 Update vllm/model_executor/utils.py Co-authored-by: Luka Govedič <[email protected]> Signed-off-by: Boyuan Feng <[email protected]> commit 04aadb3 Author: Boyuan Feng <[email protected]> Date: Wed Oct 15 17:22:05 2025 -0700 nit Signed-off-by: Boyuan Feng <[email protected]> commit 8e08521 Author: Boyuan Feng <[email protected]> Date: Wed Oct 15 17:17:45 2025 -0700 rewrite as decorator Signed-off-by: Boyuan Feng <[email protected]> commit 29782df Author: Boyuan Feng <[email protected]> Date: Wed Oct 15 16:06:12 2025 -0700 disable graph partition in custom op Signed-off-by: Boyuan Feng <[email protected]> Signed-off-by: ProExpertProg <[email protected]>
1 parent 6ad4712 commit ca4f7b2

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
5050
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
5151
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
52+
from vllm.model_executor.utils import maybe_disable_graph_partition
5253
from vllm.platforms import current_platform
5354
from vllm.triton_utils import tl, triton
5455
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
@@ -1145,7 +1146,11 @@ def fused_topk_bias(
11451146

11461147

11471148
# This is used by the Deepseek-V2 and Deepseek-V3 model
1148-
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
1149+
@torch.compile(
1150+
dynamic=True,
1151+
backend=current_platform.simple_compile_backend,
1152+
options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
1153+
)
11491154
def grouped_topk(
11501155
hidden_states: torch.Tensor,
11511156
gating_output: torch.Tensor,

vllm/model_executor/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import torch
99

10+
from vllm.utils import is_torch_equal_or_newer
11+
1012

1113
def set_random_seed(seed: int) -> None:
1214
from vllm.platforms import current_platform
@@ -83,3 +85,10 @@ def get_moe_expert_mapping(
8385
if child_map is not None:
8486
return child_map()
8587
return []
88+
89+
90+
def maybe_disable_graph_partition(current_backend: str) -> dict[str, bool]:
91+
if current_backend == "inductor" and is_torch_equal_or_newer("2.9.0.dev"):
92+
return {"graph_partition": False}
93+
else:
94+
return {}

0 commit comments

Comments
 (0)