
Conversation

Contributor

@xyang16 xyang16 commented Nov 18, 2025

Purpose

This PR adds support for the FusedMoE LoRA Triton kernel for mxfp4 models. The unfused pipeline is as follows (a shape-only sketch follows the list):

  • First matmul_ogs: hidden_states @ w1
    • input hidden_states: [M, K]
    • input w1: [E, K, 2 * N]
    • input gather_indx
      • src_indx: [topk * M]
      • dst_indx: [topk * M]
    • output intermediate_cache1: [topk * M, 2 * N]
      • The first dimension of intermediate_cache1 is topk * M, because the gather expands the M rows to topk * M rows.
  • Activation
    • input intermediate_cache1: [topk * M, 2 * N]
    • output intermediate_cache2: [topk * M, N]
      • The second dimension of intermediate_cache2 is N, because SwiGLU collapses 2N -> N.
  • Second matmul_ogs: intermediate_cache2 @ w2 + scatter
    • input intermediate_cache2: [topk * M, N]
    • input w2: [E, N, K]
    • input scatter_indx: the inverse of gather_indx; it takes the per-expert outputs and scatters them back to the original token positions in the output tensor
      • src_indx: [topk * M] == gather_indx.dst_indx
      • dst_indx: [topk * M] == gather_indx.src_indx
    • output intermediate_cache3: [topk * M, K]
      • The first dimension of intermediate_cache3 is topk * M, because moe_sum is unfused.
  • moe_sum
    • input intermediate_cache3: [topk * M, K]
    • output: [M, K]
      • The first dimension is M, because moe_sum reduces the topk * M rows to M rows.
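
A rough shape-only illustration of this unfused pipeline in plain PyTorch (dense random weights and random routing; not the actual matmul_ogs / triton_kernels API, and the real model's SwiGLU variant and gate/up layout may differ):

import torch

M, K, N, E, topk = 4, 16, 32, 8, 2
hidden_states = torch.randn(M, K)
w1 = torch.randn(E, K, 2 * N)                  # [E, K, 2N]
w2 = torch.randn(E, N, K)                      # [E, N, K]
topk_ids = torch.randint(0, E, (M, topk))      # chosen expert per (token, slot)

# Gather: expand M rows to topk * M rows, one per (token, expert) pair.
src = torch.arange(M).repeat_interleave(topk)  # stand-in for gather_indx
x = hidden_states[src]                         # [topk * M, K]
experts = topk_ids.reshape(-1)                 # [topk * M]

# First grouped matmul: hidden_states @ w1 -> intermediate_cache1 [topk * M, 2N]
cache1 = torch.einsum("tk,tkn->tn", x, w1[experts])

# Unfused activation: SwiGLU-style 2N -> N collapse -> intermediate_cache2 [topk * M, N]
gate, up = cache1.chunk(2, dim=-1)
cache2 = torch.nn.functional.silu(gate) * up

# Second grouped matmul: intermediate_cache2 @ w2 -> intermediate_cache3 [topk * M, K]
cache3 = torch.einsum("tn,tnk->tk", cache2, w2[experts])

# Unfused moe_sum: reduce topk * M rows back to M rows -> output [M, K]
out = cache3.view(M, topk, K).sum(dim=1)
print(out.shape)                               # torch.Size([4, 16])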

Notes:

  • If mxfp4 is used and the mxfp4 backend is Triton, OAITritonExperts is used.
  • Injecting the LoRA module at the activation: since matmul_ogs can fuse the activation, fused_activation is set to None to unfuse the activation from the first matmul_ogs.
  • Injecting the LoRA module at moe_sum: this requires unfusing the sum from the second matmul_ogs. The grouped reduction does scatter + accumulate, which is essentially Y[dst_indx // topk, :] += X[src_indx, :], i.e. it scatters and sums across the topk experts, collapsing topk * M rows to M rows (see the sketch after this list). Therefore, we temporarily set routing_data.n_expts_act (which is topk) to 1 so that it does not sum across experts, which unfuses moe_sum from the second matmul_ogs.
  • In topk_weight_and_reduce.py, return fused_expert_output directly instead of copying fused_expert_output to output, because I found the copy leads to NaNs in the output tensor.
  • In fused_moe_lora_op.py, changed the assert to output.shape[-1] // lora_b_stacked[0].shape[-2] == len(lora_b_stacked), because output.shape[-1] is padded due to the mxfp4 swizzle.
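
To make the grouped-reduction note concrete, here is a minimal sketch of the scatter + accumulate in plain PyTorch (src_indx/dst_indx are simplified stand-ins for the real routing indices, not the triton_kernels objects):

import torch

M, K, topk = 4, 8, 2
X = torch.randn(topk * M, K)            # per-(token, expert) rows, e.g. second matmul output
src_indx = torch.randperm(topk * M)     # which row of X feeds each output slot (stand-in)
dst_indx = torch.arange(topk * M)       # flat (token, expert) slot each row lands in (stand-in)

# Fused grouped reduction: Y[dst_indx // topk, :] += X[src_indx, :]
# i.e. scatter + sum over the topk experts, collapsing topk * M rows to M rows.
Y = torch.zeros(M, K)
Y.index_add_(0, dst_indx // topk, X[src_indx])

# Equivalent unfused form: scatter to [topk * M, K] first, then moe_sum.
cache3 = torch.empty(topk * M, K)
cache3[dst_indx] = X[src_indx]
assert torch.allclose(Y, cache3.view(M, topk, K).sum(dim=1))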

Test Plan

Install triton_kernels

pip install "git+https://github.com/triton-lang/triton.git@0a2e3a391cbb9e13d29bf12a2a0005e358102d74#subdirectory=python/triton_kernels"

Baseline (marlin):

VLLM_MXFP4_USE_MARLIN=1 vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16 \
  --compilation_config '{"compile_sizes": [1, 2, 4, 8, 16]}' \
  --enable-lora \
  --lora-modules lora1=/opt/dlami/nvme/models/gpt-oss-20b-lora-gsm8k \
  --max-lora-rank 64
python3 -m lm_eval --model local-completions \
  --model_args model=lora1,tokenizer=/opt/dlami/nvme/models/gpt-oss-20b-lora-gsm8k,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=16 \
  --tasks gsm8k
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8021|±  | 0.011|
|     |       |strict-match    |     5|exact_match|↑  |0.8014|±  | 0.011|

PR (triton):

vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16 \
  --compilation_config '{"compile_sizes": [1, 2, 4, 8, 16]}' \
  --enable-lora \
  --lora-modules lora1=/opt/dlami/nvme/models/gpt-oss-20b-lora-gsm8k \
  --max-lora-rank 64
python3 -m lm_eval --model local-completions \
  --model_args model=lora1,tokenizer=/opt/dlami/nvme/models/gpt-oss-20b-lora-gsm8k,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=16 \
  --tasks gsm8k
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8105|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.8105|±  |0.0108|

Benchmark

vllm bench serve \
  --model openai/gpt-oss-20b \
  --lora-modules lora1 \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --max-concurrency 16 \
  --num-prompts 1000 \
  --num-warmups 60 \
  --ignore-eos

Baseline (marlin):

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  132.86    
Total input tokens:                      215312    
Total generated tokens:                  199033    
Request throughput (req/s):              7.53      
Output token throughput (tok/s):         1498.06   
Peak output token throughput (tok/s):    1646.00   
Peak concurrent requests:                29.00     
Total Token throughput (tok/s):          3118.65   
---------------Time to First Token----------------
Mean TTFT (ms):                          32.01     
Median TTFT (ms):                        26.94     
P99 TTFT (ms):                           152.01    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.46     
Median TPOT (ms):                        10.33     
P99 TPOT (ms):                           13.47     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.41     
Median ITL (ms):                         9.73      
P99 ITL (ms):                            29.70     
==================================================

PR (triton):

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  99.34     
Total input tokens:                      215312    
Total generated tokens:                  199033    
Request throughput (req/s):              10.07     
Output token throughput (tok/s):         2003.47   
Peak output token throughput (tok/s):    2257.00   
Peak concurrent requests:                35.00     
Total Token throughput (tok/s):          4170.81   
---------------Time to First Token----------------
Mean TTFT (ms):                          23.38     
Median TTFT (ms):                        19.62     
P99 TTFT (ms):                           165.07    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.77      
Median TPOT (ms):                        7.60      
P99 TPOT (ms):                           10.62     
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.74      
Median ITL (ms):                         7.31      
P99 ITL (ms):                            17.60     
==================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for FusedMoE LoRA with Triton kernels for mxfp4 models, which provides a significant performance improvement as shown in the benchmarks. The changes are well-structured, adding the necessary logic to select the Triton backend and adapting the kernels for this new path. However, I've identified a critical issue where attributes in OAITritonExperts are used without being initialized, which could lead to a runtime error in non-LoRA use cases. Please address this to ensure the stability of the implementation.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@xyang16 xyang16 force-pushed the fused_moe_lora_triton branch 3 times, most recently from 8acf2f1 to 8f24eec on November 19, 2025 at 00:59
@robertgshaw2-redhat
Collaborator

nice speedup!

@xyang16 xyang16 force-pushed the fused_moe_lora_triton branch from 8f24eec to 0bb53d7 on November 19, 2025 at 01:34
@xyang16 xyang16 force-pushed the fused_moe_lora_triton branch 2 times, most recently from 8697165 to b8fd020 on November 19, 2025 at 01:53
Signed-off-by: Xin Yang <[email protected]>
@xyang16 xyang16 force-pushed the fused_moe_lora_triton branch from b8fd020 to 168d8cd on November 19, 2025 at 02:00
@xyang16 xyang16 changed the title from "Support FusedMoE LoRA Triton kernel for mxfp4 model" to "[LoRA] Support FusedMoE LoRA Triton kernel for mxfp4 model" on Nov 19, 2025
modular_triton_fused_moe,
try_get_optimal_moe_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
Contributor


@xyang16 can you add a unit test for gpt-oss lora + triton_kernels. The test can be predicated on has_triton_kernels like in https://github.com/vllm-project/vllm/blob/main/tests/kernels/moe/test_gpt_oss_triton_kernels.py

Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath left a comment


The changes to triton_kernel_fused_experts are invasive, and it is a bit confusing to reason about the fused_act=True, fused_sum=True and fused_act=False, fused_sum=False cases, as the assumptions and expectations around matmul_ogs are different in the two cases.

The main differences between the non-LoRA and the LoRA case:

  • For the non-LoRA case, no assumptions about the sizes of the matmul_ogs output tensors are made; the only requirement is that the second matmul_ogs must return a tensor of size [M, K]. For the LoRA case, we expect the outputs to have specific shapes. This pattern is similar to TritonExperts.
  • For the non-LoRA case, there are no requirements on the gather_indx and scatter_indx sizes. The LoRA case requires the tensors in these objects to have specific shapes.

For these reasons, I think it would be better to create a separate implementation of the BaseOAITritonExperts class for the LoRA case, named something like UnfusedOAITritonExperts. Apart from making the expectations easier to assert, this lets us create correct and adequately sized workspaces for both workspace13 and workspace2 and reuse them properly in the implementation. Please refer to TritonExperts; I think the implementation here would be very similar, and all the logic could be contained within the apply function, thus not disturbing the existing triton_kernel_fused_experts function. Something like:

def apply(self, ...):
    routing_data, gather_indx, scatter_indx = self._make_routing_data(
        topk_ids, topk_weights, local_num_experts
    )
    matmul_ogs(..., y=intermediate_cache1)
    activation(intermediate_cache2, intermediate_cache1)
    matmul_ogs(..., y=intermediate_cache3)

== num_tokens_post_padded.shape[0]
)
assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
assert output.shape[-1] // lora_b_stacked[0].shape[-2] == len(lora_b_stacked)
Contributor


Is this change required? It looks like a * b == c is transformed into c // b == a? I guess it is required because of the floor operation //. To better understand, can you provide an example where the first one fails and the replacement passes? Thanks.

Contributor Author

@xyang16 xyang16 Nov 19, 2025


Thanks for reviewing! I changed this because output.shape[-1] is padded due to the mxfp4 swizzle. I also put explanations for my other changes in the Notes section of the description.
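
For illustration only (hypothetical numbers, not taken from the real model), the difference shows up like this:

# Hypothetical: 2 LoRA slices of width 1440, output padded to 2944 by the mxfp4 swizzle
num_slices, slice_dim, padded_out = 2, 1440, 2944
print(num_slices * slice_dim == padded_out)   # False: 2 * 1440 == 2880, not 2944
print(padded_out // slice_dim == num_slices)  # True:  2944 // 1440 == 2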


if with_lora_support:
return get_mxfp4_backend_with_lora()

Contributor


I think we should maintain get_mxfp4_backend_with_lora() and return the appropriate backend from within that function. This is because there is no guarantee that the logic below will choose a LoRA-compatible backend.

a1q_scale=a1q_scale,
)

output.copy_(experts_output, non_blocking=True)
Contributor


I am not sure I understand why this removal is required. It looks like the output tensor isn't being filled anywhere for the non-LoRA case? Am I missing something?

Also, since this class is declared to return TopKWeightAndReduceNoOP() in finalize_weight_and_reduce_impl above, the apply method is expected to fill in the output tensor, and other parts of ModularKernel depend on that contract.

I see the redundant copy in topk_weight_and_reduce.py below; for that, I think we should avoid it by doing a .data_ptr() equivalence check between the tensors.

precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act,
)
Contributor


@xyang16 To guarantee that matmul_ogs returns an output shaped [M * topk, N], I think it is better to pass in the output tensor ourselves via the y argument. Note that matmul_ogs actually checks that the output is of the expected size here: https://github.com/triton-lang/triton/blob/c3c476f357f1e9768ea4e45aa5c17528449ab9ef/python/triton_kernels/triton_kernels/matmul_ogs.py#L180 . That way it is guaranteed that matmul_ogs respects the contract.

The same applies to the second matmul_ogs.

apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
) -> torch.Tensor:
Contributor


This violates the base class contract at

f"But got output={output.size()}, "
f"used_expert_output={fused_expert_output.size()}"
)
output.copy_(fused_expert_output, non_blocking=True)
Contributor


When the output and fused_expert_output tensors are the same, I believe this copy should be avoided by doing:

if output.data_ptr() != fused_expert_output.data_ptr():
    output.copy_(fused_expert_output, non_blocking=True)

I think this avoids having to change the signature of the MoEPrepareAndFinalizeNoEP::finalize() method.

@github-project-automation github-project-automation bot moved this from To Triage to In progress in gpt-oss Issues & Enhancements Nov 19, 2025

Labels

gpt-oss Related to GPT-OSS models

Projects

Status: In progress
