
Conversation

@syuoni
Collaborator

@syuoni syuoni commented Jun 24, 2025

[TRTLLM-5965] perf: Optimize MoE sort kernels for large-scale EP

Description

This PR implements custom sort logic that runs before the MoE GEMMs, replacing the original CUB sort invocation.

In a typical large-scale EP workload (EP=32 and per-GPU batch size 128):

  • Before this PR: 5 kernels, ~26.8 us in total

    • buildExpertMapsKernel: 10.3 us
    • CUB sort (three kernels): 11.9 us
    • computeExpertFirstTokenOffsetKernel: 4.6 us
    • In addition, we see significant bubbles between the CUB kernels on B200.
      (profiler timeline screenshot)
  • After this PR: 3 kernels, ~7.0 us in total

    • blockExpertPrefixSumKernel: 2.3 us
    • globalExpertPrefixSumKernel: 2.3 us
    • mergeExpertPrefixSumKernel: 2.4 us
      (profiler timeline screenshot)

A simplified sketch of the underlying counting-sort-style idea is given below.
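To make the approach concrete, here is a minimal, hypothetical counting-sort-style sketch of how a histogram plus prefix sum can replace a radix sort when grouping (token, expert) pairs by expert. This is not the merged implementation: the kernel and buffer names are illustrative (only token_selected_experts and expert_first_token_offset mirror identifiers used in the review discussion below), and the actual blockExpertPrefixSumKernel / globalExpertPrefixSumKernel / mergeExpertPrefixSumKernel split the prefix-sum work differently across three kernels.

#include <cstdint>
#include <cuda_runtime.h>

// (1) Count how many (token, k) pairs selected each expert.
__global__ void histogramExpertCountsKernel(int const* token_selected_experts,
                                            int* expert_counts, int num_tokens, int topk)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < num_tokens * topk)
    {
        atomicAdd(&expert_counts[token_selected_experts[idx]], 1);
    }
}

// (2) Exclusive prefix sum over the (small) number of experts gives the first-token
//     offset of each expert in the permuted layout. expert_first_token_offset must
//     have num_experts + 1 entries.
__global__ void expertExclusivePrefixSumKernel(int const* expert_counts,
                                               int64_t* expert_first_token_offset, int num_experts)
{
    if (blockIdx.x == 0 && threadIdx.x == 0)
    {
        int64_t running = 0;
        for (int e = 0; e < num_experts; ++e)
        {
            expert_first_token_offset[e] = running;
            running += expert_counts[e];
        }
        expert_first_token_offset[num_experts] = running;
    }
}

// (3) Scatter each (token, k) pair to the next free slot of its expert. The atomic
//     cursor makes the intra-expert order nondeterministic, which is one reason the
//     real kernels rank elements with prefix sums instead.
__global__ void scatterPermutedIndicesKernel(int const* token_selected_experts,
                                             int64_t const* expert_first_token_offset,
                                             int* expert_cursor, int* permuted_token_ids,
                                             int num_tokens, int topk)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < num_tokens * topk)
    {
        int expert = token_selected_experts[idx];
        int slot = atomicAdd(&expert_cursor[expert], 1);
        permuted_token_ids[expert_first_token_offset[expert] + slot] = idx;
    }
}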

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with the Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]

Launch build/test pipelines. All previously running jobs will be killed.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-[Post-Merge]-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md.
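For example (using only stage names that appear in the examples above; purely illustrative), a combined invocation might look like:

/bot run --stage-list "A10-1" --disable-fail-fast

This would run only the A10-1 test stage without fail-fast; per the note above, --stage-list does not update the GitHub check status.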

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

@syuoni syuoni force-pushed the opt-moe-sort branch 2 times, most recently from d003a6b to a9f032d Compare June 26, 2025 06:51
@syuoni syuoni requested review from djns99, dongxuy04, hlu1 and qiaoxj07 June 26, 2025 06:54
@syuoni syuoni self-assigned this Jun 26, 2025
@syuoni syuoni marked this pull request as ready for review June 26, 2025 06:54
@syuoni
Collaborator Author

syuoni commented Jun 26, 2025

/bot run

@tensorrt-cicd
Collaborator

PR_Github #9993 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #9993 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #7370 completed with status: 'FAILURE'

@syuoni
Collaborator Author

syuoni commented Jun 26, 2025

/bot run

@tensorrt-cicd
Collaborator

PR_Github #10027 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #10027 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #7399 completed with status: 'SUCCESS'

@syuoni
Collaborator Author

syuoni commented Jun 26, 2025

/bot run --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #10043 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #10043 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #7411 completed with status: 'FAILURE'

djns99
djns99 previously requested changes Jun 26, 2025
Collaborator

@djns99 djns99 left a comment

Thanks for this work! Unfortunately, though, I think we need to get rid of this assumption before we can merge this:

// This allows accommodating 256 experts x 64k tokens; reasonable workload should not exceed this

I also think we should try to be less wasteful with our block sizes. In the worst assumed case above (assuming topk=8), we are launching 16M threads, of which only 256k contribute anything.

@djns99 djns99 Jun 26, 2025
Collaborator

See my comment above about using BlockRadixRank; we can reduce this to only num_tokens*topk threads.

The final permuted idx is:

selected_expert = token_selected_experts[blockIdx.x * blockDim.x + threadIdx.x];
dest_token_id = expert_first_token_offset[selected_expert] + (block_rank[blockIdx.x][threadIdx.x] - block_exclusive_digit_prefix[blockIdx.x][selected_expert]);

Collaborator Author

I don't think I fully understand your comment. If using BlockRadixRank, what is the gridDim and blockDim?

Collaborator

The total number of threads should be num_tokens*topk; we can divide these into blocks however we want. It's an embarrassingly parallel operation in the case of mergeExpertPrefixSumKernel.

Collaborator Author

I see your point. Yes, mergeExpertPrefixSumKernel can be optimized as in my reply above, thanks!
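To make the indexing scheme discussed in this thread concrete, below is a minimal, hypothetical CUDA sketch (not the code that was merged). It does not use cub::BlockRadixRank; a shared-memory atomic stands in for the block-level ranking, and block_expert_offset (each block's starting slot within each expert's segment, assumed to be precomputed by an earlier pass) plays the role that block_rank minus block_exclusive_digit_prefix plays in the snippet above.

#include <cstdint>
#include <cuda_runtime.h>

// Launch with blockDim.x == BLOCK_THREADS and enough blocks to cover num_elements,
// where num_elements == num_tokens * topk.
template <int BLOCK_THREADS, int MAX_EXPERTS>
__global__ void rankAndScatterKernel(int const* token_selected_experts,
                                     int64_t const* expert_first_token_offset,
                                     int64_t const* block_expert_offset, // [gridDim.x][MAX_EXPERTS], precomputed
                                     int* dest_token_ids, int num_elements)
{
    __shared__ int expert_histogram[MAX_EXPERTS];

    int idx = blockIdx.x * BLOCK_THREADS + threadIdx.x;
    bool valid = idx < num_elements;
    int selected_expert = valid ? token_selected_experts[idx] : 0;

    // Zero the per-block expert histogram.
    for (int e = threadIdx.x; e < MAX_EXPERTS; e += BLOCK_THREADS)
    {
        expert_histogram[e] = 0;
    }
    __syncthreads();

    // Rank of this element among same-expert elements within the block. A real
    // implementation would use something like cub::BlockRadixRank here to avoid
    // the atomic and keep the intra-block order deterministic.
    int rank_in_block = valid ? atomicAdd(&expert_histogram[selected_expert], 1) : 0;

    if (valid)
    {
        // Destination = start of this expert's segment
        //             + this block's start within that segment
        //             + rank within this block.
        int64_t dest = expert_first_token_offset[selected_expert]
            + block_expert_offset[blockIdx.x * MAX_EXPERTS + selected_expert] + rank_in_block;
        dest_token_ids[idx] = static_cast<int>(dest);
    }
}

With this layout the grid needs only ceil(num_tokens * topk / BLOCK_THREADS) blocks, i.e. num_tokens * topk threads in total, matching the point made above.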

@syuoni
Collaborator Author

syuoni commented Jun 27, 2025

/bot run

@tensorrt-cicd
Collaborator

PR_Github #10173 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #10173 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #7511 completed with status: 'SUCCESS'

@syuoni
Collaborator Author

syuoni commented Jun 28, 2025

/bot run --add-multi-gpu-test --disable-fail-fast

syuoni added 3 commits June 28, 2025 02:04
Signed-off-by: Enwei Zhu <[email protected]>

refactor

Signed-off-by: Enwei Zhu <[email protected]>

integration

Signed-off-by: Enwei Zhu <[email protected]>

fix large workload

Signed-off-by: Enwei Zhu <[email protected]>

fix PDL

Signed-off-by: Enwei Zhu <[email protected]>

fix

Signed-off-by: Enwei Zhu <[email protected]>

fix large workload

Signed-off-by: Enwei Zhu <[email protected]>

clean unused

Signed-off-by: Enwei Zhu <[email protected]>

fix profiler

Signed-off-by: Enwei Zhu <[email protected]>

move reserve from expandInput

Signed-off-by: Enwei Zhu <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
syuoni added 6 commits June 28, 2025 02:06
Signed-off-by: Enwei Zhu <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
@syuoni
Collaborator Author

syuoni commented Jun 28, 2025

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #10185 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #10185 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #7518 completed with status: 'SUCCESS'
Pipeline passed with automatically retried tests. Check the rerun report for details.

@juney-nvidia
Collaborator

Let's merge this PR to unblock the E2E optimization of large-scale EP and continue the refinements in subsequent PRs.

@juney-nvidia juney-nvidia dismissed djns99’s stale review June 29, 2025 17:02

Hi Daniel,

We need to unblock the large-scale EP E2E performance optimizations, and I also noticed that most of the comments left on this PR have been addressed by Enwei, so for now I will unblock the merge of this PR.
Enwei will work with you to discuss further refinement of the related logic.

Thanks
June

@juney-nvidia juney-nvidia merged commit b4dab23 into NVIDIA:main Jun 29, 2025
3 checks passed
ameynaik-hub pushed a commit to ameynaik-hub/TensorRT-LLM that referenced this pull request Jun 30, 2025
syuoni added a commit to syuoni/TensorRT-LLM that referenced this pull request Jul 1, 2025
Shunkangz pushed a commit to Shunkangz/TensorRT-LLM that referenced this pull request Jul 2, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Jul 9, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Jul 10, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Jul 10, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Jul 10, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Jul 10, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Jul 11, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Jul 11, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Jul 11, 2025
nvzhihanj pushed a commit to nvzhihanj/TensorRT-LLM that referenced this pull request Jul 17, 2025
nvzhihanj pushed a commit to nvzhihanj/TensorRT-LLM that referenced this pull request Jul 26, 2025
@syuoni syuoni deleted the opt-moe-sort branch July 31, 2025 03:28