
Conversation


@litaotju litaotju commented Jul 16, 2025

PR title

Please write the PR title following this template:

[JIRA ticket link/nvbug link/github issue link][fix/feat/doc/infra/...] <summary of this PR>

For example, for a PR that adds a new cache manager feature tracked by Jira ticket TRTLLM-1000, the title would look like:

[TRTLLM-1000][feat] Support a new feature about cache manager

Description

Please briefly explain the issue and the solution.

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]

Launch build/test pipelines. All previously running jobs will be killed.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-[Post-Merge]-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since skipping tests without careful validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since reusing a pipeline without careful validation can break the top of tree.

Summary by CodeRabbit

  • New Features

    • Added support for the "DEEPGEMM" backend with a new fused Mixture of Experts (MoE) implementation optimized for FP8 precision and Blackwell GPUs.
    • Introduced FP8 quantization utilities with per-token and per-block casting, resmoothing, and enhanced FP8 weight handling including dequantization and scaling factor layout transformations.
    • Added new command-line options to select additional MoE backends.
    • Improved FP8 batched matrix multiplication with a new dequantized path for SM 100 hardware.
    • Added synchronization barrier to ensure coordinated weight loading across local MPI ranks.
    • Extended LLaMA example documentation with instructions for running the LLaMA 3.3 70B FP8 model on the PyTorch backend.
  • Bug Fixes

    • Improved tensor memory handling for transposed weights to ensure contiguity.
  • Chores

    • Added new dependencies required for the DeepGemm backend.
    • Exposed FP8 utility modules for broader usage.
    • Added local MPI communicator and barrier utilities.
    • Updated Jenkins scripts to improve rerun test reporting and stage name handling.
  • Tests

    • Introduced comprehensive tests for FP8 blockwise quantization and DeepGemm-based MoE modules.
    • Added unit tests for new FP8 casting utilities and matrix multiplication routines.

@litaotju litaotju requested a review from a team as a code owner July 16, 2025 15:34

coderabbitai bot commented Jul 16, 2025

Walkthrough

The changes introduce a new DeepGemm-based FP8 blockwise Mixture-of-Experts (MoE) backend, including its implementation, quantization utilities, and integration into the MoE factory. FP8 e8m0 quantization and resmoothing functions are added, and modules are updated to support the new backend and quantization path. Comprehensive unit tests for DeepGemm FP8 workflows are included. Additional updates improve FP8 weight handling, add MPI synchronization after prefetching, and extend CLI options.

Changes

File(s) Change Summary
examples/llm-api/quickstart_advanced.py Expanded --moe_backend CLI argument to allow 'DEEPGEMM' and 'CUTEDSL' as valid options.
requirements.txt Added deep_gemm dependency from GitHub.
tensorrt_llm/_torch/models/modeling_deepseekv3.py Added FP8 resmoothing logic; ensured tensor contiguity after transpose; dequantized and copied attention projection weights.
tensorrt_llm/_torch/modules/attention.py Added dequantization params to MLA; updated FP8 path for SM100 to use dequantized matmul; updated fp8_block_scaling_bmm_out.
tensorrt_llm/_torch/modules/fused_moe/create_moe.py Integrated DeepGemmFusedMoE as a selectable MoE backend.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py New: Implemented DeepGemmFusedMoE class and supporting FP8/activation ops for fused MoE with DeepGemm backend.
tensorrt_llm/_torch/modules/linear.py Added DeepGemm FP8 GEMM path for SM100 in FP8BlockScalesLinearMethod.apply; reordered weight copying in weight loading.
tensorrt_llm/quantization/utils/init.py Exposed new fp8_utils module in package exports.
tensorrt_llm/quantization/utils/fp8_utils.py New: Added FP8 e8m0 quantization, per-token/block casting, and resmoothing utilities.
tensorrt_llm/_torch/modules/fused_moe/quantization.py Added load_weights method with FP8 weight resmoothing for SM100 in DeepSeekFP8BlockScalesFusedMoEMethod.
tensorrt_llm/_torch/pyexecutor/model_engine.py Added MPI barrier synchronization after file prefetching in prefetch_files.
tensorrt_llm/_utils.py Added local_mpi_comm() and local_mpi_barrier() functions for local MPI communicator access and synchronization (see the sketch after this list).
tests/unittest/_torch/helpers.py Added FP8 e8m0 casting helpers and alignment utilities for tests.
tests/unittest/_torch/modules/test_fused_moe.py Added comprehensive test for DeepGemmFusedMoE with FP8 blockwise quantization and reference validation.
tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py Added test for DeepGemm FP8 block scale GEMM on Blackwell (SM100) architecture.
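
For context on the MPI items above, here is a minimal mpi4py sketch of node-local synchronization (an illustration only, not the repository implementation; it assumes the local communicator is obtained by splitting COMM_WORLD across shared-memory nodes, and the _sketch suffix marks the names as hypothetical):

from mpi4py import MPI

def local_mpi_comm_sketch() -> MPI.Comm:
    # Group the ranks that share this node's memory, i.e. the local ranks.
    return MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)

def local_mpi_barrier_sketch() -> None:
    # Block until every rank on this node arrives, e.g. so that all local ranks
    # see the prefetched weight files before model loading proceeds.
    local_mpi_comm_sketch().Barrier()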

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant ModelConfig
    participant MoEFactory
    participant DeepGemmFusedMoE
    participant QuantUtils
    participant CUDA

    User->>ModelConfig: Set moe_backend='DEEPGEMM'
    ModelConfig->>MoEFactory: create_moe()
    MoEFactory->>DeepGemmFusedMoE: Instantiate with config
    DeepGemmFusedMoE->>QuantUtils: per_token_cast_to_fp8_e8m0(input)
    DeepGemmFusedMoE->>QuantUtils: per_block_cast_to_fp8_e8m0(weights)
    DeepGemmFusedMoE->>CUDA: deep_gemm.fp8_gemm_nt(...)
    DeepGemmFusedMoE-->>User: Return output tensor

Estimated code review effort

File(s) Effort Score
examples/llm-api/quickstart_advanced.py 🎯 1 (Trivial)
requirements.txt 🎯 1 (Trivial)
tensorrt_llm/_torch/models/modeling_deepseekv3.py 🎯 3 (Moderate)
tensorrt_llm/_torch/modules/attention.py 🎯 3 (Moderate)
tensorrt_llm/_torch/modules/fused_moe/create_moe.py 🎯 2 (Simple)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py 🎯 5 (Critical)
tensorrt_llm/_torch/modules/linear.py 🎯 3 (Moderate)
tensorrt_llm/quantization/utils/init.py 🎯 1 (Trivial)
tensorrt_llm/quantization/utils/fp8_utils.py 🎯 4 (Complex)
tensorrt_llm/_torch/modules/fused_moe/quantization.py 🎯 3 (Moderate)
tensorrt_llm/_torch/pyexecutor/model_engine.py 🎯 2 (Simple)
tensorrt_llm/_utils.py 🎯 2 (Simple)
tests/unittest/_torch/helpers.py 🎯 2 (Simple)
tests/unittest/_torch/modules/test_fused_moe.py 🎯 3 (Moderate)
tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py 🎯 2 (Simple)

Suggested reviewers

  • litaotju
  • hyukn
  • yilin-void
  • brb-nv
  • suyoggupta

Poem

(\(\
( -.-) A hop, a leap, a DeepGemm bound,
((")(") FP8 blocks and experts found!
With e8m0 scales and CUDA speed,
MoE flows fast—just what we need.
Tests and helpers join the show—
Blackwell bunnies, off we go!

Note

⚡️ Unit Test Generation is now available in beta!

Learn more here, or try it out under "Finishing Touches" below.


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 77e179c and 2c908f3.

📒 Files selected for processing (1)
  • requirements.txt (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • requirements.txt
✨ Finishing Touches
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (6)
tensorrt_llm/quantization/utils/fp8_utils.py (2)

22-23: Add documentation for consistency.

Other functions in this module have docstrings. Consider adding documentation for the align function.

 def align(x: int, y: int) -> int:
+    """
+    Align x to the nearest multiple of y.
+    
+    Args:
+        x: The value to align.
+        y: The alignment boundary.
+    
+    Returns:
+        The smallest multiple of y that is >= x.
+    """
     return ceil_div(x, y) * y

26-28: Add documentation for the e8m0 utility function.

This function is a key utility for e8m0 quantization. Consider adding documentation to explain its purpose.

 def ceil_to_ue8m0(x: torch.Tensor):
+    """
+    Compute the ceiling to the nearest power of 2 for tensor elements.
+    
+    This is used for e8m0 quantization where scale factors must be powers of 2.
+    
+    Args:
+        x: Input tensor.
+    
+    Returns:
+        Tensor with each element rounded up to the nearest power of 2.
+    """
     return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
tensorrt_llm/_torch/modules/attention.py (1)

710-726: Consider memory optimization for dequantized parameters.

Creating separate dequantized parameters doubles memory usage for k_b_proj_trans and v_b_proj. For large models, this could be significant. Consider:

  • Dequantizing on-the-fly during forward pass instead of storing dequantized copies
  • Using lazy initialization to create these tensors only when needed
  • Adding a comment explaining the memory trade-off
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)

30-74: Add documentation for the complex padding logic.

The copy_after function implements intricate token padding and rearrangement logic. Please add a docstring explaining:

  • The purpose of padding to 128-token boundaries
  • The meaning of input/output tensors
  • The algorithm's key steps

This will improve maintainability and help future developers understand the optimization strategy.


76-91: Consider making output dtype configurable.

The function hardcodes the output dtype as torch.bfloat16. This might limit flexibility for different precision requirements. Consider:

  • Adding an optional dtype parameter
  • Or documenting why bfloat16 is specifically required
 def deepgemm_fp8_group_blockwise_gemm_ref(
     a: torch.Tensor,
     b: torch.Tensor,
     a_sf: torch.Tensor,
     b_sf: torch.Tensor,
     m_indices: torch.Tensor,
+    dtype: torch.dtype = torch.bfloat16,
 ) -> torch.Tensor:
 
     d = torch.empty((a.shape[0], b.shape[1]),
                     device=b.device,
-                    dtype=torch.bfloat16)
+                    dtype=dtype)
     deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a, a_sf), (b, b_sf), d,
                                                m_indices)
     return d

174-175: Fix line length violation.

Line 175 exceeds the 120-character limit. Please split it across multiple lines.

-            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
+            assert self.routing_method.top_k == 1, \
+                "Current workaround only supports top-1 routing"
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e30d7be and a2fb8e5.

📒 Files selected for processing (12)
  • examples/llm-api/quickstart_advanced.py (1 hunks)
  • requirements.txt (1 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (5 hunks)
  • tensorrt_llm/_torch/modules/attention.py (6 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/_torch/modules/linear.py (3 hunks)
  • tensorrt_llm/quantization/utils/__init__.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
  • tests/unittest/_torch/helpers.py (2 hunks)
  • tests/unittest/_torch/modules/test_fused_moe.py (3 hunks)
  • tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2)
tests/unittest/_torch/helpers.py (4)
  • calc_diff (71-75)
  • per_block_cast_to_fp8 (28-41)
  • per_block_cast_to_fp8_e8m0 (55-68)
  • per_token_cast_to_fp8_e8m0 (44-52)
tensorrt_llm/quantization/utils/fp8_utils.py (2)
  • per_block_cast_to_fp8_e8m0 (42-67)
  • per_token_cast_to_fp8_e8m0 (31-39)
tensorrt_llm/_torch/modules/linear.py (2)
tensorrt_llm/quantization/utils/fp8_utils.py (1)
  • per_token_cast_to_fp8_e8m0 (31-39)
tensorrt_llm/_utils.py (3)
  • get_sm_version (648-650)
  • shape (905-906)
  • shape (922-923)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (5)
tensorrt_llm/quantization/utils/fp8_utils.py (1)
  • resmooth_to_fp8_e8m0 (70-80)
tensorrt_llm/models/modeling_utils.py (2)
  • layer_quant_mode (161-165)
  • layer_quant_mode (301-307)
tensorrt_llm/_torch/modules/linear.py (1)
  • has_fp8_block_scales (1147-1150)
tests/integration/defs/conftest.py (1)
  • get_sm_version (1857-1860)
tensorrt_llm/_utils.py (3)
  • get_sm_version (648-650)
  • shape (905-906)
  • shape (922-923)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

175-175: Line too long (131 > 120)

(E501)

🔇 Additional comments (28)
tensorrt_llm/quantization/utils/__init__.py (1)

1-3: LGTM: Proper module exposure following existing patterns.

The addition of fp8_utils to the imports and __all__ list follows the established pattern and properly exposes the new FP8 quantization utilities for public use.

examples/llm-api/quickstart_advanced.py (1)

50-54: LGTM: Proper CLI argument expansion for new MoE backends.

The addition of 'DEEPGEMM' and 'CUTEDSL' to the --moe_backend choices correctly exposes the new MoE backend implementations while maintaining backward compatibility.

tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2)

27-61: LGTM: Comprehensive test for DeepGemm FP8 GEMM implementation.

The test function properly:

  • Targets the correct hardware (SM 100/Blackwell) with appropriate skip conditions
  • Tests multiple matrix dimensions with comprehensive parameterization
  • Uses the new FP8 casting utilities correctly
  • Compares against a reference implementation with appropriate tolerance

The 1e-2 tolerance threshold is reasonable for FP8 precision testing.
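
As a rough illustration of such a tolerance check (the repository's calc_diff helper may use a different error metric; the names below are only for illustration):

import torch

def rel_diff(out: torch.Tensor, ref: torch.Tensor) -> float:
    # Relative L2 error between an FP8 result and a higher-precision reference.
    denom = ref.float().norm().clamp(min=1e-12)
    return ((out.float() - ref.float()).norm() / denom).item()

# e.g. assert rel_diff(deepgemm_out, reference_out) < 1e-2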


21-23: Helper imports are available

I confirmed that tests/unittest/_torch/helpers.py defines all imported functions—calc_diff, per_block_cast_to_fp8, per_block_cast_to_fp8_e8m0, and per_token_cast_to_fp8_e8m0—so the from _torch.helpers statements in test_fp8_block_scale_gemm.py are valid. No changes are needed.

tensorrt_llm/_torch/modules/linear.py (4)

22-24: LGTM: Proper imports for FP8 utilities and SM version detection.

The imports are correctly added to support the new FP8 casting functionality and hardware-specific optimizations.


546-559: LGTM: Well-implemented hardware-specific optimization for SM 100.

The conditional branch properly:

  • Checks for SM 100 hardware before using DeepGemm
  • Imports deep_gemm only when needed to avoid import errors on other hardware
  • Uses the correct FP8 casting function for the DeepGemm path
  • Maintains the same output tensor structure as the original implementation

The fallback to the existing torch.ops.trtllm path ensures backward compatibility.


597-598: No dependency on weight vs. weight_scale loading order detected
After searching through all load_weights_* implementations, there’s no logic that reads one before the other—these two calls simply assign independent buffers. The reordering in load_weights_fused_qkv_linear won’t affect any downstream functionality.


612-613: No order-dependent logic found in FUSED_GATE_UP_LINEAR loader

I searched for any references to the order of fused weight vs. fused scale loading and confirmed that in tensorrt_llm/_torch/modules/linear.py the method
load_weights_fused_gate_up_linear always does:

copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_scale)

– exactly mirroring the FUSED_QKV_LINEAR implementation. There are no other code paths or initialization routines that assume a different sequence. This change should not break existing functionality.

tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3)

11-11: LGTM!

The import follows the established pattern for MoE backend modules.


35-36: LGTM!

The DEEPGEMM backend integration follows the existing pattern for MoE backend selection.


145-158: LGTM!

The DeepGemmFusedMoE instantiation correctly follows the pattern used by CutlassFusedMoE and other similar backends, passing all required parameters.

tests/unittest/_torch/helpers.py (4)

11-13: LGTM!

The align function correctly implements alignment to the nearest multiple using ceiling division.


15-17: LGTM!

The function correctly computes the ceiling to the nearest power of 2, which is appropriate for e8m0 scaling factor computation.


44-53: LGTM!

The per-token FP8 e8m0 casting correctly implements power-of-2 scaling factors, which is the key difference from standard FP8 quantization.


55-69: LGTM!

The per-block FP8 e8m0 casting correctly implements block-wise quantization with power-of-2 scaling factors and proper padding.

tests/unittest/_torch/modules/test_fused_moe.py (2)

12-13: LGTM!

The imports are properly organized and follow the existing pattern.

Also applies to: 29-30


385-551: LGTM!

The test comprehensively validates the DeepGemmFusedMoE implementation with FP8 blockwise quantization:

  • Correctly uses FP8 e8m0 quantization for both weights and activations
  • Implements a detailed reference implementation for accurate comparison
  • Tests multiple sequence lengths for robustness
  • Properly restricted to Blackwell+ GPUs with the decorator
tensorrt_llm/_torch/models/modeling_deepseekv3.py (4)

46-46: LGTM!

The import is properly placed and follows the module organization pattern.


1295-1304: LGTM!

The FP8 resmoothing logic correctly applies e8m0 quantization for SM 100 hardware when using FP8 block scales.


1211-1211: LGTM!

Adding .contiguous() after transpose operations ensures proper memory layout for subsequent operations.

Also applies to: 1252-1252


1356-1376: LGTM!

The dequantization logic correctly handles k_b_proj_trans and v_b_proj weights:

  • Properly checks for parameter existence
  • Correctly dequantizes using the weight_dequant function
  • Ensures proper tensor shapes and dtypes
tensorrt_llm/quantization/utils/fp8_utils.py (4)

8-20: LGTM!

The ceiling division implementation is correct and well-documented.


30-41: LGTM!

The per-token FP8 e8m0 casting is correctly implemented with:

  • Proper input validation
  • Power-of-2 scaling factor computation
  • Performance optimization with nvtx_range and torch.compile
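
To make the scheme concrete, a minimal reference sketch of per-token e8m0 casting (not the repository implementation; it assumes a contiguous 2D input whose last dimension is a multiple of 128, uses 448.0 as the float8_e4m3fn maximum, and the clamp epsilon is an assumption):

import torch

FP8_MAX = 448.0  # maximum magnitude representable in float8_e4m3fn

def per_token_cast_to_fp8_e8m0_sketch(x: torch.Tensor):
    # Quantize each 128-element block of every row with a power-of-two (e8m0) scale.
    m, n = x.shape
    blocks = x.view(m, n // 128, 128).float()
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale = torch.pow(2.0, torch.ceil(torch.log2(amax / FP8_MAX)))  # round up to a power of two
    q = (blocks / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.view(m, n), scale.squeeze(-1)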

43-69: LGTM!

The per-block FP8 e8m0 casting correctly handles both 2D and 3D tensors with proper padding and block-wise quantization.


71-82: LGTM!

The resmoothing function correctly converts FP8 weights to e8m0 format by:

  • Dequantizing with existing scales
  • Re-quantizing with power-of-2 scaling factors
  • Supporting both 2D and 3D weight tensors
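
As a companion to the bullets above, a hedged, self-contained sketch of the per-block cast and the resmoothing path built on top of it (the _sketch names are illustrative, not the repository API; the real functions also handle 3D tensors, and the epsilon and exact return layouts are assumptions):

import torch

FP8_MAX = 448.0  # maximum magnitude representable in float8_e4m3fn

def per_block_cast_to_fp8_e8m0_sketch(x: torch.Tensor):
    # Quantize a 2D tensor in 128x128 tiles, each with a power-of-two (e8m0) scale.
    m, n = x.shape
    pm, pn = (m + 127) // 128 * 128, (n + 127) // 128 * 128
    padded = torch.zeros(pm, pn, dtype=torch.float32, device=x.device)
    padded[:m, :n] = x.float()
    tiles = padded.view(pm // 128, 128, pn // 128, 128)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
    scale = torch.pow(2.0, torch.ceil(torch.log2(amax / FP8_MAX)))
    q = (tiles / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.view(pm, pn)[:m, :n], scale.reshape(pm // 128, pn // 128)

def resmooth_to_fp8_e8m0_sketch(weight_fp8: torch.Tensor, scale: torch.Tensor):
    # Dequantize with the existing 128x128 block scales, then requantize with e8m0 scales.
    m, n = weight_fp8.shape
    s = scale.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)[:m, :n]
    return per_block_cast_to_fp8_e8m0_sketch(weight_fp8.float() * s)
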
tensorrt_llm/_torch/modules/attention.py (2)

368-395: Clarify the temporary FP8 GEMM workaround for SM 100.

The implementation replaces FP8 GEMM with regular torch.bmm using dequantized matrices on SM 100 hardware. The commented-out code suggests this is temporary. Please clarify:

  • Is this due to FP8 GEMM support limitations on SM 100?
  • What is the performance impact of using dequantized GEMM instead of FP8 GEMM?
  • When will the FP8 GEMM implementation be restored?

Consider adding a TODO comment explaining the rationale and timeline for restoring the FP8 GEMM path.


1222-1228: LGTM! Correct integration of dequantized parameters.

The dequantized parameters are correctly passed to fp8_block_scaling_bmm_out in both BMM operations, maintaining consistency with the SM 100 workaround.

Also applies to: 1277-1283

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

183-189: Consider documenting quantization mode limitations.

The implementation only supports deepseek_fp8_block_scales quantization mode. Please add a comment explaining:

  • Why only this mode is supported
  • Plans for supporting other quantization modes
  • Whether this is a fundamental limitation of the DeepGemm backend


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

1362-1381: Review dequantization logic for potential issues

The dequantization logic for k_b_proj_trans_dequant and v_b_proj_dequant parameters looks functionally correct, but there are several considerations:

  1. CUDA tensor usage: The code correctly moves tensors to CUDA for the weight_dequant function
  2. Shape handling: The view operations and reshaping appear correct
  3. Data type conversion: The final .to(dtype) conversion is appropriate

However, there are potential concerns:

  • Memory efficiency: Creating intermediate CUDA tensors and then converting back might be memory-intensive
  • Error handling: No validation that the dequantization parameters exist before use
  • Performance: Multiple view/reshape operations could be optimized

Consider adding null checks before dequantization:

-                        if attn_module.k_b_proj_trans_dequant is not None:
+                        if (attn_module.k_b_proj_trans_dequant is not None and 
+                            k_b_proj_trans_scale is not None):

Also consider batching the CUDA operations to improve memory efficiency.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a2fb8e5 and 0286391.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (5 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (2 hunks)
🔇 Additional comments (6)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (3)

7-7: Import addition looks good.

The logger import is appropriately added to support the debug logging in the new load_weights method.


12-12: Import addition is correct.

The resmooth_to_fp8_e8m0 function import is properly added to support FP8 resmoothing functionality.


468-481: In-place mutation of weights is consistent and safe
Other load_weights implementations across the codebase routinely modify their local weights dict in-place (e.g. in fused MoE backends, linear/quantization helpers, and model converters). In this context, each dict is created fresh for a single module’s loading, so mutating it does not leak side effects elsewhere. No copy is necessary.

tensorrt_llm/_torch/models/modeling_deepseekv3.py (3)

47-47: LGTM: Import addition for FP8 resmoothing utility

The import of resmooth_to_fp8_e8m0 is correctly added to support the new FP8 resmoothing functionality implemented below.


1212-1212: Good: Adding contiguous() for memory layout optimization

The .contiguous() call ensures proper memory layout for the transposed tensor, which is important for performance in subsequent operations.


1253-1253: Good: Consistent contiguous() application

Same optimization applied consistently in the dequantization helper function.

Comment on lines +1296 to +1300
if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
) and get_sm_version() == 100:
    for name in list(weights.keys()):
        # Use ".experts." to exclude shared_experts.
        if name.endswith(
                "weight_scale_inv") and ".experts." not in name:
            weight_name = name.replace("weight_scale_inv", "weight")
            logger.debug(f"Resmoothing {weight_name}")
            weight = weights[weight_name][:]
            scale = weights[name][:]
            weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
                weight, scale)
            weights[weight_name] = weights[weight_name].cpu()
            weights[name] = weights[name].cpu()

💡 Verification agent

🧩 Analysis chain

Verify FP8 resmoothing integration and performance impact

The FP8 resmoothing logic looks correct but has some considerations:

  1. Architecture check: The condition get_sm_version() == 100 correctly targets SM 100 GPUs
  2. Weight filtering: The logic properly excludes shared experts using ".experts." not in name
  3. Memory management: Weights are moved to CPU after resmoothing, which is good for memory efficiency

However, there are potential concerns:

  • The resmoothing process may be computationally expensive and could impact loading time
  • The CPU transfer after resmoothing might not be necessary if the weights are immediately used

🏁 Script executed:

#!/bin/bash
# Description: Verify the usage of resmooth_to_fp8_e8m0 function and check for any performance implications

# Check if resmooth_to_fp8_e8m0 is properly implemented
ast-grep --pattern 'def resmooth_to_fp8_e8m0($_) {
  $$$
}'

# Check for other usages of this function in the codebase
rg -A 5 "resmooth_to_fp8_e8m0"

# Check if there are any tests for this functionality
rg -A 10 "test.*resmooth.*fp8" 

Length of output: 3138


Add unit tests for FP8 resmoothing & benchmark loading overhead

The FP8 resmoothing logic is correctly integrated—SM 100 check, expert filtering, and GPU→CPU transfer (to free GPU memory during load) are all intentional. However:

  • No unit tests cover resmooth_to_fp8_e8m0
  • The impact on model‐loading time isn’t quantified

Please address the following:

• tests/quantization/test_fp8_utils.py (or equivalent): add tests for resmooth_to_fp8_e8m0 covering
– 2D and higher-dim tensors
– corner cases (e.g. zero scales, non-contiguous tensors)
• Benchmark script or CI performance test: measure load times with and without FP8 resmoothing to ensure overhead is acceptable

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_deepseekv3.py around lines 1296 to 1309,
the FP8 resmoothing logic lacks unit tests and performance benchmarks. Add unit
tests for the resmooth_to_fp8_e8m0 function in
tests/quantization/test_fp8_utils.py that cover 2D and higher-dimensional
tensors, zero scale values, and non-contiguous tensors. Additionally, create a
benchmark script or CI performance test to measure model loading times with and
without FP8 resmoothing to evaluate and ensure the overhead is acceptable.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (4)

57-57: Replace lambda expression with a proper function definition.

Following PEP 8 guidelines, lambda expressions should not be assigned to variables.

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

107-107: Replace lambda expression with a proper function definition.

Following PEP 8 guidelines, lambda expressions should not be assigned to variables.

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

239-239: Split long line for better readability.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )

295-295: Split long line for better readability.

-        expected_m = (token_selected_experts.numel() + self.expert_size_per_partition - 1) // self.expert_size_per_partition
+        expected_m = (
+            (token_selected_experts.numel() + self.expert_size_per_partition - 1) 
+            // self.expert_size_per_partition
+        )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0286391 and d8bac92.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/quantization/utils/fp8_utils.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

57-57: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


107-107: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


239-239: Line too long (131 > 120)

(E501)


295-295: Line too long (124 > 120)

(E501)

🔇 Additional comments (9)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (9)

1-17: Import structure looks good.

The imports are well-organized and all dependencies appear to be used appropriately in the implementation. The external deep_gemm dependency and internal FP8 utilities are correctly imported for the DeepGemm FP8 MoE backend.


19-66: Triton kernel implementation looks correct.

The masked_index_copy_kernel and its wrapper function correctly implement masked copying based on expert token offsets. The indexing logic and memory access patterns are appropriate for the MoE use case.


68-116: Triton gather kernel implementation is sound.

The masked_index_gather_kernel and wrapper function correctly implement the inverse operation of the copy kernel. The tensor shape assertions provide good validation.


118-141: Helper functions are well-implemented.

The helper functions correctly implement:

  • SwiGLU activation with proper chunking and gating
  • Masked indexing for token selection
  • Metadata computation for expert token offsets

Good use of @torch.compile and @nvtx_range for optimization and profiling.


143-159: DeepGemm wrapper function is correctly implemented.

The function provides a clean interface to the deep_gemm.fp8_m_grouped_gemm_nt_masked operation with proper tensor creation and argument passing. The bfloat16 output dtype is appropriate for the FP8 workflow.


161-210: Class definition and constructor are well-structured.

The DeepGemmFusedMoE class properly inherits from CutlassFusedMoE and delegates initialization to the parent class. The docstring clearly explains the backend components and custom operations used.


227-243: Routing and input processing logic is correct.

The routing application, expert selection validation, and conditional router weight application are properly implemented with appropriate assertions and error handling.


304-323: FP8 GEMM workflow is well-implemented.

The two-stage FP8 GEMM operations with SwiGLU activation in between correctly implement the MoE expert computation using the new FP8 utilities and DeepGemm backend.


325-347: Output gathering and finalization are correct.

The final steps properly gather results from the padded tensor and use the custom finalization operation to produce the final hidden states with correct scaling and unpermutation.

@lfr-0531 lfr-0531 requested a review from a team as a code owner July 18, 2025 07:57
@lfr-0531 lfr-0531 requested a review from Naveassaf July 18, 2025 07:57

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

328-348: Address the duplicate variable creation issue.

This code segment has the exact issue mentioned in the past review comments. The variable permuted_data_tensor_padded is created twice with different parameters, causing the first allocation to be immediately overwritten.

Apply this diff to remove the duplicate allocation:

-        max_padded_tokens = (x.shape[0] + 128) // 128 * 128
-        permuted_data_tensor_padded = torch.empty(
-            (self.expert_size_per_partition, max_padded_tokens,
-             self.hidden_size),
-            dtype=self.dtype,
-            device='cuda')
-
         masked_m, start_offsets = preprocess_after_permute(
             expert_first_token_offset_tensor)
         m_max = (x.shape[0] + 127) // 128 * 128
         expected_m = (token_selected_experts.numel() +
                       self.expert_size_per_partition -
                       1) // self.expert_size_per_partition
         permuted_data_tensor_padded = torch.empty(
             self.expert_size_per_partition,
             m_max,
             self.hidden_size,
             dtype=self.dtype,
             device='cuda')
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)

59-59: Replace lambda expressions with proper function definitions.

Static analysis correctly identifies that lambda expressions assigned to variables should be rewritten as proper function definitions for better readability and debugging.

Apply this diff:

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

And similarly for line 109:

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

Also applies to: 109-109


283-283: Fix line length to comply with style guidelines.

Line exceeds the 120-character limit.

Apply this diff:

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 15eaace and bcb15bb.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

59-59: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


109-109: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


283-283: Line too long (131 > 120)

(E501)

🔇 Additional comments (13)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (5)

21-46: LGTM: Well-structured Triton kernel for masked index copy.

The kernel logic correctly handles memory access patterns with proper masking and bounds checking. The computation of indices and offsets is mathematically sound.
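
For readers unfamiliar with Triton, an equivalent (slow) pure-PyTorch sketch of the masked copy; argument meanings are inferred from the surrounding review rather than taken from the kernel signature, so treat the names and shapes as assumptions:

import torch

def masked_index_copy_sketch(output: torch.Tensor,      # [num_experts, m_max, hidden]
                             permuted_x: torch.Tensor,  # tokens already grouped by expert
                             masked_m: torch.Tensor,    # tokens per expert
                             start_offsets: torch.Tensor) -> None:
    # Copy each expert's contiguous token slice into that expert's padded row block.
    for e in range(output.shape[0]):
        n, s = int(masked_m[e]), int(start_offsets[e])
        output[e, :n] = permuted_x[s:s + n]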


70-95: LGTM: Consistent Triton kernel implementation for masked index gather.

The gather kernel mirrors the copy kernel structure appropriately, with correct index calculations and memory access patterns.


120-124: LGTM: Efficient SwiGLU implementation.

The fused SwiGLU activation using torch.compile is well-implemented and should provide good performance.
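
A minimal stand-alone sketch of the activation pattern being described (the repository's fused kernel may order the gate/value halves or fuse the surrounding ops differently):

import torch
import torch.nn.functional as F

@torch.compile(dynamic=True)
def swiglu_sketch(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into gate/value halves and apply silu(gate) * value.
    gate, value = x.chunk(2, dim=-1)
    return F.silu(gate) * value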


145-203: LGTM: Well-structured DeepGemm FP8 group blockwise GEMM wrapper.

The function properly validates tensor shapes, strides, and data types. The transformation of scaling factors and the call to the underlying DeepGemm implementation are correctly structured.


350-372: LGTM: Efficient FP8 quantization and GEMM workflow.

The forward pass correctly applies FP8 quantization using the utility functions, performs the two GEMM operations with the fused SwiGLU activation in between, and handles the data flow properly.

tensorrt_llm/quantization/utils/fp8_utils.py (8)

8-19: LGTM: Standard ceiling division implementation.

The mathematical logic is correct and the function is properly documented.


22-28: LGTM: Utility functions are mathematically sound.

Both align and ceil_to_ue8m0 functions implement correct mathematical operations for alignment and power-of-two ceiling calculations.


32-49: LGTM: Well-implemented per-token FP8 quantization.

The function correctly handles both 2D and 3D tensors, applies proper blocking (128 elements), computes scaling factors using the FP8 range (448.0), and performs the quantization with appropriate clamping.


52-77: LGTM: Proper per-block FP8 quantization implementation.

The function correctly:

  • Pads tensors to multiples of 128 for proper block alignment
  • Reshapes into 128x128 blocks for quantization
  • Computes scaling factors per block
  • Returns tensors cropped to original dimensions

80-90: LGTM: Effective weight resmoothing implementation.

The function properly applies scaling factors by repeating them to match weight dimensions and calls the appropriate quantization function.


104-133: LGTM: Complex but correct TMA-aligned tensor transformation.

The function correctly:

  • Converts float tensors to uint8 by extracting exponent bits
  • Handles proper padding and alignment for TMA requirements
  • Performs necessary transposition for column-major layout
  • Supports both 2D and 3D tensors

136-163: LGTM: Comprehensive scaling factor layout validation.

The function performs thorough validation of tensor shapes, strides, and alignment requirements for TMA operations.


168-216: LGTM: Robust scaling factor layout transformation.

The function handles multiple granularity cases and properly transforms scaling factors based on hardware requirements. The logic for different granularities (1,128) and (128,128) is correct.
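
To make the two granularities concrete, the unpadded scale-tensor shapes they imply look like this (illustrative sizes only; the TMA-aligned packed layout discussed above adds further padding):

import math

M, N, K = 512, 4096, 7168  # example activation rows, output features, reduction dim

# (1, 128) granularity: one scale per 128-wide slice of each row (per-token scales).
act_scales_shape = (M, math.ceil(K / 128))

# (128, 128) granularity: one scale per 128x128 weight tile (per-block scales).
weight_scales_shape = (math.ceil(N / 128), math.ceil(K / 128))

print(act_scales_shape, weight_scales_shape)  # (512, 56) (32, 56)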


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)

261-401: Well-orchestrated MoE forward pass implementation.

The forward_chunk method properly handles the complex flow of routing, quantization, permutation, FP8 operations, and finalization. The logic appears correct and the duplicate tensor creation issue from previous reviews has been resolved.


288-288: Fix line length violation.

The line exceeds the 120-character limit and should be broken up for better readability.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cd547c7 and b053342.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

60-60: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


112-112: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


288-288: Line too long (131 > 120)

(E501)

🔇 Additional comments (8)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (8)

1-19: Well-organized imports for FP8 MoE implementation.

The imports are logically grouped and all appear necessary for the DeepGemm FP8 backend functionality.


21-69: Triton kernels are well-implemented with proper bounds checking.

The masked index copy kernel correctly handles 3D tensor operations with expert-based indexing and includes appropriate masking for safety.


60-60: Static analysis false positive - lambda is appropriate here.

The lambda expression for grid is the standard pattern for Triton kernel launches and should not be changed to a function definition.


72-122: Gather kernel correctly implements the inverse of the copy operation.

The masked index gather kernel properly handles the inverse transformation with consistent offset calculations and masking logic.


112-112: Static analysis false positive - lambda is appropriate here.

Similar to the copy kernel, the lambda expression for grid follows the standard Triton pattern and should not be changed.


124-148: Helper functions are well-designed with appropriate optimizations.

The SwiGLU activation, indexing, and preprocessing functions are correctly implemented with torch.compile decorators for performance optimization and NVTX ranges for profiling.


150-208: Excellent defensive programming in the DeepGemm wrapper.

The function includes comprehensive type and shape assertions, proper scaling factor transformations, and clean separation of concerns. This robust implementation should prevent runtime errors.


210-259: Clean class definition with good documentation.

The DeepGemmFusedMoE class properly inherits from CutlassFusedMoE with clear documentation of the custom operations involved.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)

61-61: Address static analysis hint about lambda expression.

Static analysis suggests rewriting the lambda as a proper function definition for better readability.

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

113-113: Address static analysis hint about lambda expression.

Same lambda expression style issue as in the copy function.

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

289-289: Address line length issue.

The line exceeds the 120-character limit.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b053342 and 42ba7d3.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

61-61: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


113-113: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


289-289: Line too long (131 > 120)

(E501)

🔇 Additional comments (10)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (10)

1-20: LGTM - Well-organized imports for DeepGemm FP8 MoE implementation.

The import structure is clean and includes all necessary dependencies for the DeepGemm-based FP8 blockwise MoE backend.


22-71: Triton kernel implementation looks correct.

The masked index copy kernel properly handles memory access patterns with appropriate bounds checking and masking. The logic for computing token indices, offsets, and memory locations is sound.


73-122: Triton gather kernel implementation is correct.

The masked index gather kernel properly implements the inverse operation of the copy kernel with consistent memory access patterns and bounds checking.


125-149: Helper functions are well-implemented.

The fused SwiGLU activation, masked indexing, and preprocessing functions are correctly implemented with appropriate optimizations (torch.compile) and profiling annotations (NVTX ranges).


151-209: Comprehensive DeepGemm wrapper with proper validation.

The function provides excellent input validation, proper scaling factor layout transformation, and clean interface to the underlying FP8 GEMM kernel. The extensive assertions help catch configuration errors early.


211-261: Well-designed class inheriting from CutlassFusedMoE.

The class properly extends the base CutlassFusedMoE with appropriate initialization and parameter passing. The docstring clearly explains the custom ops used in the pipeline.


262-304: Proper routing and quantization setup.

The forward method correctly applies routing, validates tensor shapes and types, and sets up the appropriate quantization mode for DeepSeek FP8 block scales.


305-337: Correct allgather and permutation logic.

The distributed processing with allgather and the custom moe_permute_op integration are properly implemented with appropriate parameter passing.


338-383: Excellent FP8 GEMM pipeline implementation.

The core computation pipeline is well-structured:

  1. Proper handling of empty tensors
  2. Correct tensor padding and masking operations
  3. FP8 casting with per-token quantization
  4. Two-stage GEMM operations with fused SwiGLU activation
  5. Proper tensor gathering back to original layout

This demonstrates a sophisticated understanding of FP8 quantization and efficient MoE computation.


384-404: Proper finalization with custom op integration.

The final step correctly uses the moe_finalize_scale_op to combine results, apply scaling, and unpermute the output with all necessary parameters.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)

46-69: Address lambda assignment style issue

The function is well-implemented with proper shape validation. However, consider replacing the lambda assignment with a proper function definition for better readability.

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

97-121: Address lambda assignment style issue

Similar to the copy function, consider replacing the lambda assignment with a proper function definition.

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

217-248: Handle line length and consider decomposition

The method has good initial validation and routing logic. However, line 244 exceeds the 120-character limit.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 42ba7d3 and 6a1648b.

📒 Files selected for processing (3)
  • requirements.txt (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • requirements.txt
  • tensorrt_llm/quantization/utils/fp8_utils.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

59-59: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


111-111: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


244-244: Line too long (131 > 120)

(E501)

🔇 Additional comments (10)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (10)

1-18: LGTM - Well-organized imports

The imports are appropriate for the DeepGemm FP8 MoE implementation, with clear separation between external dependencies and internal modules.


20-44: Well-implemented Triton kernel

The masked_index_copy_kernel correctly implements masked copying with proper bounds checking and memory access patterns. The logic for computing row/column indices and offsets is sound.


71-95: Complementary gather kernel implementation

The masked_index_gather_kernel correctly implements the inverse operation of the copy kernel, with consistent bounds checking and memory access patterns.


123-147: Well-designed helper functions

The helper functions are appropriately decorated and implemented:

  • swiglu_fused_moe correctly implements the SwiGLU activation pattern
  • indexing ensures contiguity for performance
  • preprocess_after_permute efficiently computes expert mappings using torch.searchsorted
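
A hedged sketch of the offsets-to-counts part of that preprocessing (the searchsorted-based token-to-expert mapping is omitted, and the real helper's exact outputs may differ):

import torch

def preprocess_after_permute_sketch(expert_first_token_offset: torch.Tensor):
    # Cumulative first-token offsets of length num_experts + 1
    # -> per-expert token counts and per-expert start offsets.
    start_offsets = expert_first_token_offset[:-1]
    masked_m = expert_first_token_offset[1:] - start_offsets
    return masked_m.to(torch.int32), start_offsets.to(torch.int32)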

149-164: Clean wrapper for DeepGemm FP8 operations

The function provides a well-structured interface to the DeepGemm library with appropriate output tensor allocation and parameter passing.


186-216: Standard initialization pattern

The class initialization follows the expected pattern by calling the parent constructor with all necessary parameters.


249-293: Complex quantization and permutation logic

The quantization setup and moe_permute_op call are well-structured with appropriate parameter passing. The handling of different quantization modes and distributed processing is comprehensive.


294-314: Efficient memory management for padded tensors

The implementation correctly handles empty tensor cases and efficiently manages memory allocation for padded tensors. The use of Triton masked copy is appropriate for performance.


315-335: Well-orchestrated FP8 GEMM computation flow

The dual GEMM operations with intermediate SwiGLU activation correctly implement the MoE computation pattern. The FP8 casting before each GEMM operation is appropriate for the DeepGemm backend.


336-359: Proper result gathering and finalization

The output gathering using Triton and final scaling/unpermutation through moe_finalize_scale_op correctly completes the MoE computation pipeline.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (7)
tensorrt_llm/quantization/utils/fp8_utils.py (2)

104-133: Potential tensor creation inefficiency and magic numbers.

The function creates multiple temporary tensors and uses hardcoded constants that should be documented.

The function could be optimized and made more readable:

 def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
-    # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
+    """
+    Convert FP32 tensor to column-major TMA-aligned packed uint8 format.
+    
+    Note: For extreme performance, consider rewriting this function in CUDA.
+    
+    Args:
+        x: Input tensor (2D or 3D) of dtype float32
+        
+    Returns:
+        Packed tensor in column-major TMA-aligned layout
+    """
     assert x.dtype == torch.float and x.dim() in (2, 3)
 
-    # First, convert into UE8M0 `uint8_t`
+    # Convert to UE8M0 format by extracting exponent bits (bits 23-30)
     ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)
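
As a quick, runnable illustration of the bit trick being documented (it assumes IEEE 754 float32; see the endianness note later in this review): shifting the reinterpreted bits right by 23 exposes the biased exponent, which is the UE8M0 value.

import torch

x = torch.tensor([0.5, 1.0, 2.0, 1024.0], dtype=torch.float32)
# Bits 23-30 of an IEEE 754 float32 hold the biased exponent (exponent + 127);
# the & 0xFF mask is redundant for positive inputs but makes the intent explicit.
ue8m0 = ((x.view(torch.int32) >> 23) & 0xFF).to(torch.uint8)
print(ue8m0.tolist())  # [126, 127, 128, 137]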

136-163: Improve error messages and add input validation.

The function has good validation logic but could provide more helpful error messages.

 def check_sf_layout(sf: torch.Tensor,
                     mn: int,
                     k: int,
                     gran: Tuple[int, int],
                     num_groups: Optional[int],
                     tma_stride_check: bool = False,
                     type_check: Optional[torch.dtype] = None) -> torch.Tensor:
+    """
+    Validate scaling factor tensor layout and constraints.
+    
+    Args:
+        sf: Scaling factor tensor to validate
+        mn: M*N dimension size
+        k: K dimension size  
+        gran: Granularity tuple (m_gran, k_gran)
+        num_groups: Number of groups (None for 2D tensors)
+        tma_stride_check: Whether to validate TMA alignment
+        type_check: Expected dtype (None to skip check)
+        
+    Returns:
+        The input tensor (for chaining)
+        
+    Raises:
+        AssertionError: If layout constraints are violated
+    """
     # Type check
     if type_check is not None:
-        assert sf.dtype == type_check
+        assert sf.dtype == type_check, f"Expected dtype {type_check}, got {sf.dtype}"
 
     # Always do shape checks
-    assert sf.dtype in (torch.float, torch.int)
-    assert sf.dim() == int(num_groups is not None) + 2
+    assert sf.dtype in (torch.float, torch.int), f"Unsupported dtype {sf.dtype}"
+    expected_dims = int(num_groups is not None) + 2
+    assert sf.dim() == expected_dims, f"Expected {expected_dims} dimensions, got {sf.dim()}"
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (5)

48-70: Fix static analysis issue: replace lambda with function definition.

The lambda assignment violates Python style guidelines.

-    # launch kernel
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
-    masked_index_copy_kernel[grid](output,
+    # launch kernel
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    
+    masked_index_copy_kernel[grid](output,

99-122: Fix static analysis issue: replace lambda with function definition.

Same lambda assignment issue as in the copy function.

-    # launch kernel
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
-    masked_index_gather_kernel[grid](output,
+    # launch kernel  
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    
+    masked_index_gather_kernel[grid](output,

138-148: Consider adding input validation for the mask.

The indexing function lacks validation for the mask tensor.

 @nvtx_range("[DG] indexing")
 @torch.compile(dynamic=True)
 def indexing(x, mask):
+    """Select rows from x where mask > 0."""
+    assert mask.dim() == 1, "Mask must be 1D"
+    assert mask.shape[0] == x.shape[0], "Mask length must match first dimension of x"
     return x[mask > 0, :].contiguous()

287-293: Fix line length issue and improve readability.

Line 289 exceeds the 120-character limit and the comment indicates a temporary workaround.

         if self.apply_router_weight_on_input:
-            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert self.routing_method.top_k == 1, \
+                "Current workaround only supports top-1 routing"
+            assert x.dtype != torch.float8_e4m3fn, \
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
             x = x * token_final_scales.to(x.dtype)
             # TODO: remove this once we have correct fusedmoe kernel ready
             token_final_scales = None

295-304: Add validation for quantization configuration.

The error handling could be more informative about supported quantization modes.

         # quantize inputs
         use_deepseek_fp8_block_scale = False
         x_sf = None
         if self.has_any_quant:
             if self.has_deepseek_fp8_block_scales:
                 use_deepseek_fp8_block_scale = True
             else:
                 raise ValueError(
-                    f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}"
+                    f"Unsupported quantization mode for DeepGemm backend: {self.quant_config.quant_mode}. "
+                    f"Only FP8 block scaling is supported."
                 )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6a1648b and e94c5c4.

📒 Files selected for processing (3)
  • requirements.txt (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • requirements.txt
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

61-61: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


113-113: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


289-289: Line too long (131 > 120)

(E501)

🔇 Additional comments (8)
tensorrt_llm/quantization/utils/fp8_utils.py (2)

8-19: LGTM! Clean integer math utility.

The ceiling division implementation is correct and well-documented.


22-23: LGTM! Simple alignment utility.

Correctly implements alignment using the ceiling division helper.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (6)

22-46: LGTM! Well-implemented Triton kernel with proper bounds checking.

The masked index copy kernel correctly handles bounds checking and memory access patterns.


73-97: LGTM! Correct implementation of masked gather kernel.

The kernel properly implements the reverse operation of the copy kernel with appropriate bounds checking.


125-135: LGTM! Efficient SwiGLU implementation.

The fused SwiGLU activation is correctly implemented with proper tensor chunking and activation.


151-208: Comprehensive GEMM wrapper with good validation.

The function has excellent input validation and error checking. The extensive assertions ensure type safety and shape compatibility.


339-341: Good early exit optimization.

Properly handles the edge case where no tokens are selected for processing.


385-403: Complex operation orchestration looks correct.

The final scaling and unpermutation operation properly handles all the tensor transformations and expert routing logic.

Comment on lines 360 to 479
act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0(
permuted_data_tensor_padded)
h1 = deepgemm_fp8_group_blockwise_gemm(
a=act_input_fp8,
b=self.w3_w1_weight,
sfa=act_input_sf,
sfb=self.quant_scales[0],
masked_m=masked_m,
expected_m=expected_m,
)
h2 = swiglu_fused_moe(h1)
act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0(h2)
h3 = deepgemm_fp8_group_blockwise_gemm(
a=act_input_fp8,
b=self.w2_weight,
sfa=act_input_sf,
sfb=self.quant_scales[1],
masked_m=masked_m,
expected_m=expected_m,
)

💡 Verification agent

❓ Verification inconclusive

Verify FP8 quantization numerical stability.

The two consecutive FP8 quantization steps could potentially accumulate numerical errors.

The implementation quantizes activations twice (before each GEMM operation). Verify that this doesn't introduce significant numerical degradation:


🏁 Script executed:

#!/bin/bash
# Description: Look for numerical validation tests for double FP8 quantization
rg -A 10 -B 5 "per_token_cast_to_fp8_e8m0.*per_token_cast_to_fp8_e8m0" --type py

Length of output: 82


Add numerical-stability tests for consecutive FP8 quantization

We couldn’t find any existing tests that validate the error introduced by doing two back-to-back per_token_cast_to_fp8_e8m0 calls in fused_moe_deepgemm.py. Please add numerical‐validation tests to ensure the two-stage FP8 quantization doesn’t accumulate unacceptable error:

• Target file: tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
• Add tests that:

  • Quantize a known tensor once vs. twice and compare against a high-precision reference
  • Report maximum and RMS error, asserting they remain within acceptable bounds
  • Place tests under your existing test suite (e.g. tests/…) and name them to clearly reflect “double FP8 quantization” verification
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py around lines 360
to 379, add numerical-stability tests to validate the error from consecutive FP8
quantizations using per_token_cast_to_fp8_e8m0. Create tests that quantize a
known tensor once and twice, then compare both results against a high-precision
reference tensor. Measure and report maximum and RMS errors, asserting they stay
within acceptable limits. Implement these tests in the existing test suite under
the tests directory with clear names indicating they verify double FP8
quantization stability.
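
A minimal sketch of such a test. The round-trip helper below is a stand-in (per-token absmax scaling against the e4m3fn maximum of 448); the real test would substitute fp8_utils.per_token_cast_to_fp8_e8m0 plus the matching dequantization and tighten the bounds against a measured baseline.

import torch

def per_token_fp8_roundtrip(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the repo's quantize/dequantize pair: one scale per row,
    # derived from the row absmax and the e4m3fn maximum of 448.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale = amax / 448.0
    q = (x / scale).to(torch.float8_e4m3fn)
    return q.to(torch.float32) * scale

def test_double_fp8_quantization_error():
    torch.manual_seed(0)
    x = torch.randn(64, 1024, dtype=torch.float32)
    once = per_token_fp8_roundtrip(x)
    twice = per_token_fp8_roundtrip(once)
    err_once = (once - x).abs()
    err_twice = (twice - x).abs()
    # The second quantization should add little error on top of the first pass;
    # the factors below are illustrative bounds, not measured ones.
    assert err_twice.max() <= 3.0 * err_once.max() + 1e-6
    assert err_twice.pow(2).mean().sqrt() <= 2.0 * err_once.pow(2).mean().sqrt() + 1e-6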

Comment on lines +26 to +29
def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

🛠️ Refactor suggestion

Potential numerical instability with very small values.

The function doesn't handle the edge case where x.abs() is zero (or denormally small): torch.log2(0) returns -inf, so the computed scale collapses to 0 and later multiplications by 1 / sf produce inf/NaN.

Consider adding a small epsilon to prevent numerical issues:

-def ceil_to_ue8m0(x: torch.Tensor):
-    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+def ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs().clamp(min=1e-10))))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs().clamp(min=1e-10))))
🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 26 to 27, the
function ceil_to_ue8m0 does not handle cases where x.abs() is zero or near zero,
causing torch.log2 to fail or produce unstable results. Fix this by adding a
small epsilon value to x.abs() before applying torch.log2 to ensure numerical
stability and prevent errors.

Comment on lines +52 to +79
def per_block_cast_to_fp8_e8m0(
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if x.dim() == 2:
m, n = x.shape
x_padded = torch.zeros((align(m, 128), align(n, 128)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(2))
else:
g, m, n = x.shape
x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
dtype=x.dtype,
device=x.device)
x_padded[:, :m, :n] = x
x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(1), x_view.size(3))

🛠️ Refactor suggestion

Add input validation and improve error handling.

The function lacks proper input validation and error handling for edge cases.

+def per_block_cast_to_fp8_e8m0(
+        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize tensor to FP8 e4m3fn format with per-block scaling.
+    
+    Args:
+        x: Input tensor (2D or 3D), must be on CUDA device
+        
+    Returns:
+        Tuple of (quantized_tensor, scaling_factors)
+    """
+    assert x.is_cuda, "Input tensor must be on CUDA device"
+    assert x.dtype in (torch.float16, torch.bfloat16, torch.float32), "Input must be floating point"
+    assert x.dim() in (2, 3), "Input tensor must be 2D or 3D"

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 52 to 77, the
function per_block_cast_to_fp8_e8m0 lacks input validation and error handling.
Add checks to ensure the input tensor x has the expected number of dimensions
(either 2 or 3) and raise informative errors if not. Also, validate that the
tensor sizes are compatible with the padding and reshaping operations to prevent
runtime errors. Include try-except blocks or conditional checks to handle edge
cases gracefully and provide clear error messages.

Comment on lines +80 to +92
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
weight = weight.cuda()
sf = sf.cuda()
if weight.dim() == 2:
x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
128, dim=1)[:weight.shape[0], :weight.shape[1]]
else:
x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
return per_block_cast_to_fp8_e8m0(x)

🛠️ Refactor suggestion

Verify CUDA placement and add error handling.

The function forcibly moves tensors to CUDA without checking if CUDA is available, and lacks proper error handling.

 def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                          sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Re-quantize weights using existing scaling factors.
+    
+    Args:
+        weight: Weight tensor to re-quantize
+        sf: Existing scaling factors
+        
+    Returns:
+        Tuple of (re-quantized_weight, new_scaling_factors)
+    """
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required for FP8 operations")
     weight = weight.cuda()
     sf = sf.cuda()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
weight = weight.cuda()
sf = sf.cuda()
if weight.dim() == 2:
x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
128, dim=1)[:weight.shape[0], :weight.shape[1]]
else:
x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
return per_block_cast_to_fp8_e8m0(x)
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Re-quantize weights using existing scaling factors.
Args:
weight: Weight tensor to re-quantize
sf: Existing scaling factors
Returns:
Tuple of (re-quantized_weight, new_scaling_factors)
"""
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required for FP8 operations")
weight = weight.cuda()
sf = sf.cuda()
if weight.dim() == 2:
x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
128, dim=1)[:weight.shape[0], :weight.shape[1]]
else:
x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
return per_block_cast_to_fp8_e8m0(x)
🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 80 to 90, the
function forcibly moves tensors to CUDA without verifying CUDA availability and
lacks error handling. Modify the code to first check if CUDA is available before
moving tensors to CUDA, and add try-except blocks to catch and handle potential
errors during this operation, ensuring graceful fallback or informative error
messages.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)

111-119: Same lambda-grid issue in gather helper

Apply the same def grid(meta): … pattern here for consistency and to satisfy Ruff.


357-368: Numerical-stability tests for back-to-back FP8 quantization still missing

Previous review already requested unit tests validating error accumulation across the two consecutive per_token_cast_to_fp8_e8m0 calls; no such tests are present yet.

🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (4)

59-67: Replace the lambda used for Triton grid sizing

Using lambda assignments trips Ruff’s E731 and slightly hurts readability/debuggability.
A tiny helper keeps the intent clear and eliminates the style violation.

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

268-271: Redundant self-assignment of output_dtype

output_dtype = output_dtype is a no-op and can be dropped.

-            assert output_dtype is not None
-            output_dtype = output_dtype
+            assert output_dtype is not None

285-285: Line exceeds 120 characters

Break the long assertion for compliance with the repo’s 120-char guideline.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )

345-355: Potentially large temporary allocation

permuted_data_tensor_padded is sized (experts_per_partition, m_max, hidden_size) where m_max is ceiling-padded to 128.
For large batches & hidden sizes this can momentarily double memory usage.

If DeepGemm tolerates ragged M, consider allocating only up to masked_m.max() instead of the padded upper bound to reduce peak memory.
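
A hypothetical sketch of that alternative; the names mirror the discussion above, and whether DeepGemm accepts the smaller padded M must be verified first:

import torch

def allocate_padded_activations(x: torch.Tensor, masked_m: torch.Tensor,
                                num_experts: int, align: int = 128) -> torch.Tensor:
    # Size the buffer by the largest per-expert token count, rounded up to the
    # 128-row alignment, instead of the global ceiling-padded maximum.
    m_cap = int(masked_m.max().item())
    m_cap = (m_cap + align - 1) // align * align
    return x.new_empty(num_experts, m_cap, x.shape[-1])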

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e94c5c4 and 83a59cd.

📒 Files selected for processing (2)
  • requirements.txt (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • requirements.txt
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

59-59: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


111-111: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


285-285: Line too long (131 > 120)

(E501)

🔇 Additional comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

124-132: torch.compile may silently fall back or error on unsupported envs

torch.compile(dynamic=True) is great for PT≥2.1 + supported back-ends, but on earlier versions or when Triton / float8 isn’t available it raises at import time.
Consider wrapping the decoration in a feature-check (e.g. if torch.__version__ >= "2.1" or an env flag) to preserve library importability for downstream users who only perform inference.
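
One possible shape for such a guard (a sketch, not the repo's API; the environment-variable name is made up for illustration):

import os
import torch

def maybe_compile(fn):
    # Fall back to the eager function when torch.compile is unavailable or
    # explicitly disabled, so importing the module never fails.
    if os.environ.get("TRTLLM_DISABLE_TORCH_COMPILE") == "1":
        return fn
    if not hasattr(torch, "compile"):
        return fn
    try:
        return torch.compile(fn, dynamic=True)
    except Exception:
        return fn

@maybe_compile
def indexing(x, mask):
    return x[mask > 0, :].contiguous()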

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

420-427: Consecutive FP8 quantization numerical stability needs verification.

This issue was previously flagged but remains unaddressed. The two consecutive FP8 quantization steps (lines 420 and 421-427) could potentially accumulate numerical errors.

🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)

122-171: Fix lambda assignment style issue.

The gather kernel logic is correct, but there's a style issue that should be addressed.

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
-    masked_index_gather_kernel[grid](output,
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    masked_index_gather_kernel[grid](output,

336-336: Fix line length violation.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, \
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 83a59cd and 05b9029.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

162-162: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


336-336: Line too long (131 > 120)

(E501)

🔇 Additional comments (6)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (6)

1-18: Import structure looks clean and well-organized.

The imports are appropriate for a DeepGemm-based MoE implementation with FP8 quantization. The dependency on deep_gemm and fp8_utils aligns with the module's purpose.


20-75: Triton kernel implementation follows proper FP8 quantization patterns.

The kernel correctly implements masked index copying with group-wise FP8 quantization:

  • Proper handling of program IDs and memory masking
  • Correct FP8 E4M3 range using 448.0 constant
  • Race condition prevention for scaling factor storage
  • Power-of-2 scaling factor computation

77-120: Wrapper function properly validates inputs and launches the Triton kernel.

The implementation includes:

  • Comprehensive input validation with clear assertions
  • Proper grid size and warp calculation using Triton best practices
  • Correct parameter passing to the underlying kernel

174-198: Helper functions are clean and efficiently implement their specific operations.

The TorchScript compiled functions are well-designed:

  • Proper use of NVTX ranges for profiling
  • Clear, focused functionality for each helper
  • Efficient implementations using appropriate PyTorch operations

200-256: Excellent implementation with comprehensive validation.

The DeepGemm function is thoroughly implemented with:

  • Extensive input validation covering shapes, strides, and dtypes
  • Proper scaling factor transformation using fp8_utils
  • Appropriate tensor layout assertions
  • Clean integration with the deep_gemm library

258-454: Well-structured MoE implementation with comprehensive functionality.

The DeepGemmFusedMoE class effectively orchestrates:

  • Router application and input quantization
  • Permutation and distributed processing
  • FP8 group quantization with custom kernels
  • DeepGemm GEMM operations
  • SwiGLU activation and output gathering
  • Proper finalization and scaling

The implementation handles edge cases well and integrates cleanly with the existing MoE infrastructure.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (5)
tensorrt_llm/quantization/utils/fp8_utils.py (5)

28-29: Address numerical instability in ceil_to_ue8m0.

This is similar to a past review comment but still applies: when x.abs() contains zeros, torch.log2 returns -inf, the resulting scale becomes 0, and downstream multiplications by 1 / sf produce inf/NaN.


32-51: Add comprehensive documentation for per_token_cast_to_fp8_e8m0.

This addresses the past review comment about documenting the function and its requirements.


54-79: Improve input validation in per_block_cast_to_fp8_e8m0.

This relates to the past review comment about adding proper input validation.


82-92: Add CUDA availability check in resmooth_to_fp8_e8m0.

This addresses the past review comment about verifying CUDA availability.


168-218: Add comprehensive documentation for transform_sf_into_required_layout.

This addresses the past review comment about improving documentation and error handling.

🧹 Nitpick comments (6)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (4)

20-75: Consider adding bounds checking and documentation for the Triton kernel.

The _masked_index_copy_group_quant_fp8 kernel performs complex indexing operations without bounds checking on several computed indices (row_idx, col_idx, elem_idx). While the valid mask provides some protection, additional bounds checks could prevent potential out-of-bounds memory access.

+    """
+    Triton kernel for masked index copying with FP8 group quantization.
+    
+    Performs element-wise FP8 quantization in groups and copies data to output
+    tensor based on computed row/column indices from permutation metadata.
+    """
     # get program id and block offset
     pid = tl.program_id(0)
     block_start = pid * group_size

162-162: Replace lambda with def for better readability.

The static analysis correctly identifies that lambda expressions assigned to variables should be replaced with proper function definitions for better readability and debugging.

-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )

336-336: Fix line length violation.

The line exceeds the 120-character limit as flagged by the linter.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )

419-434: Document the activation tensor reallocation pattern.

The code reallocates act_input_fp8 and act_input_sf tensors between the two GEMM operations with different shapes. This pattern should be documented to clarify why reallocation is necessary rather than reusing buffers.

+        # Reallocate activation tensors for second GEMM with different dimensions
+        # Shape changes from [E, M, 2*intermediate] to [E, M, intermediate] after SwiGLU
         act_input_fp8 = torch.empty(h1.shape[0],
                                     h1.shape[1],
                                     h1.shape[2] // 2,
tensorrt_llm/quantization/utils/fp8_utils.py (2)

304-309: Fix docstring formatting issues.

The static analysis correctly identifies formatting problems in the docstring.

-def silu_and_mul_masked_post_quant_fwd(
-    input: torch.Tensor,
-    output: torch.Tensor,
-    output_scale: torch.Tensor,
-    quant_group_size: int,
-    masked_m: torch.Tensor,
-    scale_ue8m0: bool = False,
-):
-    """
-    input shape [expert_num, token_num_padded, hidden_dim]
-    output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
-    output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32
-    quant_group_size  int,
-    masked_m shape [expert_num],
-    """
+def silu_and_mul_masked_post_quant_fwd(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    output_scale: torch.Tensor,
+    quant_group_size: int,
+    masked_m: torch.Tensor,
+    scale_ue8m0: bool = False,
+):
+    """
+    Apply SiLU activation and multiplication with masked FP8 post-quantization.
+    
+    Args:
+        input: Input tensor with shape [expert_num, token_num_padded, hidden_dim]
+        output: Output tensor with shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
+        output_scale: Scale tensor with shape [expert_num, token_num_padded, hidden_dim // 2 // 128], dtype float32
+        quant_group_size: Quantization group size (int)
+        masked_m: Mask tensor with shape [expert_num]
+        scale_ue8m0: Whether to use UE8M0 scaling
+    """

275-278: Clarify the SiLU computation in the Triton kernel.

The expression gate / (1 + tl.exp(-gate)) is already algebraically identical to SiLU (x * sigmoid(x)), so rewriting it as gate * (1.0 / (1.0 + tl.exp(-gate))) changes nothing numerically. If readability is the goal, name the activation in a comment or use tl.sigmoid where the installed Triton version provides it.

-        gate = tl.load(
-            input_ptr_offs + token_index * stride_input_1 + size_n,
-            mask=offs_in_d < size_n,
-            other=0.0,
-        ).to(tl.float32)
-        gate = gate / (1 + tl.exp(-gate))
+        gate = tl.load(
+            input_ptr_offs + token_index * stride_input_1 + size_n,
+            mask=offs_in_d < size_n,
+            other=0.0,
+        ).to(tl.float32)
+        # SiLU: x * sigmoid(x) == x / (1 + exp(-x))
+        gate = gate * tl.sigmoid(gate)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 05b9029 and a467044.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

162-162: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


336-336: Line too long (131 > 120)

(E501)

tensorrt_llm/quantization/utils/fp8_utils.py

304-309: 1 blank line required between summary line and description

(D205)


304-309: First line should end with a period, question mark, or exclamation point

Add closing punctuation

(D415)

🔇 Additional comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

392-409: Ignore padding alignment concern—both use 128
Both the m_max = (x.shape[0] + 127)//128*128 calculation and the group_size=128 parameter employ the same 128-element alignment. There is no mismatch to document.

Likely an incorrect or invalid review comment.

Comment on lines +106 to +135
def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
assert x.dtype == torch.float and x.dim() in (2, 3)

# First, convert into UE8M0 `uint8_t`
ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)

# Second, make padded packed tensors
mn, k = x.shape[-2], x.shape[-1]
remove_dim = False
if x.dim() == 2:
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
aligned_mn = get_tma_aligned_size(mn, 4)
aligned_k = align(k, 4)
padded = torch.zeros((b, aligned_mn, aligned_k),
device=x.device,
dtype=torch.uint8)
padded[:, :mn, :k] = ue8m0_tensor
padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn,
aligned_k // 4)

# Finally, transpose
transposed = torch.transpose(
torch.empty((b, aligned_k // 4, aligned_mn),
device=x.device,
dtype=torch.int), 1, 2)
transposed[:, :, :] = padded
aligned_x = transposed[:, :mn, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x

💡 Verification agent

❓ Verification inconclusive

Verify bit manipulation safety in get_col_major_tma_aligned_packed_tensor.

The function performs direct bit manipulation on FP32 tensors (x.view(torch.int) >> 23) to extract exponent bits for UE8M0 format. This assumes IEEE 754 format and specific endianness, which may not be portable across all platforms.


🏁 Script executed:

#!/bin/bash
# Description: Check if there are any platform-specific assumptions or documentation about IEEE 754 requirements
rg -A 5 -B 5 "ieee.*754|IEEE.*754" --type py
rg -A 5 -B 5 "endian" --type py

Length of output: 78


Verify IEEE 754 and little-endian assumptions in get_col_major_tma_aligned_packed_tensor

This function relies on reinterpreting a float32 tensor as a 32-bit integer and shifting (x.view(torch.int) >> 23) to extract exponent bits. That approach assumes:

  • IEEE 754 single-precision floats
  • Little-endian byte order

I didn’t find any existing checks or documentation for these requirements in the codebase, so on big-endian or non-IEEE 754 platforms this will silently produce wrong results.

Recommendations:

  • At the top of the function, add a runtime guard:
    import sys
    assert sys.byteorder == "little", (
        "get_col_major_tma_aligned_packed_tensor requires a little-endian platform"
    )
  • In the module or function docstring, note that it only supports IEEE 754 32-bit floats and little-endian architectures.
  • If you need to support other platforms, consider adding a conversion step or alternative implementation.
🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 106 to 135, the
function get_col_major_tma_aligned_packed_tensor assumes IEEE 754 32-bit floats
and little-endian byte order when reinterpreting the float tensor as int and
shifting bits. To fix this, add a runtime assertion at the start of the function
to check that sys.byteorder is "little" and raise an error if not. Also, update
the function or module docstring to clearly state that it only supports IEEE 754
single-precision floats and little-endian architectures. If broader platform
support is needed, implement alternative handling or conversion for other cases.

@lfr-0531

/bot run

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

408-446: Consecutive FP8 quantizations may accumulate numerical error.

The implementation performs FP8 quantization twice (lines 408-413 for first GEMM input, lines 433-438 for second GEMM input after SwiGLU). This aligns with previous review concerns about numerical stability.

Consider adding numerical validation tests to ensure the double quantization doesn't introduce unacceptable error accumulation, as suggested in previous reviews.

🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

340-340: Fix line length violation.

Line exceeds the 120-character limit as flagged by static analysis.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, \
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a467044 and 4e23925.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

340-340: Line too long (131 > 120)

(E501)

🔇 Additional comments (8)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (8)

1-18: Import structure looks well-organized.

The imports are appropriately structured for a DeepGemm FP8 MoE backend, including the necessary deep_gemm, triton, and FP8 utilities.


20-75: Triton kernel implementation follows best practices.

The _masked_index_copy_group_quant_fp8 kernel correctly implements group-wise FP8 quantization with proper bounds checking and race condition protection for scaling factor writes.


77-120: Wrapper function is well-implemented with proper validation.

The function provides comprehensive input validation and efficient kernel launch parameters with appropriate heuristics for performance optimization.


122-176: Gather kernel and wrapper are efficiently implemented.

The masked_index_gather_kernel uses block-based processing and early returns for optimal performance, while the wrapper provides appropriate validation.


178-202: Helper functions are well-designed and efficient.

The TorchScript-compiled functions provide focused functionality with appropriate dynamic compilation and profiling annotations.


204-260: Comprehensive GEMM function with thorough validation.

The deepgemm_fp8_group_blockwise_gemm function provides extensive validation and proper scaling factor transformations for reliable FP8 GEMM operations.


262-312: Class initialization properly delegates to parent.

The constructor appropriately inherits from CutlassFusedMoE and provides comprehensive documentation of the backend's custom operations.


314-471: Complex but well-structured forward pass implementation.

The forward_chunk method orchestrates a sophisticated MoE pipeline with proper handling of routing, quantization, distributed processing, and dual GEMM operations. The implementation correctly integrates all the custom operations and kernels.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

428-433: Verify numerical stability of consecutive FP8 quantization.

The implementation performs FP8 quantization twice in the forward pass - once for the input activations and again after the SwiGLU activation. This could potentially accumulate numerical errors that affect model quality.

Based on previous review feedback, this concern about consecutive FP8 quantization accumulating numerical errors remains valid. Consider adding numerical validation tests to ensure the double quantization doesn't introduce unacceptable error levels.

Also applies to: 453-458

🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)

20-77: Consider numerical stability in FP8 quantization.

The Triton kernel implementation is well-structured for parallel processing, but the quantization logic on lines 68-72 could benefit from additional numerical stability measures:

 # quantization
 _absmax = tl.maximum(tl.max(tl.abs(input_data)), eps)
 output_s = _absmax / fp8_max
-output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
+# Add small epsilon to prevent log2(0) and improve numerical stability
+output_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(output_s), eps))))

88-91: Fix assertion message for clarity.

The assertion message contains incorrect grammar.

 assert (
     input.shape[-1] % group_size == 0
-), "the last dimension of `input` cannot be divisible by `group_size`"
+), "the last dimension of `input` must be divisible by `group_size`"

360-360: Fix line length violation.

Line exceeds the 120 character limit.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4e23925 and caf55b8.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

360-360: Line too long (131 > 120)

(E501)

🔇 Additional comments (6)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (6)

1-18: LGTM! Imports are well-organized and necessary.

The imports are appropriately structured with standard library imports first, followed by third-party dependencies, and finally internal TensorRT-LLM modules.


142-171: LGTM! Well-implemented Triton gather kernel.

The kernel efficiently handles masked index gathering with proper bounds checking and block-based processing for optimal memory access patterns.


173-195: LGTM! Clean and well-validated wrapper function.

The function properly validates tensor shapes and dimensions before launching the Triton kernel.


198-222: LGTM! Well-optimized helper functions.

The functions are appropriately decorated for performance profiling and optimization with @torch.compile(dynamic=True) where beneficial.


224-279: LGTM! Comprehensive validation and proper DeepGemm integration.

The function includes thorough tensor validation and correctly transforms scaling factors for the DeepGemm backend requirements.


282-491: LGTM! Well-structured MoE implementation with comprehensive functionality.

The class successfully implements a complete DeepGemm-based FP8 MoE backend with proper distributed processing support, tensor routing, and quantization workflows.

@lfr-0531 lfr-0531 force-pushed the user/fanrongl/debug_moe branch from caf55b8 to e766164 Compare July 24, 2025 06:00
@lfr-0531

/bot run

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (7)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

1296-1309: Unit tests and performance benchmarks still needed for FP8 resmoothing

The FP8 resmoothing logic is correctly implemented with proper filtering and SM 100 targeting. However, as identified in previous reviews, this functionality still lacks comprehensive unit tests and performance benchmarks to validate correctness and measure loading overhead.
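
A minimal correctness test could exercise the round trip directly with the helpers added in this PR (import path assumed from the file location tensorrt_llm/quantization/utils/fp8_utils.py; tolerances are illustrative):

import torch
from tensorrt_llm.quantization.utils.fp8_utils import (
    per_block_cast_to_fp8_e8m0, resmooth_to_fp8_e8m0)

def dequant_blockwise(w_fp8: torch.Tensor, sf: torch.Tensor, block: int = 128):
    # Expand the one-scale-per-128x128-tile layout back over element positions.
    sf_full = sf.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return w_fp8.to(torch.float32) * sf_full[:w_fp8.shape[0], :w_fp8.shape[1]]

def test_resmooth_preserves_dequantized_weights():
    torch.manual_seed(0)
    w = torch.randn(256, 512, dtype=torch.float32, device="cuda")
    w_fp8, sf = per_block_cast_to_fp8_e8m0(w)
    w_new, sf_new = resmooth_to_fp8_e8m0(w_fp8, sf)
    ref = dequant_blockwise(w_fp8, sf)
    out = dequant_blockwise(w_new, sf_new)
    torch.testing.assert_close(out, ref, rtol=0.05, atol=0.05)

A loading-overhead benchmark would wrap the resmooth_to_fp8_e8m0 call in a timed loop over representative expert weight shapes.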

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

428-433: Consecutive FP8 quantization needs numerical validation

The implementation performs consecutive FP8 quantizations which could accumulate numerical errors. As noted in previous reviews, this pattern needs validation to ensure acceptable precision loss.

Also applies to: 453-458

tensorrt_llm/quantization/utils/fp8_utils.py (5)

28-29: Potential numerical instability with very small values.

The function doesn't handle the edge case where x.abs() is zero (or denormally small): torch.log2(0) returns -inf, so the computed scale collapses to 0 and later multiplications by 1 / sf produce inf/NaN.

Consider adding a small epsilon to prevent numerical issues:

-def ceil_to_ue8m0(x: torch.Tensor):
-    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+def ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs().clamp(min=1e-10))))

54-79: Add input validation and improve error handling.

The function lacks proper input validation and error handling for edge cases.

+def per_block_cast_to_fp8_e8m0(
+        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize tensor to FP8 e4m3fn format with per-block scaling.
+    
+    Args:
+        x: Input tensor (2D or 3D), must be on CUDA device
+        
+    Returns:
+        Tuple of (quantized_tensor, scaling_factors)
+    """
+    assert x.is_cuda, "Input tensor must be on CUDA device"
+    assert x.dtype in (torch.float16, torch.bfloat16, torch.float32), "Input must be floating point"
+    assert x.dim() in (2, 3), "Input tensor must be 2D or 3D"

82-92: Verify CUDA placement and add error handling.

The function forcibly moves tensors to CUDA without checking if CUDA is available, and lacks proper error handling.

 def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                          sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Re-quantize weights using existing scaling factors.
+    
+    Args:
+        weight: Weight tensor to re-quantize
+        sf: Existing scaling factors
+        
+    Returns:
+        Tuple of (re-quantized_weight, new_scaling_factors)
+    """
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required for FP8 operations")
     weight = weight.cuda()
     sf = sf.cuda()

106-135: Verify IEEE 754 and little-endian assumptions in get_col_major_tma_aligned_packed_tensor

This function relies on reinterpreting a float32 tensor as a 32-bit integer and shifting (x.view(torch.int) >> 23) to extract exponent bits. That approach assumes:

  • IEEE 754 single-precision floats
  • Little-endian byte order

Recommendations:

  • At the top of the function, add a runtime guard:
    import sys
    assert sys.byteorder == "little", (
        "get_col_major_tma_aligned_packed_tensor requires a little-endian platform"
    )
  • In the module or function docstring, note that it only supports IEEE 754 32-bit floats and little-endian architectures.
  • If you need to support other platforms, consider adding a conversion step or alternative implementation.

218-218: Replace assert with proper exception handling.

The function uses assert False for error handling, which should be replaced with a proper exception that provides clear, informative error messages.

-    assert False, f'Unknown cases: {sf.dtype=}, {gran=}'
+    raise ValueError(f'Unsupported scaling factor layout: dtype={sf.dtype}, granularity={gran}')
🧹 Nitpick comments (5)
tests/unittest/_torch/modules/test_fused_moe.py (3)

385-401: Comprehensive test parameterization with appropriate hardware requirements.

The test is well-parameterized to cover various scenarios. However, note that the large parameter space (8 different sequence lengths) may result in long test execution times. Consider if all combinations are necessary or if a subset would provide adequate coverage.


447-452: Question: Redundant scale storage?

Both weight_scale_inv and weight_scale are stored with the same values. Is this intentional for compatibility with different backends, or could this be simplified?


472-500: Improve readability with comments explaining the complex logic.

The grouped GEMM function contains complex dequantization logic with hardcoded values (e.g., 128 for block size). Consider adding comments to explain:

  1. The significance of the 128 block size
  2. The dequantization process with scale factor repetition
  3. The overall grouped GEMM algorithm

This would improve maintainability and make the reference implementation easier to understand.
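
For example, the core loop could carry comments along these lines (a sketch of a plain per-expert reference, not the test's exact code; it assumes the inputs are already dequantized to a float dtype):

import torch

def grouped_gemm_reference(a: torch.Tensor, b_per_expert: torch.Tensor,
                           offsets: torch.Tensor) -> torch.Tensor:
    # offsets marks where each expert's rows start and end in the permuted
    # activations a, so every row slice is multiplied by that expert's weight;
    # the 128 block size mentioned above only matters when the inputs still
    # carry per-128x128-block FP8 scales that must be expanded before this step.
    out = a.new_empty(a.shape[0], b_per_expert.shape[-1])
    for expert in range(offsets.numel() - 1):
        rows = slice(int(offsets[expert]), int(offsets[expert + 1]))
        out[rows] = a[rows] @ b_per_expert[expert]
    return out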

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

360-360: Fix line length violation

Line 360 exceeds the 120 character limit.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )
tensorrt_llm/quantization/utils/fp8_utils.py (1)

305-311: Fix docstring formatting issues.

The docstring has formatting issues that violate Python documentation standards.

 def silu_and_mul_masked_post_quant_fwd(
     input: torch.Tensor,
     output: torch.Tensor,
     output_scale: torch.Tensor,
     quant_group_size: int,
     masked_m: torch.Tensor,
     scale_ue8m0: bool = False,
 ):
     """
-    input shape [expert_num, token_num_padded, hidden_dim]
-    output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
-    output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32
-    quant_group_size  int,
-    masked_m shape [expert_num],
+    Fused SiLU activation, multiplication, and masked FP8 post-quantization.
+    
+    Args:
+        input: Input tensor with shape [expert_num, token_num_padded, hidden_dim]
+        output: Output tensor with shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
+        output_scale: Output scaling factors [expert_num, token_num_padded, hidden_dim // 2 // 128], dtype float32
+        quant_group_size: Quantization group size (int)
+        masked_m: Mask tensor with shape [expert_num]
+        scale_ue8m0: Whether to use UE8M0 scaling format
     """
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between caf55b8 and e766164.

📒 Files selected for processing (19)
  • examples/llm-api/quickstart_advanced.py (1 hunks)
  • examples/models/core/llama/README.md (2 hunks)
  • jenkins/L0_MergeRequest.groovy (1 hunks)
  • jenkins/L0_Test.groovy (3 hunks)
  • requirements.txt (1 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (5 hunks)
  • tensorrt_llm/_torch/modules/attention.py (6 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (2 hunks)
  • tensorrt_llm/_torch/modules/linear.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (2 hunks)
  • tensorrt_llm/_utils.py (2 hunks)
  • tensorrt_llm/quantization/utils/__init__.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
  • tests/unittest/_torch/helpers.py (2 hunks)
  • tests/unittest/_torch/modules/test_fused_moe.py (3 hunks)
  • tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • jenkins/L0_MergeRequest.groovy
  • examples/models/core/llama/README.md
🚧 Files skipped from review as they are similar to previous changes (11)
  • requirements.txt
  • tensorrt_llm/quantization/utils/__init__.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • examples/llm-api/quickstart_advanced.py
  • tests/unittest/_torch/helpers.py
  • tensorrt_llm/_utils.py
  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py
  • tensorrt_llm/_torch/modules/linear.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
  • tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py
  • tensorrt_llm/_torch/modules/attention.py
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

360-360: Line too long (131 > 120)

(E501)

tensorrt_llm/quantization/utils/fp8_utils.py

304-309: 1 blank line required between summary line and description

(D205)


304-309: First line should end with a period, question mark, or exclamation point

Add closing punctuation

(D415)

🔇 Additional comments (20)
jenkins/L0_Test.groovy (5)

1114-1114: LGTM: Script path consolidation

The move from tests/integration/defs/test_rerun.py to jenkins/test_rerun.py is a sensible reorganization that places Jenkins-specific utilities in a more logical location.


1187-1189: LGTM: Improved test result traceability

The addition of XML modification to replace generic "pytest" testsuite names with actual stage names enhances test report clarity and traceability. The error handling with || true is appropriate.


1195-1195: LGTM: Consistent script path update

The script path change is consistent with the earlier modification on line 1114, ensuring all references point to the new location.


1203-1203: LGTM: Final consistent script path update

The script path change completes the consistent update across all three usages of the test_rerun.py script in the rerunFailedTests function.


1648-1649: LGTM: Enhanced error message for better user experience

The improvement to the error message in checkStageNameSet significantly enhances usability by:

  • Sorting available stage names for easier scanning
  • Adding proper indentation for better readability
  • Including the total count of available options
  • Providing cleaner formatting overall

This will help users quickly identify valid stage names when they encounter validation errors.

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1)

15-15: LGTM! Improved error handling robustness.

Broadening the exception handling from ModuleNotFoundError to ImportError is a good improvement that catches a wider range of import-related failures beyond just missing modules, such as compilation errors or dependency issues that may arise with the new FP8 integration.

tests/unittest/_torch/modules/test_fused_moe.py (3)

12-13: LGTM! Necessary imports for FP8 testing.

The new imports for FP8 utility functions and DeepGemmFusedMoE are appropriate for the new test functionality.

Also applies to: 29-30


402-421: LGTM! Proper test setup with deterministic patterns.

The test setup correctly uses deterministic seeding and patterns to ensure reproducible results and avoid false positive failures.


545-550: LGTM! Proper test execution with appropriate tolerances.

The test correctly compares the DeepGemm implementation against the reference with reasonable tolerances for FP8 quantization testing.

tensorrt_llm/_torch/models/modeling_deepseekv3.py (3)

41-41: LGTM: Necessary imports for FP8 functionality

The imports for logger and resmooth_to_fp8_e8m0 are correctly added to support the FP8 resmoothing logic introduced in the weight loading method.

Also applies to: 47-47


1212-1212: LGTM: Ensuring memory contiguity after transpose

Adding .contiguous() after transpose operations is a good practice to ensure optimal memory layout for subsequent tensor operations, particularly important for performance-critical paths.

Also applies to: 1253-1253


1362-1381: LGTM: Well-implemented conditional dequantization logic

The dequantization logic properly:

  • Checks for parameter existence before execution
  • Uses CUDA tensors for performance
  • Applies correct reshaping and dtype conversion
  • Handles both k_b_proj_trans_dequant and v_b_proj_dequant parameters consistently

This supports the alternative FP8 batched matrix multiplication path for SM 100 hardware.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (8)

1-18: LGTM: Comprehensive imports for DeepGemm MoE implementation

The imports are well-organized and include all necessary dependencies for the DeepGemm-based fused MoE backend functionality.


20-77: LGTM: Well-implemented Triton kernel for FP8 group quantization

The kernel correctly implements:

  • Group-wise FP8 quantization with proper scaling
  • Masked indexing with bounds checking
  • Multi-stage processing for performance optimization
  • Proper handling of tensor strides and memory access patterns

The complexity is appropriate for the performance-critical MoE operations.
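
A plain-PyTorch reference of the group-wise quantization idea (this is not the Triton kernel; the group size, epsilon, and names are illustrative assumptions):

import torch

def per_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128, eps: float = 1e-10):
    # x: [num_tokens, hidden]; hidden must be divisible by group_size.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    tokens, hidden = x.shape
    groups = x.view(tokens, hidden // group_size, group_size).float()
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = amax / fp8_max
    q = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(tokens, hidden), scale.squeeze(-1)  # quantized values and per-group scales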


79-140: LGTM: Robust Python wrapper with proper validation and tuning

The wrapper function provides:

  • Comprehensive input validation ensuring correct tensor shapes and memory layout
  • Dynamic performance tuning based on token count (different block sizes and staging)
  • Proper FP8 quantization parameter setup
  • Clean kernel launch with appropriate grid configuration

This ensures both correctness and performance optimization.


142-196: LGTM: Efficient gather kernel implementation

The masked index gather kernel and wrapper are well-implemented with:

  • Proper bounds checking and validation
  • Block-based processing for memory efficiency
  • Clean tensor shape assertions
  • Appropriate grid configuration

The implementation is straightforward and correct for the gather operation requirements.


198-222: LGTM: Well-designed helper functions with performance optimization

The helper functions are cleanly implemented:

  • swiglu_fused_moe: Correct SwiGLU activation implementation
  • indexing: Efficient masked indexing with contiguous output
  • preprocess_after_permute: Proper token-to-expert mapping computation

The use of @torch.compile(dynamic=True) and NVTX profiling annotations is appropriate for performance-critical MoE operations.
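
For reference, the SwiGLU activation on a fused gate/up projection can be written as below (the split order is an assumption; the module fuses this with quantization):

import torch
import torch.nn.functional as F

def swiglu_ref(h: torch.Tensor) -> torch.Tensor:
    # h holds the concatenated gate and up projections along the last dimension.
    gate, up = h.chunk(2, dim=-1)
    return F.silu(gate) * up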


224-280: LGTM: Robust DeepGemm integration with comprehensive validation

The function provides excellent implementation with:

  • Thorough tensor shape, stride, and dtype validation
  • Proper scaling factor layout transformation using fp8_utils
  • Correct DeepGemm API usage with masked grouped GEMM
  • Clear assertions ensuring tensor memory layout requirements

The comprehensive validation ensures correctness while the DeepGemm integration provides high-performance FP8 computation.
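
A naive reference of a masked grouped GEMM, just to make the contract concrete (argument names and layouts are assumptions; DeepGemm performs this on FP8 inputs with per-block scales):

import torch

def masked_grouped_gemm_ref(a: torch.Tensor, b: torch.Tensor, masked_m: torch.Tensor) -> torch.Tensor:
    # a: [num_experts, max_tokens, k], b: [num_experts, n, k];
    # only the first masked_m[e] rows of expert e hold valid tokens.
    num_experts, max_tokens, _ = a.shape
    out = torch.zeros(num_experts, max_tokens, b.shape[1], dtype=torch.bfloat16, device=a.device)
    for e in range(num_experts):
        m = int(masked_m[e])
        if m > 0:
            out[e, :m] = (a[e, :m].float() @ b[e].float().t()).to(torch.bfloat16)
    return out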


282-331: LGTM: Well-designed class structure with informative documentation

The DeepGemmFusedMoE class is properly structured:

  • Inherits from CutlassFusedMoE to maintain consistent interface
  • Comprehensive docstring explaining the multi-op backend composition
  • Constructor correctly forwards all parameters to parent class
  • Clear parameter documentation for the MoE layer configuration

The design maintains compatibility while enabling DeepGemm-specific optimizations.


333-491: LGTM: Comprehensive MoE forward pass implementation

The forward method correctly orchestrates the complete MoE pipeline:

  • Proper routing and expert selection
  • Efficient permutation and data layout transformations
  • FP8 quantization with group-wise scaling
  • DeepGemm GEMM operations with masked computation
  • SwiGLU activation with integrated quantization
  • Final scaling and result assembly

The implementation demonstrates good understanding of the performance-critical MoE operations and integrates well with the TensorRT-LLM infrastructure.
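
To keep the data flow in mind, an unfused, single-device reference of this pipeline (bf16 only, no FP8 or DeepGemm; parameter names and the top-k renormalization are assumptions, not the class's actual interface):

import torch
import torch.nn.functional as F

def moe_forward_ref(x, router_logits, w_gate_up, w_down, top_k: int = 2):
    # x: [tokens, hidden]; w_gate_up: [experts, 2*inter, hidden]; w_down: [experts, hidden, inter]
    probs = router_logits.float().softmax(dim=-1)
    topk_w, topk_id = probs.topk(top_k, dim=-1)
    topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for e in range(w_gate_up.shape[0]):
        tok, slot = (topk_id == e).nonzero(as_tuple=True)
        if tok.numel() == 0:
            continue
        h = x[tok] @ w_gate_up[e].t()                  # fused gate/up projection
        gate, up = h.chunk(2, dim=-1)
        h = F.silu(gate) * up                          # SwiGLU activation
        expert_out = h @ w_down[e].t()                 # down projection back to hidden size
        scaled = expert_out * topk_w[tok, slot].unsqueeze(-1)
        out.index_add_(0, tok, scaled.to(out.dtype))   # scatter-add back into token order
    return out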

Comment on lines +483 to +488
        d = torch.empty((a.shape[0], b.shape[1]),
                        device=b.device,
                        dtype=torch.bfloat16)
        m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32)
        for idx in range(offset_array.numel() - 1):
            m_indices[offset_array[idx]:offset_array[idx + 1]] = idx

⚠️ Potential issue

Fix duplicate variable declarations.

The variables d and m_indices are declared twice in the grouped_gemm function. Remove the duplicate declarations.

-        num_groups, n, k_ = b.shape
-        d = torch.empty((a.shape[0], b.shape[1]),
-                        device=b.device,
-                        dtype=torch.bfloat16)
-        m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32)
-        for idx in range(offset_array.numel() - 1):
-            m_indices[offset_array[idx]:offset_array[idx + 1]] = idx
+        num_groups, n, k_ = b.shape
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
        d = torch.empty((a.shape[0], b.shape[1]),
                        device=b.device,
                        dtype=torch.bfloat16)
        m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32)
        for idx in range(offset_array.numel() - 1):
            m_indices[offset_array[idx]:offset_array[idx + 1]] = idx
        num_groups, n, k_ = b.shape
🤖 Prompt for AI Agents
In tests/unittest/_torch/modules/test_fused_moe.py around lines 483 to 488, the
variables d and m_indices are declared twice within the grouped_gemm function.
Locate and remove the duplicate declarations of these variables to ensure each
is declared only once, preventing redundancy and potential conflicts.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

447-485: Verify numerical stability of consecutive FP8 quantizations.

The implementation performs multiple FP8 quantization operations which could potentially accumulate numerical errors. This was flagged in a previous review but tests were not added.

Add numerical stability tests to validate the error from consecutive FP8 quantizations:

#!/bin/bash
# Description: Look for existing numerical validation tests for double FP8 quantization
rg -A 10 -B 5 "numerical.*stability|double.*fp8|consecutive.*quantiz" tests/ --type py
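
One possible shape for such a test (everything here, from the helper to the 2x tolerance, is a hypothetical sketch rather than existing suite code):

import torch

def _quant_dequant_fp8(t: torch.Tensor) -> torch.Tensor:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = t.abs().amax().clamp_min(1e-10) / fp8_max
    return (t / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn).float() * scale

def test_consecutive_fp8_quantization_error():
    torch.manual_seed(0)
    x = torch.randn(1024, 4096)
    once = _quant_dequant_fp8(x)
    twice = _quant_dequant_fp8(once)
    err_once = (once - x).abs().max()
    err_twice = (twice - x).abs().max()
    # The second pass should not blow up the error; the 2x bound is a loose, assumed tolerance.
    assert err_twice <= 2 * err_once
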
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

383-383: Fix line length violation.

Line exceeds the 120 character limit.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c373ca3 and 11a31d6.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

383-383: Line too long (131 > 120)

(E501)

🔇 Additional comments (6)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (6)

20-86: LGTM: Well-implemented Triton kernel for FP8 quantization.

The kernel correctly implements masked index copy with group-wise FP8 quantization, including proper scaling factor computation and bitwise packing for efficiency.


88-162: LGTM: Well-designed Python wrapper with proper validation and optimization.

The function provides comprehensive input validation, workload-adaptive kernel tuning, and correct handling of tensor layouts and padding requirements.


165-194: LGTM: Clean and efficient gather kernel implementation.

The kernel correctly implements masked index gathering with proper bounds checking and block-based processing for efficiency.


196-218: LGTM: Appropriate wrapper for the gather kernel.

The function provides necessary input validation and uses a reasonable fixed block size for the gather operation.


221-244: LGTM: Well-designed helper functions with performance optimizations.

The functions are properly compiled and instrumented with NVTX ranges, implementing their respective operations efficiently.


247-302: LGTM: Robust GEMM implementation with comprehensive validation.

The function provides thorough input validation, proper scaling factor transformation, and correct integration with the deep_gemm library.

Comment on lines +1 to +18
from typing import List, Optional, Union

import deep_gemm
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm._utils import nvtx_range

from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .routing import BaseMoeRoutingMethod


⚠️ Potential issue

Add required NVIDIA copyright header.

According to the coding guidelines, all TensorRT-LLM source files must contain an NVIDIA copyright header with the current year.

Add the copyright header at the top of the file:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional, Union
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from typing import List, Optional, Union
import deep_gemm
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm._utils import nvtx_range
from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .routing import BaseMoeRoutingMethod
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import deep_gemm
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm._utils import nvtx_range
from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .routing import BaseMoeRoutingMethod
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py at the very top
before any imports, add the required NVIDIA copyright header with the current
year as per coding guidelines. This header should clearly state NVIDIA's
copyright ownership and the year to comply with project standards.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (3)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

1297-1310: Add unit tests for FP8 resmoothing & benchmark loading overhead

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)

1-18: Add required NVIDIA copyright header.


441-471: Add numerical-stability tests for consecutive FP8 quantization

🧹 Nitpick comments (4)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

1363-1382: Consider extracting dequantization logic to reduce duplication

The dequantization logic for k_b_proj_trans and v_b_proj is duplicated. Consider extracting it into a helper function for better maintainability.

+def dequantize_projection(weight, scale, target_shape, target_dtype):
+    """Helper function to dequantize projection weights."""
+    return weight_dequant(
+        weight.view(-1, weight.shape[-1]).cuda(),
+        scale.view(-1, scale.shape[-1]).cuda(),
+    ).view(*target_shape).to(target_dtype)

 if attn_module.k_b_proj_trans_dequant is not None:
-    attn_module.k_b_proj_trans_dequant.data.copy_(
-        weight_dequant(
-            k_b_proj_trans.view(
-                -1, k_b_proj_trans.shape[-1]).cuda(),
-            k_b_proj_trans_scale.view(
-                -1,
-                k_b_proj_trans_scale.shape[-1]).cuda(),
-        ).view(
-            *attn_module.k_b_proj_trans_dequant.shape).
-        to(attn_module.k_b_proj_trans_dequant.dtype))
+    attn_module.k_b_proj_trans_dequant.data.copy_(
+        dequantize_projection(
+            k_b_proj_trans,
+            k_b_proj_trans_scale,
+            attn_module.k_b_proj_trans_dequant.shape,
+            attn_module.k_b_proj_trans_dequant.dtype
+        ))
 if attn_module.v_b_proj_dequant is not None:
-    attn_module.v_b_proj_dequant.data.copy_(
-        weight_dequant(
-            v_b_proj.view(-1,
-                          v_b_proj.shape[-1]).cuda(),
-            v_b_proj_scale.view(
-                -1, v_b_proj_scale.shape[-1]).cuda(),
-        ).view(*attn_module.v_b_proj_dequant.shape).to(
-            attn_module.v_b_proj_dequant.dtype))
+    attn_module.v_b_proj_dequant.data.copy_(
+        dequantize_projection(
+            v_b_proj,
+            v_b_proj_scale,
+            attn_module.v_b_proj_dequant.shape,
+            attn_module.v_b_proj_dequant.dtype
+        ))
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)

78-79: Add comment explaining the scale factor packing logic

The bitcast and bit shift operation extracts the FP32 exponent efficiently, but this optimization deserves a clarifying comment.

-            output_s = output_s.to(tl.int32, bitcast=True) >> 23
-            output_s_int32 += output_s << (group_index * 8)
+            # Extract FP32 exponent by bitcasting to int32 and shifting right by 23
+            output_s = output_s.to(tl.int32, bitcast=True) >> 23
+            # Pack 4 group scales into a single int32 (8 bits per scale)
+            output_s_int32 += output_s << (group_index * 8)
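
For illustration, the same exponent extraction expressed in PyTorch (values chosen arbitrarily):

import torch

scales = torch.tensor([0.25, 1.0, 2.0, 448.0], dtype=torch.float32)
# Reinterpret the FP32 bits as int32; bits 23..30 hold the biased exponent (bias = 127).
biased_exp = (scales.view(torch.int32) >> 23) & 0xFF
print(biased_exp)  # tensor([125, 127, 128, 135], dtype=torch.int32)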

88-95: Add comprehensive docstring for better documentation

The function lacks documentation explaining its purpose, parameters, and return value.

 def masked_index_copy_group_quant_fp8(
     output: torch.Tensor,
     input: torch.Tensor,
     start_offsets: torch.Tensor,
     row_indices: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
 ):
+    """
+    Perform masked index copy with group-wise FP8 quantization.
+    
+    Args:
+        output: Output tensor for quantized values, shape [num_experts, max_tokens, hidden_size]
+        input: Input tensor to be quantized and copied, shape [num_tokens, hidden_size]
+        start_offsets: Start token offsets for each expert, shape [num_experts + 1]
+        row_indices: Expert indices for each token, shape [num_tokens]
+        group_size: Size of quantization groups (must divide hidden_size)
+        eps: Small epsilon value for numerical stability in quantization
+        
+    Returns:
+        torch.Tensor: Scale factors for the quantized values, shape [num_experts, actual_tokens, scale_dims]
+    """

377-377: Split long line for better readability

Line exceeds the 120 character limit.

-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert x.dtype != torch.float8_e4m3fn, (
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 11a31d6 and 77e179c.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (6 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (2 hunks)
  • tensorrt_llm/_torch/modules/linear.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tensorrt_llm/_torch/modules/linear.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

377-377: Line too long (131 > 120)

(E501)

🔇 Additional comments (3)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (2)

1213-1213: Good practice: Ensuring contiguous memory layout after transpose

Adding .contiguous() after transpose operations ensures optimal memory access patterns for subsequent operations, which is particularly important for CUDA kernels.

Also applies to: 1254-1254


1431-1441: FP8 scale transformation correctly implemented

The scale transformation logic properly:

  • Checks for FP8 block scales and SM 100 architecture
  • Uses the correct recipe (1, 128, 128) matching the block size
  • Creates non-trainable parameters as expected
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

247-297: Well-implemented FP8 grouped GEMM wrapper

The function provides comprehensive validation and correctly handles scale factor transformations for DeepGemm operations.

@lfr-0531 lfr-0531 force-pushed the user/fanrongl/debug_moe branch from 2c908f3 to 3a1a0b8 on July 25, 2025 02:14

@lfr-0531 (Collaborator)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #12927 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #12927 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #9637 completed with status: 'FAILURE'

@lfr-0531 (Collaborator)

/bot run

@lfr-0531 (Collaborator)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #13021 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #13021 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9725 completed with status: 'FAILURE'

@lfr-0531 (Collaborator)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #13082 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #13082 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9780 completed with status: 'FAILURE'

@lfr-0531 (Collaborator)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #13122 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #13122 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9819 completed with status: 'FAILURE'

@lfr-0531 lfr-0531 force-pushed the user/fanrongl/debug_moe branch from 972ab86 to 59b3957 on July 28, 2025 10:56

@lfr-0531 (Collaborator)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #13192 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #13192 [ run ] completed with state FAILURE

* Reapply "Fuse quantize and transform e8m0 scales (#26)" (#27)

This reverts commit 9107cfa.

* Remove compile for reducing warnings

Signed-off-by: Barry Kang <[email protected]>

---------

Signed-off-by: Barry Kang <[email protected]>

@lfr-0531 (Collaborator)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #13217 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #13217 [ run ] completed with state FAILURE

@lfr-0531 (Collaborator)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #13254 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #13254 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #9898 completed with status: 'FAILURE'

@lfr-0531 (Collaborator)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #13268 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #13268 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9907 completed with status: 'FAILURE'

@litaotju litaotju closed this Aug 1, 2025
@litaotju (Collaborator, Author) commented Aug 1, 2025

Closing, since #6486 already merged.

@lfr-0531 lfr-0531 deleted the user/fanrongl/debug_moe branch September 22, 2025 07:13