
Conversation

@hyukn
Collaborator

@hyukn hyukn commented Sep 4, 2025

Summary by CodeRabbit

  • New Features

    • Automatic, imbalance-aware tile sizing for MoE, improving adaptability under uneven token-to-expert loads.
    • Dynamic, per–tile-dimension runner selection and caching at runtime for better performance.
  • Refactor

    • Simplified public APIs for MoE runners and entry points by removing the tile dimension parameter and reordering arguments accordingly.
  • Chores

    • Streamlined CUDA-graph warmup by removing an autotuner step during warmup, reducing overhead without changing later autotuning behavior.

Description

tile_tokens_dim directly depends on num_tokens, which is a dynamic shape during both tuning and inference. When the AutoTuner prepares dummy tensors with different num_tokens, it does not update tile_tokens_dim automatically, so the values stored in the AutoTuner cache become misaligned. This causes frequent cache misses during inference and hurts performance significantly.

To avoid this issue, we move the calculation of tile_tokens_dim to just before kernel launch, so that its value always matches the num_tokens of the input tensor currently passed to the kernel runner.

In addition, tile_tokens_dim is now calculated from the token count of the tuned bucket rather than the raw input token count, because we only tune the value for the buckets. This avoids unexpected misalignment between tile_tokens_dim and the token count.

This PR also removes the warmup requests with extra input shapes that were triggered during the CUDA graph warmup phase.
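For illustration, here is a minimal sketch of the just-in-time computation. The helper name mirrors calculate_tile_tokens_dim in tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py, but the rounding rule, the [8, 64] clamp, and the example values are assumptions taken from the review walkthrough below rather than the exact implementation:

```python
import math


def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int,
                              imbalance_factor: float = 1.0) -> int:
    """Pick a power-of-two tile size from the expected per-expert token load."""
    # Expected tokens routed to each expert, optionally inflated to model
    # uneven token-to-expert distribution.
    tokens_per_expert = num_tokens * top_k / num_experts * imbalance_factor
    # Round up to the next power of two and clamp to the supported range.
    tile = 2 ** math.ceil(math.log2(max(tokens_per_expert, 1.0)))
    return min(max(tile, 8), 64)


# Recomputed just before kernel launch from the bucketed token count that the
# AutoTuner actually tuned for, so the cached tactic and the runtime value agree.
bucketed_num_tokens = 256  # hypothetical tuned bucket size
tile_tokens_dim = calculate_tile_tokens_dim(bucketed_num_tokens,
                                            num_experts=128, top_k=8,
                                            imbalance_factor=1.3)
print(tile_tokens_dim)  # 32 for this example
```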

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with the Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
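
For example, the options documented above can be combined in a single comment (both invocations use only flags listed in this help text):

```
/bot run --disable-fail-fast --stage-list "A10-PyTorch-1"
/bot run --reuse-test --gpu-type "A30, H100_PCIe"
```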

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since insufficient care and validation can break top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since insufficient care and validation can break top of tree.

@hyukn hyukn requested a review from DomBrown September 4, 2025 08:29
@hyukn hyukn force-pushed the chore/solve_tile_token_dim_calc branch from bff4b4b to bc1970e Compare September 5, 2025 03:22
@hyukn hyukn force-pushed the chore/solve_tile_token_dim_calc branch 2 times, most recently from a26e7c0 to b9f001b Compare September 5, 2025 03:40
@hyukn
Collaborator Author

hyukn commented Sep 5, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #17744 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17744 [ run ] completed with state DISABLED
L0 testing is limited to prioritized users. User hyukn is not in the prioritized list. L0 testing cannot be triggered.

@hyukn hyukn changed the title [None][chore] Make tile_tokens_dim calculation just in time before kernel launching. [None][fix] Make tile_tokens_dim calculation just in time before kernel launching. Sep 5, 2025
@hyukn hyukn force-pushed the chore/solve_tile_token_dim_calc branch from b9f001b to 0e3641a Compare September 5, 2025 05:51
@hyukn hyukn marked this pull request as ready for review September 8, 2025 08:09
@hyukn hyukn requested review from a team as code owners September 8, 2025 08:09
@hyukn hyukn requested a review from yizhang-nv September 8, 2025 08:09
@hyukn hyukn force-pushed the chore/solve_tile_token_dim_calc branch from 0e3641a to 6d4cc6e Compare September 8, 2025 08:14
@hyukn
Collaborator Author

hyukn commented Sep 8, 2025

/bot run --disable-fail-fast

@coderabbitai
Contributor

coderabbitai bot commented Sep 8, 2025

Caution

Review failed

Failed to post review comments.

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ff37048 and 0e3641a.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py (23 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (0 hunks)
💤 Files with no reviewable changes (1)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
  • tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
  • tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
  • tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
🧠 Learnings (4)
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
📚 Learning: 2025-08-21T02:39:12.009Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#7104
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1475-1480
Timestamp: 2025-08-21T02:39:12.009Z
Learning: The min latency mode functionality in TensorRT-LLM MOE kernels (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu) is deprecated and no longer being maintained/updated, as confirmed by djns99. Bug reports and optimization suggestions for the computeStridesTmaWarpSpecializedLowLatencyKernel and related min latency code paths should be deprioritized.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
📚 Learning: 2025-08-14T23:23:27.449Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
🧬 Code graph analysis (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (1)
tensorrt_llm/_torch/utils.py (1)
  • Fp4QuantizedTensor (98-105)
tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py (3)
tensorrt_llm/_torch/autotuner.py (2)
  • AutoTuner (271-752)
  • get (298-301)
cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp (1)
  • routing_logits (300-309)
cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp (1)
  • routing_logits (259-265)
📝 Walkthrough

The PR updates MoE custom ops to dynamically compute tile size with an optional imbalance factor and to cache TorchScript runners per tile dimension via new get_runner methods. Public constructors drop tile_tokens_dim; call sites now compute tile size at runtime. Fused MoE module removes its local tile computation. PyTorch ModelEngine removes an autotuner warmup branch.

Changes

  • MoE custom ops: dynamic tile-dim runners (tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py): Added imbalance_factor to calculate_tile_tokens_dim; introduced per-tile-dim runner caching via get_runner(tile_tokens_dim); removed tile_tokens_dim from runner init and public custom_op entry points; updated forward/config paths to compute tile_tokens_dim at runtime (some with imbalance_factor=1.3) and delegate to per-tile-dim runners.
  • Fused MoE call sites cleanup (tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py): Deleted local _get_tile_tokens_dim and the related import; updated kernel invocations to drop the tile_tokens_dim argument; streamlined forward without altering public signatures.
  • ModelEngine warmup simplification (tensorrt_llm/_torch/pyexecutor/model_engine.py): Removed the autotuner warmup block inside the CUDA-graph warmup loop; warmup now directly calls forward followed by torch.cuda.synchronize().

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller as CustomOp Entry
  participant RunnerMgr as MoE Runner (Python)
  participant Calc as calculate_tile_tokens_dim()
  participant TSRunner as TorchScript Runner (per tile_dim)
  participant Kernel as MoE Kernel

  Caller->>RunnerMgr: forward(...inputs...)
  RunnerMgr->>Calc: compute tile_tokens_dim(num_tokens, num_experts, top_k, imbalance_factor?)
  note right of Calc: tokens_per_expert *= imbalance_factor<br/>tile_dim = clamp(pow2(...), 8..64)
  Calc-->>RunnerMgr: tile_tokens_dim
  RunnerMgr->>RunnerMgr: get_runner(tile_tokens_dim)
  alt runner cached
    RunnerMgr-->>Caller: reuse cached TSRunner
  else miss
    RunnerMgr->>TSRunner: create/compile TS runner for tile_dim
    RunnerMgr-->>Caller: cache TSRunner
  end
  RunnerMgr->>TSRunner: run_moe(..., tile_tokens_dim)
  TSRunner->>Kernel: dispatch kernel
  Kernel-->>TSRunner: outputs
  TSRunner-->>RunnerMgr: outputs
  RunnerMgr-->>Caller: outputs
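As a rough Python sketch of the caching step shown in the diagram above (the class, factory, and method names here are illustrative placeholders, not the actual runner API):

```python
class PerTileRunnerCache:
    """Illustrative cache holding one compiled runner per tile_tokens_dim."""

    def __init__(self, build_runner):
        self._build_runner = build_runner  # factory: tile_tokens_dim -> runner
        self._runners = {}                 # tile_tokens_dim -> cached runner

    def get_runner(self, tile_tokens_dim):
        # Build the runner lazily the first time a tile size is seen, then reuse it.
        if tile_tokens_dim not in self._runners:
            self._runners[tile_tokens_dim] = self._build_runner(tile_tokens_dim)
        return self._runners[tile_tokens_dim]


# Toy usage: the "runner" is just a closure capturing its tile size.
cache = PerTileRunnerCache(lambda tile: (lambda x: f"run tile={tile} on {x}"))
print(cache.get_runner(16)("batch-a"))  # builds a runner for tile 16
print(cache.get_runner(16)("batch-b"))  # reuses the cached runner
print(cache.get_runner(32)("batch-c"))  # builds a second runner for tile 32
```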

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

Release Blocker

Suggested reviewers

  • hlu1
  • yizhang-nv

@tensorrt-cicd
Collaborator

PR_Github #18007 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #18007 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13495 completed with status: 'FAILURE'

@hyukn hyukn force-pushed the chore/solve_tile_token_dim_calc branch 3 times, most recently from ed59848 to 8ecf44c Compare September 9, 2025 04:47
@hyukn
Collaborator Author

hyukn commented Sep 9, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #18156 [ run ] triggered by Bot

@hyukn hyukn requested a review from liji-nv September 9, 2025 07:28
@tensorrt-cicd
Collaborator

PR_Github #18156 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13604 completed with status: 'FAILURE'

@hyukn hyukn force-pushed the chore/solve_tile_token_dim_calc branch 2 times, most recently from 6f8720e to 2a54d04 Compare September 15, 2025 03:18
…rnel launching.

`tile_tokens_dim` directly depends on num_tokens, which is a dynamic shape during both tuning and inference. When the AutoTuner prepares dummy tensors with different num_tokens, it does not update `tile_tokens_dim` automatically, so the values stored in the AutoTuner cache become misaligned. This causes frequent cache misses during inference and hurts performance significantly.

To avoid this issue, we move the calculation of `tile_tokens_dim` to just before kernel launch, so that its value always matches the num_tokens of the input tensor currently passed to the kernel runner. To avoid extra warmup time, the extra autotuning warmup steps for all the CUDA graph batch sizes are removed.

Signed-off-by: Yukun He <[email protected]>
@hyukn hyukn force-pushed the chore/solve_tile_token_dim_calc branch from 2a54d04 to 3ff3040 Compare September 15, 2025 03:20
@hyukn
Collaborator Author

hyukn commented Sep 15, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #18557 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #18557 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13930 completed with status: 'SUCCESS'

@hyukn hyukn merged commit cd80e0a into NVIDIA:main Sep 18, 2025
11 checks passed
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
…el launching. (NVIDIA#7529)


Signed-off-by: Yukun He <[email protected]>
MrGeva pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Sep 21, 2025
…el launching. (NVIDIA#7529)


Signed-off-by: Yukun He <[email protected]>
@hyukn hyukn deleted the chore/solve_tile_token_dim_calc branch October 31, 2025 09:17
