
Conversation

@lfr-0531 lfr-0531 commented Aug 26, 2025

Summary by CodeRabbit

  • Refactor

    • Separated and compiled score computation to improve routing performance and reliability.
    • Simplified attention multiply path on newer GPUs to reduce memory moves and latency.
    • Replaced Python-level MoE preprocessing with a fused GPU kernel to cut Python overhead and boost throughput.
    • Changed FP8 quantization buffer allocation to avoid needless zero-initialization for faster startup.
  • Chores

    • Performance-focused cleanups across GPU execution paths.

Description

  • Change the torch.zeros() in per_token_quant_and_transform to torch.empty() to skip unnecessary zero-initialization
  • Remove the out.copy_(output) in fp8_block_scaling_bmm_out by writing the bmm result directly into out
  • Enable torch.compile() to fuse the sigmoid + bias add in Deepseekv3RoutingImpl (see the sketch below)
  • Add a new Triton kernel that fuses preprocess_after_permute in fused_moe_deepgemm
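
For context, a minimal standalone sketch of the compiled sigmoid + bias-add helper (based on the get_scores snippet quoted in the review below; the real method lives on Deepseekv3RoutingImpl and its exact compile options may differ):

import torch

@torch.compile(options={"max-autotune": True})
def get_scores(logits: torch.Tensor, e_score_correction_bias: torch.Tensor):
    # Two element-wise ops that torch.compile can fuse into a single kernel.
    scores = torch.sigmoid(logits)
    scores_with_bias = scores + e_score_correction_bias
    return scores, scores_with_bias

# Illustrative usage (shapes are arbitrary):
# scores, scores_with_bias = get_scores(torch.randn(16, 256), torch.zeros(256))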

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
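
For example, a typical invocation that combines several of the options above (the stage name is only a placeholder):

/bot run --disable-fail-fast --stage-list "A10-PyTorch-1"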

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since skipping validation without due care can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since reusing a pipeline without due care and validation can break the top of tree.

@lfr-0531 lfr-0531 requested review from a team as code owners August 26, 2025 03:13
coderabbitai bot commented Aug 26, 2025

📝 Walkthrough


Adds a compiled get_scores helper to Deepseekv3 routing and refactors noaux_tc to call it; simplifies SM100 FP8 bmm to write directly into output; replaces Python MoE preprocessing with a Triton kernel for token→expert mapping and counts; changes an FP8 buffer allocation from zeros to empty.

Changes

  • DeepSeekV3 routing refactor (tensorrt_llm/_torch/models/modeling_deepseekv3.py): Added get_scores(self, logits, e_score_correction_bias) annotated with @torch.compile(...) to compute sigmoid(logits) and bias-adjusted scores; noaux_tc now delegates score computation to get_scores and preserves the downstream logic (shape/NaN checks and fused/non-fused branches).
  • Attention FP8 SM100 path (tensorrt_llm/_torch/modules/attention.py): For SM 100, replaces the two-step pattern of a bmm into a temporary followed by a copy with a single torch.bmm(..., out=out) call that writes directly into the output; other SM branches are unchanged.
  • Fused MoE preprocessing via Triton (tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py): Removed Python helpers (swiglu_fused_moe, indexing); added a Triton _preprocess_after_permute_kernel(...) and updated preprocess_after_permute to launch it, producing token_to_expert_map and tokens_per_expert on GPU; removed the related F usage.
  • FP8 utils allocation tweak (tensorrt_llm/quantization/utils/fp8_utils.py): In per_token_quant_and_transform, changed the output_scale allocation from torch.zeros(...) to torch.empty(...) (CUDA int32) to avoid zero-initialization while preserving kernel writes and subsequent usage.
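
To make the attention change concrete, here is a minimal sketch of the before/after pattern (tensor shapes are illustrative; the names follow fp8_block_scaling_bmm_out as quoted later in this review):

import torch

mat1 = torch.randn(8, 4, 16)           # (tokens, batch, k) before the transpose
mat2_dequant = torch.randn(4, 32, 16)  # (batch, n, k), dequantized weights on the SM100 path
out = torch.empty(4, 8, 32)            # caller-provided output buffer

# Old path: materialize a temporary result, then copy it into `out`
tmp = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
out.copy_(tmp)

# New path: write the bmm result directly into `out`, dropping the temporary and the copy
torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)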

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller
  participant Deepseekv3RoutingImpl as Deepseekv3RoutingImpl
  participant get_scores as get_scores() [compiled]

  Caller->>Deepseekv3RoutingImpl: noaux_tc(logits, e_score_correction_bias)
  activate Deepseekv3RoutingImpl
  Deepseekv3RoutingImpl->>get_scores: (logits, e_score_correction_bias)
  activate get_scores
  get_scores-->>Deepseekv3RoutingImpl: (scores, scores_with_bias)
  deactivate get_scores
  Deepseekv3RoutingImpl->>Deepseekv3RoutingImpl: shape checks, NaN checks
  alt fused routing
    Deepseekv3RoutingImpl->>Deepseekv3RoutingImpl: fused path
  else non-fused routing
    Deepseekv3RoutingImpl->>Deepseekv3RoutingImpl: standard path
  end
  Deepseekv3RoutingImpl-->>Caller: routing outputs
  deactivate Deepseekv3RoutingImpl
sequenceDiagram
  autonumber
  actor Pipeline
  participant Preprocess as preprocess_after_permute
  participant Triton as _preprocess_after_permute_kernel

  Pipeline->>Preprocess: expert_offsets
  Preprocess->>Preprocess: allocate masked_m, token_to_expert_map
  Preprocess->>Triton: launch(kernel, TOTAL_TOKENS, NUM_EXPERTS, BLOCK_SIZE_M)
  rect rgba(200,230,255,0.3)
  note right of Triton: Phase 1 — token→expert mapping (write token_map)
  end
  rect rgba(200,255,200,0.3)
  note right of Triton: Phase 2 — per-expert token counts (write masked_m/tokens_per_expert)
  end
  Triton-->>Preprocess: filled masked_m, token_to_expert_map
  Preprocess-->>Pipeline: (masked_m, token_to_expert_map)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

Community want to contribute

Suggested reviewers

  • litaotju
  • Barry-Delaney
  • yuxianq


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/attention.py (1)

1-1: Please add the NVIDIA copyright header to tensorrt_llm/_torch/modules/attention.py

Other Python modules in this directory (for example, rms_norm.py and layer_norm.py) already include the required header, but attention.py does not. Per our coding guidelines, every source file (*.py) under tensorrt_llm/ must begin with:

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

—so please prepend that exact line at the very top of tensorrt_llm/_torch/modules/attention.py.

🧹 Nitpick comments (9)
tensorrt_llm/_torch/modules/attention.py (2)

557-559: SM100 path: assert mat2_dequant is provided before using it.

For SM100 we now depend on the dequantized weights. Add a defensive check to fail fast with a clear message if it’s missing.

-    elif sm_version == 100:
-        torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
+    elif sm_version == 100:
+        assert mat2_dequant is not None, "SM100 path requires mat2_dequant (BF16) to be provided"
+        torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)

544-561: Return value mismatch: function annotated to return Tensor but returns None.

Either return out or change the annotation to -> None. Returning out keeps BC, clarifies intent, and helps static analyzers.

 def fp8_block_scaling_bmm_out(
     mat1: torch.Tensor,
     mat2_fp8: torch.Tensor,
     mat2_scale: torch.Tensor,
     out: torch.Tensor,
     mat2_dequant: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
@@
     else:
         raise NotImplementedError(f"SM{sm_version} is not supported")
 
+    return out
tensorrt_llm/quantization/utils/fp8_utils.py (1)

1-1: Missing NVIDIA copyright header (2025).

Please prepend the NVIDIA header to comply with project standards.

Use the same header check script posted in attention.py.

tensorrt_llm/_torch/models/modeling_deepseekv3.py (2)

282-287: Prefer torch.sigmoid and enable dynamic shape compilation to reduce recompiles.

  • F.sigmoid is deprecated; torch.sigmoid is the recommended API.
  • Routing shapes (batch/tokens) can vary; dynamic=True helps TorchDynamo/Inductor avoid recompiling per shape.
-    @torch.compile(options={"max-autotune": True})
+    @torch.compile(dynamic=True, options={"max-autotune": True})
     def get_scores(self, logits, e_score_correction_bias):
-        scores = F.sigmoid(logits)
+        scores = torch.sigmoid(logits)
         scores_with_bias = scores + e_score_correction_bias
         return scores, scores_with_bias

1-1: NVIDIA copyright header (2025) is missing.

This file currently starts with third‑party license text. Per guidelines, prepend the NVIDIA header above (or directly below) while keeping upstream license notices intact.

Use the same header check script posted in attention.py.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (4)

218-256: Fused kernel is correct; consider algorithmic improvement for many experts.

The token→expert map computes in O(NUM_EXPERTS) per token by scanning boundaries. For large expert counts, a per‑block binary search over expert_offsets_ptr can reduce it to O(log NUM_EXPERTS), improving latency at scale.

If helpful, I can draft a Triton binary‑search variant that preserves your current memory pattern and masks.


265-293: Add dtype/device/contiguity assertions to the Python wrapper for early failure and safer launches.

Small guards make bad inputs fail fast and self‑document the expectations.

 @nvtx_range("[DG] preprocess_after_permute")
 def preprocess_after_permute(expert_first_token_offset_tensor,
                              permuted_data_tensor):
     """
     Python wrapper that launches a single fused kernel to get the token-to-expert map
     and the number of tokens per expert.
     """
+    # Sanity checks
+    assert expert_first_token_offset_tensor.is_cuda, "expert offsets must be CUDA tensor"
+    assert permuted_data_tensor.is_cuda, "permuted data must be CUDA tensor"
+    assert expert_first_token_offset_tensor.dtype in (torch.int32, torch.int64), \
+        f"unexpected dtype for expert offsets: {expert_first_token_offset_tensor.dtype}"
+    assert permuted_data_tensor.is_contiguous(), "permuted_data_tensor must be contiguous"
+    assert expert_first_token_offset_tensor.is_contiguous(), "expert offsets must be contiguous"
+
     total_tokens = permuted_data_tensor.shape[0]
     num_experts = expert_first_token_offset_tensor.shape[0] - 1

265-293: Grid selection logic is sensible; keep a zero‑token guard near call‑site or here.

You already avoid calling this when permuted_data_tensor.numel() == 0. If future callers change, add an early return for total_tokens == 0 to avoid launching a (0, 2) grid.

Do you want me to add that guard and wire a unit test comparing against a Python reference using torch.searchsorted?
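
For reference, such a pure-PyTorch comparison could look like the sketch below (testing only; it assumes the offsets tensor is the exclusive prefix sum of length num_experts + 1, with offsets[-1] == total_tokens):

import torch

def preprocess_after_permute_reference(expert_first_token_offset, total_tokens):
    # token i belongs to expert e iff offsets[e] <= i < offsets[e + 1]
    token_ids = torch.arange(total_tokens, device=expert_first_token_offset.device)
    token_to_expert_map = torch.searchsorted(expert_first_token_offset[1:],
                                             token_ids, right=True)
    tokens_per_expert = (expert_first_token_offset[1:] -
                         expert_first_token_offset[:-1]).to(torch.int32)
    return tokens_per_expert, token_to_expert_map

# e.g. offsets = torch.tensor([0, 3, 3, 7]) gives counts [3, 0, 4];
# tokens 0-2 map to expert 0 and tokens 3-6 map to expert 2.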


1-1: Missing NVIDIA copyright header (2025).

Please prepend the NVIDIA header to this file, per repo policy.

Use the same header check script posted in attention.py.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 4f84a45 and bb82e26.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (1 hunks)
  • tensorrt_llm/_torch/modules/attention.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/quantization/utils/fp8_utils.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/quantization/utils/fp8_utils.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
🧬 Code graph analysis (2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h (1)
  • n_group (222-222)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
tensorrt_llm/_utils.py (1)
  • nvtx_range (843-862)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/quantization/utils/fp8_utils.py (1)

479-481: Action Required: Manually verify torch.empty usage on a CUDA-enabled environment

  • Using
    output_scale = torch.empty((scale_k_padded // 4, m_padded),
                               dtype=torch.int32,
                               device='cuda')
    is safe because the custom kernel _per_token_quant_and_transform_kernel fully initializes every element of the used slice before we slice to [:m, :].
  • Our sandbox environment failed to import PyTorch (ModuleNotFoundError: No module named 'torch'), so we couldn’t execute the suggested spot-check here.
  • Please run the following on a CUDA machine to ensure there are no uninitialized or non-finite values in output_scale:
#!/bin/bash
python - <<'PY'
import torch
from tensorrt_llm.quantization.utils.fp8_utils import per_token_quant_and_transform

torch.manual_seed(0)
x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16).contiguous()
_, sf = per_token_quant_and_transform(x, quant_group_size=128, scale_ue8m0=True)
assert torch.isfinite(sf).all(), "Found non-finite in output_scale"
print("ok:", sf.shape, sf.dtype, sf.stride())
PY
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

288-297: Please verify graph caching and dtype promotion

Extraction and compilation of the score computation look good. Keep an eye on dtype promotion: e_score_correction_bias is bfloat16 while scores is fp32, so PyTorch will upcast to fp32. If you change the dtype of logits later, double-check that the result is still what you expect.
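
A quick, model-independent way to confirm the promotion rule:

import torch

print(torch.promote_types(torch.float32, torch.bfloat16))  # torch.float32
x = torch.randn(4, dtype=torch.float32) + torch.randn(4, dtype=torch.bfloat16)
print(x.dtype)  # torch.float32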

To confirm that only a single compiled graph is cached across varying token counts, please rerun the following profiling script in an environment where PyTorch (with CUDA) is installed and accessible:

#!/bin/bash
python - <<'PY'
import torch, time
from tensorrt_llm._torch.models.modeling_deepseekv3 import Deepseekv3RoutingImpl

# Initialize routing implementation
impl = Deepseekv3RoutingImpl(
    top_k=2,
    n_group=8,
    topk_group=2,
    routed_scaling_factor=1.0,
    is_fused=False
)

# Test two different sequence lengths
for seq_len in (512, 768):
    logits = torch.randn(seq_len, 256, device="cuda", dtype=torch.float32)
    bias   = torch.zeros(256,     device="cuda", dtype=torch.bfloat16)
    t0 = time.time()
    _ = impl.noaux_tc(logits, bias)
    torch.cuda.synchronize()
    print(f"Ran shape {seq_len} in", time.time() - t0)
PY

Verify from the timing (or debug logs) that the second run does not include additional compilation overhead. If you see two separate compilation steps, we may need to investigate how inputs are traced or consider unifying dynamic dimensions.

Signed-off-by: Fanrong Li <[email protected]>
@lfr-0531 lfr-0531 force-pushed the user/fanrongl/fix_element_wise branch from 4cee276 to 31687b0 on August 26, 2025 06:30
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (7)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (7)

240-242: Use int32 for token_to_expert_map to cut bandwidth and match downstream kernels

The map fits in int32 (num_experts <= 2^31-1), and both masked_index_copy_group_quant_fp8 and masked_index_gather_kernel operate on int32 efficiently. This avoids unnecessary 64-bit traffic and register pressure.

Apply:

-        tl.store(token_map_ptr + token_offsets,
-                 expert_ids.to(tl.int64),
-                 mask=token_mask)
+        tl.store(token_map_ptr + token_offsets,
+                 expert_ids,  # int32 is sufficient and matches consumers
+                 mask=token_mask)
-    token_to_expert_map = torch.empty(total_tokens,
-                                      dtype=torch.int64,
-                                      device='cuda')
+    token_to_expert_map = torch.empty(total_tokens,
+                                      dtype=torch.int32,
+                                      device='cuda')

Also applies to: 270-272


223-226: Avoid excessive kernel specialization: make TOTAL_TOKENS a runtime arg (not tl.constexpr)

TOTAL_TOKENS is only used for a mask; keeping it constexpr forces a new compilation per sequence length/batch and increases JIT cache churn. NUM_EXPERTS and BLOCK_SIZE_M can stay constexpr.

Apply:

 def _preprocess_after_permute_kernel(
     expert_offsets_ptr,
     masked_m_ptr,
     token_map_ptr,
-    TOTAL_TOKENS: tl.constexpr,
+    total_tokens,
     NUM_EXPERTS: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
 ):
@@
-        token_mask = token_offsets < TOTAL_TOKENS
+        token_mask = token_offsets < total_tokens
@@
     _preprocess_after_permute_kernel[grid](
         expert_first_token_offset_tensor,
         masked_m,
         token_to_expert_map,
-        TOTAL_TOKENS=total_tokens,
+        total_tokens=total_tokens,
         NUM_EXPERTS=num_experts,
         BLOCK_SIZE_M=BLOCK_SIZE_M,
     )

Also applies to: 232-232, 290-293


235-240: O(NUM_EXPERTS) scan per token block; consider binary search on offsets

The linear scan across NUM_EXPERTS per block is simple but scales poorly when the expert count is large (e.g., 128–256). A branchless lower_bound-style binary search over expert_offsets (exclusive prefix sums) would reduce this to O(log NUM_EXPERTS) and cut global loads. This can be done with a small loop over bit widths using vectorized tl.load on the offsets.

If you’d like, I can sketch a branchless Triton lower_bound for the offsets array.
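
As a plain-Python illustration of that idea (not Triton code; it only shows the fixed-trip-count lookup the kernel could adopt over the exclusive prefix-sum offsets):

def token_to_expert(offsets, token_id):
    # Return e such that offsets[e] <= token_id < offsets[e + 1].
    # The loop runs about log2(num_experts) iterations and only conditionally
    # advances the index, which maps onto a branchless (predicated) Triton loop.
    ends = offsets[1:]                 # per-expert end offsets
    n = len(ends)
    i = 0
    step = 1 << (n.bit_length() - 1)   # largest power of two <= n
    while step:
        if i + step <= n and ends[i + step - 1] <= token_id:
            i += step
        step >>= 1
    return i

# e.g. token_to_expert([0, 3, 3, 7], 5) == 2 (tokens 3-6 belong to expert 2)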


258-266: Mark wrapper no-grad and document shapes/dtypes (Google-style docstring)

This path is inference-only; disabling autograd avoids overhead. Also, add explicit shape/dtype docs to ease maintenance.

Apply:

-@nvtx_range("[DG] preprocess_after_permute")
+@torch.no_grad()
+@nvtx_range("[DG] preprocess_after_permute")
 def preprocess_after_permute(expert_first_token_offset_tensor,
                              permuted_data_tensor):
-    """
-    Python wrapper that launches a single fused kernel to get the token-to-expert map
-    and the number of tokens per expert.
-    """
+    """
+    Launch a fused Triton kernel to compute:
+      - token_to_expert_map: token index -> expert id
+      - masked_m: number of tokens per expert
+
+    Args:
+        expert_first_token_offset_tensor (torch.Tensor):
+            1D int32/int64 CUDA tensor of length num_experts + 1 (exclusive prefix sums).
+            offsets[-1] must equal total_tokens.
+        permuted_data_tensor (torch.Tensor):
+            CUDA tensor; only shape[0] (total_tokens) is used.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]:
+            masked_m (int32[num_experts]), token_to_expert_map (int32[total_tokens]).
+    """

265-273: Add cheap sanity checks to fail fast on device/dtype mismatches

These assertions guard common integration errors without affecting hot paths.

Apply:

     total_tokens = permuted_data_tensor.shape[0]
     num_experts = expert_first_token_offset_tensor.shape[0] - 1
+    # Sanity checks
+    assert expert_first_token_offset_tensor.is_cuda and permuted_data_tensor.is_cuda, \
+        "Inputs must be CUDA tensors"
+    assert expert_first_token_offset_tensor.ndim == 1, \
+        "expert_first_token_offset_tensor must be 1D"
+    assert expert_first_token_offset_tensor.numel() == num_experts + 1, \
+        "Offsets length must be num_experts + 1"
+    assert expert_first_token_offset_tensor.dtype in (torch.int32, torch.int64), \
+        "Offsets dtype must be int32 or int64"

275-284: Grid Y=1 phase may launch many idle programs when grid_m_size >> num_experts

The fused 2D grid is nice to save a launch, but for small expert counts and large token counts, the “counts” phase (pid_y == 1) will run with grid_x = grid_m_size (often >> num_experts) and most programs will immediately mask out. If this shows up in profiles, consider either:

  • Launching a tiny second kernel just for counts (one-time per batch, negligible launch cost).
  • Or keep the fused kernel but set grid_x = max(grid_m_size, num_experts) only when pid_y == 0 via two launches of the same kernel with different grids.

Happy to provide a minimal two-launch variant if you want to compare in Nsight.


1-1: License header

This Python file appears to be missing the NVIDIA copyright header. If this repo enforces it globally, please add it.

I can prep a small PR-wide patch to add/refresh headers if useful.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between bb82e26 and 31687b0.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (1 hunks)
  • tensorrt_llm/_torch/modules/attention.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • tensorrt_llm/quantization/utils/fp8_utils.py
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
🧠 Learnings (2)
📚 Learning: 2025-08-14T23:23:27.449Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
tensorrt_llm/_utils.py (1)
  • nvtx_range (843-862)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

275-283: Confirm Triton helper availability across CI images

The sandbox lacks Triton, so we couldn’t automatically verify whether triton.next_power_of_2 is provided by all wheels. Please verify in your CI and test environments:

  • Ensure hasattr(triton, "next_power_of_2") returns True.

  • If any Triton wheel omits next_power_of_2, inline the following helper:

    def _next_pow2(x: int) -> int:
        return 1 if x <= 1 else 1 << ((x - 1).bit_length())
    
    block_size_m = triton.cdiv(total_tokens, num_experts)
    BLOCK_SIZE_M = _next_pow2(block_size_m)

@lfr-0531

/bot run

@tensorrt-cicd

PR_Github #16530 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #16530 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12414 completed with status: 'SUCCESS'
Pipeline passed with automatically retried tests. Check the rerun report for details.

@lfr-0531 lfr-0531 merged commit e12868b into NVIDIA:main Aug 27, 2025
5 checks passed
@lfr-0531 lfr-0531 deleted the user/fanrongl/fix_element_wise branch September 22, 2025 07:15