[LoRA] Support FusedMoE LoRA Triton kernel for mxfp4 model #28971
Code Review
This pull request introduces support for FusedMoE LoRA with Triton kernels for mxfp4 models, which provides a significant performance improvement as shown in the benchmarks. The changes are well-structured, adding the necessary logic to select the Triton backend and adapting the kernels for this new path. However, I've identified a critical issue where attributes in OAITritonExperts are used without being initialized, which could lead to a runtime error in non-LoRA use cases. Please address this to ensure the stability of the implementation.
nice speedup!
    modular_triton_fused_moe,
    try_get_optimal_moe_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
@xyang16 can you add a unit test for gpt-oss LoRA + triton_kernels? The test can be predicated on has_triton_kernels, like in https://github.com/vllm-project/vllm/blob/main/tests/kernels/moe/test_gpt_oss_triton_kernels.py
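A minimal sketch of such a test, assuming placeholder names (the model id, the LoRA adapter path, and the triton_kernels availability check); the real test should gate on vLLM's has_triton_kernels helper and reuse the existing gpt-oss LoRA fixtures:

# Sketch only: model id, adapter path, and the availability check are placeholders.
import importlib.util

import pytest

HAS_TRITON_KERNELS = importlib.util.find_spec("triton_kernels") is not None


@pytest.mark.skipif(not HAS_TRITON_KERNELS, reason="requires triton_kernels")
def test_gpt_oss_lora_with_triton_kernels():
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(model="openai/gpt-oss-20b", enable_lora=True, max_lora_rank=16)
    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(max_tokens=8),
        lora_request=LoRARequest("adapter", 1, "/path/to/gpt-oss-lora"),  # placeholder path
    )
    # Smoke check: the LoRA + triton_kernels path produced some text.
    assert outputs[0].outputs[0].text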
varun-sundar-rabindranath left a comment
The changes to triton_kernel_fused_experts are invasive, and it is a bit confusing to reason about the fused_act=True, fused_sum=True and fused_act=False, fused_sum=False cases, as the assumptions and expectations from matmul_ogs are different in the two cases.
The main differences between the non-LoRA and the LoRA case:
- For the non-LoRA case, no assumptions are made about the sizes of the matmul_ogs output tensors. The only requirement is that the second matmul_ogs must return a tensor of size [M, K]. For the LoRA case, we expect the outputs to be of a specific shape - this pattern is similar to TritonExperts.
- For the non-LoRA case, there are no requirements on the gather_indx and scatter_indx sizes. The LoRA case requires the tensors in these objects to be of a specific shape.
For these reasons, I think it would be better to create a separate implementation of the BaseOAITritonExperts class for the LoRA case, naming it something like UnfusedOAITritonExperts. Apart from being easier to assert expectations in, this lets us create correct and adequately sized workspaces for both workspace13 and workspace2 and reuse them properly in the implementation. Please refer to TritonExperts: I think the implementation here would be very similar, and all the logic could be contained within the apply function, thus not disturbing the existing triton_kernel_fused_experts function. Something like:
def apply():
    routing_data, gather_indx, scatter_indx = self._make_routing_data(
        topk_ids, topk_weights, local_num_experts
    )
    # first matmul_ogs writes into a preallocated workspace
    matmul_ogs(..., y=intermediate_cache1)
    activation(intermediate_cache2, intermediate_cache1)
    # second matmul_ogs writes into a preallocated workspace
    matmul_ogs(..., y=intermediate_cache3)
        == num_tokens_post_padded.shape[0]
    )
-   assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
+   assert output.shape[-1] // lora_b_stacked[0].shape[-2] == len(lora_b_stacked)
Is this change required? It looks like a * b == c is transformed into c // b == a? I guess it is required because of the floor operation //. To better understand, can you provide an example where the first one fails and the replacement passes? Thanks.
Thanks for reviewing! I changed this since output.shape[-1] is padded because of the mxfp4 swizzle. I also put explanations for my other changes in the Notes section of the description.
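As a concrete illustration with made-up numbers, the padding breaks the original equality check but not the floor-division form:

# Made-up numbers, only to illustrate the padding effect described above.
num_slices = 2        # len(lora_b_stacked)
slice_width = 2880    # lora_b_stacked[0].shape[-2]
padded_width = 6144   # output.shape[-1] after mxfp4 swizzle padding (>= 2 * 2880 = 5760)

assert num_slices * slice_width != padded_width    # old assert would fail: 5760 != 6144
assert padded_width // slice_width == num_slices   # new assert passes: 6144 // 2880 == 2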
    if with_lora_support:
        return get_mxfp4_backend_with_lora()
I think we should maintain get_mxfp4_backend_with_lora() and return the appropriate backend from within that function, because there is no guarantee that the logic below will choose a LoRA-compatible backend.
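A sketch of the shape being suggested; the Mxfp4Backend member names, current_platform, and the has_triton_kernels check are assumptions about the surrounding module, not the PR's exact code:

# Sketch only: keep LoRA-compatible backend selection contained in this function
# instead of falling through to the general selection logic below.
def get_mxfp4_backend_with_lora() -> "Mxfp4Backend":
    if current_platform.is_cuda() and has_triton_kernels():
        # The Triton kernels path gains FusedMoE LoRA support in this PR.
        return Mxfp4Backend.TRITON
    # Otherwise fall back to a backend already known to work with LoRA.
    return Mxfp4Backend.MARLIN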
        a1q_scale=a1q_scale,
    )
-   output.copy_(experts_output, non_blocking=True)
I am not sure I understand why this removal is required. It looks like the output tensor isn't being filled anywhere for the non-LoRA case? Am I missing something?
Also, since this class is declared to return TopKWeightAndReduceNoOP() in finalize_weight_and_reduce_impl above, the apply method is expected to fill in the output tensor, and other parts of ModularKernel depend on that contract.
I see the redundant copy in topk_weight_and_reduce.py below; for that, I think we should avoid it by doing a .data_ptr() equivalence check between the tensors.
    precision_config=quant_config.w1_precision,
    gammas=gammas if apply_router_weight_on_input else None,
    fused_activation=act,
)
@xyang16 To guarantee that matmul_ogs returns an output shaped [M * topk, N], I think it is better to pass in the output tensor ourselves using the argument y. Note that matmul_ogs actually checks that the output is of the expected size here: https://github.com/triton-lang/triton/blob/c3c476f357f1e9768ea4e45aa5c17528449ab9ef/python/triton_kernels/triton_kernels/matmul_ogs.py#L180. That way it is guaranteed that matmul_ogs respects the contract.
The same applies to the second matmul_ogs.
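A sketch of passing y explicitly for the first matmul_ogs; variable names follow the snippet above, and the bias argument and exact keyword usage are assumptions:

# Sketch only: preallocate the output so matmul_ogs must honor the [M * topk, 2 * N]
# shape (it validates y's shape internally rather than allocating its own buffer).
M = hidden_states.shape[0]
topk = routing_data.n_expts_act
intermediate_cache1 = torch.empty(
    (M * topk, w1.shape[-1]),  # w1 is [E, K, 2 * N], so the last dim is 2 * N
    dtype=hidden_states.dtype,
    device=hidden_states.device,
)
matmul_ogs(
    hidden_states,
    w1,
    None,  # bias (assumed unused here)
    routing_data=routing_data,
    gather_indx=gather_indx,
    precision_config=quant_config.w1_precision,
    y=intermediate_cache1,
)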
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: mk.TopKWeightAndReduce,
-) -> None:
+) -> torch.Tensor:
This violates the base class contract at def finalize(...).
| f"But got output={output.size()}, " | ||
| f"used_expert_output={fused_expert_output.size()}" | ||
| ) | ||
| output.copy_(fused_expert_output, non_blocking=True) |
When the output and fused_expert_output tensors are the same, I believe this copy should be avoided by doing:
    if output.data_ptr() != fused_expert_output.data_ptr():
        output.copy_(fused_expert_output, non_blocking=True)
I think this will prevent changing the signature of the MoEPrepareAndFinalizeNoEP::finalize() method.
Purpose
This PR adds support for the FusedMoE LoRA Triton kernel for mxfp4 models.
The computation flow is:
1. hidden_states @ w1 (first matmul_ogs, with gather): inputs hidden_states: [M, K], w1: [E, K, 2 * N], gather_indx; output intermediate_cache1: [topk * M, 2 * N]
2. activation: input intermediate_cache1: [topk * M, 2 * N]; output intermediate_cache2: [topk * M, N]
3. intermediate_cache2 @ w2 (second matmul_ogs, with scatter): inputs intermediate_cache2: [topk * M, N], w2: [E, N, K], scatter_indx (inverse of gather_indx; takes per-expert outputs and scatters them back to the original token positions in the output tensor); output intermediate_cache3: [topk * M, K]
4. moe_sum: input intermediate_cache3: [topk * M, K]; output output: [M, K]
Notes:
- In matmul_ogs, the fused scatter does Y[dst_indx // topk, :] += X[src_indx, :], i.e. it scatter-sums across multiple experts and collapses M * topk rows to M rows. Therefore, we need to temporarily set routing_data.n_expts_act (which is topk) to 1 so that it does not sum across multiple experts, in order to unfuse moe_sum from the second matmul_ogs.
- Changed the assert to output.shape[-1] // lora_b_stacked[0].shape[-2] == len(lora_b_stacked), because output.shape[-1] is padded due to the mxfp4 swizzle.
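To make the first note concrete, here is a tiny standalone PyTorch illustration of that scatter-sum behavior (not the actual matmul_ogs kernel; shapes and indices are made up):

import torch

# Illustration only: mimic Y[dst_indx // topk, :] += X[src_indx, :] with plain PyTorch.
M, topk, K = 2, 2, 4
X = torch.arange(M * topk * K, dtype=torch.float32).reshape(M * topk, K)  # per (token, expert) rows
src_indx = torch.arange(M * topk)  # identity permutation, for simplicity
dst_indx = torch.arange(M * topk)

# Fused scatter-sum: rows belonging to the same token land in the same output row.
Y = torch.zeros(M, K)
Y.index_add_(0, dst_indx // topk, X[src_indx])
print(Y.shape)  # torch.Size([2, 4]) -> topk * M rows collapsed to M

# With routing_data.n_expts_act temporarily set to 1, the divisor becomes 1, so no rows
# are summed and all topk * M rows are kept; moe_sum then runs as a separate step.
Y_unfused = torch.zeros(M * topk, K)
Y_unfused.index_add_(0, dst_indx // 1, X[src_indx])
print(Y_unfused.shape)  # torch.Size([4, 4])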
Test Plan
Install triton_kernels
Baseline (marlin):
PR (triton):
Benchmark
Baseline (marlin):
PR (triton):