
Conversation

danielafrimi
Collaborator

@danielafrimi danielafrimi commented Jul 14, 2025

W4A8 GEMM

  1. Support running w4a8_awq quantized models exported from ModelOpt.

  2. The kernel supports multiple GEMMs with mixed dtypes.

Summary by CodeRabbit

  • New Features

    • Added support for W4A8 AWQ quantization, enabling 4-bit weight and 8-bit activation quantized linear layers.
    • Introduced a generalized fine-grained mixed-dtype GEMM operator with configurable activation/output dtypes and scaling factor.
    • Added new test coverage for W4A8 AWQ quantization and the new mixed-dtype GEMM operator.
  • Bug Fixes

    • Corrected tensor dimension handling and scaling logic in quantized linear methods.
  • Refactor

    • Unified and renamed GEMM runner classes and custom operators for broader mixed-dtype support.
    • Updated tests and modules to use the new mixed-dtype GEMM API.
  • Tests

    • Added comprehensive tests for W4A8 AWQ quantization and fine-grained mixed-dtype GEMM.
    • Removed legacy W4A16 GEMM test, updating others for the new API.

Signed-off-by: Daniel Afrimi <[email protected]>
@danielafrimi danielafrimi requested a review from a team as a code owner July 14, 2025 13:18
@danielafrimi
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #11818 [ run ] triggered by Bot

@svc-trtllm-gh-bot svc-trtllm-gh-bot added the "Community want to contribute" label (PRs initiated from Community) Jul 14, 2025
@tensorrt-cicd
Collaborator

PR_Github #11818 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #8758 completed with status: 'SUCCESS'
Pipeline passed with automatically retried tests. Check the rerun report for details.

Collaborator

@Naveassaf Naveassaf left a comment


Great work Daniel

Had some annoying comments - didn't find any bugs :)

Signed-off-by: Daniel Afrimi <[email protected]>
Contributor

coderabbitai bot commented Jul 17, 2025

Walkthrough

The changes introduce a generalized fine-grained mixed-dtype GEMM runner and operator, replacing the previous W4A16-specific implementation. The new runner supports configurable activation and output data types, an alpha scaling parameter, and extended quantization methods including W4A8 AWQ. The Python API, C++ backend, and test suites are updated to reflect these enhancements, and new tests are added for W4A8 functionality.

Changes

  • cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp, .h: Renamed the runner class to finegrainedMixedDtypeGemmRunner, added output-dtype and alpha support, updated logic.
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py: Replaced W4A16GemmRunner with FinegrainedMixedDtypeGemm, updated custom op and autotuning logic.
  • tensorrt_llm/_torch/modules/linear.py: Added W4A8 AWQ quantization support, refactored weight handling, and integrated the new GEMM operator.
  • tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py: New tests for FinegrainedMixedDtypeGemm covering various quantization and dtype scenarios.
  • tests/unittest/_torch/thop/test_w4a8_linear.py: New test for the W4A8 AWQ quantized linear layer and GEMM operator.
  • tests/unittest/_torch/thop/test_w4a16_linear.py: Updated to use FinegrainedMixedDtypeGemm and the new operator interface.
  • tests/unittest/_torch/thop/test_w4a16_gemm.py: Deleted the legacy W4A16 GEMM operator test.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant LinearModule
    participant FinegrainedMixedDtypeGemm
    participant CUDAOp

    User->>LinearModule: Forward(input)
    LinearModule->>FinegrainedMixedDtypeGemm: forward(input, weight, scales, group_size, ...)
    FinegrainedMixedDtypeGemm->>CUDAOp: finegrained_mixed_dtype_gemm(input, weight, scales, group_size, output_dtype, alpha, ...)
    CUDAOp-->>FinegrainedMixedDtypeGemm: Output Tensor
    FinegrainedMixedDtypeGemm-->>LinearModule: Output Tensor
    LinearModule-->>User: Output Tensor
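The diagram maps to a single custom-op call from Python. Below is a minimal, hypothetical sketch of that call; the trtllm op namespace, the tensor layouts, and anything beyond the keyword names shown in the diagram are assumptions, not the PR's exact schema.

import torch

# Illustrative shapes only; the packed-int4 weight and per-group scale layouts are assumed.
group_size = 128
x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")            # activations
w_packed = torch.randint(-128, 127, (1024, 1024), dtype=torch.int8,
                         device="cuda")                                  # int4 weights, two per byte
scales = torch.randn(2048, 1024 // group_size,
                     dtype=torch.float16, device="cuda")                 # per-group weight scales

# Keyword names follow the sequence diagram; alpha only matters on the FP8 (W4A8) path.
y = torch.ops.trtllm.finegrained_mixed_dtype_gemm(
    input=x,
    weight=w_packed,
    scales=scales,
    group_size=group_size,
    output_dtype=torch.float16,
    alpha=1.0,
)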

Suggested reviewers

  • achartier
  • omera-nv
  • juney-nvidia
  • Fridah-nv

Poem

🐇
A hop and a skip, new GEMMs arrive,
With dtypes mixed, our ops now thrive!
W4A8 and W4A16, both in the race,
Alpha scaling, new tests in place.
Old code is gone, new flows begin—
Fine-grained precision, let’s leap right in!


Signed-off-by: Daniel Afrimi <[email protected]>
@danielafrimi
Collaborator Author

/bot run

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
tests/unittest/_torch/thop/test_w4a8_linear.py (1)

28-29: Good test coverage with asymmetric dimensions.

The test now properly validates the case where input_dim != output_dim (128 != 512), addressing the previous concern.

🧹 Nitpick comments (4)
tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (1)

3-3: Consider using absolute imports for better maintainability.

The relative import from utils.util import could break if the test file is moved. Consider using an absolute import path.

-from utils.util import woq_assert_near_eq, woq_groupwise_gt_matmul
+from tests.unittest._torch.thop.utils.util import woq_assert_near_eq, woq_groupwise_gt_matmul
tensorrt_llm/_torch/modules/linear.py (3)

1086-1086: Consider breaking this long line for better readability.

-            or input.dtype,  # NOTE: output_dtype can only be bf16/fp16 for W4A8
+            or input.dtype,  # output_dtype: bf16/fp16 only for W4A8

1190-1191: Address line length for better readability.

-        # NOTE: pre_quant_scale is the same for q,k,v since modelopt checks which layer shared the same input and create an avg pre_quant_scale
-        # Usually when modelopt exports the quantized model, pre_quant_Scale is fused in the layer norm (this case relevant if fused is disabled - modelopt internal)
+        # NOTE: pre_quant_scale is the same for q, k, v since modelopt checks which layers
+        # share the same input and creates an averaged pre_quant_scale.
+        # Usually when modelopt exports the quantized model, pre_quant_scale is fused
+        # into the layer norm (this case is relevant only if fusion is disabled - modelopt internal).

1237-1237: Break this long comment line.

-            # NOTE:Create this tensor in load_weights, since not all layer have this tensor and memory is not allocated for it (same as W4A16)
+            # NOTE: Create this tensor in load_weights, since not all layers have this 
+            # tensor and memory is not allocated for it (same as W4A16)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dbf2918 and 128ec57.

📒 Files selected for processing (8)
  • cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp (3 hunks)
  • cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h (1 hunks)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (3 hunks)
  • tensorrt_llm/_torch/modules/linear.py (8 hunks)
  • tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (1 hunks)
  • tests/unittest/_torch/thop/test_w4a16_gemm.py (0 hunks)
  • tests/unittest/_torch/thop/test_w4a16_linear.py (3 hunks)
  • tests/unittest/_torch/thop/test_w4a8_linear.py (1 hunks)
💤 Files with no reviewable changes (1)
  • tests/unittest/_torch/thop/test_w4a16_gemm.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/unittest/_torch/thop/test_w4a16_linear.py (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
  • FinegrainedMixedDtypeGemm (678-717)
  • finegrained_mixed_dtype_gemm (722-764)
tensorrt_llm/_torch/modules/linear.py (6)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (7)
  • finegrained_mixed_dtype_gemm (722-764)
  • _ (216-255)
  • _ (334-342)
  • _ (423-433)
  • _ (605-632)
  • _ (665-675)
  • _ (884-971)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (1)
  • create_weights (165-173)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)
  • create_weights (296-304)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (9)
  • create_weights (72-85)
  • create_weights (227-239)
  • create_weights (251-283)
  • create_weights (432-464)
  • create_weights (551-624)
  • create_weights (810-865)
  • create_weights (1065-1070)
  • create_weights (1173-1185)
  • apply (181-188)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (1)
  • create_weights (139-147)
tensorrt_llm/quantization/mode.py (2)
  • is_int4_weight_only_per_group (129-130)
  • QuantAlgo (23-44)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py

114-114: Line too long (135 > 120)

(E501)


1098-1098: Line too long (132 > 120)

(E501)


1190-1190: Line too long (143 > 120)

(E501)


1191-1191: Line too long (165 > 120)

(E501)


1237-1237: Line too long (142 > 120)

(E501)

🔇 Additional comments (13)
tests/unittest/_torch/thop/test_w4a16_linear.py (1)

6-85: Changes look good!

The migration from W4A16GemmRunner to FinegrainedMixedDtypeGemm is implemented correctly with proper keyword arguments in the operator call.

cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h (1)

27-44: Header changes are well-structured.

The addition of outputDtype parameter and alpha scaling factor with a sensible default value properly extends the interface for mixed dtype support.

cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp (2)

47-143: Constructor implementation properly handles mixed dtype combinations.

The validation logic and error messages are clear. Good support for Float8_e4m3fn activation with multiple output dtypes.


217-261: Output dtype handling and alpha parameter passing are correct.

The implementation properly uses mOutputDtype to determine the tensor output type and correctly handles the alpha scaling parameter.

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)

678-691: LGTM! Clean refactoring to support mixed dtype GEMM.

The class has been properly renamed and extended to support configurable activation and output data types alongside the quantization mode. The caching mechanism is correctly implemented.


720-764: Well-structured custom operator with proper validation.

The operator correctly validates the presence of zero point tensors based on the quantization mode and includes all necessary parameters for mixed dtype GEMM operations.

tensorrt_llm/_torch/modules/linear.py (7)

113-119: LGTM! Correct handling of both W4A16 and W4A8 AWQ methods.

The logic properly determines the activation dtype based on the quantization method and applies the necessary preprocessing for both AWQ variants.


916-924: Good migration to the generalized GEMM operator.

The change from w4a16_gemm to finegrained_mixed_dtype_gemm properly supports the unified mixed-dtype GEMM interface with appropriate parameters including output_dtype.


969-969: Note the consistent transpose pattern for weight scales.

The weight scale tensors are now transposed and made contiguous across all loading methods. This aligns with the expected layout for the new finegrained_mixed_dtype_gemm operator.

Also applies to: 986-987, 1008-1009


1054-1067: Well-documented quantization flow.

The docstring clearly explains the modelopt flow for w4a8_awq, which is helpful for understanding the implementation. The flow correctly handles pre_quant_scale multiplication, FP8 quantization, and output rescaling.


1071-1075: Good handling of pre-quantized FP8 inputs.

The code efficiently handles the case where the input is already in FP8 format, avoiding redundant quantization operations.


1262-1264: Correct integration of W4A8 AWQ support.

The new quantization method is properly integrated into the factory function following the established pattern.


1389-1394: LGTM! Property correctly identifies W4A8 AWQ quantization.

The implementation follows the established pattern for quantization method properties.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/linear.py (2)

1026-1026: Fix parameter reference in error message.

The error message should use the in_features parameter instead of module.in_features for consistency.

-                f"in_features ({module.in_features}) must be divisible by group_size ({group_size}) "
+                f"in_features ({in_features}) must be divisible by group_size ({group_size}) "

114-114: Fix line length violations.

Several lines exceed the 120-character limit. Consider breaking these lines for better readability.

Examples of how to fix:

-    if module.has_w4a16_awq or module.has_w4a8_awq:
+    if (module.has_w4a16_awq or module.has_w4a8_awq):
-    def load_weight_scales_w4a8(self,
+    def load_weight_scales_w4a8(
+            self,
             weights: List[Dict],
             tp_size: int = 1,
             tp_rank: int = 0,
-            tp_mode: Optional[TensorParallelMode] = None):
+            tp_mode: Optional[TensorParallelMode] = None
+    ):

Also applies to: 1093-1093, 1185-1186, 1232-1232

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 128ec57 and 95ca1ea.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/linear.py (8 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py

114-114: Line too long (135 > 120)

(E501)


1093-1093: Line too long (132 > 120)

(E501)


1185-1185: Line too long (143 > 120)

(E501)


1186-1186: Line too long (165 > 120)

(E501)


1232-1232: Line too long (142 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/_torch/modules/linear.py (5)

113-120: LGTM! Clean extension to support W4A8 AWQ.

The logic correctly determines the activation dtype based on the quantization method (float16 for W4A16, float8_e4m3fn for W4A8) and uses it appropriately in the weight preprocessing step.


916-924: LGTM! Successful migration to generalized GEMM operator.

The migration from the specialized W4A16 operator to finegrained_mixed_dtype_gemm is well-implemented. The parameter mapping is correct and the absence of the alpha parameter is appropriate for W4A16 AWQ.


1012-1238: Excellent implementation of W4A8 AWQ support.

The new W4A8_AWQ_LinearMethod class is well-structured and follows established patterns while properly handling FP8-specific requirements. The implementation includes:

  • Proper FP8 quantization with input_scale and alpha parameters
  • Correct weight scale handling with float16 dtype
  • Comprehensive scale computation logic
  • Clear documentation of the modelopt flow
  • Consistent error handling and validation

The class mirrors the W4A16 implementation while adding the necessary FP8-specific adaptations.


1257-1259: LGTM! Proper method selection logic.

The addition of W4A8 AWQ method selection follows the established pattern and correctly checks for the appropriate quantization algorithm.


1384-1388: LGTM! Consistent property addition.

The has_w4a8_awq property follows the same pattern as other quantization method properties and provides the necessary interface for detecting W4A8 AWQ quantization.

@tensorrt-cicd
Collaborator

PR_Github #12210 [ run ] triggered by Bot

@Naveassaf Naveassaf self-requested a review July 17, 2025 14:12
Collaborator

@Naveassaf Naveassaf left a comment


LGTM

@tensorrt-cicd
Collaborator

PR_Github #12210 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9066 completed with status: 'FAILURE'

@Naveassaf
Collaborator

/bot run

@tensorrt-cicd
Collaborator

PR_Github #12374 [ run ] triggered by Bot

@danielafrimi
Collaborator Author

/bot kill

@tensorrt-cicd
Collaborator

PR_Github #12375 [ kill ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12374 [ run ] completed with state ABORTED

@tensorrt-cicd
Collaborator

PR_Github #12375 [ kill ] completed with state SUCCESS
Successfully killed previous jobs for commit 95ca1ea

Signed-off-by: Daniel Afrimi <[email protected]>
@danielafrimi
Collaborator Author

/bot run

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/linear.py (1)

120-120: Address line length violations for better code style.

Several lines exceed the 120-character limit. Consider breaking long lines for better readability:

-        if module.has_w4a16_awq or module.has_w4a8_awq:
+        if module.has_w4a16_awq or module.has_w4a8_awq:
-    def load_weight_scales_w4a8(self,
-                                weights: List[Dict],
+    def load_weight_scales_w4a8(
+            self,
+            weights: List[Dict],

Also applies to: 1106-1106, 1203-1204, 1257-1257

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 95ca1ea and 213f5df.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/linear.py (10 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py

120-120: Line too long (135 > 120)

(E501)


1106-1106: Line too long (132 > 120)

(E501)


1203-1203: Line too long (143 > 120)

(E501)


1204-1204: Line too long (165 > 120)

(E501)


1257-1257: Line too long (142 > 120)

(E501)

🔇 Additional comments (15)
tensorrt_llm/_torch/modules/linear.py (15)

50-54: Well-implemented helper method.

The flip method provides a clean way to get the orthogonal tensor parallel mode, which is needed for proper sharding of activation scales. The implementation is correct and the comment clearly explains its purpose.
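For context, a toy stand-in for what such a flip helper boils down to (the real TensorParallelMode in linear.py carries more members and logic; this only illustrates the orthogonal-mode idea):

from enum import Enum

class TensorParallelMode(Enum):
    # Toy version for illustration; member values are assumptions.
    COLUMN = "column"
    ROW = "row"

    @classmethod
    def flip(cls, mode: "TensorParallelMode") -> "TensorParallelMode":
        # Activation scales shard along the dimension orthogonal to the weights,
        # so a column-parallel layer shards them row-wise, and vice versa.
        return cls.ROW if mode is cls.COLUMN else cls.COLUMN

# A column-parallel linear layer therefore shards pre_quant_scale row-wise.
assert TensorParallelMode.flip(TensorParallelMode.COLUMN) is TensorParallelMode.ROW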


119-125: Proper extension for W4A8 support.

The changes correctly extend the weight preprocessing logic to handle both W4A16 and W4A8 quantization methods by selecting the appropriate activation dtype (torch.float16 for W4A16 vs torch.float8_e4m3fn for W4A8).


902-902: Improved flexibility with dtype parameter.

Using dtype instead of hardcoded torch.float16 makes the weight scale creation more flexible and consistent with the module's data type.


918-918: Proper conditional application of pre-quantization scaling.

Moving the pre-quantization scaling inside the condition ensures it's only applied when the scale tensor exists.


922-930: Successfully migrated to generalized GEMM operator.

The transition from the specialized W4A16 GEMM to the generic finegrained_mixed_dtype_gemm operator is well-implemented, with proper parameter mapping and the addition of required parameters like output_dtype and alpha.


964-982: Correct pre-quantization scale handling.

The use of TensorParallelMode.flip(module.tp_mode) correctly handles the sharding of activation scales along the orthogonal dimension to the weights. The tensor creation and copying logic is also properly implemented.


999-999: Performance optimization with contiguous memory layout.

Adding .T.contiguous() ensures optimal memory layout for the concatenated weight scales, which can improve performance in subsequent operations.

Also applies to: 1021-1022


1027-1063: Well-structured weight creation for W4A8 AWQ.

The create_weights method properly sets up all required parameters for W4A8 quantization:

  • Quantized weights with correct shape and dtype
  • Float16 weight scales (appropriate for FP8 activations)
  • Input scaling parameters for FP8 quantization
  • Alpha parameter for output rescaling

The implementation follows established patterns and includes proper validation; a rough parameter sketch follows below.
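A sketch of what that setup amounts to, with assumed parameter names, shapes, and packing layout (the PR's actual tensors may differ):

import torch
from torch import nn

def create_w4a8_weights(module: nn.Module, in_features: int, out_features: int,
                        group_size: int) -> None:
    # Hypothetical helper; mirrors the validation and parameters described above.
    if in_features % group_size != 0:
        raise ValueError(
            f"in_features ({in_features}) must be divisible by group_size ({group_size})")

    # 4-bit weights packed two-per-byte into an int8 buffer (packing layout assumed).
    module.weight = nn.Parameter(
        torch.empty(in_features, out_features // 2, dtype=torch.int8),
        requires_grad=False)
    # Per-group weight scales kept in float16, as required alongside FP8 activations.
    module.weight_scale = nn.Parameter(
        torch.empty(out_features, in_features // group_size, dtype=torch.float16),
        requires_grad=False)
    # Scalar scales: one for FP8 activation quantization, one (alpha) for output rescaling.
    module.input_scale = nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False)
    module.alpha = nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False)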


1065-1099: Excellent implementation with clear documentation.

The apply method is well-documented with a detailed docstring explaining the ModelOpt quantization flow. The implementation correctly:

  • Handles conditional pre-quantization scaling
  • Supports both pre-quantized FP8 inputs and dynamic quantization
  • Uses the generalized GEMM operator with proper parameter mapping
  • Includes appropriate output dtype constraints

1101-1138: Consistent and correct scale loading implementation.

The load_weight_scales_w4a8 method follows established patterns while properly handling W4A8-specific requirements:

  • Extracts all necessary scaling factors
  • Converts weight scales to appropriate float16 dtype
  • Includes proper consistency checks with assertions
  • Computes alpha correctly for output rescaling

1140-1175: Proper vanilla weight loading for W4A8.

The implementation correctly follows the W4A8 quantization requirements:

  • Uses appropriate helper functions for weight loading
  • Handles pre-quantization scales with correct tensor parallel sharding
  • Applies weight_scale_2 division as required by the quantization scheme
  • Sets up all scaling parameters properly

1177-1220: Comprehensive fused QKV weight loading.

The implementation properly handles fused QKV weights for W4A8:

  • Uses appropriate preprocessing with FP8 activation dtype
  • Correctly concatenates and processes weight scales with division
  • Includes optional pre-quantization scale handling with proper comments
  • Follows established patterns for fused weight operations

1222-1262: Consistent fused gate-up weight loading implementation.

The method correctly implements fused gate-up weight loading for W4A8:

  • Proper weight preprocessing and concatenation
  • Correct scale processing with weight_scale_2 division
  • Optional pre-quantization scale handling consistent with QKV implementation
  • Good code comments explaining behavior

1282-1284: Proper integration of W4A8 AWQ method selection.

The addition correctly extends the quantization method selection logic to handle W4A8 AWQ, following the established pattern for other quantization methods.


1409-1413: Consistent property implementation for W4A8 AWQ detection.

The has_w4a8_awq property follows the exact same pattern as has_w4a16_awq, providing consistent interface for detecting W4A8 AWQ quantization.

@tensorrt-cicd
Collaborator

PR_Github #12376 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12376 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #9195 completed with status: 'FAILURE'

@danielafrimi
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #12377 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12377 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #9196 completed with status: 'FAILURE'

@danielafrimi
Collaborator Author

/bot run

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (5)
tensorrt_llm/_torch/modules/linear.py (5)

119-125: LGTM! Correctly handles both AWQ quantization modes.

The logic properly differentiates between W4A16 and W4A8 AWQ modes and sets the appropriate activation dtype. The addition of .contiguous() addresses previous performance feedback.

Minor issue: Line 120 exceeds the 120-character limit. Consider breaking it into multiple lines for better readability.


1067-1101: Comprehensive apply method with clear documentation.

The method correctly implements the modelopt W4A8 AWQ flow:

  1. Pre-quantization scaling
  2. FP8 input quantization with proper fallback
  3. Mixed-dtype GEMM with alpha scaling

The docstring provides excellent clarity on the quantization pipeline; a sketch of the flow is included below.

Minor issue: Line length violations on lines 1096-1097. Consider breaking the comment for better readability.
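A minimal sketch of that three-step flow, assuming the attribute and op names used here (the real apply also handles bias, zero points, and autotuning, and float8_e4m3fn requires a recent PyTorch):

import torch

def w4a8_awq_apply(module, x: torch.Tensor) -> torch.Tensor:
    # 1. Smooth the activations with the AWQ pre-quant scale, if the checkpoint provides one.
    if getattr(module, "pre_quant_scale", None) is not None:
        x = x * module.pre_quant_scale

    # 2. Quantize to FP8 unless the caller already handed us an FP8 tensor.
    if x.dtype != torch.float8_e4m3fn:
        x = (x / module.input_scale).to(torch.float8_e4m3fn)

    # 3. Mixed-dtype GEMM: FP8 activations x INT4 weights, rescaled by alpha on output.
    return torch.ops.trtllm.finegrained_mixed_dtype_gemm(
        input=x,
        weight=module.weight,
        scales=module.weight_scale,
        group_size=module.group_size,
        output_dtype=torch.float16,  # per the review, output must be bf16/fp16 for W4A8
        alpha=float(module.alpha),
    )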


1103-1140: Robust weight scale loading implementation.

The load_weight_scales_w4a8 method properly handles:

  • Shared scaling factors for concatenated weights
  • Proper tensor parallel sharding
  • Correct dtype conversions to float16
  • Alpha computation for GEMM output rescaling

The assertions ensure consistency across multiple weights; a condensed sketch of the loading logic follows below.

Minor issue: Line 1108 exceeds character limit. Consider breaking the method signature.
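A condensed sketch of that loading logic, using the checkpoint keys the review mentions (weight_scale, weight_scale_2, input_scale); the concatenation axis and the alpha formula are assumptions about how the output rescale is derived:

import torch

def load_weight_scales_w4a8_sketch(weights: list) -> dict:
    input_scale = weights[0]["input_scale"].float()
    weight_scale_2 = weights[0]["weight_scale_2"].float()

    # All shards of a fused layer must agree on the global scales.
    for w in weights[1:]:
        assert torch.equal(w["input_scale"].float(), input_scale)
        assert torch.equal(w["weight_scale_2"].float(), weight_scale_2)

    # Per-group scales are concatenated across shards and kept in float16 next to FP8 activations.
    weight_scale = torch.cat([w["weight_scale"] for w in weights], dim=0).to(torch.float16)

    # Alpha rescales the FP8 GEMM output back to the unquantized range (formula assumed).
    alpha = input_scale * weight_scale_2
    return {"weight_scale": weight_scale, "input_scale": input_scale, "alpha": alpha}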


1179-1223: Comprehensive fused QKV weight loading.

The method handles the complexities of fused QKV layers:

  • Proper weight preprocessing for mixed GEMM
  • Correct concatenation and scaling of weight scales
  • Conditional pre_quant_scale loading based on modelopt export settings
  • Appropriate tensor parallel mode flipping for activation scales

The comments explain the modelopt behavior regarding pre_quant_scale fusion.

Minor issues: Lines 1205-1206 exceed character limits.


1224-1264: Consistent fused gate-up weight loading.

The implementation mirrors the QKV approach with:

  • Proper weight preprocessing and concatenation
  • Correct scale handling and alpha computation
  • Conditional pre_quant_scale loading
  • Consistent memory layout optimization

Minor issue: Line 1259 exceeds character limit.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 213f5df and 3339381.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/linear.py (10 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py

120-120: Line too long (135 > 120)

(E501)


1108-1108: Line too long (132 > 120)

(E501)


1205-1205: Line too long (143 > 120)

(E501)


1206-1206: Line too long (165 > 120)

(E501)


1259-1259: Line too long (142 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (11)
tensorrt_llm/_torch/modules/linear.py (11)

50-54: LGTM! Well-documented helper method.

The flip method correctly implements the orthogonal tensor parallel mode logic needed for activation scale sharding. The docstring clearly explains its purpose.


904-904: Good generalization of dtype handling.

Removing the explicit dtype parameter makes the method more flexible and consistent with the generalized approach.


920-920: Performance optimization applied correctly.

The contiguous() call addresses previous review feedback about memory layout optimization during preprocessing.


924-932: LGTM! Updated to use the generalized mixed-dtype GEMM.

The function call correctly migrates from the W4A16-specific interface to the new finegrained_mixed_dtype_gemm with appropriate parameters including output_dtype and proper argument mapping.


965-984: Correct tensor parallel handling for activation scales.

The use of TensorParallelMode.flip(module.tp_mode) is appropriate since pre_quant_scale applies to activations rather than weights, requiring the orthogonal sharding dimension.


1001-1001: Memory layout optimization applied.

The .T.contiguous() call ensures optimal memory layout for the concatenated weight scale tensor.


1023-1023: Consistent memory layout optimization.

The .T.contiguous() call maintains consistency with other weight scale handling and optimizes memory layout.


1027-1066: Well-structured weight creation for W4A8 AWQ.

The create_weights method properly handles the FP8 activation requirements:

  • Uses float16 for weight scales (required for FP8 activation)
  • Creates necessary parameters for input scaling and alpha
  • Proper error handling for group_size divisibility

The implementation follows established patterns from W4A16 AWQ.


1142-1177: Correct vanilla weight loading with proper tensor parallel handling.

The implementation correctly:

  • Loads pre_quant_scale with flipped tensor parallel mode (for activation scaling)
  • Handles weight scale division by weight_scale_2
  • Maintains proper tensor shapes with transpose and contiguous operations
  • Sets inverse input scale appropriately

The dtype assertion ensures compatibility between pre_quant_scale and module dtype.


1284-1286: LGTM! Consistent pattern for W4A8 AWQ method selection.

The addition correctly identifies W4A8 AWQ configurations and returns the appropriate quantization method, following the established pattern from W4A16 AWQ.


1411-1415: LGTM! Consistent property for W4A8 AWQ detection.

The property correctly identifies W4A8 AWQ quantization mode, following the same pattern as has_w4a16_awq. This provides a clean interface for checking the quantization configuration.

@tensorrt-cicd
Collaborator

PR_Github #12381 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12381 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9200 completed with status: 'SUCCESS'

@Naveassaf Naveassaf enabled auto-merge (squash) July 20, 2025 14:30
@Naveassaf Naveassaf merged commit 5300a99 into NVIDIA:main Jul 20, 2025
3 checks passed
reasonsolo pushed a commit to reasonsolo/TensorRT-LLM that referenced this pull request Jul 21, 2025
Signed-off-by: Daniel Afrimi <[email protected]>
timlee0212 pushed a commit to timlee0212/TensorRT-LLM that referenced this pull request Jul 21, 2025
Signed-off-by: Daniel Afrimi <[email protected]>
NVShreyas pushed a commit to NVShreyas/TensorRT-LLM that referenced this pull request Jul 28, 2025
Signed-off-by: Daniel Afrimi <[email protected]>
Signed-off-by: Shreyas Misra <[email protected]>
Ransiki pushed a commit to Ransiki/TensorRT-LLM that referenced this pull request Jul 29, 2025
Signed-off-by: Daniel Afrimi <[email protected]>
Signed-off-by: Ransiki Zhang <[email protected]>