Skip to content

Conversation

@yifeizhang-c
Copy link
Contributor

@yifeizhang-c yifeizhang-c commented Jul 15, 2025

[TRTLLM-6368] Update DeepEP dispatch API

Description

This PR updates DeepEP's dispatch API to hide topk_idx's value conversion and dtype casting inside kernels to better adapt to current torch.ops.trtllm.fused_moe implementation.

Summary by CodeRabbit

  • New Features

    • Added support for a new offset parameter in expert dispatch operations, enhancing flexibility in expert selection for mixture-of-experts models.
  • Refactor

    • Unified handling of token selection data types and streamlined the dispatch process for expert parallelism, improving code consistency and maintainability.
  • Chores

    • Updated the DeepEP submodule to a newer version.

@yifeizhang-c
Copy link
Contributor Author

yifeizhang-c commented Jul 15, 2025

TODO:

  • Update dtype to int32
  • Check whether passing self.expert_size_per_partition * self.mapping.moe_ep_rank is necessary

@yifeizhang-c
Copy link
Contributor Author

@yifeizhang-c yifeizhang-c force-pushed the dev-yifeiz-update-deepep-dispatch-api branch 3 times, most recently from f7e371a to c06ff41 Compare July 16, 2025 09:44
Signed-off-by: Yifei Zhang <[email protected]>
@yifeizhang-c yifeizhang-c force-pushed the dev-yifeiz-update-deepep-dispatch-api branch from c06ff41 to fb70f7f Compare July 17, 2025 03:05
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jul 17, 2025

Walkthrough

The changes update the DeepEP submodule version and modify the dispatch method in the VariableLengthBuffer class to accept a new global_expert_id_offset parameter. The forward_chunk method in fused_moe_wide_ep.py is updated to pass this offset and handle token_selected_slots without type conversion, removing redundant adjustments.

Changes

File(s) Change Summary
cpp/tensorrt_llm/deep_ep/CMakeLists.txt Updated DeepEP submodule commit hash to a newer version.
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py Added global_expert_id_offset parameter to VariableLengthBuffer.dispatch method and updated call.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py Updated forward_chunk to pass the offset to dispatch, removed int64 conversion, and simplified logic.

Sequence Diagram(s)

sequenceDiagram
    participant FusedMoeWideEP
    participant VariableLengthBuffer
    participant DeepEPBuffer

    FusedMoeWideEP->>VariableLengthBuffer: dispatch(x, topk_idx, topk_weights, num_experts, global_expert_id_offset)
    VariableLengthBuffer->>DeepEPBuffer: dispatch(..., global_expert_id_offset)
    DeepEPBuffer-->>VariableLengthBuffer: results
    VariableLengthBuffer-->>FusedMoeWideEP: results
Loading

Poem

In the warren where tensors leap,
A buffer learned a secret deep.
With offsets new and slots in tow,
The rabbits made their data flow.
Submodules fresh, the code hops on,
Dispatching dreams from dusk till dawn. 🐇✨


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6e1aee6 and fb70f7f.

📒 Files selected for processing (3)
  • cpp/tensorrt_llm/deep_ep/CMakeLists.txt (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (2 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (7)
cpp/tensorrt_llm/deep_ep/CMakeLists.txt (1)

1-1: LGTM! Submodule update to newer DeepEP version.

The commit hash update aligns with the PR objectives to update the DeepEP dispatch API. The hash format is correct and follows the expected pattern for git commit references.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (4)

458-461: LGTM! DeepEP dispatch API update with offset parameter.

The changes correctly:

  1. Add the global_expert_id_offset parameter calculated as self.expert_size_per_partition * self.mapping.moe_ep_rank
  2. Update the pad_empty_recv_tensors call to return and use token_selected_slots instead of recv_topk_idx

This aligns with the updated DeepEP dispatch API signature.


464-464: LGTM! Unified data type handling for token_selected_slots.

The removal of .to(torch.int64) conversion unifies the data type handling as mentioned in the PR objectives, allowing the kernel to handle the type conversion internally.


592-595: LGTM! Consistent API update for post-quantization alltoall.

The changes maintain consistency with the earlier dispatch call:

  1. Adding the same offset parameter calculation
  2. Updating the pad_empty_recv_tensors call parameters
  3. Handling the tuple return value correctly for the (x, x_sf) case

The implementation correctly handles the adapter logic for DeepEP data type conversion.


624-624: LGTM! Consistent token handling for DeepEPLowLatency.

The removal of .to(torch.int64) conversion is consistent with the other dispatch calls, maintaining unified data type handling across all DeepEP dispatch methods.

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (2)

62-62: LGTM! API signature update with new offset parameter.

The addition of global_expert_id_offset: int parameter to the dispatch method signature is correct and follows Python type annotation conventions.


79-80: LGTM! Parameter forwarding to underlying buffer.

The global_expert_id_offset parameter is correctly forwarded to the underlying self.buffer.dispatch call, maintaining the API contract between the wrapper and the actual DeepEP buffer implementation.

✨ Finishing Touches
  • 📝 Generate Docstrings

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@yifeizhang-c yifeizhang-c marked this pull request as ready for review July 17, 2025 03:05
@yifeizhang-c yifeizhang-c requested a review from a team as a code owner July 17, 2025 03:06
@yifeizhang-c
Copy link
Contributor Author

@yuantailing for extra review of this PR (corresponding to task 9.1).

@yifeizhang-c
Copy link
Contributor Author

/bot run

1 similar comment
@yifeizhang-c
Copy link
Contributor Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12171 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12171 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9039 completed with status: 'SUCCESS'

@kaiyux kaiyux merged commit 0155e7a into NVIDIA:main Jul 18, 2025
4 checks passed
reasonsolo pushed a commit to reasonsolo/TensorRT-LLM that referenced this pull request Jul 21, 2025
timlee0212 pushed a commit to timlee0212/TensorRT-LLM that referenced this pull request Jul 21, 2025
@coderabbitai coderabbitai bot mentioned this pull request Jul 24, 2025
NVShreyas pushed a commit to NVShreyas/TensorRT-LLM that referenced this pull request Jul 28, 2025
Signed-off-by: Yifei Zhang <[email protected]>
Signed-off-by: Shreyas Misra <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants