
Commit 0bb53d7

Support FusedMoE LoRA Triton kernel for mxfp4 model
Signed-off-by: Xin Yang <[email protected]>
1 parent 67745d1 commit 0bb53d7

File tree

8 files changed: +99 -43 lines changed


vllm/lora/layers/fused_moe.py

Lines changed: 19 additions & 6 deletions
@@ -28,6 +28,9 @@
 from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
     FusedMoEModularMethod,
 )
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    modular_oai_triton_fused_moe,
+)


 class FusedMoEWithLoRA(BaseLayerWithLoRA):
@@ -108,15 +111,23 @@ def _inject_lora_into_fused_moe(self):
         self.base_layer.ensure_moe_quant_config_init()
         quant_config = self.base_layer.quant_method.moe_quant_config

-        m_fused_moe_fn = (
-            modular_triton_fused_moe(
-                quant_config, shared_experts=self.base_layer.shared_experts
+        if quant_config.use_mxfp4_w4a16:
+            from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Backend
+
+            mxfp4_backend = self.base_layer.quant_method.mxfp4_backend
+            m_fused_moe_fn = (
+                modular_oai_triton_fused_moe(
+                    quant_config, shared_experts=self.base_layer.shared_experts
+                )
+                if mxfp4_backend == Mxfp4Backend.TRITON
+                else modular_marlin_fused_moe(
+                    quant_config, shared_experts=self.base_layer.shared_experts
+                )
             )
-            if not quant_config.use_mxfp4_w4a16
-            else modular_marlin_fused_moe(
+        else:
+            m_fused_moe_fn = modular_triton_fused_moe(
                 quant_config, shared_experts=self.base_layer.shared_experts
             )
-        )

         def fwd_decorator(layer, func):
             def wrapper(*args, **kwargs):
@@ -279,9 +290,11 @@ def wrapper(*args, **kwargs):
         fused_experts.activation = act_decorator(
             self.base_layer, fused_experts.activation
         )
+        fused_experts.fuse_act = False
         fused_experts.moe_sum = moe_sum_decorator(
             self.base_layer, fused_experts.moe_sum
         )
+        fused_experts.fuse_sum = False

         self.base_layer.quant_method = FusedMoEModularMethod(
             self.base_layer.quant_method, m_fused_moe_fn
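For readers skimming the diff, the change above reduces to a three-way dispatch over the modular fused-MoE kernels. Below is a minimal sketch of that branching with the vLLM objects replaced by hypothetical stand-ins (the *_stub names are illustrative and not part of the commit):

from enum import Enum


class Mxfp4BackendStub(Enum):
    # Stand-in for vllm.model_executor.layers.quantization.mxfp4.Mxfp4Backend;
    # only the two members relevant to this branch are modeled.
    MARLIN = 1
    TRITON = 2


def select_fused_moe_kernel_stub(use_mxfp4_w4a16: bool, mxfp4_backend) -> str:
    # mxfp4 weights: use the OAI Triton modular kernel when the Triton backend
    # was selected, otherwise fall back to the Marlin modular kernel.
    if use_mxfp4_w4a16:
        if mxfp4_backend == Mxfp4BackendStub.TRITON:
            return "modular_oai_triton_fused_moe"
        return "modular_marlin_fused_moe"
    # Everything else keeps the existing Triton fused-MoE path.
    return "modular_triton_fused_moe"


print(select_fused_moe_kernel_stub(True, Mxfp4BackendStub.TRITON))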

vllm/lora/ops/triton_ops/fused_moe_lora_op.py

Lines changed: 1 addition & 1 deletion
@@ -430,7 +430,7 @@ def _fused_moe_lora(
         == expert_ids.shape[0]
         == num_tokens_post_padded.shape[0]
     )
-    assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
+    assert output.shape[-1] // lora_b_stacked[0].shape[-2] == len(lora_b_stacked)
     assert output.shape[0] == topk_weights.shape[0]
     assert top_k_num == topk_weights.shape[1]
     device = qcurr_hidden_states.device
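The relaxed assertion appears to matter when the output buffer is wider than the concatenated LoRA-B slices, presumably because of padding such as an mxfp4 hidden-size round-up. The numbers below are made up purely to illustrate the difference between the two checks:

import torch

# Two LoRA-B slices of width 1440 each; the output buffer is padded to 2944
# columns (hypothetical sizes, only the last two dimensions matter here).
lora_b_stacked = [torch.empty(1, 8, 1440, 16) for _ in range(2)]
output = torch.empty(4, 2, 2944)

# Old check: output width must equal num_slices * slice_width -> fails on padding.
print(len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1])   # False

# New check: the slice width only has to divide the (possibly padded) output
# width into num_slices chunks -> passes.
print(output.shape[-1] // lora_b_stacked[0].shape[-2] == len(lora_b_stacked))  # True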

vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py

Lines changed: 72 additions & 11 deletions
@@ -1,17 +1,26 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from collections.abc import Callable
+
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache,
+)
 from vllm.triton_utils import tl, triton
 from vllm.utils.import_utils import has_triton_kernels

@@ -96,6 +105,7 @@ def triton_kernel_moe_forward(
         routing_data,
         gather_idx,
         scatter_idx,
+        topk=topk,
         activation=activation,
         quant_config=quant_config,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -113,14 +123,21 @@ def triton_kernel_fused_experts(
     routing_data,  # RoutingData
     gather_indx,  # GatherIndx
     scatter_indx,  # ScatterIndx
+    topk: int,
     activation: str = "silu",
+    activation_func: Callable[[str, torch.Tensor, torch.Tensor], None] = None,
+    moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
     quant_config: FusedMoEQuantConfig | None = None,
     swiglu_alpha: float = 1.702,
     swiglu_limit: float = 7.0,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
+    intermediate_cache13: torch.Tensor | None = None,
+    intermediate_cache2: torch.Tensor | None = None,
     a1q_scale: torch.Tensor | None = None,
+    fuse_act: bool = True,
+    fuse_sum: bool = True,
 ) -> torch.Tensor:
     if quant_config is None:
         quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
@@ -134,16 +151,20 @@
     assert hidden_states.shape[-1] == w1.shape[-2]
     assert w2.shape[-1] == w1.shape[1]

+    M, K = hidden_states.shape
     E, _, N = w1.shape

     if global_num_experts == -1:
         global_num_experts = E

-    act = FusedActivation(
-        FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
-        (swiglu_alpha, swiglu_limit),
-        2,
-    )
+    if not fuse_act:
+        act = None
+    else:
+        act = FusedActivation(
+            FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+            (swiglu_alpha, swiglu_limit),
+            2,
+        )
     gammas = routing_data.gate_scal if routing_data else None

     intermediate_cache1 = matmul_ogs(
@@ -157,16 +178,35 @@
         fused_activation=act,
     )

+    if not fuse_act:
+        intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N // 2))
+        activation_func(
+            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
+        )
+    else:
+        intermediate_cache2 = intermediate_cache1
+
+    n_expts_act = routing_data.n_expts_act
+    if not fuse_sum:
+        routing_data.n_expts_act = 1
+
     intermediate_cache3 = matmul_ogs(
-        intermediate_cache1,
+        intermediate_cache2,
         w2,
         quant_config.w2_bias,
         routing_data,
         scatter_indx=scatter_indx,
         precision_config=quant_config.w2_precision,
         gammas=None if apply_router_weight_on_input else gammas,
-        y=output_tensor,
     )
+
+    if not fuse_sum:
+        moe_sum(intermediate_cache3.view(-1, topk, K), output_tensor)
+
+        # Set the original n_expts_act back
+        routing_data.n_expts_act = n_expts_act
+        return output_tensor
+
     return intermediate_cache3


@@ -239,6 +279,8 @@ def __init__(self, quant_config: FusedMoEQuantConfig):
         # TODO (varun) : Enable activation quantization
         assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
         super().__init__(quant_config)
+        self.fuse_act = True
+        self.fuse_sum = True

     @property
     def activation_formats(
@@ -263,7 +305,7 @@ def workspace_shapes(
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         # workspace are allocated inside the kernel
-        workspace1 = (M, K)
+        workspace1 = (M, topk, max(N // 2, K))
         workspace2 = (0, 0)
         output = (M, K)
         return (workspace1, workspace2, output)
@@ -297,20 +339,39 @@ def apply(
             topk_ids, topk_weights, local_num_experts
         )

-        experts_output = triton_kernel_fused_experts(
-            None,
+        topk = topk_ids.size(1)
+        triton_kernel_fused_experts(
+            output,
             hidden_states,
             w1,
             w2,
             routing_data,
             gather_indx,
             scatter_indx,
+            topk=topk,
             activation=activation,
+            activation_func=self.activation,
+            moe_sum=self.moe_sum,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
             global_num_experts=local_num_experts,
             expert_map=None,  # applied already
+            intermediate_cache13=workspace2,
+            intermediate_cache2=workspace13,
             a1q_scale=a1q_scale,
+            fuse_act=self.fuse_act,
+            fuse_sum=self.fuse_sum,
         )

-        output.copy_(experts_output, non_blocking=True)
+    def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
+        ops.moe_sum(input, output)
+
+
+def modular_oai_triton_fused_moe(
+    quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
+) -> mk.FusedMoEModularKernel:
+    return mk.FusedMoEModularKernel(
+        MoEPrepareAndFinalizeNoEP(),
+        OAITritonExperts(quant_config),
+        shared_experts,
+    )
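When fuse_act and fuse_sum are disabled (the LoRA path), the activation and the top-k reduction run outside matmul_ogs, so the LoRA decorators can hook in between. The shape bookkeeping is easy to lose in the diff; the sketch below retraces it with plain torch stand-ins. All sizes and the SwiGLU approximation are assumptions for illustration, not the actual Triton kernels:

import torch

# Shape walk-through of the unfused epilogue (fuse_act=False, fuse_sum=False).
M, topk, K, N = 4, 2, 8, 16  # tokens, experts per token, hidden size, gate+up width

# First matmul_ogs with no fused activation: one row per (token, expert) pair,
# full gate+up width N.
intermediate_cache1 = torch.randn(M * topk, N)

# External activation halves the width to N // 2 (approximated here with
# silu(gate) * up purely for the shape check).
gate, up = intermediate_cache1.chunk(2, dim=-1)
intermediate_cache2 = torch.nn.functional.silu(gate) * up        # (M*topk, N//2)

# Second matmul_ogs (down projection) maps back to the hidden size K.
w2 = torch.randn(N // 2, K)
intermediate_cache3 = intermediate_cache2 @ w2                   # (M*topk, K)

# With fuse_sum disabled, the top-k contributions are reduced explicitly,
# mirroring moe_sum(intermediate_cache3.view(-1, topk, K), output_tensor).
output_tensor = intermediate_cache3.view(M, topk, K).sum(dim=1)  # (M, K)
print(output_tensor.shape)  # torch.Size([4, 8])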

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 3 deletions
@@ -218,7 +218,6 @@ def maybe_roundup_hidden_size(
     act_dtype: torch.dtype,
     quant_config: QuantizationConfig | None,
     moe_parallel_config: FusedMoEParallelConfig,
-    is_lora_enabled: bool,
 ) -> int:
     """
     Given layer hidden size and MoE configurations, round up hidden_size
@@ -252,7 +251,7 @@
         get_mxfp4_backend,
     )

-    current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
+    current_mxfp4_backend = get_mxfp4_backend()
     if (
         current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
         or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
@@ -386,7 +385,6 @@ def __init__(
             moe_in_dtype,
             quant_config,
             self.moe_parallel_config,
-            is_lora_enabled=self.vllm_config.lora_config is not None,
         )

         # For smuggling this layer into the fused moe custom op

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 1 addition & 1 deletion
@@ -1087,7 +1087,7 @@ def _finalize(
         if not self.prepare_finalize.supports_async():
             assert not dbo_enabled()

-            self.prepare_finalize.finalize(
+            output = self.prepare_finalize.finalize(
                 output,
                 fused_out,
                 topk_weights,

vllm/model_executor/layers/fused_moe/prepare_finalize.py

Lines changed: 2 additions & 2 deletions
@@ -65,10 +65,10 @@ def finalize(
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+    ) -> torch.Tensor:
         if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
             weight_and_reduce_impl = TopKWeightAndReduceContiguous()
-        weight_and_reduce_impl.apply(
+        return weight_and_reduce_impl.apply(
             output=output,
             fused_expert_output=fused_expert_output,
             topk_weights=topk_weights,

vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py

Lines changed: 1 addition & 2 deletions
@@ -68,8 +68,7 @@ def apply(
             f"But got output={output.size()}, "
             f"used_expert_output={fused_expert_output.size()}"
         )
-        output.copy_(fused_expert_output, non_blocking=True)
-        return output
+        return fused_expert_output


 class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
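Taken together, the changes in modular_kernel.py, prepare_finalize.py, and topk_weight_and_reduce.py switch the finalize step from copy-in-place to pass-through: the no-op weight-and-reduce returns the fused expert output, finalize propagates that return value, and the modular kernel rebinds its output to it, presumably skipping a redundant device copy when the experts have already written into the output workspace. A simplified sketch of the before/after contract, using stand-in functions rather than the vLLM classes:

import torch

def finalize_old(output: torch.Tensor, fused_expert_output: torch.Tensor) -> None:
    # Previous behaviour: always an extra device copy, even when
    # fused_expert_output is the tensor the kernel wrote in place.
    output.copy_(fused_expert_output, non_blocking=True)

def finalize_new(output: torch.Tensor, fused_expert_output: torch.Tensor) -> torch.Tensor:
    # New behaviour: hand the fused output back and let the caller rebind
    # `output = prepare_finalize.finalize(...)`, skipping the copy.
    return fused_expert_output

buf = torch.zeros(4, 8)
fused = buf  # e.g. the Triton path wrote directly into the output workspace
out = finalize_new(buf, fused)
assert out is buf  # no copy needed when the buffers already alias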

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 2 additions & 17 deletions
@@ -74,24 +74,9 @@ class Mxfp4Backend(Enum):
     TRITON = 6


-def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
-    """
-    Not all MXFP4 backends support LoRA. Select backends that are known to
-    have LoRA support.
-    """
-    if not current_platform.is_cuda():
-        return Mxfp4Backend.NONE
-
-    logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
-    return Mxfp4Backend.MARLIN
-
-
-def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
+def get_mxfp4_backend() -> Mxfp4Backend:
     # Backend Selection

-    if with_lora_support:
-        return get_mxfp4_backend_with_lora()
-
     if current_platform.is_cuda():
         if (
             current_platform.is_device_capability(90)
@@ -215,7 +200,7 @@ def get_quant_method(
 class Mxfp4MoEMethod(FusedMoEMethodBase):
     def __init__(self, moe: FusedMoEConfig):
         super().__init__(moe)
-        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
+        self.mxfp4_backend = get_mxfp4_backend()
         self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
         self.max_capture_size = (
             get_current_vllm_config().compilation_config.max_cudagraph_capture_size
