
Conversation

@zhou-yuxin (Collaborator) commented Aug 21, 2025

export LLM_MODELS_ROOT=/models
pytest accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]

Summary by CodeRabbit

  • New Features

    • Expanded FP8/e4m3 attention support with additional 192x128 MLA variants (including separate Q/K/V) and optional BF16 outputs.
    • Memory-store enablement now considers configured output precision, widening supported scenarios.
  • Bug Fixes

    • Adjusted V-tile transposition to use the correct DV dimension.
    • More robust output packing and per-tile scaling propagation across paths.
  • Tests

    • Enabled FP8 MLA tests on SM90 (in addition to SM120).

coderabbitai bot (Contributor) commented Aug 21, 2025

📝 Walkthrough

Allow FP8 MLA tests on SM90; generate and register new e4m3 192x128 S_q_k_v kernel variants (including bf16-output); make TMA-store enablement consider kernel output dtype; add per-tile scale handling and new packing path; adjust V-tile transpose to use DV; extend kernel-traits signatures.

Changes

  • Test gating (cpp/kernels/fmha_v2/fmha_test.py): FP8 MLA skip condition changed to skip when SM is not in [90, 120], allowing FP8 MLA on SM90 and SM120; updated skip message.
  • Kernel enumeration & config (cpp/kernels/fmha_v2/setup.py): Use output_dtype = kspec.output_dtype or kspec.dtype for TMA-store enablement; include InputLayout.SEPARATE_Q_K_V in input-layout combos; add 192x128 context MLA specs with output_dtype ∈ [None, 'bf16']; permit e4m3 in the MLA 192x128 path; minor comment tweak.
  • Output packing & per-tile scaling, Hopper (cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h): Add uint32_t params_scale_bmm2_ to Gmem_tile_o_qgmma_fp32_16bits; replace macro-based packing with Acc_packer<float, Output_type, Scale>::run, gathering into uint4, packing to uint2, and issuing a single fmha::stg store; propagate the scale through constructors.
  • DMA V-tile transpose (cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h): transpose_v_tile now iterates Kernel_traits::DV_GROUPS and uses Kernel_traits::DV in the destination offset calculation (V-tile transpose adjusted to the DV dimension).
  • Kernel trait signature & aliases (cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h): Extend the Kernel_traits_Hopper_qgmma_e4m3_fp32 base template parameter list with RETURN_SOFTMAX_STATS_, OutputType, and SAGE_BLOCK_SIZE_Q_/K_/V_; add a conditional Gmem_tile_o alias selecting 32bit_8bit vs qgmma_fp32_16bits based on OutputType.
  • Cubin declarations & meta (cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h): Add extern kernel entrypoints for e4m3 64x128 S_q_k_v 192x128 (standard and bf16-output) and register four FusedMultiHeadAttentionKernelMetaInfoV2 entries, including causal/bf16 variants.
  • Runner gating (cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp): Replace the BF16-context MLA check with the broader isHopperContextMLA = isSm90 && headSizeV == 128; change gating logic to use isHopperContextMLA alongside isHopperFP8GenerationMLA.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Test as pytest
    participant Selector as KernelSelector
    participant Meta as CubinMeta
    participant Runner as FusedMHARunnerV2
    participant Kernel as FMHA Kernel

    Note over Test: SM gating updated (SM90 allowed for FP8 MLA)
    Test->>Selector: Request MLA kernel (192x128), dtype=e4m3, layout=SEPARATE_Q_K_V, output_dtype?
    Selector->>Meta: Query entries for 192x128 S_q_k_v (e4m3, bf16 variants)
    Meta-->>Selector: Return matching cubin & entrypoint
    Selector->>Runner: configure launch params (TMA-store decision reads output_dtype)
    Runner->>Kernel: Launch(params, launch_params, stream)
    Note right of Kernel: 16-bit output path uses Acc_packer and params_scale_bmm2_ then stg store
sequenceDiagram
    autonumber
    participant Warp as WarpSpec
    participant DMA as DMA::Device
    participant GMEM as GlobalMem

    Warp->>DMA: transpose_v_tile(V)
    DMA->>DMA: for dgroup_idx in Kernel_traits::DV_GROUPS
    DMA->>GMEM: Load V slice (source offset)
    DMA->>GMEM: Store transposed V slice (dest offset uses Kernel_traits::DV)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • chzblych
  • QiJune
  • Wanli-Jiang
  • yuxianq


@zhou-yuxin (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #16022 [ run ] triggered by Bot

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
cpp/kernels/fmha_v2/fmha_test.py (1)

1-1: Add NVIDIA copyright header to cpp/kernels/fmha_v2/fmha_test.py

Per CODING_GUIDELINES.md, all Python source files must begin with the standard NVIDIA header (current year). Please prepend the following header before the first import in fmha_test.py (copying the exact text from existing Python files such as triton_backend/tools/tests/test_llmapi_cancel.py):

+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#   * Redistributions of source code must retain the above copyright notice,
+#     this list of conditions and the following disclaimer.
+#   * Redistributions in binary form must reproduce the above copyright
+#     notice, this list of conditions and the following disclaimer in the
+#     documentation and/or other materials provided with the distribution.
+#   * Neither the name of the NVIDIA CORPORATION & AFFILIATES nor the names
+#     of its contributors may be used to endorse or promote products derived
+#     from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS “AS IS” AND ANY EXPRESS
+# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
 import subprocess
🧹 Nitpick comments (1)
cpp/kernels/fmha_v2/fmha_test.py (1)

74-76: Skip message is misleading (permits Ada+Hopper but message says “only hopper”)

Condition allows sm89 (Ada) and sm90 (Hopper), but the message says “only hopper supports fp8 fmha currently.” Update the text for clarity.
Apply:

-        pytest.skip("only hopper supports fp8 fmha currently.")
+        pytest.skip("only Ada (sm89) and Hopper (sm90) support fp8 fmha currently.")
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 2d40e87 and fa8ee2f.

📒 Files selected for processing (6)
  • cpp/kernels/fmha_v2/fmha_test.py (1 hunks)
  • cpp/kernels/fmha_v2/setup.py (5 hunks)
  • cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (3 hunks)
  • cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (2 hunks)
  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (1 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (2 hunks)
🧰 Additional context used
📓 Path-based instructions (5)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • cpp/kernels/fmha_v2/fmha_test.py
  • cpp/kernels/fmha_v2/setup.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • cpp/kernels/fmha_v2/fmha_test.py
  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
  • cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
  • cpp/kernels/fmha_v2/setup.py
  • cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...

Files:

  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
  • cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
  • cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Parameter names must be consistent between declarations and definitions

Files:

  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
  • cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
  • cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
**/*.{h,hpp,hxx,hh,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Header files must use include guards named TRTLLM_<NAME>_H without underscore prefix/suffix (e.g., TRTLLM_FOO_BAR_HELLO_H)

Files:

  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
  • cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
  • cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
🔇 Additional comments (15)
cpp/kernels/fmha_v2/fmha_test.py (2)

168-169: Test enablement aligns with FP8 MLA support expansion

The change from sm_version != 120 to sm_version not in [90, 120] correctly extends FP8 MLA test coverage to SM90 while maintaining backward compatibility with SM120. This aligns with the broader PR objectives of enabling FP8 context MLA for Hopper (SM90).


168-169: Enable FP8 context MLA on SM90: gating change looks correct

Allowing FP8 context MLA on sm90 in addition to sm120 matches the new kernel registrations and seems correct. The updated skip message is consistent with the condition.

cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (1)

586-593: Template parameter expansion looks correct

The template parameter extension from ENABLE_BMM1_SOFTCAPPING_SCALE_ to include RETURN_SOFTMAX_STATS_, OutputType, and the three SAGE_BLOCK_SIZE_* parameters properly aligns with the base class definition, enabling the e4m3 kernel variants to support additional output types and SAGE attention configurations. Note that this was the only substantive change in the file; line 593 simply aligns with the new parameter list.
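
For intuition, the conditional output-tile alias mentioned in the walkthrough can be pictured with std::conditional; the trait and tile names below are simplified stand-ins rather than the actual definitions in kernel_traits.h.

#include <type_traits>

// Simplified stand-ins for the two real output-tile templates.
template <typename Traits> struct Gmem_tile_o_gmma_32bit_8bit {};
template <typename Traits> struct Gmem_tile_o_qgmma_fp32_16bits {};

struct e4m3_t {}; // placeholder for the FP8 storage type
struct bf16_t {}; // placeholder for the BF16 storage type

template <typename Traits, typename OutputType>
struct Kernel_traits_sketch
{
    // 8-bit (e4m3) outputs keep the 32bit->8bit tile; 16-bit outputs such as
    // bf16 switch to the fp32->16-bit packing tile.
    using Gmem_tile_o = typename std::conditional<std::is_same<OutputType, e4m3_t>::value,
        Gmem_tile_o_gmma_32bit_8bit<Traits>, Gmem_tile_o_qgmma_fp32_16bits<Traits>>::type;
};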

cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1)

758-768: Correct DV-based transposition for V-tile handling

The changes correctly update the V-tile transposition logic to use DV_GROUPS and DV dimensions instead of the generic D_GROUPS and D dimensions. This is essential for handling V-tiles that may have different dimensions from Q/K tiles, particularly in MLA architectures where V dimensions can differ.

The loop now iterates over Kernel_traits::DV_GROUPS (line 758) and the destination offset calculation uses Kernel_traits::DV (line 767), ensuring proper memory layout for the transposed V-tile in the DV dimension space.
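
As a plain-C++ illustration of the indexing change (the real transpose_v_tile works on shared-memory tiles with warp-specialized copies, so only the DV_GROUPS loop bound and the DV-based offsets are meant to mirror it):

// Illustrative only: walk the V head dimension in DV_GROUPS chunks and compute
// offsets against DV, the V head size, rather than the generic Q/K head size D.
template <typename Kernel_traits>
void transpose_v_tile_sketch(float const* src, float* dst, int rows)
{
    constexpr int kGroupSize = Kernel_traits::DV / Kernel_traits::DV_GROUPS;
    for (int dgroup_idx = 0; dgroup_idx < Kernel_traits::DV_GROUPS; ++dgroup_idx)
    {
        for (int row = 0; row < rows; ++row)
        {
            for (int col = 0; col < kGroupSize; ++col)
            {
                int const d = dgroup_idx * kGroupSize + col;
                dst[d * rows + row] = src[row * Kernel_traits::DV + d];
            }
        }
    }
}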

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (2)

1974-1977: Verification of FMHA E4M3 192×128 entries

  • The four meta-table entries for
    • E4M3→E4M3 non-causal & causal
    • E4M3→BF16 non-causal & causal
    are present at lines 1974–1977 in fmha_cubin.h.
  • The corresponding extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(...) declarations (non-causal & _output_bf16 variants) are also present at line 263 (and causal via shared launcher).
  • We did not locate the actual run_… definitions in this header; they’re likely implemented in a .cu/.cpp file.
  • No cudaFuncSetAttribute or cuFuncSetAttribute calls for dynamic shared memory appear here, nor any TMA-store gating logic based on output dtype within this file.

Please manually verify:

  • That the run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90 functions are defined and linked into the runtime dispatcher.
  • That the SM90 launch path sets MaxDynamicSharedMemorySize ≥ 164096 B for these kernels (see the sketch after this list).
  • That TMA-store enablement is gated on the output dtype (E4M3 vs BF16) in the dispatcher logic.
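
For the shared-memory item above, the runtime-API opt-in typically looks like the sketch below; the kernel symbol and byte count are placeholders, and the actual dispatcher may instead use the driver API (cuFuncSetAttribute) on functions loaded from the cubins.

#include <cuda_runtime.h>

__global__ void fmhaKernelPlaceholder() {} // stand-in for the real cubin entrypoint

cudaError_t configureDynamicSmem()
{
    // Large context-MLA tiles exceed the 48 KiB default, so the launch path must
    // opt in to a bigger dynamic shared-memory carve-out before launching.
    constexpr int kSmemBytes = 164 * 1024; // placeholder; must cover the kernel's requirement
    return cudaFuncSetAttribute(
        fmhaKernelPlaceholder, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes);
}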

263-264: SM90 192x128 S_q_k_v externs — declarations found; implementations not located, please verify

Short: I found the two externs in the header and SM90 guards/kernel entries, but no non-header definitions were located in the repo so I could not verify parameter-name consistency or that the implementations are built for SM90.

  • Declarations: cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h:263–264 (extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(...) and the output_bf16 variant).
  • Kernel table references: same file, approximately lines 1974–1977 — kSM_90 entries reference those run_fmha_v2_* symbols.
  • Arch guards present: #ifndef EXCLUDE_SM_90 at fmha_cubin.h lines 28, 1333, 1449.
  • Definitions: repository search returned no non-extern/implementation definitions for these symbols (so parameter-name consistency could not be verified).

Please confirm one of the following:

  • Add/point to the missing implementations (likely .cu/.cpp or a linked binary/cubin) and ensure they are guarded/compiled for SM90; and verify parameter names match the declaration (params, launch_params, stream).
  • Or, if the implementations are intentionally provided elsewhere (submodule/binary), note that in the build so reviewers know there is no ODR/link risk.
cpp/kernels/fmha_v2/setup.py (4)

1917-1920: LGTM!

The change to use output_dtype when available (falling back to dtype otherwise) is a good improvement that provides flexibility for kernels with different output types.


3815-3818: Missing output separator in InputLayout combinations.

The combination generation for qgmma_flash_warpspec_kernels includes InputLayout.SEPARATE_Q_K_V in the product with other parameters. This aligns with the MLA context requirements.


3933-3971: Good implementation of context MLA kernel generation.

The loop over output_type options (None, 'bf16') properly handles the different output data types for 192x128 MLA kernels. The kv_loop_step is appropriately set to 128 to manage shared memory constraints.


6422-6424: Verify the Deepseek MLA generation condition.

The condition includes 'e4m3' in addition to 'bf16' and 'e4m3_fp32' for MLA generation kernels. This is consistent with the FP8 support expansion, but please ensure that the 'e4m3' dtype without explicit FP32 accumulation is intended for these MLA variants.

Are there any specific numerical precision requirements that should be documented for the 'e4m3' dtype usage in MLA generation kernels?

cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (5)

1226-1232: LGTM! Clean handling of scale_bmm2 initialization.

The conditional initialization properly handles both cubin and non-cubin compilation paths using preprocessor directives.


1262-1280: Excellent refactoring of output packing logic.

The new implementation using Acc_packer<float, Output_type, Scale> is much cleaner and more maintainable than the previous macro-based approach. The conditional Scale flag based on UNIFIED_EPILOGUE_SCALE provides good flexibility for different scaling modes.
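
As a rough sketch of what such a packer does (this is not the actual Acc_packer code, and the real tile carries the scale as a uint32_t, so the float scale here is a simplification): scale each fp32 accumulator when requested, convert to the 16-bit output type, and pack the results so the epilogue can issue one wide store.

#include <cstdint>
#include <cuda_bf16.h>

// Illustrative packer: 4 fp32 accumulators -> 4 bf16 values packed in a uint2.
// The caller then writes the uint2 with a single store (fmha::stg in the real code).
template <bool Scale>
__device__ inline uint2 pack_fp32_to_bf16(float const (&acc)[4], float scale_bmm2)
{
    float scaled[4];
#pragma unroll
    for (int i = 0; i < 4; ++i)
    {
        scaled[i] = Scale ? acc[i] * scale_bmm2 : acc[i];
    }
    __nv_bfloat162 lo = __floats2bfloat162_rn(scaled[0], scaled[1]);
    __nv_bfloat162 hi = __floats2bfloat162_rn(scaled[2], scaled[3]);
    uint2 packed;
    packed.x = reinterpret_cast<uint32_t const&>(lo);
    packed.y = reinterpret_cast<uint32_t const&>(hi);
    return packed;
}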


831-836: Good application of the new packing API.

The TMA store path correctly uses the new Acc_packer API with appropriate Scale flag control.


927-931: Consistent packing implementation across all paths.

All three packing paths (regular, pack_4_elts, and non-pack_4_elts) correctly use the new Acc_packer API with consistent Scale flag handling.

Also applies to: 957-961, 986-991


658-662: Scale_bmm2 forwarding verified as consistent across all constructors

I’ve checked all packed‐tile constructors in gmem_tile_o_packed.h (Gmem_tile_o_gmma_32bit_8bit, Gmem_tile_o_qgmma_fp32_16bits, etc.) and each uses the #ifdef GENERATE_CUBIN guard to choose between params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2 and plain params.scale_bmm2. The unpacked (float) tile in gmem_tile_o.h correctly forwards params.scale_bmm2 directly (no device pointer applies). No missing or divergent patterns were found—no further changes needed here.
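
The forwarding pattern described above looks roughly like this (struct and member names trimmed down to the essentials; GENERATE_CUBIN selects whether the scale may come from a device-side pointer):

#include <cstdint>

struct Params_sketch
{
    uint32_t scale_bmm2;          // host-provided scale
    uint32_t const* scale_bmm2_d; // optional device-side scale (cubin path)
};

struct Gmem_tile_o_sketch
{
    __device__ explicit Gmem_tile_o_sketch(Params_sketch const& params)
#ifdef GENERATE_CUBIN
        // Cubin builds prefer the device-side scale when a pointer is provided.
        : params_scale_bmm2_(params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2)
#else
        : params_scale_bmm2_(params.scale_bmm2)
#endif
    {
    }

    uint32_t params_scale_bmm2_;
};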

@tensorrt-cicd (Collaborator)

PR_Github #16022 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12044 completed with status: 'FAILURE'

@zhou-yuxin (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #16203 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #16203 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12186 completed with status: 'FAILURE'

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between fa8ee2f and 8cf2834.

📒 Files selected for processing (1)
  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (2 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...

Files:

  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Parameter names must be consistent between declarations and definitions

Files:

  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
**/*.{h,hpp,hxx,hh,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Header files must use include guards named TRTLLM_<NAME>_H without underscore prefix/suffix (e.g., TRTLLM_FOO_BAR_HELLO_H)

Files:

  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
🧬 Code graph analysis (1)
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (2)
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (2)
  • fmha (17-243)
  • v2 (20-198)
cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h (1)
  • fmha (23-720)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (1)

592-593: Base alias update: parameter forwarding looks correct

Forwarding the newly added template parameters (RETURN_SOFTMAX_STATS_, OutputType, SAGE block sizes) into the Base alias is consistent with the upstream Kernel_traits signature. No concerns here.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1)

462-472: Unify MLA gating across contexts and fix typo in fmhaRunner.cpp

After inspecting the ripgrep results, we confirmed that in fmhaRunner.cpp:

  • At lines 462–472, the code correctly uses
    isHopperContextMLA (covering SM90 context MLA) and
    isHopperFP8GenerationMLA for generation MLA
  • But at lines 489–497, softmax‐stats gating still uses a BF16-only predicate (isHopperBF16ContextMLA), leading to inconsistent support when running FP8 context MLA on SM90
  • There’s also a typo in the comment at line 466: “hooper” → “hopper”

Action items:

  • Replace the BF16-only predicate at lines 489–497 with a unified context-MLA check that matches isHopperContextMLA (possibly parameterized for FP8 if a different layout is desired).
  • Fix the “hooper” typo in the Deepseek-V2 comment at line 466.

Proposed minimal diff:

--- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
+++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
@@ -464,7 +464,7 @@
             // Even on SM90, we use ampere-style kernel, will be optimized later
             mLaunchParams.warp_specialization = false;
             mLaunchParams.useKernelWithoutAlibi = false;
-            // Deepseek-V2 kernel is not hooper style right now.
+            // Deepseek-V2 kernel is not hopper style right now.
             mLaunchParams.useBase2ExpTrick = false;
             mLaunchParams.use_tma = false;
             mLaunchParams.dynamic_scheduler = false;
@@ -489,13 +489,18 @@
     else
     {
-        bool isHopperBF16ContextMLA = (mFixedParams.headSize == mFixedParams.headSizeV + 64) && isSm90
-            && mFixedParams.dataType == DATA_TYPE_BF16 && mFixedParams.headSizeV == 128;
-        mLaunchParams.supportReturnSoftmaxStats = (runnerParams.softmaxStatsPtr != nullptr
-            && mLaunchParams.flash_attention && mLaunchParams.warp_specialization
-            && ((!isHopperBF16ContextMLA
-                    && mLaunchParams.attention_input_layout == AttentionInputLayout::Q_CONTIGUOUS_KV)
-                || (isHopperBF16ContextMLA
-                    && (mLaunchParams.attention_input_layout == AttentionInputLayout::SEPARATE_Q_K_V))));
+        // Keep the context-MLA predicate consistent with isHopperContextMLA above.
+        bool const isHopperContextMLA_forStats =
+            isSm90 && mFixedParams.headSizeV == 128;
+        // If FP8-context MLA should differ, add a dataType check here; otherwise treat all context-MLA the same.
+        mLaunchParams.supportReturnSoftmaxStats =
+            (runnerParams.softmaxStatsPtr != nullptr
+             && mLaunchParams.flash_attention
+             && mLaunchParams.warp_specialization
+             && ((isHopperContextMLA_forStats
+                     && mLaunchParams.attention_input_layout == AttentionInputLayout::SEPARATE_Q_K_V)
+                 || (!isHopperContextMLA_forStats
+                     && mLaunchParams.attention_input_layout == AttentionInputLayout::Q_CONTIGUOUS_KV)));
     }

Please apply these changes to ensure consistent softmax-stats support across BF16 and FP8 context MLA on SM90, and correct the typo.

🧹 Nitpick comments (1)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1)

456-458: Make the new flag const and clarify the comment scope

  • The local flag is never mutated; prefer const per guidelines.
  • The comment is ambiguous; explicitly call out that the context-MLA detection keys off SM90 and V=128.

Apply this diff:

-        // Now we have SM90 context and FP8 generation MLA kernels
-        bool isHopperContextMLA = isSm90 && mFixedParams.headSizeV == 128;
+        // MLA kernels available on SM90 for context (V = 128) and for FP8 generation (V = 512).
+        bool const isHopperContextMLA = isSm90 && mFixedParams.headSizeV == 128;
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 8cf2834 and 3f3f846.

📒 Files selected for processing (2)
  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (2 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...

Files:

  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
**/*.{cpp,cxx,cc,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu}: Avoid literal values except for 0, nullptr, true, false; use named constexpr for other literals
Place semicolon of empty for/while loop on a new line
Always use brace-delimited bodies for switch/while/do-for/if/else
Use inline C comments in argument lists when parameter meaning is unclear (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) ... is forbidden)
Switch on enums should enumerate all values and omit default to catch new values at compile time
Structure switch statements; prohibit fallthrough except between empty cases; each case ends with break or throw; return at end of case not allowed; put break inside braces for compound case
Prefer anonymous namespaces over static for internal linkage of functions
Every defined function must be called at least once (no unused methods)

Files:

  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Parameter names must be consistent between declarations and definitions

Files:

  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@zhou-yuxin (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #16316 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #16316 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12268 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@zhou-yuxin zhou-yuxin requested a review from zhhuang-nv August 25, 2025 01:53
@PerkzZheng (Collaborator)

@zhou-yuxin has this test already been added to the test list ? others LGTM.

pytest accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]

@zhhuang-nv (Collaborator)

@zhou-yuxin has this test already been added to the test list ? others LGTM.

pytest accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]

Yes, it's in l0_h100.yml (pre-merge stage).

@zhou-yuxin zhou-yuxin removed the request for review from qsang-nv August 25, 2025 14:19
@zhou-yuxin (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #16504 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #16504 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12397 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@PerkzZheng PerkzZheng merged commit f01101f into NVIDIA:main Aug 26, 2025
4 checks passed