
Conversation

Collaborator

@MatthiasKohl MatthiasKohl commented Sep 30, 2025

Description

This PR adds full Helix parallelism support to the MLA attention module:

  • adds Helix post-process kernels
  • updates the MLA kernels to use the RoPE position IDs in generation, when the generated token is at a different position than the previous KV cache values
  • adds tests for the post-process kernels and for the MLA module, comparing the Helix and non-Helix implementations

Test Coverage

  • tests/unittest/_torch/modules/test_mla_helix.py : Full Helix MLA test
  • tests/unittest/_torch/thop/parallel/test_helix_postprocess.py : Helix post-process unit test

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR follows the TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • [x] Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
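
For example, a typical invocation using the flags documented above might be: /bot run --disable-fail-fast --stage-list "A10-PyTorch-1".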

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

Summary by CodeRabbit

  • New Features
    • GPU-accelerated Helix post-processing exposed as a Torch operator.
    • Attention updated to support CP Helix across context and generation, including position offset handling.
    • New public collective for CP all-gather in distributed workflows.
  • Performance
    • Improved GEMM configuration selection for large shapes and clearer runtime error messages.
  • Tests
    • Added comprehensive unit and multi-GPU distributed tests for Helix CP attention and post-processing across dtypes, sizes, and edge cases.

Signed-off-by: Matthias Jouanneaux <[email protected]>
Signed-off-by: Matthias Jouanneaux <[email protected]>
Signed-off-by: Matthias Jouanneaux <[email protected]>
@MatthiasKohl MatthiasKohl requested review from a team as code owners September 30, 2025 17:01
@MatthiasKohl MatthiasKohl changed the title from "User/mjoux/helix add mla latest" to "[TRTLLM-5966][feat] Helix: add full MLA support for Helix" on Sep 30, 2025
@MatthiasKohl
Collaborator Author

/bot run

Contributor

coderabbitai bot commented Sep 30, 2025

📝 Walkthrough

Walkthrough

Adds a Helix post-processing GPU kernel and Torch op, integrates the CP Helix flow in attention with post-processing and optional position offsets in MLA RoPE, refactors the distributed allgather and exposes cp_allgather, tweaks GEMM runner selection and its error message, updates the build, and adds comprehensive unit tests.

Changes

  • Helix Post-Processing Core (CUDA) — cpp/tensorrt_llm/kernels/helixKernels.cu, cpp/tensorrt_llm/kernels/helixKernels.h: Introduces a templated GPU post-processing kernel and host launcher for Helix; defines the params struct; adds BF16/F16 instantiations; includes alignment/size checks and a warp reduction helper.
  • Torch Op and Build Integration (Helix Post-Process) — cpp/tensorrt_llm/thop/helixPostProcessOp.cpp, cpp/tensorrt_llm/thop/CMakeLists.txt, tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py: Adds the Torch CUDA op helix_post_process with validations and stream launch; registers the op in Torch; includes the source in the th_common target; provides a fake op for meta/inference.
  • MLA RoPE Offset Support — cpp/tensorrt_llm/kernels/mlaKernels.cu: Extends the kernel and its invocation to accept optional helix_position_offsets for position id selection in RoPE.
  • Distributed Ops Refactor and API — tensorrt_llm/_torch/distributed/ops.py, tensorrt_llm/_torch/distributed/__init__.py: Refactors the internal allgather to explicit group/rank; adds public wrappers allgather and cp_allgather; updates exports.
  • Attention CP/Helix Integration — tensorrt_llm/_torch/modules/attention.py: Integrates CP sizing/config into the attention and MLA paths; adds Helix all-to-all and post-processing usage; threads position_ids and latent_cache_gen; adjusts heads/reshapes for CP.
  • GEMM Runner Adjustments — cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp: Improves the error message to include the return code; updates the tileN selection heuristic when N > 256.
  • Unit Tests: Helix Post-Processing — tests/unittest/_torch/thop/parallel/test_helix_postprocess.py: Adds correctness and validation tests for helix_post_process across dtypes/shapes/scales with baseline comparison and error cases.
  • Unit Tests: Distributed MLA Helix — tests/unittest/_torch/modules/test_mla_helix.py: Adds multi-GPU distributed tests for MLA Helix scenarios, including KV setup, RoPE config, execution, and validation against a reference.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Py as PyTorch
  participant Op as trtllm::helix_post_process (Torch Op)
  participant K as helixPostProcess<T> (Host)
  participant GPU as helix_postprocess_kernel<T> (CUDA)
  Py->>Op: helix_post_process(gathered_o, gathered_stats, scale)
  Op->>Op: Validate shapes/dtypes/alignment
  Op->>K: Build HelixPostProcParams<T>, launch on stream
  K->>GPU: Configure grid/block, launch
  GPU->>GPU: Warp-reduce corrected sums
  GPU->>GPU: Accumulate per-token/head blocks
  GPU-->>K: Write output [num_tokens, num_heads*kv_lora_rank]
  K-->>Op: Kernel complete
  Op->>Op: Optional scale multiply
  Op-->>Py: Return output tensor
  note over GPU,K: New Helix post-processing pathway
sequenceDiagram
  autonumber
  participant Attn as Attention/MLA Forward
  participant Rope as applyMLARopeAndAssignQKVKernelOptContext
  participant Pos as helix_position_offsets
  Attn->>Rope: Launch kernel(..., helix_position_offsets)
  alt offsets provided
    Rope->>Pos: Read offset[global_token_idx]
    Rope-->>Attn: Use offset for RoPE
  else no offsets
    Rope-->>Attn: Use local_token_idx for RoPE
  end
  note over Rope: Modified position id selection
sequenceDiagram
  autonumber
  participant Attn as Attention (CP Helix)
  participant Dist as alltoall_helix / cp_allgather
  participant Op as helix_post_process
  Attn->>Dist: Exchange per-CP shard outputs/stats
  Dist-->>Attn: Gathered O and stats
  Attn->>Op: helix_post_process(gathered_o, gathered_stats, scale)
  Op-->>Attn: Post-processed O
  Attn-->>Attn: Continue projection/output mapping
  note over Attn,Op: New CP Helix data exchange and post-process
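As a concrete illustration of the first diagram above, a minimal sketch of calling the new op from Python. The tensor shapes and the (max, sum) statistics layout are assumptions inferred from the review notes, not confirmed signatures; only the op name and argument order come from this PR.

import torch

# Hypothetical sizes, for illustration only.
cp_size, num_tokens, num_heads, kv_lora_rank = 4, 8, 16, 128

gathered_o = torch.randn(cp_size, num_tokens, num_heads * kv_lora_rank,
                         dtype=torch.float16, device="cuda")
# Assumed layout: last dim holds (max, sum) per (cp rank, token, head).
maxes = torch.randn(cp_size, num_tokens, num_heads, 1, device="cuda")
sums = torch.rand(cp_size, num_tokens, num_heads, 1, device="cuda") + 0.5
gathered_stats = torch.cat([maxes, sums], dim=-1)

# Merge the per-CP-rank partial attention outputs; scale=1.0 means no extra scaling.
output = torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, 1.0)
assert output.shape == (num_tokens, num_heads * kv_lora_rank)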

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 22.95%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title Check ✅ Passed — The title succinctly identifies the JIRA ticket TRTLLM-5966, indicates a feature addition, and summarizes the primary change of adding full MLA support for Helix, which matches the pull request’s main objective.
  • Description Check ✅ Passed — The pull request description includes the required ## Description, ## Test Coverage, and ## PR Checklist sections, with clear explanations of the change, associated tests, and checklist items, matching the repository template structure.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (5)
cpp/tensorrt_llm/kernels/helixKernels.h (1)

30-44: Document new public interfaces

HelixPostProcParams and helixPostProcess are new exported symbols; our header rules require Doxygen comments describing their contract. Please add //! documentation blocks so downstream users know how to populate the params and what the launcher does, as per the coding guidelines.

cpp/tensorrt_llm/thop/helixPostProcessOp.cpp (1)

72-77: Ensure macro hygiene and consider using a function.

The CALL_CPP_OP macro creates local variables and invokes a function, which could lead to name collisions or unexpected behavior if used multiple times. Consider converting this to a templated helper function for better type safety and to avoid potential macro pitfalls.

Consider replacing the macro with a templated function:

template<typename T>
void invokeHelixPostProcess(torch::Tensor& output, 
                            torch::Tensor const& gathered_o, 
                            torch::Tensor const& gathered_stats,
                            int cp_size, int num_tokens, int num_heads, int kv_lora_rank,
                            cudaStream_t stream) {
    tensorrt_llm::kernels::HelixPostProcParams<T> params{
        reinterpret_cast<T*>(output.mutable_data_ptr()),
        reinterpret_cast<T const*>(gathered_o.data_ptr()),
        reinterpret_cast<float2 const*>(gathered_stats.data_ptr()),
        static_cast<int>(cp_size), static_cast<int>(num_tokens),
        static_cast<int>(num_heads), static_cast<int>(kv_lora_rank)
    };
    tensorrt_llm::kernels::helixPostProcess(params, stream);
}

Then replace lines 79-90:

-#define CALL_CPP_OP(T)                                                                                                 \
-    tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),              \
-        reinterpret_cast<T const*>(gathered_o.data_ptr()), reinterpret_cast<float2 const*>(gathered_stats.data_ptr()), \
-        static_cast<int>(cp_size), static_cast<int>(num_tokens), static_cast<int>(num_heads),                          \
-        static_cast<int>(kv_lora_rank)};                                                                               \
-    tensorrt_llm::kernels::helixPostProcess(params, stream);
-
     if (gathered_o.scalar_type() == at::ScalarType::Half)
     {
-        CALL_CPP_OP(__half);
+        invokeHelixPostProcess<__half>(output, gathered_o, gathered_stats, cp_size, num_tokens, num_heads, kv_lora_rank, stream);
     }
     else if (gathered_o.scalar_type() == at::ScalarType::BFloat16)
     {
 #ifdef ENABLE_BF16
-        CALL_CPP_OP(__nv_bfloat16);
+        invokeHelixPostProcess<__nv_bfloat16>(output, gathered_o, gathered_stats, cp_size, num_tokens, num_heads, kv_lora_rank, stream);
 #else
         TLLM_THROW("BFloat16 must be enabled to use helix_post_process with bf16 tensors.");
 #endif
     }
tests/unittest/_torch/thop/parallel/test_helix_postprocess.py (1)

175-201: Handle unused variable in alignment test correctly.

The static analysis tool flags line 197's output variable as unused, but this is a false positive. The variable is assigned to verify that the operation succeeds without raising an error. The current pattern is acceptable, though you could make the intent clearer.

Consider making the intent more explicit by assigning to _ or adding a comment:

         try:
-            output = torch.ops.trtllm.helix_post_process(
+            _ = torch.ops.trtllm.helix_post_process(
                 gathered_o, gathered_stats, 1.0)
-            # Should not raise an error
+            # Success: Should not raise an error for valid alignment
         except RuntimeError as e:
             pytest.fail(f"Should not raise error for valid alignment: {e}")
tensorrt_llm/_torch/modules/attention.py (2)

823-823: Document the TODO for CP-aware weight loading.

The TODO comment on line 823 notes that weight loading needs to be CP-aware so that v_b_proj is split according to cp_size; if it is not, weights may be loaded incorrectly. This is an important follow-up task.

Would you like me to open a new issue to track implementing CP-aware weight loading for v_b_proj?


1467-1469: Unused parameters in forward_generation signature.

Static analysis correctly identifies that compressed_kv and k_pe parameters are unused in forward_generation. These parameters are passed for consistency with forward_context but are not used in the generation path where q_nope and q_pe are derived directly from q.

Consider removing unused parameters or adding a comment explaining why they're in the signature:

 def forward_generation(
     self,
     q: torch.Tensor,
-    compressed_kv: torch.Tensor,
-    k_pe: torch.Tensor,
+    compressed_kv: torch.Tensor,  # Unused: q already contains all needed information
+    k_pe: torch.Tensor,  # Unused: q already contains all needed information
     position_ids: torch.Tensor,

Or if the parameters are vestigial, consider removing them entirely and updating all call sites.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1560cca and cfa8989.

📒 Files selected for processing (12)
  • cpp/tensorrt_llm/kernels/helixKernels.cu (1 hunks)
  • cpp/tensorrt_llm/kernels/helixKernels.h (1 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.cu (3 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp (2 hunks)
  • cpp/tensorrt_llm/thop/CMakeLists.txt (1 hunks)
  • cpp/tensorrt_llm/thop/helixPostProcessOp.cpp (1 hunks)
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1 hunks)
  • tensorrt_llm/_torch/distributed/__init__.py (1 hunks)
  • tensorrt_llm/_torch/distributed/ops.py (4 hunks)
  • tensorrt_llm/_torch/modules/attention.py (30 hunks)
  • tests/unittest/_torch/modules/test_mla_helix.py (1 hunks)
  • tests/unittest/_torch/thop/parallel/test_helix_postprocess.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (8)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
  • cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
  • cpp/tensorrt_llm/kernels/mlaKernels.cu
  • cpp/tensorrt_llm/kernels/helixKernels.h
  • cpp/tensorrt_llm/kernels/helixKernels.cu
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
  • cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
  • cpp/tensorrt_llm/kernels/mlaKernels.cu
  • cpp/tensorrt_llm/kernels/helixKernels.h
  • cpp/tensorrt_llm/kernels/helixKernels.cu
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
  • cpp/tensorrt_llm/kernels/mlaKernels.cu
  • tests/unittest/_torch/thop/parallel/test_helix_postprocess.py
  • tensorrt_llm/_torch/distributed/__init__.py
  • cpp/tensorrt_llm/kernels/helixKernels.h
  • tensorrt_llm/_torch/distributed/ops.py
  • cpp/tensorrt_llm/kernels/helixKernels.cu
  • tensorrt_llm/_torch/modules/attention.py
  • tests/unittest/_torch/modules/test_mla_helix.py
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
  • cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
  • cpp/tensorrt_llm/kernels/helixKernels.h
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
  • cpp/tensorrt_llm/kernels/mlaKernels.cu
  • tests/unittest/_torch/thop/parallel/test_helix_postprocess.py
  • tensorrt_llm/_torch/distributed/__init__.py
  • cpp/tensorrt_llm/kernels/helixKernels.h
  • tensorrt_llm/_torch/distributed/ops.py
  • cpp/tensorrt_llm/kernels/helixKernels.cu
  • tensorrt_llm/_torch/modules/attention.py
  • tests/unittest/_torch/modules/test_mla_helix.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • tests/unittest/_torch/thop/parallel/test_helix_postprocess.py
  • tensorrt_llm/_torch/distributed/__init__.py
  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/modules/attention.py
  • tests/unittest/_torch/modules/test_mla_helix.py
**/*.{h,hpp,hh,hxx}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Document new class interfaces and function prototypes with Doxygen; use //! for single-line and //!< for members.

Files:

  • cpp/tensorrt_llm/kernels/helixKernels.h
**/*.{h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use include guards named 'TRTLLM_<FILE_NAME_IN_CAPS_WITH_UNDERSCORES>_H' (no leading or trailing underscore; directory names excluded).

Files:

  • cpp/tensorrt_llm/kernels/helixKernels.h
🧠 Learnings (1)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/thop/parallel/test_helix_postprocess.py
🧬 Code graph analysis (7)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
cpp/tensorrt_llm/kernels/mlaKernels.h (1)
  • helix_position_offsets (106-107)
tensorrt_llm/_torch/distributed/__init__.py (1)
tensorrt_llm/_torch/distributed/ops.py (2)
  • allgather (233-239)
  • cp_allgather (242-248)
cpp/tensorrt_llm/kernels/helixKernels.h (1)
cpp/tensorrt_llm/kernels/helixKernels.cu (4)
  • void (40-68)
  • void (82-207)
  • helixPostProcess (210-234)
  • helixPostProcess (210-210)
tensorrt_llm/_torch/distributed/ops.py (2)
cpp/tensorrt_llm/thop/allgatherOp.cpp (4)
  • input (108-111)
  • input (108-108)
  • allgather (122-137)
  • allgather (122-122)
tensorrt_llm/mapping.py (6)
  • rank (328-329)
  • rank (332-339)
  • tp_group (368-369)
  • tp_rank (342-343)
  • cp_group (376-377)
  • cp_rank (351-353)
cpp/tensorrt_llm/kernels/helixKernels.cu (1)
cpp/tensorrt_llm/common/envUtils.cpp (2)
  • getEnvEnablePDL (246-261)
  • getEnvEnablePDL (246-246)
tensorrt_llm/_torch/modules/attention.py (5)
tensorrt_llm/_torch/attention_backend/interface.py (6)
  • AttentionBackend (552-630)
  • PositionalEmbeddingParams (506-524)
  • PredefinedAttentionMask (530-539)
  • AttentionMetadata (40-336)
  • forward (591-614)
  • num_tokens (267-268)
tensorrt_llm/_torch/attention_backend/utils.py (2)
  • create_attention (27-79)
  • get_attention_backend (10-24)
tensorrt_llm/_torch/distributed/ops.py (1)
  • alltoall_helix (251-286)
tensorrt_llm/mapping.py (4)
  • has_cp_ulysses (410-412)
  • rank (328-329)
  • rank (332-339)
  • cp_group (376-377)
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp (2)
  • helix_post_process (27-98)
  • helix_post_process (27-27)
tests/unittest/_torch/modules/test_mla_helix.py (6)
tensorrt_llm/_torch/attention_backend/interface.py (9)
  • AttentionMetadata (40-336)
  • RopeParams (350-502)
  • seq_lens (167-168)
  • seq_lens (171-192)
  • num_contexts (195-196)
  • num_contexts (199-202)
  • create_rope_const_params (426-502)
  • create_cuda_graph_metadata (275-317)
  • from_config (372-424)
tensorrt_llm/_torch/distributed/ops.py (1)
  • cp_allgather (242-248)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (5)
  • get_buffers (693-702)
  • shutdown (81-82)
  • shutdown (368-369)
  • shutdown (1072-1077)
  • shutdown (1223-1224)
tensorrt_llm/_torch/utils.py (1)
  • model_extra_attrs (58-64)
tensorrt_llm/_utils.py (2)
  • str_dtype_to_binding (216-219)
  • torch_dtype_to_str (225-226)
tensorrt_llm/mapping.py (4)
  • CpType (21-29)
  • Mapping (32-519)
  • rank (328-329)
  • rank (332-339)
🪛 Ruff (0.13.1)
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

518-518: Unused function argument: gathered_stats

(ARG001)


518-518: Unused function argument: scale

(ARG001)

tests/unittest/_torch/thop/parallel/test_helix_postprocess.py

197-197: Local variable output is assigned to but never used

Remove assignment to unused variable output

(F841)

tensorrt_llm/_torch/modules/attention.py

1467-1467: Unused method argument: compressed_kv

(ARG002)


1468-1468: Unused method argument: k_pe

(ARG002)

tests/unittest/_torch/modules/test_mla_helix.py

794-794: Consider moving this statement to an else block

(TRY300)


795-795: Do not catch blind exception: Exception

(BLE001)


798-798: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


798-798: Create your own exception

(TRY002)


798-798: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (29)
cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp (1)

140-140: Nice improvement to error diagnostics.

Including the return code in the GEMM failure message makes triaging easier.

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1)

517-519: LGTM! Fake op registration is correct.

The unused parameters gathered_stats and scale flagged by static analysis are expected and correct for a fake op registration. Fake ops only provide shape and dtype inference for TorchScript compilation; they don't execute the actual computation. The return shape correctly drops the first dimension (cp_size) from gathered_o.
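
A fake (meta) registration of this kind typically just allocates an output with the inferred shape and dtype. A minimal sketch, assuming a recent PyTorch with torch.library.register_fake; the decorator form, function body, and argument names are illustrative rather than the exact code in cpp_custom_ops.py:

import torch

@torch.library.register_fake("trtllm::helix_post_process")
def _helix_post_process_fake(gathered_o, gathered_stats, scale):
    # Shape/dtype inference only: drop the leading cp_size dimension.
    return gathered_o.new_empty(gathered_o.shape[1:])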

tests/unittest/_torch/modules/test_mla_helix.py (7)

1-39: LGTM! Imports and MPI setup are appropriate.

The imports are well-organized and include all necessary dependencies for distributed MLA testing. The cloudpickle registration for MPI serialization is correctly configured to handle custom types across process boundaries.


42-146: LGTM! Well-structured test configuration.

The Scenario and RopeConfig dataclasses are well-designed with appropriate defaults, frozen for immutability, and kw_only for clarity. The max_position_embeddings property correctly ensures sufficient capacity for all test scenarios.


148-416: LGTM! Helper functions are well-implemented.

The helper functions provide comprehensive support for the distributed test:

  • KV cache and metadata setup is correctly configured for MLA with Helix parallelism
  • Weight initialization uses appropriate techniques (Kaiming uniform, block scaling)
  • The inverse RoPE transformation in _make_latent_cache_gen correctly recovers original values from embedded cache
  • Error reporting provides detailed diagnostics for debugging

418-604: LGTM! Distributed execution logic is correct.

The _run_mla_distributed function correctly orchestrates the Helix-distributed MLA execution:

  • Properly distributes weights across CP ranks
  • Correctly handles context and generation phases
  • CUDA graph capture and replay are implemented correctly with proper warmup
  • Latent cache generation for non-last ranks is appropriately handled
  • Thorough validation against reference outputs with detailed error reporting

606-785: LGTM! Multi-GPU test orchestration is correct.

The _full_test_multi_gpu function properly orchestrates the complete test:

  • Rank 0 generates reference output with single-GPU execution
  • Reference output is correctly broadcast to all ranks via cp_allgather
  • Both reference and distributed paths support CUDA graph for performance measurement
  • Test parameters are properly constructed and distributed

787-799: Exception handling is acceptable for MPI context.

The broad exception catch and re-raise pattern is appropriate here for distributed MPI execution where we need to capture and propagate exceptions across process boundaries. The preserved traceback helps with debugging distributed failures.

While static analysis suggests improvements, the current pattern is reasonable for this MPI testing context where exception details must cross process boundaries.


802-840: LGTM! Test function and benchmarking script are well-structured.

The pytest test function is correctly parameterized with test scenarios and validates mismatch ratios appropriately. The main block provides a useful benchmarking script for performance measurement across scenarios.

cpp/tensorrt_llm/kernels/helixKernels.cu (3)

38-68: LGTM! Warp reduction is correctly implemented.

The warpReduceCorrectedSum function correctly implements numerically stable warp-level reduction for softmax normalization. The SM100-specific redux instruction provides an optimized path, with a proper fallback for older architectures.


70-207: LGTM! Kernel implementation is well-optimized.

The helix_postprocess_kernel is well-designed with several optimization strategies:

  • Warp specialization (warp 0 for correction, others for pre-loading) maximizes parallelism
  • Pre-loading and pipelining reduce memory latency
  • SM90+ programmatic stream serialization primitives are correctly guarded
  • Memory accesses are coalesced via vectorized loads/stores
  • Shared memory usage is efficient

209-243: LGTM! Host launcher is correctly implemented.

The helixPostProcess function properly:

  • Validates alignment requirements for vectorized memory access
  • Checks size constraints against kernel limits
  • Configures launch parameters with correct grid/block dimensions
  • Enables PDL (Programmatic Dependent Launch) based on environment variable for SM90+ optimization
  • Instantiates templates for supported types (__half, __nv_bfloat16)
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp (3)

37-42: LGTM! Shape derivation from inputs.

The kv_lora_rank is correctly derived from the gathered_o shape and num_heads. The validation ensures that the dimension is evenly divisible, preventing potential issues downstream.


58-63: Good alignment checks for async operations.

The 16-byte alignment requirements for gathered_o and the constraint that kv_lora_rank * sizeof(data_type) must be a multiple of 16 are correctly enforced. These checks ensure safe async memcpy operations as noted in the inline comment.
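
Expressed at the Python level, the constraint being enforced is roughly the following (a sketch only; the function name is illustrative):

def has_valid_helix_alignment(gathered_o, kv_lora_rank):
    # A 16-byte-aligned base pointer plus a 16-byte-multiple row size per head
    # are what make the vectorized/async copies safe.
    elem_size = gathered_o.element_size()  # 2 bytes for fp16/bf16
    return (gathered_o.data_ptr() % 16 == 0
            and (kv_lora_rank * elem_size) % 16 == 0)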


110-110: Namespace closing comment is present.

Per the coding guidelines, namespace closing braces must include a trailing comment with the namespace name. Line 110 already ends with } // namespace torch_ext, so no change is needed.

tests/unittest/_torch/thop/parallel/test_helix_postprocess.py (3)

25-43: Baseline reference implementation looks correct.

The baseline function implements the expected Helix post-processing logic in PyTorch for verification. The implementation correctly:

  • Computes global max and corrected statistics
  • Applies scaling and exponential correction
  • Performs reduction and normalization
  • Handles dtype casting appropriately
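
For readers without the test file at hand, a condensed sketch of that reference logic, assuming each rank's partial output is already normalized by its local softmax sum and that gathered_stats stores (max, sum) per (cp rank, token, head); the actual baseline in the test may differ in details:

import torch

def helix_post_process_ref(gathered_o, gathered_stats, scale):
    # gathered_o: [cp, tokens, heads * kv_lora_rank] in fp16/bf16 (assumed layout)
    # gathered_stats: [cp, tokens, heads, 2] holding (max, sum) per rank/token/head
    cp, tokens, _ = gathered_o.shape
    heads = gathered_stats.shape[2]
    o = gathered_o.float().view(cp, tokens, heads, -1)
    maxes, sums = gathered_stats[..., 0], gathered_stats[..., 1]
    global_max = maxes.max(dim=0, keepdim=True).values
    corrected = sums * torch.exp(maxes - global_max)      # re-scale partial sums
    weights = corrected / corrected.sum(dim=0, keepdim=True)
    merged = (o * weights.unsqueeze(-1)).sum(dim=0)       # weighted merge over cp
    return (merged * scale).reshape(tokens, -1).to(gathered_o.dtype)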

214-217: Excellent large-input test coverage.

Testing with larger inputs (16 cp_size, 64 heads, 512 kv_lora_rank) for both float16 and bfloat16 helps ensure the operator performs correctly and efficiently at scale.


74-78: Gathered_stats layout verified – no changes required. The C++ struct’s float2 holds max in the first component and sum in the second, so using indices 0 and 1 in the test is correct.

tensorrt_llm/_torch/modules/attention.py (11)

209-217: CP size properly integrated into world_size calculation.

The world_size calculation now includes cp_size, and the Mapping is constructed with cp_size and cp_config. This ensures distributed operations account for context parallelism ranks.
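
A rough sketch of what such a construction looks like; the exact keyword set and the cp_config contents are assumptions, not the code in attention.py:

from tensorrt_llm.mapping import CpType, Mapping

tp_size, pp_size, cp_size = 2, 1, 2
mapping = Mapping(
    world_size=tp_size * pp_size * cp_size,   # CP ranks now count toward world_size
    rank=0,
    tp_size=tp_size,
    pp_size=pp_size,
    cp_size=cp_size,
    cp_config={"cp_type": CpType.HELIX},      # assumed schema; see tensorrt_llm/mapping.py
)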


607-614: latent_cache_gen parameter added to MLA inplace op.

The custom op signature is updated to accept latent_cache_gen, enabling generation-time control over which latent cache is used. This aligns with the TODO comments (lines 1145-1149) about using next-rank latent cache in CP Helix scenarios.


732-733: CP Ulysses not yet supported for MLA.

The early NotImplementedError when CP Ulysses is detected is appropriate. The error message is clear and informative.


746-748: Verify head count divisibility by tp_size * cp_size.

The assertion requires self.num_heads % (tp_size * cp_size) == 0, ensuring heads can be evenly distributed across tensor-parallel and context-parallel ranks. This is critical for correctness.


750-750: Robust RMS norm epsilon retrieval.

Using getattr with a default fallback (1e-6) ensures compatibility when rms_norm_eps is not present in the config. This is a good defensive coding practice.


832-851: Creative mapping_o construction for CP Helix output projection.

The mapping_o treats tp_size * cp_size as the effective tp_size while setting cp_size=1. This allows the o_proj to perform row-wise tensor parallelism across the combined TP and CP dimensions, which is necessary after Helix post-processing reduces across CP ranks. This is a clever approach.
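
A sketch of that idea in isolation (the real constructor call in attention.py may pass additional arguments):

from tensorrt_llm.mapping import Mapping

def make_o_proj_mapping(mapping: Mapping) -> Mapping:
    # Fold the CP dimension into TP so o_proj's row-parallel split covers the
    # tp_size * cp_size head groups produced after helix_post_process.
    return Mapping(
        world_size=mapping.world_size,
        rank=mapping.rank,
        tp_size=mapping.tp_size * mapping.cp_size,
        pp_size=mapping.pp_size,
        cp_size=1,
    )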


1003-1049: CP Helix post-processing integration looks correct.

The _attn_forward method now:

  1. Allocates softmax_stats for tracking partial attention statistics
  2. Calls attention with helix_position_offsets (position_ids)
  3. Splits partial outputs and stats by cp_size
  4. Performs alltoall_helix to gather chunks across CP ranks
  5. Calls helix_post_process to merge and normalize results

This aligns with the Helix attention algorithm. The scale=1.0 parameter suggests no additional scaling is needed.
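
Condensing those steps into a small sketch; the chunking dimensions and the exchange signature are assumptions (the real code uses alltoall_helix from the distributed ops), so the communication step is abstracted behind a callable here:

import torch

def merge_helix_attention(attn_output, softmax_stats, cp_size, exchange):
    # attn_output / softmax_stats: per-rank partial results written by the
    # attention kernel called with helix_position_offsets=position_ids.
    o_chunks = attn_output.chunk(cp_size, dim=-1)          # split dim assumed
    stats_chunks = softmax_stats.chunk(cp_size, dim=0)     # split dim assumed
    # exchange stands in for alltoall_helix; it returns the gathered tensors.
    gathered_o, gathered_stats = exchange(o_chunks, stats_chunks)
    # Merge partials into the final output; scale=1.0 means no extra scaling.
    return torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, 1.0)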


1145-1152: TODO documents latent_cache_gen usage for CP Helix.

The TODO correctly identifies that in CP Helix generation, ranks other than the last should use the latent cache from the next logical rank's first token. The latent_cache_gen parameter enables this workaround.


1194-1194: helix_position_offsets passed when cp_size > 1.

The helix_position_offsets parameter is set to position_ids when CP is enabled, allowing the attention kernel to apply position-based adjustments during generation when tokens have different positions than cached KV values.


1590-1596: Output slicing for CP Helix compatibility.

When cp_size > 1, the output is sliced to num_heads_tp_cp * v_head_dim to match the o_proj input expectations after post-processing. The comment clarifies this is for testing Helix parallelism compatibility.


694-694: Verify MLA assertion constraint
Ensure enforcing num_heads == num_key_value_heads is intentional for MLA (i.e., that grouped-query or multi-query attention patterns are not supported); if so, update the module docstring to clarify this limitation.

@tensorrt-cicd
Collaborator

PR_Github #20403 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #20403 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #15395 completed with status: 'FAILURE'

Signed-off-by: Matthias Jouanneaux <[email protected]>
@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #20446 [ run ] triggered by Bot

Signed-off-by: Matthias Jouanneaux <[email protected]>
@MatthiasKohl
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #20452 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #20446 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #15413 (Blue Ocean) completed with status: ABORTED

@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #20452 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15419 completed with status: 'FAILURE'

@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #20466 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #20466 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15432 completed with status: 'SUCCESS'
