Draft: FP8 R1 #6100
Conversation
Walkthrough

The changes introduce a new DeepGemm-based FP8 blockwise Mixture-of-Experts (MoE) backend, including its implementation, quantization utilities, and integration into the MoE factory. FP8 e8m0 quantization and resmoothing functions are added, and modules are updated to support the new backend and quantization path. Comprehensive unit tests for DeepGemm FP8 workflows are included. Additional updates improve FP8 weight handling, add MPI synchronization after prefetching, and extend CLI options.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant ModelConfig
    participant MoEFactory
    participant DeepGemmFusedMoE
    participant QuantUtils
    participant CUDA
    User->>ModelConfig: Set moe_backend='DEEPGEMM'
    ModelConfig->>MoEFactory: create_moe()
    MoEFactory->>DeepGemmFusedMoE: Instantiate with config
    DeepGemmFusedMoE->>QuantUtils: per_token_cast_to_fp8_e8m0(input)
    DeepGemmFusedMoE->>QuantUtils: per_block_cast_to_fp8_e8m0(weights)
    DeepGemmFusedMoE->>CUDA: deep_gemm.fp8_gemm_nt(...)
    DeepGemmFusedMoE-->>User: Return output tensor
```
Actionable comments posted: 1
🧹 Nitpick comments (6)
tensorrt_llm/quantization/utils/fp8_utils.py (2)
22-23: Add documentation for consistency. Other functions in this module have docstrings. Consider adding documentation for the `align` function.

```diff
 def align(x: int, y: int) -> int:
+    """
+    Align x to the nearest multiple of y.
+
+    Args:
+        x: The value to align.
+        y: The alignment boundary.
+
+    Returns:
+        The smallest multiple of y that is >= x.
+    """
     return ceil_div(x, y) * y
```
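For quick reference, a self-contained sketch of the two helpers; the `ceil_div` body shown here is the standard integer formula and is assumed rather than copied from the module:

```python
def ceil_div(x: int, y: int) -> int:
    # Smallest integer q such that q * y >= x (standard formula, assumed).
    return (x + y - 1) // y


def align(x: int, y: int) -> int:
    # Round x up to the nearest multiple of y, e.g. for 128-wide FP8 blocks.
    return ceil_div(x, y) * y


assert align(130, 128) == 256
assert align(128, 128) == 128
```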
26-28: Add documentation for the e8m0 utility function. This function is a key utility for e8m0 quantization. Consider adding documentation to explain its purpose.

```diff
 def ceil_to_ue8m0(x: torch.Tensor):
+    """
+    Compute the ceiling to the nearest power of 2 for tensor elements.
+
+    This is used for e8m0 quantization where scale factors must be powers of 2.
+
+    Args:
+        x: Input tensor.
+
+    Returns:
+        Tensor with each element rounded up to the nearest power of 2.
+    """
     return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
```

tensorrt_llm/_torch/modules/attention.py (1)
710-726: Consider memory optimization for dequantized parameters. Creating separate dequantized parameters doubles memory usage for `k_b_proj_trans` and `v_b_proj`. For large models, this could be significant. Consider:
- Dequantizing on-the-fly during forward pass instead of storing dequantized copies
- Using lazy initialization to create these tensors only when needed
- Adding a comment explaining the memory trade-off
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)
30-74: Add documentation for the complex padding logic. The `copy_after` function implements intricate token padding and rearrangement logic. Please add a docstring explaining:
- The purpose of padding to 128-token boundaries
- The meaning of input/output tensors
- The algorithm's key steps
This will improve maintainability and help future developers understand the optimization strategy.
76-91: Consider making output dtype configurable. The function hardcodes the output dtype as `torch.bfloat16`. This might limit flexibility for different precision requirements. Consider:

- Adding an optional `dtype` parameter
- Or documenting why bfloat16 is specifically required

```diff
 def deepgemm_fp8_group_blockwise_gemm_ref(
     a: torch.Tensor,
     b: torch.Tensor,
     a_sf: torch.Tensor,
     b_sf: torch.Tensor,
     m_indices: torch.Tensor,
+    dtype: torch.dtype = torch.bfloat16,
 ) -> torch.Tensor:
     d = torch.empty((a.shape[0], b.shape[1]),
                     device=b.device,
-                    dtype=torch.bfloat16)
+                    dtype=dtype)
     deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a, a_sf), (b, b_sf), d,
                                                m_indices)
     return d
```
174-175: Fix line length violation. Line 175 exceeds the 120-character limit. Please split it across multiple lines.

```diff
-        assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
+        assert self.routing_method.top_k == 1, \
+            "Current workaround only supports top-1 routing"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- examples/llm-api/quickstart_advanced.py (1 hunks)
- requirements.txt (1 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (5 hunks)
- tensorrt_llm/_torch/modules/attention.py (6 hunks)
- tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
- tensorrt_llm/_torch/modules/linear.py (3 hunks)
- tensorrt_llm/quantization/utils/__init__.py (1 hunks)
- tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
- tests/unittest/_torch/helpers.py (2 hunks)
- tests/unittest/_torch/modules/test_fused_moe.py (3 hunks)
- tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2)
tests/unittest/_torch/helpers.py (4)
- calc_diff (71-75)
- per_block_cast_to_fp8 (28-41)
- per_block_cast_to_fp8_e8m0 (55-68)
- per_token_cast_to_fp8_e8m0 (44-52)

tensorrt_llm/quantization/utils/fp8_utils.py (2)
- per_block_cast_to_fp8_e8m0 (42-67)
- per_token_cast_to_fp8_e8m0 (31-39)
tensorrt_llm/_torch/modules/linear.py (2)
tensorrt_llm/quantization/utils/fp8_utils.py (1)
- per_token_cast_to_fp8_e8m0 (31-39)

tensorrt_llm/_utils.py (3)
- get_sm_version (648-650)
- shape (905-906)
- shape (922-923)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (5)
tensorrt_llm/quantization/utils/fp8_utils.py (1)
- resmooth_to_fp8_e8m0 (70-80)

tensorrt_llm/models/modeling_utils.py (2)
- layer_quant_mode (161-165)
- layer_quant_mode (301-307)

tensorrt_llm/_torch/modules/linear.py (1)
- has_fp8_block_scales (1147-1150)

tests/integration/defs/conftest.py (1)
- get_sm_version (1857-1860)

tensorrt_llm/_utils.py (3)
- get_sm_version (648-650)
- shape (905-906)
- shape (922-923)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
175-175: Line too long (131 > 120)
(E501)
🔇 Additional comments (28)
tensorrt_llm/quantization/utils/__init__.py (1)
1-3: LGTM: Proper module exposure following existing patterns. The addition of `fp8_utils` to the imports and `__all__` list follows the established pattern and properly exposes the new FP8 quantization utilities for public use.

examples/llm-api/quickstart_advanced.py (1)
50-54: LGTM: Proper CLI argument expansion for new MoE backends. The addition of `'DEEPGEMM'` and `'CUTEDSL'` to the `--moe_backend` choices correctly exposes the new MoE backend implementations while maintaining backward compatibility.

tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2)
27-61: LGTM: Comprehensive test for DeepGemm FP8 GEMM implementation. The test function properly:
- Targets the correct hardware (SM 100/Blackwell) with appropriate skip conditions
- Tests multiple matrix dimensions with comprehensive parameterization
- Uses the new FP8 casting utilities correctly
- Compares against a reference implementation with appropriate tolerance
The 1e-2 tolerance threshold is reasonable for FP8 precision testing.
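To make the tolerance concrete, here is a hedged sketch of a DeepGEMM-style relative-difference metric; the exact `calc_diff` body lives in tests/unittest/_torch/helpers.py, so the formulation below is an assumption:

```python
import torch


def calc_diff_sketch(x: torch.Tensor, y: torch.Tensor) -> float:
    # Cosine-style similarity turned into a difference score; identical
    # tensors give ~0, unrelated tensors approach 1.
    x, y = x.double(), y.double()
    sim = 2 * (x * y).sum() / (x.square() + y.square()).sum()
    return float(1 - sim)


a = torch.randn(256, 256)
assert calc_diff_sketch(a, a) < 1e-12
assert calc_diff_sketch(a, a + 0.05 * torch.randn_like(a)) < 1e-2
```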
21-23: Helper imports are available. I confirmed that tests/unittest/_torch/helpers.py defines all imported functions—calc_diff, per_block_cast_to_fp8, per_block_cast_to_fp8_e8m0, and per_token_cast_to_fp8_e8m0—so the `from _torch.helpers` statements in test_fp8_block_scale_gemm.py are valid. No changes are needed.

tensorrt_llm/_torch/modules/linear.py (4)
22-24: LGTM: Proper imports for FP8 utilities and SM version detection. The imports are correctly added to support the new FP8 casting functionality and hardware-specific optimizations.

546-559: LGTM: Well-implemented hardware-specific optimization for SM 100. The conditional branch properly:
- Checks for SM 100 hardware before using DeepGemm
- Imports `deep_gemm` only when needed to avoid import errors on other hardware
- Uses the correct FP8 casting function for the DeepGemm path
- Maintains the same output tensor structure as the original implementation
The fallback to the existing `torch.ops.trtllm` path ensures backward compatibility.
597-598: No dependency on weight vs. weight_scale loading order detected
After searching through all `load_weights_*` implementations, there's no logic that reads one before the other—these two calls simply assign independent buffers. The reordering in `load_weights_fused_qkv_linear` won't affect any downstream functionality.
612-613: No order-dependent logic found in FUSED_GATE_UP_LINEAR loader. I searched for any references to the order of fused weight vs. fused scale loading and confirmed that in tensorrt_llm/_torch/modules/linear.py the method `load_weights_fused_gate_up_linear` always does:

```python
copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_scale)
```

exactly mirroring the FUSED_QKV_LINEAR implementation. There are no other code paths or initialization routines that assume a different sequence. This change should not break existing functionality.
tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3)
11-11: LGTM! The import follows the established pattern for MoE backend modules.

35-36: LGTM! The DEEPGEMM backend integration follows the existing pattern for MoE backend selection.

145-158: LGTM! The DeepGemmFusedMoE instantiation correctly follows the pattern used by CutlassFusedMoE and other similar backends, passing all required parameters.
tests/unittest/_torch/helpers.py (4)
11-13: LGTM! The align function correctly implements alignment to the nearest multiple using ceiling division.

15-17: LGTM! The function correctly computes the ceiling to the nearest power of 2, which is appropriate for e8m0 scaling factor computation.

44-53: LGTM! The per-token FP8 e8m0 casting correctly implements power-of-2 scaling factors, which is the key difference from standard FP8 quantization.

55-69: LGTM! The per-block FP8 e8m0 casting correctly implements block-wise quantization with power-of-2 scaling factors and proper padding.
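As a quick illustration of what "block-wise with power-of-2 scaling" means in practice, here is a round-trip sanity check against the utility in tensorrt_llm/quantization/utils/fp8_utils.py (assumes a CUDA device; shapes are multiples of 128 so no padding is exercised, and the error bound is illustrative):

```python
import torch

from tensorrt_llm.quantization.utils.fp8_utils import per_block_cast_to_fp8_e8m0

w = torch.randn(256, 256, device="cuda")
q, sf = per_block_cast_to_fp8_e8m0(w)  # q: float8_e4m3fn codes, sf: (2, 2) power-of-two scales

# Expand each 128x128 block's scale back to element granularity and dequantize.
deq = (q.float().view(2, 128, 2, 128) * sf.view(2, 1, 2, 1)).view(256, 256)
assert (deq - w).abs().max() < 0.07 * w.abs().max()
```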
tests/unittest/_torch/modules/test_fused_moe.py (2)
12-13: LGTM! The imports are properly organized and follow the existing pattern.
Also applies to: 29-30
385-551: LGTM! The test comprehensively validates the DeepGemmFusedMoE implementation with FP8 blockwise quantization:
- Correctly uses FP8 e8m0 quantization for both weights and activations
- Implements a detailed reference implementation for accurate comparison
- Tests multiple sequence lengths for robustness
- Properly restricted to Blackwell+ GPUs with the decorator
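For illustration, a minimal sketch of such a guard; the real tests use the repository's own get_sm_version() helper and decorator, so the capability check below is an assumption:

```python
import pytest
import torch


def _is_sm100() -> bool:
    # Assumed capability check: SM 100 corresponds to compute capability 10.0.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= 100


skip_non_blackwell = pytest.mark.skipif(
    not _is_sm100(), reason="DeepGemm FP8 e8m0 path requires SM 100 (Blackwell)")
```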
tensorrt_llm/_torch/models/modeling_deepseekv3.py (4)
46-46: LGTM! The import is properly placed and follows the module organization pattern.

1295-1304: LGTM! The FP8 resmoothing logic correctly applies e8m0 quantization for SM 100 hardware when using FP8 block scales.

1211-1211: LGTM! Adding `.contiguous()` after transpose operations ensures proper memory layout for subsequent operations.

Also applies to: 1252-1252

1356-1376: LGTM! The dequantization logic correctly handles k_b_proj_trans and v_b_proj weights:
- Properly checks for parameter existence
- Correctly dequantizes using the weight_dequant function
- Ensures proper tensor shapes and dtypes
tensorrt_llm/quantization/utils/fp8_utils.py (4)
8-20: LGTM! The ceiling division implementation is correct and well-documented.

30-41: LGTM! The per-token FP8 e8m0 casting is correctly implemented with:
- Proper input validation
- Power-of-2 scaling factor computation
- Performance optimization with nvtx_range and torch.compile
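To show the power-of-2 scaling concretely, here is a small reference-style sketch of the per-token recipe (128-element groups, e4m3 range 448, scales rounded up to powers of two); it mirrors the description above rather than the module's exact code and assumes the last dimension is a multiple of 128:

```python
import torch


def per_token_cast_to_fp8_e8m0_sketch(x: torch.Tensor):
    m, n = x.shape
    xv = x.float().view(m, n // 128, 128)
    amax = xv.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    # e8m0: the scale is a pure power of two chosen so scaled values fit e4m3's +-448 range.
    sf = torch.pow(2.0, torch.ceil(torch.log2(amax / 448.0)))
    q = (xv / sf).to(torch.float8_e4m3fn)
    return q.view(m, n), sf.squeeze(-1)


x = torch.randn(4, 256)
q, sf = per_token_cast_to_fp8_e8m0_sketch(x)
deq = (q.float().view(4, 2, 128) * sf.view(4, 2, 1)).view(4, 256)
assert (deq - x).abs().max() <= 0.08 * x.abs().max()
# Scales are (numerically) exact powers of two.
assert torch.allclose(sf, torch.exp2(torch.round(torch.log2(sf))))
```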
43-69: LGTM! The per-block FP8 e8m0 casting correctly handles both 2D and 3D tensors with proper padding and block-wise quantization.

71-82: LGTM! The resmoothing function correctly converts FP8 weights to e8m0 format by:
- Dequantizing with existing scales
- Re-quantizing with power-of-2 scaling factors
- Supporting both 2D and 3D weight tensors
tensorrt_llm/_torch/modules/attention.py (2)
368-395: Clarify the temporary FP8 GEMM workaround for SM 100. The implementation replaces FP8 GEMM with regular `torch.bmm` using dequantized matrices on SM 100 hardware. The commented-out code suggests this is temporary. Please clarify:
- Is this due to FP8 GEMM support limitations on SM 100?
- What is the performance impact of using dequantized GEMM instead of FP8 GEMM?
- When will the FP8 GEMM implementation be restored?
Consider adding a TODO comment explaining the rationale and timeline for restoring the FP8 GEMM path.
1222-1228: LGTM! Correct integration of dequantized parameters. The dequantized parameters are correctly passed to `fp8_block_scaling_bmm_out` in both BMM operations, maintaining consistency with the SM 100 workaround.

Also applies to: 1277-1283
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
183-189: Consider documenting quantization mode limitations. The implementation only supports the `deepseek_fp8_block_scales` quantization mode. Please add a comment explaining:
- Why only this mode is supported
- Plans for supporting other quantization modes
- Whether this is a fundamental limitation of the DeepGemm backend
Actionable comments posted: 1
🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
1362-1381: Review dequantization logic for potential issues. The dequantization logic for the `k_b_proj_trans_dequant` and `v_b_proj_dequant` parameters looks functionally correct, but there are several considerations:

- CUDA tensor usage: The code correctly moves tensors to CUDA for the `weight_dequant` function
- Shape handling: The view operations and reshaping appear correct
- Data type conversion: The final `.to(dtype)` conversion is appropriate

However, there are potential concerns:
- Memory efficiency: Creating intermediate CUDA tensors and then converting back might be memory-intensive
- Error handling: No validation that the dequantization parameters exist before use
- Performance: Multiple view/reshape operations could be optimized
Consider adding null checks before dequantization:
```diff
-        if attn_module.k_b_proj_trans_dequant is not None:
+        if (attn_module.k_b_proj_trans_dequant is not None and
+                k_b_proj_trans_scale is not None):
```

Also consider batching the CUDA operations to improve memory efficiency.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (5 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (2 hunks)
🔇 Additional comments (6)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (3)
7-7: Import addition looks good. The logger import is appropriately added to support the debug logging in the new `load_weights` method.

12-12: Import addition is correct. The `resmooth_to_fp8_e8m0` function import is properly added to support FP8 resmoothing functionality.

468-481: In-place mutation of `weights` is consistent and safe. Other `load_weights` implementations across the codebase routinely modify their local `weights` dict in-place (e.g. in fused MoE backends, linear/quantization helpers, and model converters). In this context, each dict is created fresh for a single module's loading, so mutating it does not leak side effects elsewhere. No copy is necessary.

tensorrt_llm/_torch/models/modeling_deepseekv3.py (3)

47-47: LGTM: Import addition for FP8 resmoothing utility. The import of `resmooth_to_fp8_e8m0` is correctly added to support the new FP8 resmoothing functionality implemented below.

1212-1212: Good: Adding contiguous() for memory layout optimization. The `.contiguous()` call ensures proper memory layout for the transposed tensor, which is important for performance in subsequent operations.

1253-1253: Good: Consistent contiguous() application. The same optimization is applied consistently in the dequantization helper function.
```python
if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
) and get_sm_version() == 100:
    for name in list(weights.keys()):
        # Use ".experts." to exclude shared_experts.
        if name.endswith(
                "weight_scale_inv") and ".experts." not in name:
            weight_name = name.replace("weight_scale_inv", "weight")
            logger.debug(f"Resmoothing {weight_name}")
            weight = weights[weight_name][:]
            scale = weights[name][:]
            weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
                weight, scale)
            weights[weight_name] = weights[weight_name].cpu()
            weights[name] = weights[name].cpu()
```
💡 Verification agent
🧩 Analysis chain
Verify FP8 resmoothing integration and performance impact
The FP8 resmoothing logic looks correct but has some considerations:
- Architecture check: The condition `get_sm_version() == 100` correctly targets SM 100 GPUs
- Weight filtering: The logic properly excludes shared experts using `".experts." not in name`
- Memory management: Weights are moved to CPU after resmoothing, which is good for memory efficiency
However, there are potential concerns:
- The resmoothing process may be computationally expensive and could impact loading time
- The CPU transfer after resmoothing might not be necessary if the weights are immediately used
🏁 Script executed:
```bash
#!/bin/bash
# Description: Verify the usage of resmooth_to_fp8_e8m0 function and check for any performance implications

# Check if resmooth_to_fp8_e8m0 is properly implemented
ast-grep --pattern 'def resmooth_to_fp8_e8m0($_) {
  $$$
}'

# Check for other usages of this function in the codebase
rg -A 5 "resmooth_to_fp8_e8m0"

# Check if there are any tests for this functionality
rg -A 10 "test.*resmooth.*fp8"
```

Length of output: 3138
Add unit tests for FP8 resmoothing & benchmark loading overhead
The FP8 resmoothing logic is correctly integrated—SM 100 check, expert filtering, and GPU→CPU transfer (to free GPU memory during load) are all intentional. However:
- No unit tests cover `resmooth_to_fp8_e8m0`
Please address the following:
• tests/quantization/test_fp8_utils.py (or equivalent): add tests for resmooth_to_fp8_e8m0 (a minimal sketch follows after this list) covering
– 2D and higher-dim tensors
– corner cases (e.g. zero scales, non-contiguous tensors)
• Benchmark script or CI performance test: measure load times with and without FP8 resmoothing to ensure overhead is acceptable
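A possible starting point, hedged: the sketch below assumes the import path and signatures used elsewhere in this PR (resmooth_to_fp8_e8m0 and per_block_cast_to_fp8_e8m0 from tensorrt_llm.quantization.utils.fp8_utils) and uses an illustrative tolerance; corner cases such as zero scales and non-contiguous tensors would still need their own cases.

```python
import pytest
import torch

from tensorrt_llm.quantization.utils.fp8_utils import (
    per_block_cast_to_fp8_e8m0, resmooth_to_fp8_e8m0)


def _dequant_2d(w: torch.Tensor, sf: torch.Tensor) -> torch.Tensor:
    # Expand per-128x128-block scales back to element granularity.
    m, n = w.shape
    return w.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
        128, dim=1)[:m, :n]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA")
def test_resmooth_to_fp8_e8m0_round_trip():
    ref = torch.randn(256, 384, dtype=torch.float32, device="cuda")
    weight, sf = per_block_cast_to_fp8_e8m0(ref)
    new_weight, new_sf = resmooth_to_fp8_e8m0(weight, sf)
    # Resmoothed scales must be powers of two (e8m0) ...
    assert torch.allclose(new_sf, torch.exp2(torch.round(torch.log2(new_sf))))
    # ... and dequantizing with them should still reconstruct the original
    # tensor within an (illustrative) FP8 block-scale tolerance.
    rel_err = (_dequant_2d(new_weight, new_sf) - ref).abs().max() / ref.abs().max()
    assert rel_err < 0.15
```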
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_deepseekv3.py around lines 1296 to 1309,
the FP8 resmoothing logic lacks unit tests and performance benchmarks. Add unit
tests for the resmooth_to_fp8_e8m0 function in
tests/quantization/test_fp8_utils.py that cover 2D and higher-dimensional
tensors, zero scale values, and non-contiguous tensors. Additionally, create a
benchmark script or CI performance test to measure model loading times with and
without FP8 resmoothing to evaluate and ensure the overhead is acceptable.
Actionable comments posted: 1
🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (4)
57-57: Replace lambda expression with a proper function definition. Following PEP 8 guidelines, lambda expressions should not be assigned to variables.

```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
```
107-107: Replace lambda expression with a proper function definition. Following PEP 8 guidelines, lambda expressions should not be assigned to variables.

```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
```
239-239: Split long line for better readability.

```diff
-        assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        assert x.dtype != torch.float8_e4m3fn, (
+            "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        )
```
295-295: Split long line for better readability.

```diff
-        expected_m = (token_selected_experts.numel() + self.expert_size_per_partition - 1) // self.expert_size_per_partition
+        expected_m = (
+            (token_selected_experts.numel() + self.expert_size_per_partition - 1)
+            // self.expert_size_per_partition
+        )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
- tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tensorrt_llm/quantization/utils/fp8_utils.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
57-57: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
107-107: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
239-239: Line too long (131 > 120)
(E501)
295-295: Line too long (124 > 120)
(E501)
🔇 Additional comments (9)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (9)
1-17: Import structure looks good. The imports are well-organized and all dependencies appear to be used appropriately in the implementation. The external `deep_gemm` dependency and internal FP8 utilities are correctly imported for the DeepGemm FP8 MoE backend.

19-66: Triton kernel implementation looks correct. The `masked_index_copy_kernel` and its wrapper function correctly implement masked copying based on expert token offsets. The indexing logic and memory access patterns are appropriate for the MoE use case.

68-116: Triton gather kernel implementation is sound. The `masked_index_gather_kernel` and wrapper function correctly implement the inverse operation of the copy kernel. The tensor shape assertions provide good validation.

118-141: Helper functions are well-implemented. The helper functions correctly implement:
- SwiGLU activation with proper chunking and gating
- Masked indexing for token selection
- Metadata computation for expert token offsets
Good use of `@torch.compile` and `@nvtx_range` for optimization and profiling.
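For readers unfamiliar with the activation, a plain-PyTorch sketch of the SwiGLU pattern the fused helper implements; the chunk order (gate half vs. up half) is an assumption here:

```python
import torch
import torch.nn.functional as F


def swiglu_sketch(h: torch.Tensor) -> torch.Tensor:
    # Split the fused w3/w1 projection output in half along the last dim and
    # gate one half with SiLU of the other.
    gate, up = h.chunk(2, dim=-1)
    return F.silu(gate) * up


h = torch.randn(8, 2 * 1024)
assert swiglu_sketch(h).shape == (8, 1024)
```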
143-159: DeepGemm wrapper function is correctly implemented. The function provides a clean interface to the `deep_gemm.fp8_m_grouped_gemm_nt_masked` operation with proper tensor creation and argument passing. The bfloat16 output dtype is appropriate for the FP8 workflow.

161-210: Class definition and constructor are well-structured. The `DeepGemmFusedMoE` class properly inherits from `CutlassFusedMoE` and delegates initialization to the parent class. The docstring clearly explains the backend components and custom operations used.

227-243: Routing and input processing logic is correct. The routing application, expert selection validation, and conditional router weight application are properly implemented with appropriate assertions and error handling.

304-323: FP8 GEMM workflow is well-implemented. The two-stage FP8 GEMM operations with SwiGLU activation in between correctly implement the MoE expert computation using the new FP8 utilities and DeepGemm backend.

325-347: Output gathering and finalization are correct. The final steps properly gather results from the padded tensor and use the custom finalization operation to produce the final hidden states with correct scaling and unpermutation.
Actionable comments posted: 0
♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
328-348: Address the duplicate variable creation issue. This code segment has the exact issue mentioned in the past review comments. The variable `permuted_data_tensor_padded` is created twice with different parameters, causing the first allocation to be immediately overwritten.

Apply this diff to remove the duplicate allocation:

```diff
-        max_padded_tokens = (x.shape[0] + 128) // 128 * 128
-        permuted_data_tensor_padded = torch.empty(
-            (self.expert_size_per_partition, max_padded_tokens,
-             self.hidden_size),
-            dtype=self.dtype,
-            device='cuda')
-
         masked_m, start_offsets = preprocess_after_permute(
             expert_first_token_offset_tensor)
         m_max = (x.shape[0] + 127) // 128 * 128
         expected_m = (token_selected_experts.numel() +
                       self.expert_size_per_partition -
                       1) // self.expert_size_per_partition
         permuted_data_tensor_padded = torch.empty(
             self.expert_size_per_partition,
             m_max,
             self.hidden_size,
             dtype=self.dtype,
             device='cuda')
```
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)
59-59: Replace lambda expressions with proper function definitions. Static analysis correctly identifies that lambda expressions assigned to variables should be rewritten as proper function definitions for better readability and debugging.

Apply this diff:

```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
```

And similarly for line 109:

```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
```

Also applies to: 109-109
283-283: Fix line length to comply with style guidelines. The line exceeds the 120-character limit.

Apply this diff:

```diff
-        assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        assert x.dtype != torch.float8_e4m3fn, (
+            "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
- tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
59-59: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
109-109: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
283-283: Line too long (131 > 120)
(E501)
🔇 Additional comments (13)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (5)
21-46: LGTM: Well-structured Triton kernel for masked index copy. The kernel logic correctly handles memory access patterns with proper masking and bounds checking. The computation of indices and offsets is mathematically sound.

70-95: LGTM: Consistent Triton kernel implementation for masked index gather. The gather kernel mirrors the copy kernel structure appropriately, with correct index calculations and memory access patterns.

120-124: LGTM: Efficient SwiGLU implementation. The fused SwiGLU activation using `torch.compile` is well-implemented and should provide good performance.

145-203: LGTM: Well-structured DeepGemm FP8 group blockwise GEMM wrapper. The function properly validates tensor shapes, strides, and data types. The transformation of scaling factors and the call to the underlying DeepGemm implementation are correctly structured.

350-372: LGTM: Efficient FP8 quantization and GEMM workflow. The forward pass correctly applies FP8 quantization using the utility functions, performs the two GEMM operations with the fused SwiGLU activation in between, and handles the data flow properly.
tensorrt_llm/quantization/utils/fp8_utils.py (8)
8-19: LGTM: Standard ceiling division implementation. The mathematical logic is correct and the function is properly documented.

22-28: LGTM: Utility functions are mathematically sound. Both the `align` and `ceil_to_ue8m0` functions implement correct mathematical operations for alignment and power-of-two ceiling calculations.

32-49: LGTM: Well-implemented per-token FP8 quantization. The function correctly handles both 2D and 3D tensors, applies proper blocking (128 elements), computes scaling factors using the FP8 range (448.0), and performs the quantization with appropriate clamping.

52-77: LGTM: Proper per-block FP8 quantization implementation. The function correctly:
- Pads tensors to multiples of 128 for proper block alignment
- Reshapes into 128x128 blocks for quantization
- Computes scaling factors per block
- Returns tensors cropped to original dimensions
80-90: LGTM: Effective weight resmoothing implementation. The function properly applies scaling factors by repeating them to match weight dimensions and calls the appropriate quantization function.

104-133: LGTM: Complex but correct TMA-aligned tensor transformation. The function correctly:
- Converts float tensors to uint8 by extracting exponent bits
- Handles proper padding and alignment for TMA requirements
- Performs necessary transposition for column-major layout
- Supports both 2D and 3D tensors
136-163: LGTM: Comprehensive scaling factor layout validation. The function performs thorough validation of tensor shapes, strides, and alignment requirements for TMA operations.

168-216: LGTM: Robust scaling factor layout transformation. The function handles multiple granularity cases and properly transforms scaling factors based on hardware requirements. The logic for different granularities (1,128) and (128,128) is correct.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)
261-401: Well-orchestrated MoE forward pass implementation. The forward_chunk method properly handles the complex flow of routing, quantization, permutation, FP8 operations, and finalization. The logic appears correct and the duplicate tensor creation issue from previous reviews has been resolved.
288-288: Fix line length violation. The line exceeds the 120-character limit and should be broken up for better readability.

```diff
-        assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        assert x.dtype != torch.float8_e4m3fn, (
+            "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
60-60: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
112-112: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
288-288: Line too long (131 > 120)
(E501)
🔇 Additional comments (8)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (8)
1-19: Well-organized imports for FP8 MoE implementation. The imports are logically grouped and all appear necessary for the DeepGemm FP8 backend functionality.

21-69: Triton kernels are well-implemented with proper bounds checking. The masked index copy kernel correctly handles 3D tensor operations with expert-based indexing and includes appropriate masking for safety.

60-60: Static analysis false positive - lambda is appropriate here. The lambda expression for `grid` is the standard pattern for Triton kernel launches and should not be changed to a function definition.

72-122: Gather kernel correctly implements the inverse of the copy operation. The masked index gather kernel properly handles the inverse transformation with consistent offset calculations and masking logic.

112-112: Static analysis false positive - lambda is appropriate here. Similar to the copy kernel, the lambda expression for `grid` follows the standard Triton pattern and should not be changed.

124-148: Helper functions are well-designed with appropriate optimizations. The SwiGLU activation, indexing, and preprocessing functions are correctly implemented with torch.compile decorators for performance optimization and NVTX ranges for profiling.

150-208: Excellent defensive programming in the DeepGemm wrapper. The function includes comprehensive type and shape assertions, proper scaling factor transformations, and clean separation of concerns. This robust implementation should prevent runtime errors.

210-259: Clean class definition with good documentation. The DeepGemmFusedMoE class properly inherits from CutlassFusedMoE with clear documentation of the custom operations involved.
Actionable comments posted: 0
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)
61-61: Address static analysis hint about lambda expression. Static analysis suggests rewriting the lambda as a proper function definition for better readability.

```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
```
113-113: Address static analysis hint about lambda expression. Same lambda expression style issue as in the copy function.

```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
```
289-289: Address line length issue. The line exceeds the 120-character limit.

```diff
-        assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        assert x.dtype != torch.float8_e4m3fn, (
+            "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
61-61: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
113-113: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
289-289: Line too long (131 > 120)
(E501)
🔇 Additional comments (10)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (10)
1-20: LGTM - Well-organized imports for DeepGemm FP8 MoE implementation.The import structure is clean and includes all necessary dependencies for the DeepGemm-based FP8 blockwise MoE backend.
22-71: Triton kernel implementation looks correct.The masked index copy kernel properly handles memory access patterns with appropriate bounds checking and masking. The logic for computing token indices, offsets, and memory locations is sound.
73-122: Triton gather kernel implementation is correct.The masked index gather kernel properly implements the inverse operation of the copy kernel with consistent memory access patterns and bounds checking.
125-149: Helper functions are well-implemented.The fused SwiGLU activation, masked indexing, and preprocessing functions are correctly implemented with appropriate optimizations (torch.compile) and profiling annotations (NVTX ranges).
151-209: Comprehensive DeepGemm wrapper with proper validation.The function provides excellent input validation, proper scaling factor layout transformation, and clean interface to the underlying FP8 GEMM kernel. The extensive assertions help catch configuration errors early.
211-261: Well-designed class inheriting from CutlassFusedMoE.The class properly extends the base CutlassFusedMoE with appropriate initialization and parameter passing. The docstring clearly explains the custom ops used in the pipeline.
262-304: Proper routing and quantization setup.The forward method correctly applies routing, validates tensor shapes and types, and sets up the appropriate quantization mode for DeepSeek FP8 block scales.
305-337: Correct allgather and permutation logic.The distributed processing with allgather and the custom moe_permute_op integration are properly implemented with appropriate parameter passing.
338-383: Excellent FP8 GEMM pipeline implementation. The core computation pipeline is well-structured:
- Proper handling of empty tensors
- Correct tensor padding and masking operations
- FP8 casting with per-token quantization
- Two-stage GEMM operations with fused SwiGLU activation
- Proper tensor gathering back to original layout
This demonstrates a sophisticated understanding of FP8 quantization and efficient MoE computation.
384-404: Proper finalization with custom op integration. The final step correctly uses the moe_finalize_scale_op to combine results, apply scaling, and unpermute the output with all necessary parameters.
Actionable comments posted: 0
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)
46-69: Address lambda assignment style issue. The function is well-implemented with proper shape validation. However, consider replacing the lambda assignment with a proper function definition for better readability.

```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
```
97-121: Address lambda assignment style issue. Similar to the copy function, consider replacing the lambda assignment with a proper function definition.

```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
```
217-248: Handle line length and consider decomposition. The method has good initial validation and routing logic. However, line 244 exceeds the 120-character limit.

```diff
-        assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        assert x.dtype != torch.float8_e4m3fn, (
+            "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- requirements.txt (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
- tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- requirements.txt
- tensorrt_llm/quantization/utils/fp8_utils.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
59-59: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
111-111: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
244-244: Line too long (131 > 120)
(E501)
🔇 Additional comments (10)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (10)
1-18: LGTM - Well-organized imports. The imports are appropriate for the DeepGemm FP8 MoE implementation, with clear separation between external dependencies and internal modules.
20-44: Well-implemented Triton kernel. The `masked_index_copy_kernel` correctly implements masked copying with proper bounds checking and memory access patterns. The logic for computing row/column indices and offsets is sound.

71-95: Complementary gather kernel implementation. The `masked_index_gather_kernel` correctly implements the inverse operation of the copy kernel, with consistent bounds checking and memory access patterns.

123-147: Well-designed helper functions. The helper functions are appropriately decorated and implemented:

- `swiglu_fused_moe` correctly implements the SwiGLU activation pattern
- `indexing` ensures contiguity for performance
- `preprocess_after_permute` efficiently computes expert mappings using `torch.searchsorted`
149-164: Clean wrapper for DeepGemm FP8 operations. The function provides a well-structured interface to the DeepGemm library with appropriate output tensor allocation and parameter passing.

186-216: Standard initialization pattern. The class initialization follows the expected pattern by calling the parent constructor with all necessary parameters.

249-293: Complex quantization and permutation logic. The quantization setup and moe_permute_op call are well-structured with appropriate parameter passing. The handling of different quantization modes and distributed processing is comprehensive.

294-314: Efficient memory management for padded tensors. The implementation correctly handles empty tensor cases and efficiently manages memory allocation for padded tensors. The use of Triton masked copy is appropriate for performance.

315-335: Well-orchestrated FP8 GEMM computation flow. The dual GEMM operations with intermediate SwiGLU activation correctly implement the MoE computation pattern. The FP8 casting before each GEMM operation is appropriate for the DeepGemm backend.

336-359: Proper result gathering and finalization. The output gathering using Triton and final scaling/unpermutation through moe_finalize_scale_op correctly completes the MoE computation pipeline.
Actionable comments posted: 6
🧹 Nitpick comments (7)
tensorrt_llm/quantization/utils/fp8_utils.py (2)
104-133: Potential tensor creation inefficiency and magic numbers. The function creates multiple temporary tensors and uses hardcoded constants that should be documented.

The function could be optimized and made more readable:

```diff
 def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
-    # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
+    """
+    Convert FP32 tensor to column-major TMA-aligned packed uint8 format.
+
+    Note: For extreme performance, consider rewriting this function in CUDA.
+
+    Args:
+        x: Input tensor (2D or 3D) of dtype float32
+
+    Returns:
+        Packed tensor in column-major TMA-aligned layout
+    """
     assert x.dtype == torch.float and x.dim() in (2, 3)
-    # First, convert into UE8M0 `uint8_t`
+    # Convert to UE8M0 format by extracting exponent bits (bits 23-30)
     ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)
```
136-163: Improve error messages and add input validation. The function has good validation logic but could provide more helpful error messages.

```diff
 def check_sf_layout(sf: torch.Tensor,
                     mn: int,
                     k: int,
                     gran: Tuple[int, int],
                     num_groups: Optional[int],
                     tma_stride_check: bool = False,
                     type_check: Optional[torch.dtype] = None) -> torch.Tensor:
+    """
+    Validate scaling factor tensor layout and constraints.
+
+    Args:
+        sf: Scaling factor tensor to validate
+        mn: M*N dimension size
+        k: K dimension size
+        gran: Granularity tuple (m_gran, k_gran)
+        num_groups: Number of groups (None for 2D tensors)
+        tma_stride_check: Whether to validate TMA alignment
+        type_check: Expected dtype (None to skip check)
+
+    Returns:
+        The input tensor (for chaining)
+
+    Raises:
+        AssertionError: If layout constraints are violated
+    """
     # Type check
     if type_check is not None:
-        assert sf.dtype == type_check
+        assert sf.dtype == type_check, f"Expected dtype {type_check}, got {sf.dtype}"

     # Always do shape checks
-    assert sf.dtype in (torch.float, torch.int)
-    assert sf.dim() == int(num_groups is not None) + 2
+    assert sf.dtype in (torch.float, torch.int), f"Unsupported dtype {sf.dtype}"
+    expected_dims = int(num_groups is not None) + 2
+    assert sf.dim() == expected_dims, f"Expected {expected_dims} dimensions, got {sf.dim()}"
```

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (5)
48-70: Fix static analysis issue: replace lambda with function definition. The lambda assignment violates Python style guidelines.

```diff
-    # launch kernel
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
-    masked_index_copy_kernel[grid](output,
+    # launch kernel
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+
+    masked_index_copy_kernel[grid](output,
```
99-122: Fix static analysis issue: replace lambda with function definition. Same lambda assignment issue as in the copy function.

```diff
-    # launch kernel
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
-    masked_index_gather_kernel[grid](output,
+    # launch kernel
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+
+    masked_index_gather_kernel[grid](output,
```
138-148: Consider adding input validation for the mask. The indexing function lacks validation for the mask tensor.

```diff
 @nvtx_range("[DG] indexing")
 @torch.compile(dynamic=True)
 def indexing(x, mask):
+    """Select rows from x where mask > 0."""
+    assert mask.dim() == 1, "Mask must be 1D"
+    assert mask.shape[0] == x.shape[0], "Mask length must match first dimension of x"
     return x[mask > 0, :].contiguous()
```
287-293: Fix line length issue and improve readability. Line 289 exceeds the 120-character limit and the comment indicates a temporary workaround.

```diff
         if self.apply_router_weight_on_input:
-            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+            assert self.routing_method.top_k == 1, \
+                "Current workaround only supports top-1 routing"
+            assert x.dtype != torch.float8_e4m3fn, \
+                "Current workaround for apply_router_weight_on_input does not support fp8 input"
             x = x * token_final_scales.to(x.dtype)
             # TODO: remove this once we have correct fusedmoe kernel ready
             token_final_scales = None
```
295-304: Add validation for quantization configuration. The error handling could be more informative about supported quantization modes.

```diff
         # quantize inputs
         use_deepseek_fp8_block_scale = False
         x_sf = None
         if self.has_any_quant:
             if self.has_deepseek_fp8_block_scales:
                 use_deepseek_fp8_block_scale = True
             else:
                 raise ValueError(
-                    f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}"
+                    f"Unsupported quantization mode for DeepGemm backend: {self.quant_config.quant_mode}. "
+                    f"Only FP8 block scaling is supported."
                 )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- requirements.txt (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
- tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- requirements.txt
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
61-61: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
113-113: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
289-289: Line too long (131 > 120)
(E501)
🔇 Additional comments (8)
tensorrt_llm/quantization/utils/fp8_utils.py (2)
8-19: LGTM! Clean integer math utility. The ceiling division implementation is correct and well-documented.

22-23: LGTM! Simple alignment utility. Correctly implements alignment using the ceiling division helper.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (6)
22-46: LGTM! Well-implemented Triton kernel with proper bounds checking. The masked index copy kernel correctly handles bounds checking and memory access patterns.

73-97: LGTM! Correct implementation of masked gather kernel. The kernel properly implements the reverse operation of the copy kernel with appropriate bounds checking.

125-135: LGTM! Efficient SwiGLU implementation. The fused SwiGLU activation is correctly implemented with proper tensor chunking and activation.

151-208: Comprehensive GEMM wrapper with good validation. The function has excellent input validation and error checking. The extensive assertions ensure type safety and shape compatibility.

339-341: Good early exit optimization. Properly handles the edge case where no tokens are selected for processing.

385-403: Complex operation orchestration looks correct. The final scaling and unpermutation operation properly handles all the tensor transformations and expert routing logic.
```python
act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0(
    permuted_data_tensor_padded)
h1 = deepgemm_fp8_group_blockwise_gemm(
    a=act_input_fp8,
    b=self.w3_w1_weight,
    sfa=act_input_sf,
    sfb=self.quant_scales[0],
    masked_m=masked_m,
    expected_m=expected_m,
)
h2 = swiglu_fused_moe(h1)
act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0(h2)
h3 = deepgemm_fp8_group_blockwise_gemm(
    a=act_input_fp8,
    b=self.w2_weight,
    sfa=act_input_sf,
    sfb=self.quant_scales[1],
    masked_m=masked_m,
    expected_m=expected_m,
)
```
💡 Verification agent
❓ Verification inconclusive
Verify FP8 quantization numerical stability.
The two consecutive FP8 quantization steps could potentially accumulate numerical errors.
The implementation quantizes activations twice (before each GEMM operation). Verify that this doesn't introduce significant numerical degradation:
🏁 Script executed:
```bash
#!/bin/bash
# Description: Look for numerical validation tests for double FP8 quantization
rg -A 10 -B 5 "per_token_cast_to_fp8_e8m0.*per_token_cast_to_fp8_e8m0" --type py
```

Length of output: 82
Add numerical-stability tests for consecutive FP8 quantization
We couldn’t find any existing tests that validate the error introduced by doing two back-to-back per_token_cast_to_fp8_e8m0 calls in fused_moe_deepgemm.py. Please add numerical‐validation tests to ensure the two-stage FP8 quantization doesn’t accumulate unacceptable error:
• Target file: tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
• Add tests that:
- Quantize a known tensor once vs. twice and compare against a high-precision reference
- Report maximum and RMS error, asserting they remain within acceptable bounds
• Place tests under your existing test suite (e.g. tests/…) and name them to clearly reflect "double FP8 quantization" verification; a minimal sketch follows below
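A possible shape for such a test, hedged: the inline quantize/dequantize pair below mirrors the documented per-token recipe (128-element groups, 448 range, power-of-two scales); the real test should exercise fp8_utils.per_token_cast_to_fp8_e8m0 itself, whose exact scale layout may differ from this sketch.

```python
import torch


def _quant_dequant(x: torch.Tensor) -> torch.Tensor:
    # Reference per-token e8m0 quantize + dequantize (assumes last dim % 128 == 0).
    m, n = x.shape
    xv = x.float().view(m, n // 128, 128)
    amax = xv.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    sf = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
    q = (xv / sf).to(torch.float8_e4m3fn)
    return (q.float() * sf).view(m, n)


def test_double_fp8_e8m0_quantization_is_stable():
    x = torch.randn(64, 512)
    once = _quant_dequant(x)
    twice = _quant_dequant(once)
    err_once = (once - x).abs().max()
    err_twice = (twice - x).abs().max()
    # Re-quantizing an already quantized tensor should add essentially no
    # extra error: power-of-two scales keep the FP8 codes representable.
    assert err_twice <= err_once + 1e-6
    assert err_once <= 0.08 * x.abs().max()
```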
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py around lines 360
to 379, add numerical-stability tests to validate the error from consecutive FP8
quantizations using per_token_cast_to_fp8_e8m0. Create tests that quantize a
known tensor once and twice, then compare both results against a high-precision
reference tensor. Measure and report maximum and RMS errors, asserting they stay
within acceptable limits. Implement these tests in the existing test suite under
the tests directory with clear names indicating they verify double FP8
quantization stability.
```python
def ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
```
🛠️ Refactor suggestion
Potential numerical instability with very small values.
The function doesn't handle edge cases where x.abs() could be zero or very close to zero, which would cause issues with torch.log2().
Consider adding a small epsilon to prevent numerical issues:
```diff
-def ceil_to_ue8m0(x: torch.Tensor):
-    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+def ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs().clamp(min=1e-10))))
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 def ceil_to_ue8m0(x: torch.Tensor):
-    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs().clamp(min=1e-10))))
```
🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 26 to 27, the
function ceil_to_ue8m0 does not handle cases where x.abs() is zero or near zero,
causing torch.log2 to fail or produce unstable results. Fix this by adding a
small epsilon value to x.abs() before applying torch.log2 to ensure numerical
stability and prevent errors.
```python
def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if x.dim() == 2:
        m, n = x.shape
        x_padded = torch.zeros((align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(2))
    else:
        g, m, n = x.shape
        x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:, :m, :n] = x
        x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(1), x_view.size(3))
```
🛠️ Refactor suggestion
Add input validation and improve error handling.
The function lacks proper input validation and error handling for edge cases.
```diff
+def per_block_cast_to_fp8_e8m0(
+        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize tensor to FP8 e4m3fn format with per-block scaling.
+
+    Args:
+        x: Input tensor (2D or 3D), must be on CUDA device
+
+    Returns:
+        Tuple of (quantized_tensor, scaling_factors)
+    """
+    assert x.is_cuda, "Input tensor must be on CUDA device"
+    assert x.dtype in (torch.float16, torch.bfloat16, torch.float32), "Input must be floating point"
+    assert x.dim() in (2, 3), "Input tensor must be 2D or 3D"
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 52 to 77, the
function per_block_cast_to_fp8_e8m0 lacks input validation and error handling.
Add checks to ensure the input tensor x has the expected number of dimensions
(either 2 or 3) and raise informative errors if not. Also, validate that the
tensor sizes are compatible with the padding and reshaping operations to prevent
runtime errors. Include try-except blocks or conditional checks to handle edge
cases gracefully and provide clear error messages.
```python
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    weight = weight.cuda()
    sf = sf.cuda()
    if weight.dim() == 2:
        x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
            128, dim=1)[:weight.shape[0], :weight.shape[1]]
    else:
        x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
            128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
    return per_block_cast_to_fp8_e8m0(x)
```
🛠️ Refactor suggestion
Verify CUDA placement and add error handling.
The function forcibly moves tensors to CUDA without checking if CUDA is available, and lacks proper error handling.
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Re-quantize weights using existing scaling factors.
+
+ Args:
+ weight: Weight tensor to re-quantize
+ sf: Existing scaling factors
+
+ Returns:
+ Tuple of (re-quantized_weight, new_scaling_factors)
+ """
+ if not torch.cuda.is_available():
+ raise RuntimeError("CUDA is required for FP8 operations")
weight = weight.cuda()
sf = sf.cuda()📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Re-quantize weights using existing scaling factors.

    Args:
        weight: Weight tensor to re-quantize
        sf: Existing scaling factors

    Returns:
        Tuple of (re-quantized_weight, new_scaling_factors)
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for FP8 operations")
    weight = weight.cuda()
    sf = sf.cuda()
    if weight.dim() == 2:
        x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
            128, dim=1)[:weight.shape[0], :weight.shape[1]]
    else:
        x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
            128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
    return per_block_cast_to_fp8_e8m0(x)
```
🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 80 to 90, the
function forcibly moves tensors to CUDA without verifying CUDA availability and
lacks error handling. Modify the code to first check if CUDA is available before
moving tensors to CUDA, and add try-except blocks to catch and handle potential
errors during this operation, ensuring graceful fallback or informative error
messages.
Actionable comments posted: 0
♻️ Duplicate comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)
111-119: Same `lambda`-grid issue in the gather helper

Apply the same `def grid(meta): …` pattern here for consistency and to satisfy Ruff.
357-368: Numerical-stability tests for back-to-back FP8 quantization still missing

Previous review already requested unit tests validating error accumulation across the two consecutive `per_token_cast_to_fp8_e8m0` calls; no such tests are present yet. See the sketch below.
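One possible shape for such a check, using a local stand-in quantizer rather than the library function (the stand-in assumes 128-wide groups, the 448.0 FP8-E4M3 maximum, and power-of-2 scales, mirroring the per-block cast quoted earlier in this thread; the real `per_token_cast_to_fp8_e8m0` may differ in signature and rounding):

```python
import torch


def _cast_e8m0(x: torch.Tensor, group: int = 128):
    # Stand-in per-token quantizer: per-group amax, scale rounded up to a power of 2.
    m, n = x.shape
    xg = x.view(m, n // group, group).float()
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    sf = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
    return (xg / sf).to(torch.float8_e4m3fn).view(m, n), sf.squeeze(-1)


def _dequant(q: torch.Tensor, sf: torch.Tensor, group: int = 128) -> torch.Tensor:
    m, n = q.shape
    return (q.view(m, n // group, group).float() * sf.unsqueeze(-1)).view(m, n)


def test_back_to_back_fp8_quantization_is_stable():
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(64, 512, device=device)
    q1, s1 = _cast_e8m0(x)
    d1 = _dequant(q1, s1)
    q2, s2 = _cast_e8m0(d1)  # second pass, as in the GEMM1 -> SwiGLU -> GEMM2 path
    d2 = _dequant(q2, s2)
    # With power-of-2 scales the second pass is expected to be bit-exact; relax the
    # tolerance if the production kernels round differently.
    drift = (d2 - d1).abs().max().item()
    assert drift == 0.0, f"re-quantization drifted by {drift}"
```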
🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (4)
59-67: Replace the `lambda` used for Triton grid sizing

Using `lambda` assignments trips Ruff's E731 and slightly hurts readability/debuggability. A tiny helper keeps the intent clear and eliminates the style violation.

```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
```
268-271: Redundant self-assignment of `output_dtype`

`output_dtype = output_dtype` is a no-op and can be dropped.

```diff
-    assert output_dtype is not None
-    output_dtype = output_dtype
+    assert output_dtype is not None
```
285-285: Line exceeds 120 characters

Break the long assertion for compliance with the repo's 120-char guideline.

```diff
-        assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        assert x.dtype != torch.float8_e4m3fn, (
+            "Current workaround for apply_router_weight_on_input does not support fp8 input"
+        )
```
345-355: Potentially large temporary allocation

`permuted_data_tensor_padded` is sized `(experts_per_partition, m_max, hidden_size)`, where `m_max` is ceiling-padded to 128. For large batches and hidden sizes this can momentarily double memory usage. If DeepGemm tolerates ragged M, consider allocating only up to `masked_m.max()` instead of the padded upper bound to reduce peak memory; a hypothetical sketch follows.
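A minimal sketch of that sizing change, assuming `masked_m` holds the per-expert token counts used elsewhere in `forward_chunk` and that DeepGemm only needs 128-row alignment (both assumptions, not verified against the kernel's actual requirements):

```python
import torch


def alloc_permuted_buffer(masked_m: torch.Tensor, experts_per_partition: int,
                          hidden_size: int, dtype: torch.dtype,
                          device: torch.device) -> torch.Tensor:
    # Pad only to the largest real per-expert token count instead of the global worst case.
    m_max = (int(masked_m.max().item()) + 127) // 128 * 128
    return torch.empty(experts_per_partition, m_max, hidden_size,
                       dtype=dtype, device=device)
```

Note that reading `masked_m.max()` on the host forces a device synchronization, so the memory saving has to be weighed against that stall.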
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- requirements.txt (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- requirements.txt
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
59-59: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
111-111: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
285-285: Line too long (131 > 120)
(E501)
🔇 Additional comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
124-132: `torch.compile` may silently fall back or error on unsupported envs

`torch.compile(dynamic=True)` is great for PT≥2.1 + supported back-ends, but on earlier versions or when Triton / float8 isn't available it raises at import time. Consider wrapping the decoration in a feature-check (e.g. `if torch.__version__ >= "2.1"` or an env flag) to preserve library importability for downstream users who only perform inference; one possible guard is sketched below.
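A minimal sketch of such a guard, assuming a decorator wrapper is acceptable; the `TRTLLM_DISABLE_TORCH_COMPILE` environment variable is a hypothetical name, not an existing flag:

```python
import os

import torch


def maybe_compile(fn):
    """Apply torch.compile when available and not disabled; otherwise return fn unchanged."""
    if os.environ.get("TRTLLM_DISABLE_TORCH_COMPILE") == "1":  # hypothetical opt-out flag
        return fn
    if not hasattr(torch, "compile"):  # torch.compile only exists in PyTorch >= 2.0
        return fn
    return torch.compile(fn, dynamic=True)


@maybe_compile
def scale_rows(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Trivial placeholder body; the real helpers in this file keep their own logic.
    return x * w.unsqueeze(-1)
```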
Actionable comments posted: 0
♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
420-427: Consecutive FP8 quantization numerical stability needs verification.This issue was previously flagged but remains unaddressed. The two consecutive FP8 quantization steps (lines 420 and 421-427) could potentially accumulate numerical errors.
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)
122-171: Fix lambda assignment style issue.The gather kernel logic is correct, but there's a style issue that should be addressed.
```diff
-    grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
-    masked_index_gather_kernel[grid](output,
+    def grid(meta):
+        return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
+    masked_index_gather_kernel[grid](output,
```
336-336: Fix line length violation.

```diff
-    assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+    assert x.dtype != torch.float8_e4m3fn, \
+        "Current workaround for apply_router_weight_on_input does not support fp8 input"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
162-162: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
336-336: Line too long (131 > 120)
(E501)
🔇 Additional comments (6)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (6)
1-18: Import structure looks clean and well-organized.The imports are appropriate for a DeepGemm-based MoE implementation with FP8 quantization. The dependency on
deep_gemmandfp8_utilsaligns with the module's purpose.
20-75: Triton kernel implementation follows proper FP8 quantization patterns.The kernel correctly implements masked index copying with group-wise FP8 quantization:
- Proper handling of program IDs and memory masking
- Correct FP8 E4M3 range using 448.0 constant
- Race condition prevention for scaling factor storage
- Power-of-2 scaling factor computation
77-120: Wrapper function properly validates inputs and launches the Triton kernel.The implementation includes:
- Comprehensive input validation with clear assertions
- Proper grid size and warp calculation using Triton best practices
- Correct parameter passing to the underlying kernel
174-198: Helper functions are clean and efficiently implement their specific operations.The TorchScript compiled functions are well-designed:
- Proper use of NVTX ranges for profiling
- Clear, focused functionality for each helper
- Efficient implementations using appropriate PyTorch operations
200-256: Excellent implementation with comprehensive validation.The DeepGemm function is thoroughly implemented with:
- Extensive input validation covering shapes, strides, and dtypes
- Proper scaling factor transformation using fp8_utils
- Appropriate tensor layout assertions
- Clean integration with the deep_gemm library
258-454: Well-structured MoE implementation with comprehensive functionality.The DeepGemmFusedMoE class effectively orchestrates:
- Router application and input quantization
- Permutation and distributed processing
- FP8 group quantization with custom kernels
- DeepGemm GEMM operations
- SwiGLU activation and output gathering
- Proper finalization and scaling
The implementation handles edge cases well and integrates cleanly with the existing MoE infrastructure.
Actionable comments posted: 1
♻️ Duplicate comments (5)
tensorrt_llm/quantization/utils/fp8_utils.py (5)
28-29: Address numerical instability in ceil_to_ue8m0.This is similar to a past review comment but still applies. The function will fail when
x.abs()contains zero values due totorch.log2(0)being undefined.
32-51: Add comprehensive documentation for per_token_cast_to_fp8_e8m0.This addresses the past review comment about documenting the function and its requirements.
54-79: Improve input validation in per_block_cast_to_fp8_e8m0.This relates to the past review comment about adding proper input validation.
82-92: Add CUDA availability check in resmooth_to_fp8_e8m0.This addresses the past review comment about verifying CUDA availability.
168-218: Add comprehensive documentation for transform_sf_into_required_layout.This addresses the past review comment about improving documentation and error handling.
🧹 Nitpick comments (6)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (4)
20-75: Consider adding bounds checking and documentation for the Triton kernel.The
_masked_index_copy_group_quant_fp8kernel performs complex indexing operations without bounds checking on several computed indices (row_idx,col_idx,elem_idx). While thevalidmask provides some protection, additional bounds checks could prevent potential out-of-bounds memory access.+ """ + Triton kernel for masked index copying with FP8 group quantization. + + Performs element-wise FP8 quantization in groups and copies data to output + tensor based on computed row/column indices from permutation metadata. + """ # get program id and block offset pid = tl.program_id(0) block_start = pid * group_size
162-162: Replace lambda with def for better readability.The static analysis correctly identifies that lambda expressions assigned to variables should be replaced with proper function definitions for better readability and debugging.
- grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), ) + def grid(meta): + return (triton.cdiv(total_elems, meta['BLOCK_SIZE']), )
336-336: Fix line length violation.The line exceeds the 120-character limit as flagged by the linter.
- assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" + assert x.dtype != torch.float8_e4m3fn, ( + "Current workaround for apply_router_weight_on_input does not support fp8 input" + )
419-434: Document the activation tensor reallocation pattern.The code reallocates
act_input_fp8andact_input_sftensors between the two GEMM operations with different shapes. This pattern should be documented to clarify why reallocation is necessary rather than reusing buffers.+ # Reallocate activation tensors for second GEMM with different dimensions + # Shape changes from [E, M, 2*intermediate] to [E, M, intermediate] after SwiGLU act_input_fp8 = torch.empty(h1.shape[0], h1.shape[1], h1.shape[2] // 2,tensorrt_llm/quantization/utils/fp8_utils.py (2)
304-309: Fix docstring formatting issues.The static analysis correctly identifies formatting problems in the docstring.
-def silu_and_mul_masked_post_quant_fwd( - input: torch.Tensor, - output: torch.Tensor, - output_scale: torch.Tensor, - quant_group_size: int, - masked_m: torch.Tensor, - scale_ue8m0: bool = False, -): - """ - input shape [expert_num, token_num_padded, hidden_dim] - output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8 - output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32 - quant_group_size int, - masked_m shape [expert_num], - """ +def silu_and_mul_masked_post_quant_fwd( + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + quant_group_size: int, + masked_m: torch.Tensor, + scale_ue8m0: bool = False, +): + """ + Apply SiLU activation and multiplication with masked FP8 post-quantization. + + Args: + input: Input tensor with shape [expert_num, token_num_padded, hidden_dim] + output: Output tensor with shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8 + output_scale: Scale tensor with shape [expert_num, token_num_padded, hidden_dim // 2 // 128], dtype float32 + quant_group_size: Quantization group size (int) + masked_m: Mask tensor with shape [expert_num] + scale_ue8m0: Whether to use UE8M0 scaling + """
275-278: Optimize SiLU calculation in Triton kernel.The current implementation manually computes sigmoid using
1 / (1 + tl.exp(-gate)), which is less numerically stable than using a dedicated SiLU implementation. Consider using a more stable formulation.- gate = tl.load( - input_ptr_offs + token_index * stride_input_1 + size_n, - mask=offs_in_d < size_n, - other=0.0, - ).to(tl.float32) - gate = gate / (1 + tl.exp(-gate)) + gate = tl.load( + input_ptr_offs + token_index * stride_input_1 + size_n, + mask=offs_in_d < size_n, + other=0.0, + ).to(tl.float32) + # Use more numerically stable SiLU: x * sigmoid(x) + gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
- tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
162-162: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
336-336: Line too long (131 > 120)
(E501)
tensorrt_llm/quantization/utils/fp8_utils.py
304-309: 1 blank line required between summary line and description
(D205)
304-309: First line should end with a period, question mark, or exclamation point
Add closing punctuation
(D415)
🔇 Additional comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
392-409: Ignore padding alignment concern—both use 128
Both them_max = (x.shape[0] + 127)//128*128calculation and thegroup_size=128parameter employ the same 128-element alignment. There is no mismatch to document.Likely an incorrect or invalid review comment.
```python
def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
    # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
    assert x.dtype == torch.float and x.dim() in (2, 3)

    # First, convert into UE8M0 `uint8_t`
    ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)

    # Second, make padded packed tensors
    mn, k = x.shape[-2], x.shape[-1]
    remove_dim = False
    if x.dim() == 2:
        x, remove_dim = x.unsqueeze(0), True
    b = x.shape[0]
    aligned_mn = get_tma_aligned_size(mn, 4)
    aligned_k = align(k, 4)
    padded = torch.zeros((b, aligned_mn, aligned_k),
                         device=x.device,
                         dtype=torch.uint8)
    padded[:, :mn, :k] = ue8m0_tensor
    padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn,
                                                        aligned_k // 4)

    # Finally, transpose
    transposed = torch.transpose(
        torch.empty((b, aligned_k // 4, aligned_mn),
                    device=x.device,
                    dtype=torch.int), 1, 2)
    transposed[:, :, :] = padded
    aligned_x = transposed[:, :mn, :]
    return aligned_x.squeeze(0) if remove_dim else aligned_x
```
💡 Verification agent
❓ Verification inconclusive
Verify bit manipulation safety in get_col_major_tma_aligned_packed_tensor.
The function performs direct bit manipulation on FP32 tensors (x.view(torch.int) >> 23) to extract exponent bits for UE8M0 format. This assumes IEEE 754 format and specific endianness, which may not be portable across all platforms.
🏁 Script executed:
#!/bin/bash
# Description: Check if there are any platform-specific assumptions or documentation about IEEE 754 requirements
rg -A 5 -B 5 "ieee.*754|IEEE.*754" --type py
rg -A 5 -B 5 "endian" --type pyLength of output: 78
Verify IEEE 754 and little-endian assumptions in get_col_major_tma_aligned_packed_tensor
This function relies on reinterpreting a float32 tensor as a 32-bit integer and shifting (x.view(torch.int) >> 23) to extract exponent bits. That approach assumes:
- IEEE 754 single-precision floats
- Little-endian byte order
I didn’t find any existing checks or documentation for these requirements in the codebase, so on big-endian or non-IEEE 754 platforms this will silently produce wrong results.
Recommendations:
- At the top of the function, add a runtime guard:
```python
import sys

assert sys.byteorder == "little", (
    "get_col_major_tma_aligned_packed_tensor requires a little-endian platform"
)
```
- In the module or function docstring, note that it only supports IEEE 754 32-bit floats and little-endian architectures.
- If you need to support other platforms, consider adding a conversion step or alternative implementation.
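As a companion sanity check, the shift can be verified against `torch.frexp` on positive inputs such as scale factors (a sketch assuming IEEE 754 float32, which PyTorch uses on all currently supported platforms):

```python
import torch


def check_ue8m0_extraction(n: int = 1024) -> None:
    x = torch.rand(n, dtype=torch.float32) * 1000 + 1e-3   # positive, like the SF tensors
    biased_exp = (x.view(torch.int32) >> 23) & 0xFF        # the byte that gets packed
    # frexp returns a mantissa in [0.5, 1), so the IEEE biased exponent equals exp + 126.
    _, exp = torch.frexp(x)
    assert torch.equal(biased_exp, (exp + 126).to(torch.int32))
```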
🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 106 to 135, the
function get_col_major_tma_aligned_packed_tensor assumes IEEE 754 32-bit floats
and little-endian byte order when reinterpreting the float tensor as int and
shifting bits. To fix this, add a runtime assertion at the start of the function
to check that sys.byteorder is "little" and raise an error if not. Also, update
the function or module docstring to clearly state that it only supports IEEE 754
single-precision floats and little-endian architectures. If broader platform
support is needed, implement alternative handling or conversion for other cases.
/bot run
Actionable comments posted: 0
♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
408-446: Consecutive FP8 quantizations may accumulate numerical error.The implementation performs FP8 quantization twice (lines 408-413 for first GEMM input, lines 433-438 for second GEMM input after SwiGLU). This aligns with previous review concerns about numerical stability.
Consider adding numerical validation tests to ensure the double quantization doesn't introduce unacceptable error accumulation, as suggested in previous reviews.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
340-340: Fix line length violation.Line exceeds the 120-character limit as flagged by static analysis.
- assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" + assert x.dtype != torch.float8_e4m3fn, \ + "Current workaround for apply_router_weight_on_input does not support fp8 input"
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py(1 hunks)
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
340-340: Line too long (131 > 120)
(E501)
🔇 Additional comments (8)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (8)
1-18: Import structure looks well-organized.The imports are appropriately structured for a DeepGemm FP8 MoE backend, including the necessary deep_gemm, triton, and FP8 utilities.
20-75: Triton kernel implementation follows best practices.The
_masked_index_copy_group_quant_fp8kernel correctly implements group-wise FP8 quantization with proper bounds checking and race condition protection for scaling factor writes.
77-120: Wrapper function is well-implemented with proper validation.The function provides comprehensive input validation and efficient kernel launch parameters with appropriate heuristics for performance optimization.
122-176: Gather kernel and wrapper are efficiently implemented.The
masked_index_gather_kerneluses block-based processing and early returns for optimal performance, while the wrapper provides appropriate validation.
178-202: Helper functions are well-designed and efficient.The TorchScript-compiled functions provide focused functionality with appropriate dynamic compilation and profiling annotations.
204-260: Comprehensive GEMM function with thorough validation.The
deepgemm_fp8_group_blockwise_gemmfunction provides extensive validation and proper scaling factor transformations for reliable FP8 GEMM operations.
262-312: Class initialization properly delegates to parent.The constructor appropriately inherits from CutlassFusedMoE and provides comprehensive documentation of the backend's custom operations.
314-471: Complex but well-structured forward pass implementation.The
forward_chunkmethod orchestrates a sophisticated MoE pipeline with proper handling of routing, quantization, distributed processing, and dual GEMM operations. The implementation correctly integrates all the custom operations and kernels.
Actionable comments posted: 0
♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
428-433: Verify numerical stability of consecutive FP8 quantization.The implementation performs FP8 quantization twice in the forward pass - once for the input activations and again after the SwiGLU activation. This could potentially accumulate numerical errors that affect model quality.
Based on previous review feedback, this concern about consecutive FP8 quantization accumulating numerical errors remains valid. Consider adding numerical validation tests to ensure the double quantization doesn't introduce unacceptable error levels.
Also applies to: 453-458
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)
20-77: Consider numerical stability in FP8 quantization.The Triton kernel implementation is well-structured for parallel processing, but the quantization logic on lines 68-72 could benefit from additional numerical stability measures:
# quantization _absmax = tl.maximum(tl.max(tl.abs(input_data)), eps) output_s = _absmax / fp8_max -output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) +# Add small epsilon to prevent log2(0) and improve numerical stability +output_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(output_s), eps))))
88-91: Fix assertion message for clarity.The assertion message contains incorrect grammar.
assert ( input.shape[-1] % group_size == 0 -), "the last dimension of `input` cannot be divisible by `group_size`" +), "the last dimension of `input` must be divisible by `group_size`"
360-360: Fix line length violation.Line exceeds the 120 character limit.
- assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" + assert x.dtype != torch.float8_e4m3fn, ( + "Current workaround for apply_router_weight_on_input does not support fp8 input" + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py(1 hunks)
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
360-360: Line too long (131 > 120)
(E501)
🔇 Additional comments (6)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (6)
1-18: LGTM! Imports are well-organized and necessary.The imports are appropriately structured with standard library imports first, followed by third-party dependencies, and finally internal TensorRT-LLM modules.
142-171: LGTM! Well-implemented Triton gather kernel.The kernel efficiently handles masked index gathering with proper bounds checking and block-based processing for optimal memory access patterns.
173-195: LGTM! Clean and well-validated wrapper function.The function properly validates tensor shapes and dimensions before launching the Triton kernel.
198-222: LGTM! Well-optimized helper functions.The functions are appropriately decorated for performance profiling and optimization with
@torch.compile(dynamic=True)where beneficial.
224-279: LGTM! Comprehensive validation and proper DeepGemm integration.The function includes thorough tensor validation and correctly transforms scaling factors for the DeepGemm backend requirements.
282-491: LGTM! Well-structured MoE implementation with comprehensive functionality.The class successfully implements a complete DeepGemm-based FP8 MoE backend with proper distributed processing support, tensor routing, and quantization workflows.
caf55b8 to e766164 (Compare)
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (7)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
1296-1309: Unit tests and performance benchmarks still needed for FP8 resmoothingThe FP8 resmoothing logic is correctly implemented with proper filtering and SM 100 targeting. However, as identified in previous reviews, this functionality still lacks comprehensive unit tests and performance benchmarks to validate correctness and measure loading overhead.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
428-433: Consecutive FP8 quantization needs numerical validationThe implementation performs consecutive FP8 quantizations which could accumulate numerical errors. As noted in previous reviews, this pattern needs validation to ensure acceptable precision loss.
Also applies to: 453-458
tensorrt_llm/quantization/utils/fp8_utils.py (5)
28-29: Potential numerical instability with very small values.The function doesn't handle edge cases where
x.abs()could be zero or very close to zero, which would cause issues withtorch.log2().Consider adding a small epsilon to prevent numerical issues:
-def ceil_to_ue8m0(x: torch.Tensor): - return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) +def ceil_to_ue8m0(x: torch.Tensor): + return torch.pow(2.0, torch.ceil(torch.log2(x.abs().clamp(min=1e-10))))
54-79: Add input validation and improve error handling.The function lacks proper input validation and error handling for edge cases.
+def per_block_cast_to_fp8_e8m0( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize tensor to FP8 e4m3fn format with per-block scaling. + + Args: + x: Input tensor (2D or 3D), must be on CUDA device + + Returns: + Tuple of (quantized_tensor, scaling_factors) + """ + assert x.is_cuda, "Input tensor must be on CUDA device" + assert x.dtype in (torch.float16, torch.bfloat16, torch.float32), "Input must be floating point" + assert x.dim() in (2, 3), "Input tensor must be 2D or 3D"
82-92: Verify CUDA placement and add error handling.The function forcibly moves tensors to CUDA without checking if CUDA is available, and lacks proper error handling.
def resmooth_to_fp8_e8m0(weight: torch.Tensor, sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Re-quantize weights using existing scaling factors. + + Args: + weight: Weight tensor to re-quantize + sf: Existing scaling factors + + Returns: + Tuple of (re-quantized_weight, new_scaling_factors) + """ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for FP8 operations") weight = weight.cuda() sf = sf.cuda()
106-135: Verify IEEE 754 and little-endian assumptions in get_col_major_tma_aligned_packed_tensorThis function relies on reinterpreting a
float32tensor as a 32-bit integer and shifting (x.view(torch.int) >> 23) to extract exponent bits. That approach assumes:
- IEEE 754 single-precision floats
- Little-endian byte order
Recommendations:
- At the top of the function, add a runtime guard:
import sys assert sys.byteorder == "little", ( "get_col_major_tma_aligned_packed_tensor requires a little-endian platform" )- In the module or function docstring, note that it only supports IEEE 754 32-bit floats and little-endian architectures.
- If you need to support other platforms, consider adding a conversion step or alternative implementation.
218-218: Replace assert with proper exception handling.The function uses
assert Falsefor error handling, which should be replaced with a proper exception that provides clear, informative error messages.- assert False, f'Unknown cases: {sf.dtype=}, {gran=}' + raise ValueError(f'Unsupported scaling factor layout: dtype={sf.dtype}, granularity={gran}')
🧹 Nitpick comments (5)
tests/unittest/_torch/modules/test_fused_moe.py (3)
385-401: Comprehensive test parameterization with appropriate hardware requirements.The test is well-parameterized to cover various scenarios. However, note that the large parameter space (8 different sequence lengths) may result in long test execution times. Consider if all combinations are necessary or if a subset would provide adequate coverage.
447-452: Question: Redundant scale storage?Both
weight_scale_invandweight_scaleare stored with the same values. Is this intentional for compatibility with different backends, or could this be simplified?
472-500: Improve readability with comments explaining the complex logic.The grouped GEMM function contains complex dequantization logic with hardcoded values (e.g., 128 for block size). Consider adding comments to explain:
- The significance of the 128 block size
- The dequantization process with scale factor repetition
- The overall grouped GEMM algorithm
This would improve maintainability and make the reference implementation easier to understand.
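For illustration, a minimal annotated sketch of the block-wise dequantization that the reference path performs, assuming 128x128 scale blocks as in the `resmooth_to_fp8_e8m0` code quoted earlier (this is not the test's actual helper):

```python
import torch


def dequant_block_fp8(w_fp8: torch.Tensor, sf: torch.Tensor, block: int = 128) -> torch.Tensor:
    """Dequantize [n, k] float8_e4m3fn weights whose scales are stored per block x block tile."""
    n, k = w_fp8.shape
    # Each scale entry covers one 128x128 tile, so expand it to element granularity
    # along both axes and crop any padding rows/columns introduced by alignment.
    sf_full = sf.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)[:n, :k]
    return w_fp8.float() * sf_full
```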
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
360-360: Fix line length violationLine 360 exceeds the 120 character limit.
- assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" + assert x.dtype != torch.float8_e4m3fn, ( + "Current workaround for apply_router_weight_on_input does not support fp8 input" + )tensorrt_llm/quantization/utils/fp8_utils.py (1)
305-311: Fix docstring formatting issues.The docstring has formatting issues that violate Python documentation standards.
def silu_and_mul_masked_post_quant_fwd( input: torch.Tensor, output: torch.Tensor, output_scale: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor, scale_ue8m0: bool = False, ): """ - input shape [expert_num, token_num_padded, hidden_dim] - output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8 - output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32 - quant_group_size int, - masked_m shape [expert_num], + Fused SiLU activation, multiplication, and masked FP8 post-quantization. + + Args: + input: Input tensor with shape [expert_num, token_num_padded, hidden_dim] + output: Output tensor with shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8 + output_scale: Output scaling factors [expert_num, token_num_padded, hidden_dim // 2 // 128], dtype float32 + quant_group_size: Quantization group size (int) + masked_m: Mask tensor with shape [expert_num] + scale_ue8m0: Whether to use UE8M0 scaling format """
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (19)
- examples/llm-api/quickstart_advanced.py (1 hunks)
- examples/models/core/llama/README.md (2 hunks)
- jenkins/L0_MergeRequest.groovy (1 hunks)
- jenkins/L0_Test.groovy (3 hunks)
- requirements.txt (1 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (5 hunks)
- tensorrt_llm/_torch/modules/attention.py (6 hunks)
- tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3 hunks)
- tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (2 hunks)
- tensorrt_llm/_torch/modules/linear.py (3 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (2 hunks)
- tensorrt_llm/_utils.py (2 hunks)
- tensorrt_llm/quantization/utils/__init__.py (1 hunks)
- tensorrt_llm/quantization/utils/fp8_utils.py (1 hunks)
- tests/unittest/_torch/helpers.py (2 hunks)
- tests/unittest/_torch/modules/test_fused_moe.py (3 hunks)
- tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- jenkins/L0_MergeRequest.groovy
- examples/models/core/llama/README.md
🚧 Files skipped from review as they are similar to previous changes (11)
- requirements.txt
- tensorrt_llm/quantization/utils/init.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- examples/llm-api/quickstart_advanced.py
- tests/unittest/_torch/helpers.py
- tensorrt_llm/_utils.py
- tensorrt_llm/_torch/modules/fused_moe/create_moe.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/modules/fused_moe/quantization.py
- tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py
- tensorrt_llm/_torch/modules/attention.py
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
360-360: Line too long (131 > 120)
(E501)
tensorrt_llm/quantization/utils/fp8_utils.py
304-309: 1 blank line required between summary line and description
(D205)
304-309: First line should end with a period, question mark, or exclamation point
Add closing punctuation
(D415)
🔇 Additional comments (20)
jenkins/L0_Test.groovy (5)
1114-1114: LGTM: Script path consolidationThe move from
tests/integration/defs/test_rerun.pytojenkins/test_rerun.pyis a sensible reorganization that places Jenkins-specific utilities in a more logical location.
1187-1189: LGTM: Improved test result traceabilityThe addition of XML modification to replace generic "pytest" testsuite names with actual stage names enhances test report clarity and traceability. The error handling with
|| trueis appropriate.
1195-1195: LGTM: Consistent script path updateThe script path change is consistent with the earlier modification on line 1114, ensuring all references point to the new location.
1203-1203: LGTM: Final consistent script path updateThe script path change completes the consistent update across all three usages of the test_rerun.py script in the rerunFailedTests function.
1648-1649: LGTM: Enhanced error message for better user experienceThe improvement to the error message in
checkStageNameSetsignificantly enhances usability by:
- Sorting available stage names for easier scanning
- Adding proper indentation for better readability
- Including the total count of available options
- Providing cleaner formatting overall
This will help users quickly identify valid stage names when they encounter validation errors.
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1)
15-15: LGTM! Improved error handling robustness.Broadening the exception handling from
ModuleNotFoundErrortoImportErroris a good improvement that catches a wider range of import-related failures beyond just missing modules, such as compilation errors or dependency issues that may arise with the new FP8 integration.tests/unittest/_torch/modules/test_fused_moe.py (3)
12-13: LGTM! Necessary imports for FP8 testing.The new imports for FP8 utility functions and DeepGemmFusedMoE are appropriate for the new test functionality.
Also applies to: 29-30
402-421: LGTM! Proper test setup with deterministic patterns.The test setup correctly uses deterministic seeding and patterns to ensure reproducible results and avoid false positive failures.
545-550: LGTM! Proper test execution with appropriate tolerances.The test correctly compares the DeepGemm implementation against the reference with reasonable tolerances for FP8 quantization testing.
tensorrt_llm/_torch/models/modeling_deepseekv3.py (3)
41-41: LGTM: Necessary imports for FP8 functionalityThe imports for
loggerandresmooth_to_fp8_e8m0are correctly added to support the FP8 resmoothing logic introduced in the weight loading method.Also applies to: 47-47
1212-1212: LGTM: Ensuring memory contiguity after transposeAdding
.contiguous()after transpose operations is a good practice to ensure optimal memory layout for subsequent tensor operations, particularly important for performance-critical paths.Also applies to: 1253-1253
1362-1381: LGTM: Well-implemented conditional dequantization logicThe dequantization logic properly:
- Checks for parameter existence before execution
- Uses CUDA tensors for performance
- Applies correct reshaping and dtype conversion
- Handles both k_b_proj_trans_dequant and v_b_proj_dequant parameters consistently
This supports the alternative FP8 batched matrix multiplication path for SM 100 hardware.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (8)
1-18: LGTM: Comprehensive imports for DeepGemm MoE implementationThe imports are well-organized and include all necessary dependencies for the DeepGemm-based fused MoE backend functionality.
20-77: LGTM: Well-implemented Triton kernel for FP8 group quantizationThe kernel correctly implements:
- Group-wise FP8 quantization with proper scaling
- Masked indexing with bounds checking
- Multi-stage processing for performance optimization
- Proper handling of tensor strides and memory access patterns
The complexity is appropriate for the performance-critical MoE operations.
79-140: LGTM: Robust Python wrapper with proper validation and tuningThe wrapper function provides:
- Comprehensive input validation ensuring correct tensor shapes and memory layout
- Dynamic performance tuning based on token count (different block sizes and staging)
- Proper FP8 quantization parameter setup
- Clean kernel launch with appropriate grid configuration
This ensures both correctness and performance optimization.
142-196: LGTM: Efficient gather kernel implementationThe masked index gather kernel and wrapper are well-implemented with:
- Proper bounds checking and validation
- Block-based processing for memory efficiency
- Clean tensor shape assertions
- Appropriate grid configuration
The implementation is straightforward and correct for the gather operation requirements.
198-222: LGTM: Well-designed helper functions with performance optimizationThe helper functions are cleanly implemented:
swiglu_fused_moe: Correct SwiGLU activation implementationindexing: Efficient masked indexing with contiguous outputpreprocess_after_permute: Proper token-to-expert mapping computationThe use of
@torch.compile(dynamic=True)and NVTX profiling annotations is appropriate for performance-critical MoE operations.
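For readers unfamiliar with the activation, a minimal eager-mode sketch of SwiGLU over a fused gate/up projection (the ordering of the two halves below is an assumption for illustration; the Triton kernel quoted earlier in this thread reads the gate from the second half of the hidden dimension):

```python
import torch
import torch.nn.functional as F


def swiglu_reference(h: torch.Tensor) -> torch.Tensor:
    """SwiGLU over a fused projection of shape [..., 2 * intermediate]."""
    up, gate = h.chunk(2, dim=-1)  # assumed layout: [up | gate]
    return F.silu(gate) * up       # output has half the last dimension of the input
```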
224-280: LGTM: Robust DeepGemm integration with comprehensive validationThe function provides excellent implementation with:
- Thorough tensor shape, stride, and dtype validation
- Proper scaling factor layout transformation using fp8_utils
- Correct DeepGemm API usage with masked grouped GEMM
- Clear assertions ensuring tensor memory layout requirements
The comprehensive validation ensures correctness while the DeepGemm integration provides high-performance FP8 computation.
282-331: LGTM: Well-designed class structure with informative documentationThe
DeepGemmFusedMoEclass is properly structured:
- Inherits from
CutlassFusedMoEto maintain consistent interface- Comprehensive docstring explaining the multi-op backend composition
- Constructor correctly forwards all parameters to parent class
- Clear parameter documentation for the MoE layer configuration
The design maintains compatibility while enabling DeepGemm-specific optimizations.
333-491: LGTM: Comprehensive MoE forward pass implementationThe forward method correctly orchestrates the complete MoE pipeline:
- Proper routing and expert selection
- Efficient permutation and data layout transformations
- FP8 quantization with group-wise scaling
- DeepGemm GEMM operations with masked computation
- SwiGLU activation with integrated quantization
- Final scaling and result assembly
The implementation demonstrates good understanding of the performance-critical MoE operations and integrates well with the TensorRT-LLM infrastructure.
```python
d = torch.empty((a.shape[0], b.shape[1]),
                device=b.device,
                dtype=torch.bfloat16)
m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32)
for idx in range(offset_array.numel() - 1):
    m_indices[offset_array[idx]:offset_array[idx + 1]] = idx
```
Fix duplicate variable declarations.
The variables d and m_indices are declared twice in the grouped_gemm function. Remove the duplicate declarations.
- num_groups, n, k_ = b.shape
- d = torch.empty((a.shape[0], b.shape[1]),
- device=b.device,
- dtype=torch.bfloat16)
- m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32)
- for idx in range(offset_array.numel() - 1):
- m_indices[offset_array[idx]:offset_array[idx + 1]] = idx
+ num_groups, n, k_ = b.shape📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
num_groups, n, k_ = b.shape
```
🤖 Prompt for AI Agents
In tests/unittest/_torch/modules/test_fused_moe.py around lines 483 to 488, the
variables d and m_indices are declared twice within the grouped_gemm function.
Locate and remove the duplicate declarations of these variables to ensure each
is declared only once, preventing redundancy and potential conflicts.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
447-485: Verify numerical stability of consecutive FP8 quantizations.The implementation performs multiple FP8 quantization operations which could potentially accumulate numerical errors. This was flagged in a previous review but tests were not added.
Add numerical stability tests to validate the error from consecutive FP8 quantizations:
```bash
#!/bin/bash
# Description: Look for existing numerical validation tests for double FP8 quantization
rg -A 10 -B 5 "numerical.*stability|double.*fp8|consecutive.*quantiz" tests/ --type py
```
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
383-383: Fix line length violation.Line exceeds the 120 character limit.
- assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" + assert x.dtype != torch.float8_e4m3fn, ( + "Current workaround for apply_router_weight_on_input does not support fp8 input" + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
383-383: Line too long (131 > 120)
(E501)
🔇 Additional comments (6)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (6)
20-86: LGTM: Well-implemented Triton kernel for FP8 quantization.The kernel correctly implements masked index copy with group-wise FP8 quantization, including proper scaling factor computation and bitwise packing for efficiency.
88-162: LGTM: Well-designed Python wrapper with proper validation and optimization.The function provides comprehensive input validation, workload-adaptive kernel tuning, and correct handling of tensor layouts and padding requirements.
165-194: LGTM: Clean and efficient gather kernel implementation.The kernel correctly implements masked index gathering with proper bounds checking and block-based processing for efficiency.
196-218: LGTM: Appropriate wrapper for the gather kernel.The function provides necessary input validation and uses a reasonable fixed block size for the gather operation.
221-244: LGTM: Well-designed helper functions with performance optimizations.The functions are properly compiled and instrumented with NVTX ranges, implementing their respective operations efficiently.
247-302: LGTM: Robust GEMM implementation with comprehensive validation.The function provides thorough input validation, proper scaling factor transformation, and correct integration with the deep_gemm library.
```python
from typing import List, Optional, Union

import deep_gemm
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm._utils import nvtx_range

from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .routing import BaseMoeRoutingMethod
```
Add required NVIDIA copyright header.
According to the coding guidelines, all TensorRT-LLM source files must contain an NVIDIA copyright header with the current year.
Add the copyright header at the top of the file:
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import List, Optional, Union📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import deep_gemm
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm._utils import nvtx_range

from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .routing import BaseMoeRoutingMethod
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py at the very top
before any imports, add the required NVIDIA copyright header with the current
year as per coding guidelines. This header should clearly state NVIDIA's
copyright ownership and the year to comply with project standards.
Actionable comments posted: 0
♻️ Duplicate comments (3)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
1297-1310: Add unit tests for FP8 resmoothing & benchmark loading overheadtensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (2)
1-18: Add required NVIDIA copyright header.
441-471: Add numerical-stability tests for consecutive FP8 quantization
🧹 Nitpick comments (4)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
1363-1382: Consider extracting dequantization logic to reduce duplicationThe dequantization logic for
k_b_proj_transandv_b_projis duplicated. Consider extracting it into a helper function for better maintainability.+def dequantize_projection(weight, scale, target_shape, target_dtype): + """Helper function to dequantize projection weights.""" + return weight_dequant( + weight.view(-1, weight.shape[-1]).cuda(), + scale.view(-1, scale.shape[-1]).cuda(), + ).view(*target_shape).to(target_dtype) if attn_module.k_b_proj_trans_dequant is not None: - attn_module.k_b_proj_trans_dequant.data.copy_( - weight_dequant( - k_b_proj_trans.view( - -1, k_b_proj_trans.shape[-1]).cuda(), - k_b_proj_trans_scale.view( - -1, - k_b_proj_trans_scale.shape[-1]).cuda(), - ).view( - *attn_module.k_b_proj_trans_dequant.shape). - to(attn_module.k_b_proj_trans_dequant.dtype)) + attn_module.k_b_proj_trans_dequant.data.copy_( + dequantize_projection( + k_b_proj_trans, + k_b_proj_trans_scale, + attn_module.k_b_proj_trans_dequant.shape, + attn_module.k_b_proj_trans_dequant.dtype + )) if attn_module.v_b_proj_dequant is not None: - attn_module.v_b_proj_dequant.data.copy_( - weight_dequant( - v_b_proj.view(-1, - v_b_proj.shape[-1]).cuda(), - v_b_proj_scale.view( - -1, v_b_proj_scale.shape[-1]).cuda(), - ).view(*attn_module.v_b_proj_dequant.shape).to( - attn_module.v_b_proj_dequant.dtype)) + attn_module.v_b_proj_dequant.data.copy_( + dequantize_projection( + v_b_proj, + v_b_proj_scale, + attn_module.v_b_proj_dequant.shape, + attn_module.v_b_proj_dequant.dtype + ))tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)
78-79: Add comment explaining the scale factor packing logic

The bitcast and bit shift operation extracts the FP32 exponent efficiently, but this optimization deserves a clarifying comment.

```diff
-    output_s = output_s.to(tl.int32, bitcast=True) >> 23
-    output_s_int32 += output_s << (group_index * 8)
+    # Extract FP32 exponent by bitcasting to int32 and shifting right by 23
+    output_s = output_s.to(tl.int32, bitcast=True) >> 23
+    # Pack 4 group scales into a single int32 (8 bits per scale)
+    output_s_int32 += output_s << (group_index * 8)
```
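The trick is easy to verify outside Triton. A small standalone PyTorch illustration (values assumed, not part of the kernel): for power-of-two scales, bitcasting FP32 to int32 and shifting right by 23 yields the biased 8-bit exponent, and four such exponents fit in one int32.

```python
import torch

scales = torch.tensor([1.0, 2.0, 0.5, 0.25])    # assumed power-of-two group scales
exponents = scales.view(torch.int32) >> 23       # biased exponents: 127, 128, 126, 125
packed = 0
for group_index, e in enumerate(exponents.tolist()):
    packed |= e << (group_index * 8)             # 8 bits per group, as in the kernel
print(hex(packed))                               # 0x7d7e807f
```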
88-95: Add comprehensive docstring for better documentation

The function lacks documentation explaining its purpose, parameters, and return value.

```diff
 def masked_index_copy_group_quant_fp8(
     output: torch.Tensor,
     input: torch.Tensor,
     start_offsets: torch.Tensor,
     row_indices: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
 ):
+    """
+    Perform masked index copy with group-wise FP8 quantization.
+
+    Args:
+        output: Output tensor for quantized values, shape [num_experts, max_tokens, hidden_size]
+        input: Input tensor to be quantized and copied, shape [num_tokens, hidden_size]
+        start_offsets: Start token offsets for each expert, shape [num_experts + 1]
+        row_indices: Expert indices for each token, shape [num_tokens]
+        group_size: Size of quantization groups (must divide hidden_size)
+        eps: Small epsilon value for numerical stability in quantization
+
+    Returns:
+        torch.Tensor: Scale factors for the quantized values, shape [num_experts, actual_tokens, scale_dims]
+    """
```
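For orientation, the shapes described in that docstring could be instantiated as below; this is illustrative only, and the actual call is left commented out because the kernel's exact contract should be taken from the PR.

```python
import torch

num_experts, max_tokens, hidden_size, group_size = 4, 8, 256, 128
output = torch.empty(num_experts, max_tokens, hidden_size,
                     dtype=torch.float8_e4m3fn, device="cuda")
x = torch.randn(6, hidden_size, dtype=torch.bfloat16, device="cuda")
start_offsets = torch.tensor([0, 2, 4, 5, 6], dtype=torch.int32, device="cuda")   # [num_experts + 1]
row_indices = torch.tensor([0, 0, 1, 1, 2, 3], dtype=torch.int32, device="cuda")  # [num_tokens]
# scales = masked_index_copy_group_quant_fp8(output, x, start_offsets, row_indices, group_size)
```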
377-377: Split long line for better readability

Line exceeds the 120 character limit.

```diff
-    assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
+    assert x.dtype != torch.float8_e4m3fn, (
+        "Current workaround for apply_router_weight_on_input does not support fp8 input"
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (6 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (2 hunks)
- tensorrt_llm/_torch/modules/linear.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/modules/fused_moe/quantization.py
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
377-377: Line too long (131 > 120)
(E501)
🔇 Additional comments (3)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (2)
1213-1213: Good practice: Ensuring contiguous memory layout after transpose

Adding .contiguous() after transpose operations ensures optimal memory access patterns for subsequent operations, which is particularly important for CUDA kernels.

Also applies to: 1254-1254
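A minimal standalone illustration of the point (not taken from the model code): transpose returns a strided view, and .contiguous() materializes a dense copy that kernels can read linearly.

```python
import torch

w = torch.randn(128, 512)
w_t = w.transpose(0, 1)
print(w_t.is_contiguous())   # False: same storage, swapped strides
w_t = w_t.contiguous()
print(w_t.is_contiguous())   # True: freshly laid-out memory
```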
1431-1441: FP8 scale transformation correctly implemented

The scale transformation logic properly:
- Checks for FP8 block scales and SM 100 architecture
- Uses the correct recipe (1, 128, 128) matching the block size
- Creates non-trainable parameters as expected
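A rough sketch of that gating, with assumed helper names and an identity placeholder where the real (1, 128, 128) recipe transform would go; the actual logic lives in modeling_deepseekv3.py.

```python
import torch
import torch.nn as nn

def is_sm100() -> bool:
    major, _minor = torch.cuda.get_device_capability()
    return major == 10  # SM 100 (Blackwell)

def register_block_scales(module: nn.Module, weight_scale: torch.Tensor) -> None:
    if not is_sm100():
        return  # only rewrite FP8 block scales on SM 100
    transformed = weight_scale  # placeholder for the (1, 128, 128) recipe transform
    module.weight_scale = nn.Parameter(transformed, requires_grad=False)  # non-trainable
```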
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
247-297: Well-implemented FP8 grouped GEMM wrapper

The function provides comprehensive validation and correctly handles scale factor transformations for DeepGemm operations.
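For readers unfamiliar with such wrappers, the validation typically amounts to dtype, contiguity, and block-alignment checks along these lines (illustrative sketch, not the PR's actual code):

```python
import torch

def validate_fp8_grouped_gemm_inputs(a: torch.Tensor, b: torch.Tensor) -> None:
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.is_contiguous() and b.is_contiguous()
    assert a.shape[-1] == b.shape[-1], "K dimensions must match"
    assert a.shape[-1] % 128 == 0, "K must be a multiple of the 128 block size"
```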
2c908f3 to 3a1a0b8
/bot run
PR_Github #12927 [ run ] triggered by Bot
PR_Github #12927 [ run ] completed with state
/bot run
/bot run
PR_Github #13021 [ run ] triggered by Bot
PR_Github #13021 [ run ] completed with state
/bot run
PR_Github #13082 [ run ] triggered by Bot
PR_Github #13082 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #13122 [ run ] triggered by Bot
PR_Github #13122 [ run ] completed with state
…o avoid OOM (#25) Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: Barry Kang <[email protected]> Signed-off-by: Fanrong Li <[email protected]>
This reverts commit fe01f02. Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: Yuxian Qiu <[email protected]> Signed-off-by: Fanrong Li <[email protected]>
972ab86 to 59b3957
/bot run --disable-fail-fast
PR_Github #13192 [ run ] triggered by Bot
PR_Github #13192 [ run ] completed with state
* Reapply "Fuse quantize and transform e8m0 scales (#26)" (#27) This reverts commit 9107cfa. * Remove compile for reducing warnings Signed-off-by: Barry Kang <[email protected]> --------- Signed-off-by: Barry Kang <[email protected]>
/bot run --disable-fail-fast
PR_Github #13217 [ run ] triggered by Bot
PR_Github #13217 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #13254 [ run ] triggered by Bot
PR_Github #13254 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #13268 [ run ] triggered by Bot
PR_Github #13268 [ run ] completed with state
Closing, since #6486 already merged.
PR title
Please write the PR title following this template:
[JIRA ticket link/nvbug link/github issue link][fix/feat/doc/infra/...] <summary of this PR>
For example, assume I have a PR that adds a new feature to the cache manager under Jira ticket TRTLLM-1000; the title would look like
[TRTLLM-1000][feat] Support a new feature about cache manager
Description
Please briefly explain the issue and the solution.
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]

Launch build/test pipelines. All previously running jobs will be killed.

- --disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.
- --post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-[Post-Merge]-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md.

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
Summary by CodeRabbit
New Features
Bug Fixes
Chores
Tests