[TRTLLM-5966][feat] Helix: add full MLA support for Helix #8104
base: main
Conversation
Signed-off-by: Matthias Jouanneaux <[email protected]>
/bot run
📝 Walkthrough
Adds a Helix post-processing GPU kernel and Torch op, integrates the CP Helix flow in attention with post-processing and optional position offsets in MLA RoPE, refactors distributed allgather and exposes cp_allgather, tweaks GEMM runner selection and error message, updates the build, and adds comprehensive unit tests.
Sequence Diagram(s)
sequenceDiagram
autonumber
actor Py as PyTorch
participant Op as trtllm::helix_post_process (Torch Op)
participant K as helixPostProcess<T> (Host)
participant GPU as helix_postprocess_kernel<T> (CUDA)
Py->>Op: helix_post_process(gathered_o, gathered_stats, scale)
Op->>Op: Validate shapes/dtypes/alignment
Op->>K: Build HelixPostProcParams<T>, launch on stream
K->>GPU: Configure grid/block, launch
GPU->>GPU: Warp-reduce corrected sums
GPU->>GPU: Accumulate per-token/head blocks
GPU-->>K: Write output [num_tokens, num_heads*kv_lora_rank]
K-->>Op: Kernel complete
Op->>Op: Optional scale multiply
Op-->>Py: Return output tensor
note over GPU,K: New Helix post-processing pathway
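For illustration, a minimal sketch of how the new Torch op in the diagram above could be invoked. The op name and argument order (gathered_o, gathered_stats, scale) come from this PR; the concrete shapes and the float32 (max, sum) stats layout are assumptions here, not a verified spec.

```python
# Hedged usage sketch; requires TensorRT-LLM with the trtllm ops registered and a CUDA device.
import torch

cp_size, num_tokens, num_heads, kv_lora_rank = 4, 2, 8, 512

# Assumed layout: per-CP-rank partial outputs and per-head (max, sum) softmax stats.
gathered_o = torch.randn(cp_size, num_tokens, num_heads * kv_lora_rank,
                         dtype=torch.float16, device="cuda")
gathered_stats = torch.randn(cp_size, num_tokens, num_heads, 2,
                             dtype=torch.float32, device="cuda")

# The op merges the CP shards; the output drops the leading cp_size dimension.
output = torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, 1.0)
assert output.shape == (num_tokens, num_heads * kv_lora_rank)
```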
sequenceDiagram
autonumber
participant Attn as Attention/MLA Forward
participant Rope as applyMLARopeAndAssignQKVKernelOptContext
participant Pos as helix_position_offsets
Attn->>Rope: Launch kernel(..., helix_position_offsets)
alt offsets provided
Rope->>Pos: Read offset[global_token_idx]
Rope-->>Attn: Use offset for RoPE
else no offsets
Rope-->>Attn: Use local_token_idx for RoPE
end
note over Rope: Modified position id selection
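The position-id selection shown in the diagram can be sketched in plain PyTorch as follows. This is not the CUDA kernel itself; the function and tensor names are illustrative.

```python
from typing import Optional

import torch


def select_rope_positions(num_tokens: int,
                          helix_position_offsets: Optional[torch.Tensor]) -> torch.Tensor:
    """Illustrative per-token position-id selection mirroring the diagram above."""
    if helix_position_offsets is not None:
        # Offsets provided: each token uses its globally consistent position.
        return helix_position_offsets[:num_tokens]
    # No offsets: fall back to the local token index.
    return torch.arange(num_tokens)
```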
sequenceDiagram
autonumber
participant Attn as Attention (CP Helix)
participant Dist as alltoall_helix / cp_allgather
participant Op as helix_post_process
Attn->>Dist: Exchange per-CP shard outputs/stats
Dist-->>Attn: Gathered O and stats
Attn->>Op: helix_post_process(gathered_o, gathered_stats, scale)
Op-->>Attn: Post-processed O
Attn-->>Attn: Continue projection/output mapping
note over Attn,Op: New CP Helix data exchange and post-process
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 4
🧹 Nitpick comments (5)
cpp/tensorrt_llm/kernels/helixKernels.h (1)
30-44: Document new public interfaces
HelixPostProcParams and helixPostProcess are new exported symbols; our header rules require Doxygen comments describing their contract. Please add //! documentation blocks so downstream users know how to populate the params and what the launcher does. As per coding guidelines.
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp (1)
72-77: Ensure macro hygiene and consider using a function
The CALL_CPP_OP macro creates local variables and invokes a function, which could lead to name collisions or unexpected behavior if used multiple times. Consider converting this to a templated helper function for better type safety and to avoid potential macro pitfalls.
Consider replacing the macro with a templated function:
    template <typename T>
    void invokeHelixPostProcess(torch::Tensor& output, torch::Tensor const& gathered_o,
        torch::Tensor const& gathered_stats, int cp_size, int num_tokens, int num_heads, int kv_lora_rank,
        cudaStream_t stream)
    {
        tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),
            reinterpret_cast<T const*>(gathered_o.data_ptr()),
            reinterpret_cast<float2 const*>(gathered_stats.data_ptr()), static_cast<int>(cp_size),
            static_cast<int>(num_tokens), static_cast<int>(num_heads), static_cast<int>(kv_lora_rank)};
        tensorrt_llm::kernels::helixPostProcess(params, stream);
    }

Then replace lines 79-90:
    -#define CALL_CPP_OP(T)                                                                                         \
    -    tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),      \
    -        reinterpret_cast<T const*>(gathered_o.data_ptr()), reinterpret_cast<float2 const*>(gathered_stats.data_ptr()), \
    -        static_cast<int>(cp_size), static_cast<int>(num_tokens), static_cast<int>(num_heads),                  \
    -        static_cast<int>(kv_lora_rank)};                                                                       \
    -    tensorrt_llm::kernels::helixPostProcess(params, stream);
     if (gathered_o.scalar_type() == at::ScalarType::Half) {
    -    CALL_CPP_OP(__half);
    +    invokeHelixPostProcess<__half>(output, gathered_o, gathered_stats, cp_size, num_tokens, num_heads, kv_lora_rank, stream);
     } else if (gathered_o.scalar_type() == at::ScalarType::BFloat16) {
     #ifdef ENABLE_BF16
    -    CALL_CPP_OP(__nv_bfloat16);
    +    invokeHelixPostProcess<__nv_bfloat16>(output, gathered_o, gathered_stats, cp_size, num_tokens, num_heads, kv_lora_rank, stream);
     #else
         TLLM_THROW("BFloat16 must be enabled to use helix_post_process with bf16 tensors.");
     #endif
     }

tests/unittest/_torch/thop/parallel/test_helix_postprocess.py (1)
175-201: Handle unused variable in alignment test correctly
The static analysis tool flags line 197's output variable as unused, but this is a false positive. The variable is assigned to verify that the operation succeeds without raising an error. The current pattern is acceptable, though you could make the intent clearer.
Consider making the intent more explicit by assigning to _ or adding a comment:

     try:
    -    output = torch.ops.trtllm.helix_post_process(
    +    _ = torch.ops.trtllm.helix_post_process(
             gathered_o, gathered_stats, 1.0)
    -    # Should not raise an error
    +    # Success: Should not raise an error for valid alignment
     except RuntimeError as e:
         pytest.fail(f"Should not raise error for valid alignment: {e}")

tensorrt_llm/_torch/modules/attention.py (2)
823-823: Document the TODO for CP-aware weight loading
The TODO comment on line 823 notes that weight loading needs to be CP-aware for splitting v_b_proj. This is an important future task.
The TODO at line 823 indicates that weight loading for v_b_proj needs CP awareness. This could lead to incorrect behavior if weights are not split according to cp_size.
Would you like me to open a new issue to track implementing CP-aware weight loading for v_b_proj?
1467-1469: Unused parameters in forward_generation signature
Static analysis correctly identifies that the compressed_kv and k_pe parameters are unused in forward_generation. These parameters are passed for consistency with forward_context but are not used in the generation path, where q_nope and q_pe are derived directly from q.
Consider removing the unused parameters or adding a comment explaining why they're in the signature:

     def forward_generation(
         self,
         q: torch.Tensor,
    -    compressed_kv: torch.Tensor,
    -    k_pe: torch.Tensor,
    +    compressed_kv: torch.Tensor,  # Unused: q already contains all needed information
    +    k_pe: torch.Tensor,  # Unused: q already contains all needed information
         position_ids: torch.Tensor,

Or, if the parameters are vestigial, consider removing them entirely and updating all call sites.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
cpp/tensorrt_llm/kernels/helixKernels.cu (1 hunks)
cpp/tensorrt_llm/kernels/helixKernels.h (1 hunks)
cpp/tensorrt_llm/kernels/mlaKernels.cu (3 hunks)
cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp (2 hunks)
cpp/tensorrt_llm/thop/CMakeLists.txt (1 hunks)
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp (1 hunks)
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1 hunks)
tensorrt_llm/_torch/distributed/__init__.py (1 hunks)
tensorrt_llm/_torch/distributed/ops.py (4 hunks)
tensorrt_llm/_torch/modules/attention.py (30 hunks)
tests/unittest/_torch/modules/test_mla_helix.py (1 hunks)
tests/unittest/_torch/thop/parallel/test_helix_postprocess.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (8)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}
: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...
Files:
cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
cpp/tensorrt_llm/kernels/mlaKernels.cu
cpp/tensorrt_llm/kernels/helixKernels.h
cpp/tensorrt_llm/kernels/helixKernels.cu
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.
Files:
cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
cpp/tensorrt_llm/kernels/mlaKernels.cu
cpp/tensorrt_llm/kernels/helixKernels.h
cpp/tensorrt_llm/kernels/helixKernels.cu
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
cpp/tensorrt_llm/kernels/mlaKernels.cu
tests/unittest/_torch/thop/parallel/test_helix_postprocess.py
tensorrt_llm/_torch/distributed/__init__.py
cpp/tensorrt_llm/kernels/helixKernels.h
tensorrt_llm/_torch/distributed/ops.py
cpp/tensorrt_llm/kernels/helixKernels.cu
tensorrt_llm/_torch/modules/attention.py
tests/unittest/_torch/modules/test_mla_helix.py
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}
: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.
Files:
cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
cpp/tensorrt_llm/kernels/helixKernels.h
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
cpp/tensorrt_llm/kernels/mlaKernels.cu
tests/unittest/_torch/thop/parallel/test_helix_postprocess.py
tensorrt_llm/_torch/distributed/__init__.py
cpp/tensorrt_llm/kernels/helixKernels.h
tensorrt_llm/_torch/distributed/ops.py
cpp/tensorrt_llm/kernels/helixKernels.cu
tensorrt_llm/_torch/modules/attention.py
tests/unittest/_torch/modules/test_mla_helix.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
tests/unittest/_torch/thop/parallel/test_helix_postprocess.py
tensorrt_llm/_torch/distributed/__init__.py
tensorrt_llm/_torch/distributed/ops.py
tensorrt_llm/_torch/modules/attention.py
tests/unittest/_torch/modules/test_mla_helix.py
**/*.{h,hpp,hh,hxx}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Document new class interfaces and function prototypes with Doxygen; use //! for single-line and //!< for members.
Files:
cpp/tensorrt_llm/kernels/helixKernels.h
**/*.{h,hpp,hh,hxx,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use include guards named 'TRTLLM_<FILE_NAME_IN_CAPS_WITH_UNDERSCORES>_H' (no leading or trailing underscore; directory names excluded).
Files:
cpp/tensorrt_llm/kernels/helixKernels.h
🧠 Learnings (1)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/_torch/thop/parallel/test_helix_postprocess.py
🧬 Code graph analysis (7)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
  cpp/tensorrt_llm/kernels/mlaKernels.h (1)
    helix_position_offsets (106-107)
tensorrt_llm/_torch/distributed/__init__.py (1)
  tensorrt_llm/_torch/distributed/ops.py (2)
    allgather (233-239)
    cp_allgather (242-248)
cpp/tensorrt_llm/kernels/helixKernels.h (1)
  cpp/tensorrt_llm/kernels/helixKernels.cu (4)
    void (40-68)
    void (82-207)
    helixPostProcess (210-234)
    helixPostProcess (210-210)
tensorrt_llm/_torch/distributed/ops.py (2)
  cpp/tensorrt_llm/thop/allgatherOp.cpp (4)
    input (108-111)
    input (108-108)
    allgather (122-137)
    allgather (122-122)
  tensorrt_llm/mapping.py (6)
    rank (328-329)
    rank (332-339)
    tp_group (368-369)
    tp_rank (342-343)
    cp_group (376-377)
    cp_rank (351-353)
cpp/tensorrt_llm/kernels/helixKernels.cu (1)
  cpp/tensorrt_llm/common/envUtils.cpp (2)
    getEnvEnablePDL (246-261)
    getEnvEnablePDL (246-246)
tensorrt_llm/_torch/modules/attention.py (5)
  tensorrt_llm/_torch/attention_backend/interface.py (6)
    AttentionBackend (552-630)
    PositionalEmbeddingParams (506-524)
    PredefinedAttentionMask (530-539)
    AttentionMetadata (40-336)
    forward (591-614)
    num_tokens (267-268)
  tensorrt_llm/_torch/attention_backend/utils.py (2)
    create_attention (27-79)
    get_attention_backend (10-24)
  tensorrt_llm/_torch/distributed/ops.py (1)
    alltoall_helix (251-286)
  tensorrt_llm/mapping.py (4)
    has_cp_ulysses (410-412)
    rank (328-329)
    rank (332-339)
    cp_group (376-377)
  cpp/tensorrt_llm/thop/helixPostProcessOp.cpp (2)
    helix_post_process (27-98)
    helix_post_process (27-27)
tests/unittest/_torch/modules/test_mla_helix.py (6)
  tensorrt_llm/_torch/attention_backend/interface.py (9)
    AttentionMetadata (40-336)
    RopeParams (350-502)
    seq_lens (167-168)
    seq_lens (171-192)
    num_contexts (195-196)
    num_contexts (199-202)
    create_rope_const_params (426-502)
    create_cuda_graph_metadata (275-317)
    from_config (372-424)
  tensorrt_llm/_torch/distributed/ops.py (1)
    cp_allgather (242-248)
  tensorrt_llm/_torch/pyexecutor/resource_manager.py (5)
    get_buffers (693-702)
    shutdown (81-82)
    shutdown (368-369)
    shutdown (1072-1077)
    shutdown (1223-1224)
  tensorrt_llm/_torch/utils.py (1)
    model_extra_attrs (58-64)
  tensorrt_llm/_utils.py (2)
    str_dtype_to_binding (216-219)
    torch_dtype_to_str (225-226)
  tensorrt_llm/mapping.py (4)
    CpType (21-29)
    Mapping (32-519)
    rank (328-329)
    rank (332-339)
🪛 Ruff (0.13.1)
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
518-518: Unused function argument: gathered_stats
(ARG001)
518-518: Unused function argument: scale
(ARG001)
tests/unittest/_torch/thop/parallel/test_helix_postprocess.py
197-197: Local variable output is assigned to but never used. Remove assignment to unused variable output.
(F841)
tensorrt_llm/_torch/modules/attention.py
1467-1467: Unused method argument: compressed_kv
(ARG002)
1468-1468: Unused method argument: k_pe
(ARG002)
tests/unittest/_torch/modules/test_mla_helix.py
794-794: Consider moving this statement to an else block
(TRY300)
795-795: Do not catch blind exception: Exception
(BLE001)
798-798: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
798-798: Create your own exception
(TRY002)
798-798: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (29)
cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp (1)
140-140: Nice improvement to error diagnostics
Including the return code in the GEMM failure message makes triaging easier.
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1)
517-519: LGTM! Fake op registration is correct
The unused parameters gathered_stats and scale flagged by static analysis are expected and correct for a fake op registration. Fake ops only provide shape and dtype inference for TorchScript compilation; they don't execute the actual computation. The return shape correctly drops the first dimension (cp_size) from gathered_o.
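For readers unfamiliar with fake ops, a registration of this kind typically looks like the following minimal sketch. This is hypothetical illustration, not the actual contents of cpp_custom_ops.py, and it assumes PyTorch >= 2.4 for torch.library.register_fake and that the op is registered under "trtllm::helix_post_process".

```python
import torch


@torch.library.register_fake("trtllm::helix_post_process")
def _helix_post_process_fake(gathered_o, gathered_stats, scale):
    # Only shape/dtype inference happens here: drop the leading cp_size dimension.
    # gathered_stats and scale are intentionally unused.
    return gathered_o.new_empty(gathered_o.shape[1:])
```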
tests/unittest/_torch/modules/test_mla_helix.py (7)
1-39: LGTM! Imports and MPI setup are appropriate
The imports are well-organized and include all necessary dependencies for distributed MLA testing. The cloudpickle registration for MPI serialization is correctly configured to handle custom types across process boundaries.
42-146: LGTM! Well-structured test configuration
The Scenario and RopeConfig dataclasses are well-designed with appropriate defaults, frozen for immutability, and kw_only for clarity. The max_position_embeddings property correctly ensures sufficient capacity for all test scenarios.
148-416: LGTM! Helper functions are well-implemented
The helper functions provide comprehensive support for the distributed test:
- KV cache and metadata setup is correctly configured for MLA with Helix parallelism
- Weight initialization uses appropriate techniques (Kaiming uniform, block scaling)
- The inverse RoPE transformation in _make_latent_cache_gen correctly recovers original values from the embedded cache
- Error reporting provides detailed diagnostics for debugging
418-604: LGTM! Distributed execution logic is correct
The _run_mla_distributed function correctly orchestrates the Helix-distributed MLA execution:
- Properly distributes weights across CP ranks
- Correctly handles context and generation phases
- CUDA graph capture and replay are implemented correctly with proper warmup
- Latent cache generation for non-last ranks is appropriately handled
- Thorough validation against reference outputs with detailed error reporting
606-785: LGTM! Multi-GPU test orchestration is correct
The _full_test_multi_gpu function properly orchestrates the complete test:
- Rank 0 generates reference output with single-GPU execution
- Reference output is correctly broadcast to all ranks via cp_allgather
- Both reference and distributed paths support CUDA graph for performance measurement
- Test parameters are properly constructed and distributed
787-799: Exception handling is acceptable for MPI context
The broad exception catch and re-raise pattern is appropriate here for distributed MPI execution, where we need to capture and propagate exceptions across process boundaries. The preserved traceback helps with debugging distributed failures.
While static analysis suggests improvements, the current pattern is reasonable for this MPI testing context where exception details must cross process boundaries.
802-840: LGTM! Test function and benchmarking script are well-structured
The pytest test function is correctly parameterized with test scenarios and validates mismatch ratios appropriately. The main block provides a useful benchmarking script for performance measurement across scenarios.
cpp/tensorrt_llm/kernels/helixKernels.cu (3)
38-68: LGTM! Warp reduction is correctly implemented
The warpReduceCorrectedSum function correctly implements numerically stable warp-level reduction for softmax normalization. The SM100-specific redux instruction provides an optimized path, with a proper fallback for older architectures.
70-207: LGTM! Kernel implementation is well-optimized
The helix_postprocess_kernel is well-designed with several optimization strategies:
- Warp specialization (warp 0 for correction, others for pre-loading) maximizes parallelism
- Pre-loading and pipelining reduce memory latency
- SM90+ programmatic stream serialization primitives are correctly guarded
- Memory accesses are coalesced via vectorized loads/stores
- Shared memory usage is efficient
209-243: LGTM! Host launcher is correctly implemented
The helixPostProcess function properly:
- Validates alignment requirements for vectorized memory access
- Checks size constraints against kernel limits
- Configures launch parameters with correct grid/block dimensions
- Enables PDL (Programmatic Dependent Launch) based on environment variable for SM90+ optimization
- Instantiates templates for supported types (__half, __nv_bfloat16)
cpp/tensorrt_llm/thop/helixPostProcessOp.cpp (3)
37-42: LGTM! Shape derivation from inputs
The kv_lora_rank is correctly derived from the gathered_o shape and num_heads. The validation ensures that the dimension is evenly divisible, preventing potential issues downstream.
58-63: Good alignment checks for async operations
The 16-byte alignment requirements for gathered_o and the constraint that kv_lora_rank * sizeof(data_type) must be a multiple of 16 are correctly enforced. These checks ensure safe async memcpy operations, as noted in the inline comment.
110-110: Namespace closing comment already present
The coding guidelines require namespace closing braces to carry a trailing comment with the namespace name; line 110 already has } // namespace torch_ext, so no change is needed.
tests/unittest/_torch/thop/parallel/test_helix_postprocess.py (3)
25-43: Baseline reference implementation looks correct
The baseline function implements the expected Helix post-processing logic in PyTorch for verification. The implementation correctly does the following (a rough sketch of this merge follows the list):
- Computes global max and corrected statistics
- Applies scaling and exponential correction
- Performs reduction and normalization
- Handles dtype casting appropriately
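The listed steps correspond to the standard split-KV softmax merge. The sketch below is a rough PyTorch rendering of that logic, under the assumption (not verified against the test source) that each CP rank's partial output is already normalized by its local softmax sum; shapes and names are illustrative.

```python
import torch


def helix_merge_reference(gathered_o: torch.Tensor,   # [cp, tokens, heads, kv_lora_rank]
                          stats_max: torch.Tensor,    # [cp, tokens, heads]
                          stats_sum: torch.Tensor,    # [cp, tokens, heads]
                          scale: float = 1.0) -> torch.Tensor:
    # Global max across CP ranks for numerical stability.
    global_max = stats_max.max(dim=0, keepdim=True).values
    # Exponential correction of each rank's softmax sum.
    corrected = stats_sum * torch.exp(stats_max - global_max)
    total = corrected.sum(dim=0, keepdim=True)
    # Re-weight each rank's partial output and reduce over the CP dimension.
    weights = (corrected / total).unsqueeze(-1)         # [cp, tokens, heads, 1]
    merged = (gathered_o.float() * weights).sum(dim=0)  # [tokens, heads, kv_lora_rank]
    # Optional scale and cast back to the input dtype.
    return (scale * merged).to(gathered_o.dtype)
```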
214-217: Excellent large-input test coverage
Testing with larger inputs (16 cp_size, 64 heads, 512 kv_lora_rank) for both float16 and bfloat16 helps ensure the operator performs correctly and efficiently at scale.
74-78: Gathered_stats layout verified – no changes required
The C++ struct's float2 holds max in the first component and sum in the second, so using indices 0 and 1 in the test is correct.
tensorrt_llm/_torch/modules/attention.py (11)
209-217: CP size properly integrated into world_size calculation
The world_size calculation now includes cp_size, and the Mapping is constructed with cp_size and cp_config. This ensures distributed operations account for context parallelism ranks.
607-614: latent_cache_gen parameter added to MLA inplace op
The custom op signature is updated to accept latent_cache_gen, enabling generation-time control over which latent cache is used. This aligns with the TODO comments (lines 1145-1149) about using the next-rank latent cache in CP Helix scenarios.
732-733: CP Ulysses not yet supported for MLA
The early NotImplementedError when CP Ulysses is detected is appropriate. The error message is clear and informative.
746-748: Verify head count divisibility by tp_size * cp_size
The assertion requires self.num_heads % (tp_size * cp_size) == 0, ensuring heads can be evenly distributed across tensor-parallel and context-parallel ranks. This is critical for correctness.
750-750: Robust RMS norm epsilon retrieval
Using getattr with a default fallback (1e-6) ensures compatibility when rms_norm_eps is not present in the config. This is a good defensive coding practice.
832-851: Creative mapping_o construction for CP Helix output projection
The mapping_o treats tp_size * cp_size as the effective tp_size while setting cp_size=1. This allows the o_proj to perform row-wise tensor parallelism across the combined TP and CP dimensions, which is necessary after Helix post-processing reduces across CP ranks. This is a clever approach.
1003-1049: CP Helix post-processing integration looks correct
The _attn_forward method now:
- Allocates softmax_stats for tracking partial attention statistics
- Calls attention with helix_position_offsets (position_ids)
- Splits partial outputs and stats by cp_size
- Performs alltoall_helix to gather chunks across CP ranks
- Calls helix_post_process to merge and normalize results
This aligns with the Helix attention algorithm. The scale=1.0 parameter suggests no additional scaling is needed.
1145-1152: TODO documents latent_cache_gen usage for CP Helix
The TODO correctly identifies that in CP Helix generation, ranks other than the last should use the latent cache from the next logical rank's first token. The latent_cache_gen parameter enables this workaround.
1194-1194: helix_position_offsets passed when cp_size > 1
The helix_position_offsets parameter is set to position_ids when CP is enabled, allowing the attention kernel to apply position-based adjustments during generation when tokens have different positions than cached KV values.
1590-1596: Output slicing for CP Helix compatibility
When cp_size > 1, the output is sliced to num_heads_tp_cp * v_head_dim to match the o_proj input expectations after post-processing. The comment clarifies this is for testing Helix parallelism compatibility.
694-694: Verify MLA assertion constraint
Ensure enforcing num_heads == num_key_value_heads is intentional for MLA (i.e., that grouped-query or multi-query attention patterns are not supported); if so, update the module docstring to clarify this limitation.
cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp (outdated comment, resolved)
PR_Github #20403 [ run ] triggered by Bot
PR_Github #20403 [ run ] completed with state
Signed-off-by: Matthias Jouanneaux <[email protected]>
/bot run
PR_Github #20446 [ run ] triggered by Bot
Signed-off-by: Matthias Jouanneaux <[email protected]>
/bot run --disable-fail-fast
PR_Github #20452 [ run ] triggered by Bot
PR_Github #20446 [ run ] completed with state
/bot run
PR_Github #20452 [ run ] completed with state
Signed-off-by: Matthias Jouanneaux <[email protected]>
/bot run
PR_Github #20466 [ run ] triggered by Bot
PR_Github #20466 [ run ] completed with state
Description
This PR adds full Helix parallelism support to the MLA attention module:
Test Coverage
tests/unittest/_torch/modules/test_mla_helix.py: Full Helix MLA test
tests/unittest/_torch/thop/parallel/test_helix_postprocess.py: Helix post-process unit test
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
[ x ] Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message.
See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test"
is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.