[None][feat] Hopper Fp8 context mla #7116
Conversation
Signed-off-by: Yuxin <[email protected]>
Signed-off-by: Yuxin <[email protected]>
📝 Walkthrough

Allow FP8 MLA tests on SM90; generate and register new e4m3 192x128 S_q_k_v kernel variants (including bf16-output); make TMA-store enablement consider the kernel output dtype; add per-tile scale handling and a new packing path; adjust the V-tile transpose to use DV; extend kernel-traits signatures.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Test as pytest
    participant Selector as KernelSelector
    participant Meta as CubinMeta
    participant Runner as FusedMHARunnerV2
    participant Kernel as FMHA Kernel
    Note over Test: SM gating updated (SM90 allowed for FP8 MLA)
    Test->>Selector: Request MLA kernel (192x128), dtype=e4m3, layout=SEPARATE_Q_K_V, output_dtype?
    Selector->>Meta: Query entries for 192x128 S_q_k_v (e4m3, bf16 variants)
    Meta-->>Selector: Return matching cubin & entrypoint
    Selector->>Runner: configure launch params (TMA-store decision reads output_dtype)
    Runner->>Kernel: Launch(params, launch_params, stream)
    Note right of Kernel: 16-bit output path uses Acc_packer and params_scale_bmm2_ then stg store
```
```mermaid
sequenceDiagram
    autonumber
    participant Warp as WarpSpec
    participant DMA as DMA::Device
    participant GMEM as GlobalMem
    Warp->>DMA: transpose_v_tile(V)
    DMA->>DMA: for dgroup_idx in Kernel_traits::DV_GROUPS
    DMA->>GMEM: Load V slice (source offset)
    DMA->>GMEM: Store transposed V slice (dest offset uses Kernel_traits::DV)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
/bot run

PR_Github #16022 [ run ] triggered by Bot
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
cpp/kernels/fmha_v2/fmha_test.py (1)
1-1: Add NVIDIA copyright header to cpp/kernels/fmha_v2/fmha_test.py

Per CODING_GUIDELINES.md, all Python source files must begin with the standard NVIDIA header (current year). Please prepend the following header before the first import in fmha_test.py (copying the exact text from existing Python files such as triton_backend/tools/tests/test_llmapi_cancel.py):

```diff
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of the NVIDIA CORPORATION & AFFILIATES nor the names
+#    of its contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS “AS IS” AND ANY EXPRESS
+# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
 import subprocess
```
🧹 Nitpick comments (1)
cpp/kernels/fmha_v2/fmha_test.py (1)
74-76: Skip message is misleading (permits Ada+Hopper but message says “only hopper”)

The condition allows sm89 (Ada) and sm90 (Hopper), but the message says “only hopper supports fp8 fmha currently.” Update the text for clarity.

Apply:

```diff
-        pytest.skip("only hopper supports fp8 fmha currently.")
+        pytest.skip("only Ada (sm89) and Hopper (sm90) support fp8 fmha currently.")
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (6)
- cpp/kernels/fmha_v2/fmha_test.py (1 hunks)
- cpp/kernels/fmha_v2/setup.py (5 hunks)
- cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (3 hunks)
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (2 hunks)
- cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (1 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (2 hunks)
🧰 Additional context used
📓 Path-based instructions (5)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
cpp/kernels/fmha_v2/fmha_test.py
cpp/kernels/fmha_v2/setup.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
cpp/kernels/fmha_v2/fmha_test.py
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/setup.py
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Parameter names must be consistent between declarations and definitions
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
**/*.{h,hpp,hxx,hh,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Header files must use include guards named TRTLLM__H without underscores prefix/suffix (e.g., TRTLLM_FOO_BAR_HELLO_H)
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
🔇 Additional comments (15)
cpp/kernels/fmha_v2/fmha_test.py (2)
168-169: Test enablement aligns with FP8 MLA support expansion

The change from sm_version != 120 to sm_version not in [90, 120] correctly extends FP8 MLA test coverage to SM90 while maintaining backward compatibility with SM120. This aligns with the broader PR objectives of enabling FP8 context MLA for Hopper (SM90).

168-169: Enable FP8 context MLA on SM90: gating change looks correct

Allowing FP8 context MLA on sm90 in addition to sm120 matches the new kernel registrations and seems correct. The updated skip message is consistent with the condition.
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (1)
586-593: Template parameter expansion looks correct

The template parameter extension from ENABLE_BMM1_SOFTCAPPING_SCALE_ to include RETURN_SOFTMAX_STATS_, OutputType, and the three SAGE_BLOCK_SIZE_* parameters properly aligns with the base class definition, enabling the e4m3 kernel variants to support additional output types and SAGE attention configurations. Note that this was the only substantive change in the file; line 593 simply aligns with the new parameter list.

cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (2)
758-768: Correct DV-based transposition for V-tile handling

The changes correctly update the V-tile transposition logic to use the DV_GROUPS and DV dimensions instead of the generic D_GROUPS and D dimensions. This is essential for handling V-tiles that may have different dimensions from Q/K tiles, particularly in MLA architectures where V dimensions can differ. The loop now iterates over Kernel_traits::DV_GROUPS (line 758) and the destination offset calculation uses Kernel_traits::DV (line 767), ensuring proper memory layout for the transposed V-tile in the DV dimension space.
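As a toy illustration of the indexing change (not the actual dma.h code, which uses TMA and shared-memory machinery), the sketch below shows the two pieces called out above: the group loop bound comes from DV_GROUPS and the destination stride from DV. ToyTraits and the assumed row/column layout are purely hypothetical.

```cpp
// Hypothetical traits mirroring the 192x128 MLA configuration; values are illustrative only.
struct ToyTraits
{
    static constexpr int D = 192;       // Q/K head size.
    static constexpr int DV = 128;      // V head size (differs from D in MLA).
    static constexpr int DV_GROUPS = 2; // Assumed grouping of the V head dimension.
};

// Copies a V tile stored as DV_GROUPS column slices into a row-major [rows x DV] destination.
// The key points: the loop bound is DV_GROUPS (not D_GROUPS) and the destination row stride
// is DV (not D).
template <typename Kernel_traits>
void transposeVTileSketch(float const* src, float* dst, int rows)
{
    int const sliceCols = Kernel_traits::DV / Kernel_traits::DV_GROUPS;
    for (int dgroupIdx = 0; dgroupIdx < Kernel_traits::DV_GROUPS; ++dgroupIdx)
    {
        for (int col = 0; col < sliceCols; ++col)
        {
            for (int row = 0; row < rows; ++row)
            {
                // Source slice is column-major within each group (assumed layout).
                int const srcOffset = (dgroupIdx * sliceCols + col) * rows + row;
                // Destination uses the V head size DV as the row stride.
                int const dstOffset = row * Kernel_traits::DV + dgroupIdx * sliceCols + col;
                dst[dstOffset] = src[srcOffset];
            }
        }
    }
}
// Example instantiation: transposeVTileSketch<ToyTraits>(src, dst, /*rows=*/64);
```

Using DV rather than D as the destination stride is what keeps the 192x128 case (D = 192, DV = 128) from writing with the wrong stride.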
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (2)
1974-1977: Verification of FMHA E4M3 192×128 entries

- The four meta-table entries for
  • E4M3→E4M3 non-causal & causal
  • E4M3→BF16 non-causal & causal
  are present at lines 1974–1977 in fmha_cubin.h.
- The corresponding extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(...) declarations (non-causal & _output_bf16 variants) are also present at line 263 (and causal via shared launcher).
- We did not locate the actual run_… definitions in this header; they’re likely implemented in a .cu/.cpp file.
- No cudaFuncSetAttribute or cuFuncSetAttribute calls for dynamic shared memory appear here, nor any TMA-store gating logic based on output dtype within this file.

Please manually verify:

- That the run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90 functions are defined and linked into the runtime dispatcher.
- That the SM90 launch path sets MaxDynamicSharedMemorySize ≥ 164 096 B for these kernels.
- That TMA-store enablement is gated on the output dtype (E4M3 vs BF16) in the dispatcher logic.
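On that last point, the shape of the check being asked for is roughly the following. This is a purely hypothetical sketch: the enum, function name, and the concrete policy are assumptions used only to make "gated on the output dtype" concrete, not the dispatcher's actual code.

```cpp
// Hypothetical dtype enum and gating helper; the real dispatcher uses the project's own
// Data_type values and launch-parameter fields.
enum class DtypeSketch
{
    kE4M3,
    kBF16,
    kFP16
};

// Before this change the decision could look only at the kernel (input) dtype; with mixed
// e4m3-in / bf16-out kernels, the output dtype must participate. The policy below (use the
// TMA store only when the output dtype matches the dtype the store descriptor was built for)
// is an assumption for illustration.
inline bool enableTmaStoreSketch(DtypeSketch kernelDtype, DtypeSketch outputDtype)
{
    return outputDtype == kernelDtype;
}
```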
263-264: SM90 192x128 S_q_k_v externs — declarations found; implementations not located, please verify

Short: I found the two externs in the header and SM90 guards/kernel entries, but no non-header definitions were located in the repo, so I could not verify parameter-name consistency or that the implementations are built for SM90.
- Declarations: cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h:263–264 (extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(...) and the output_bf16 variant).
- Kernel table references: same file, approximately lines 1974–1977 — kSM_90 entries reference those run_fmha_v2_* symbols.
- Arch guards present: #ifndef EXCLUDE_SM_90 at fmha_cubin.h lines 28, 1333, 1449.
- Definitions: repository search returned no non-extern/implementation definitions for these symbols (so parameter-name consistency could not be verified).
Please confirm one of the following:
- Add/point to the missing implementations (likely .cu/.cpp or a linked binary/cubin) and ensure they are guarded/compiled for SM90; and verify parameter names match the declaration (params, launch_params, stream).
- Or, if the implementations are intentionally provided elsewhere (submodule/binary), note that in the build so reviewers know there is no ODR/link risk.
cpp/kernels/fmha_v2/setup.py (4)
1917-1920: LGTM!

The change to use output_dtype when available (falling back to dtype otherwise) is a good improvement that provides flexibility for kernels with different output types.
3815-3818: Missing output separator in InputLayout combinations.

The combination generation for qgmma_flash_warpspec_kernels includes InputLayout.SEPARATE_Q_K_V in the product with other parameters. This aligns with the MLA context requirements.
3933-3971: Good implementation of context MLA kernel generation.

The loop over output_type options (None, 'bf16') properly handles the different output data types for the 192x128 MLA kernels. The kv_loop_step is appropriately set to 128 to manage shared memory constraints.
6422-6424: Verify the Deepseek MLA generation condition.

The condition includes 'e4m3' in addition to 'bf16' and 'e4m3_fp32' for MLA generation kernels. This is consistent with the FP8 support expansion, but please ensure that the 'e4m3' dtype without explicit FP32 accumulation is intended for these MLA variants.

Are there any specific numerical precision requirements that should be documented for the 'e4m3' dtype usage in MLA generation kernels?

cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (5)
1226-1232: LGTM! Clean handling of scale_bmm2 initialization.

The conditional initialization properly handles both cubin and non-cubin compilation paths using preprocessor directives.

1262-1280: Excellent refactoring of output packing logic.

The new implementation using Acc_packer<float, Output_type, Scale> is much cleaner and more maintainable than the previous macro-based approach. The conditional Scale flag based on UNIFIED_EPILOGUE_SCALE provides good flexibility for different scaling modes.

831-836: Good application of the new packing API.

The TMA store path correctly uses the new Acc_packer API with appropriate Scale flag control.

927-931: Consistent packing implementation across all paths.

All three packing paths (regular, pack_4_elts, and non-pack_4_elts) correctly use the new Acc_packer API with consistent Scale flag handling.

Also applies to: 957-961, 986-991

658-662: Scale_bmm2 forwarding verified as consistent across all constructors

I’ve checked all packed-tile constructors in gmem_tile_o_packed.h (Gmem_tile_o_gmma_32bit_8bit, Gmem_tile_o_qgmma_fp32_16bits, etc.) and each uses the #ifdef GENERATE_CUBIN guard to choose between params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2 and plain params.scale_bmm2. The unpacked (float) tile in gmem_tile_o.h correctly forwards params.scale_bmm2 directly (no device pointer applies). No missing or divergent patterns were found; no further changes are needed here.
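To make the two patterns above concrete, here is a simplified, hypothetical sketch of (a) the GENERATE_CUBIN-guarded choice between a device-side and host-side scale_bmm2, and (b) an Acc_packer-style helper that converts and optionally scales accumulators before the store. The names are stand-ins, not the actual fmha classes; in the real kernel this logic lives in the device-side epilogue.

```cpp
// Stand-in for the launch parameters; only the fields relevant to the scale choice are shown.
struct ParamsSketch
{
    float scale_bmm2;          // Host-provided scale.
    float const* scale_bmm2_d; // Optional device-side scale pointer (may be nullptr).
};

// Pattern (a): prefer the per-run device-side scale when building cubins, otherwise the host
// value. In the real code this choice sits in the packed output-tile constructors.
inline float pickScaleBmm2(ParamsSketch const& params)
{
#ifdef GENERATE_CUBIN
    return params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2;
#else
    return params.scale_bmm2;
#endif
}

// Pattern (b): a single packer template replaces per-dtype packing macros. The Scale flag
// decides whether the BMM2 scale is applied here or earlier in a unified epilogue.
template <typename Src_type, typename Output_type, bool Scale>
struct AccPackerSketch
{
    static Output_type run(Src_type acc, float scaleBmm2)
    {
        float const value = Scale ? static_cast<float>(acc) * scaleBmm2 : static_cast<float>(acc);
        return static_cast<Output_type>(value); // Conversion to bf16/e4m3 is elided in this sketch.
    }
};
```

A caller would then do something like AccPackerSketch<float, float, /*Scale=*/true>::run(acc, pickScaleBmm2(params)) before the stg or TMA store; the real Acc_packer additionally packs multiple elements per register.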
PR_Github #16022 [ run ] completed with state

/bot run

PR_Github #16203 [ run ] triggered by Bot

PR_Github #16203 [ run ] completed with state
Actionable comments posted: 1
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
- cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (2 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Parameter names must be consistent between declarations and definitions
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
**/*.{h,hpp,hxx,hh,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Header files must use include guards named TRTLLM__H without underscores prefix/suffix (e.g., TRTLLM_FOO_BAR_HELLO_H)
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
🧬 Code graph analysis (1)
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (2)
- cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (2)
  - fmha (17-243)
  - v2 (20-198)
- cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h (1)
  - fmha (23-720)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (1)
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (1)
592-593: Base alias update: parameter forwarding looks correct

Forwarding the newly added template parameters (RETURN_SOFTMAX_STATS_, OutputType, SAGE block sizes) into the Base alias is consistent with the upstream Kernel_traits signature. No concerns here.
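For readers unfamiliar with the pattern, the sketch below (hypothetical names, heavily trimmed parameter lists) shows what forwarding the newly added template parameters into the Base alias means: the derived traits must pass every new parameter through in the same order, or the specialization silently diverges from the base signature.

```cpp
// Trimmed, hypothetical analogue of the Kernel_traits hierarchy; only the forwarding pattern
// is meaningful here.
template <int STEP_Q_, int STEP_KV_, bool RETURN_SOFTMAX_STATS_, typename OutputType,
    int SAGE_BLOCK_SIZE_Q_, int SAGE_BLOCK_SIZE_K_, int SAGE_BLOCK_SIZE_V_>
struct Kernel_traits_base_sketch
{
    static constexpr bool RETURN_SOFTMAX_STATS = RETURN_SOFTMAX_STATS_;
    using Output_type = OutputType;
};

template <int STEP_Q_, int STEP_KV_, bool RETURN_SOFTMAX_STATS_, typename OutputType,
    int SAGE_BLOCK_SIZE_Q_, int SAGE_BLOCK_SIZE_K_, int SAGE_BLOCK_SIZE_V_>
struct Kernel_traits_e4m3_sketch
{
    // The Base alias forwards every parameter, including the newly added ones, in order.
    using Base = Kernel_traits_base_sketch<STEP_Q_, STEP_KV_, RETURN_SOFTMAX_STATS_, OutputType,
        SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
};
```

Dropping or reordering any of these parameters could still compile for some instantiations while selecting the wrong base configuration, which is why the review checks the forwarding explicitly.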
Signed-off-by: Yuxin <[email protected]>
Signed-off-by: Yuxin <[email protected]>
Force-pushed from 8cf2834 to 3f3f846 (Compare)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1)
462-472: Unify MLA gating across contexts and fix typo in fmhaRunner.cpp

After inspecting the ripgrep results, we confirmed that in fmhaRunner.cpp:

- At lines 462–472, the code correctly uses isHopperContextMLA (covering SM90 context MLA) and isHopperFP8GenerationMLA for generation MLA
- But at lines 489–497, softmax-stats gating still uses a BF16-only predicate (isHopperBF16ContextMLA), leading to inconsistent support when running FP8 context MLA on SM90
- There’s also a typo in the comment at line 466: “hooper” → “hopper”

Action items:

- Replace the BF16-only predicate at lines 489–497 with a unified context-MLA check that matches isHopperContextMLA (possibly parameterized for FP8 if a different layout is desired).
- Fix the “hooper” typo in the Deepseek-V2 comment at line 466.

Proposed minimal diff:

```diff
--- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
+++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
@@ -464,7 +464,7 @@
         // Even on SM90, we use ampere-style kernel, will be optimized later
         mLaunchParams.warp_specialization = false;
         mLaunchParams.useKernelWithoutAlibi = false;
-        // Deepseek-V2 kernel is not hooper style right now.
+        // Deepseek-V2 kernel is not hopper style right now.
         mLaunchParams.useBase2ExpTrick = false;
         mLaunchParams.use_tma = false;
         mLaunchParams.dynamic_scheduler = false;
@@ -489,13 +489,18 @@
     else
     {
-        bool isHopperBF16ContextMLA = (mFixedParams.headSize == mFixedParams.headSizeV + 64) && isSm90
-            && mFixedParams.dataType == DATA_TYPE_BF16 && mFixedParams.headSizeV == 128;
-        mLaunchParams.supportReturnSoftmaxStats = (runnerParams.softmaxStatsPtr != nullptr
-            && mLaunchParams.flash_attention && mLaunchParams.warp_specialization
-            && ((!isHopperBF16ContextMLA
-                    && mLaunchParams.attention_input_layout == AttentionInputLayout::Q_CONTIGUOUS_KV)
-                || (isHopperBF16ContextMLA
-                    && (mLaunchParams.attention_input_layout == AttentionInputLayout::SEPARATE_Q_K_V))));
+        // Keep the context-MLA predicate consistent with isHopperContextMLA above.
+        bool const isHopperContextMLA_forStats =
+            isSm90 && mFixedParams.headSizeV == 128;
+        // If FP8-context MLA should differ, add a dataType check here; otherwise treat all context-MLA the same.
+        mLaunchParams.supportReturnSoftmaxStats =
+            (runnerParams.softmaxStatsPtr != nullptr
+            && mLaunchParams.flash_attention
+            && mLaunchParams.warp_specialization
+            && ((isHopperContextMLA_forStats
+                    && mLaunchParams.attention_input_layout == AttentionInputLayout::SEPARATE_Q_K_V)
+                || (!isHopperContextMLA_forStats
+                    && mLaunchParams.attention_input_layout == AttentionInputLayout::Q_CONTIGUOUS_KV)));
     }
```

Please apply these changes to ensure consistent softmax-stats support across BF16 and FP8 context MLA on SM90, and correct the typo.
🧹 Nitpick comments (1)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1)
456-458: Make the new flag const and clarify the comment scope
- The local flag is never mutated; prefer const per guidelines.
- The comment is ambiguous; explicitly call out that the context-MLA detection keys off SM90 and V=128.
Apply this diff:
```diff
-    // Now we have SM90 context and FP8 generation MLA kernels
-    bool isHopperContextMLA = isSm90 && mFixedParams.headSizeV == 128;
+    // MLA kernels available on SM90 for context (V = 128) and for FP8 generation (V = 512).
+    bool const isHopperContextMLA = isSm90 && mFixedParams.headSizeV == 128;
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
- cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (2 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...
Files:
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
**/*.{cpp,cxx,cc,cu}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu}
: Avoid literal values except for 0, nullptr, true, false; use named constexpr for other literals
Place semicolon of empty for/while loop on a new line
Always use brace-delimited bodies for switch/while/do-for/if/else
Use inline C comments in argument lists when parameter meaning is unclear (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) ... is forbidden)
Switch on enums should enumerate all values and omit default to catch new values at compile time
Structure switch statements; prohibit fallthrough except between empty cases; each case ends with break or throw; return at end of case not allowed; put break inside braces for compound case
Prefer anonymous namespaces over static for internal linkage of functions
Every defined function must be called at least once (no unused methods)
Files:
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Parameter names must be consistent between declarations and definitions
Files:
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
/bot run

PR_Github #16316 [ run ] triggered by Bot

PR_Github #16316 [ run ] completed with state
@zhou-yuxin has this test already been added to the test list? Others LGTM.
Yes, it's in l0_h100.yml (pre-merge stage).
/bot run

PR_Github #16504 [ run ] triggered by Bot

PR_Github #16504 [ run ] completed with state
Summary by CodeRabbit
New Features
Bug Fixes
Tests