[TRTLLM-5863][feat] Support Weight-Only-Quantization in PyTorch Workflow #5850
Conversation
/bot run

PR_Github #11327 [ run ] triggered by Bot

PR_Github #11327 [ run ] completed with state
cf6bc94 to c35c81b (Compare)
/bot run

PR_Github #11718 [ run ] triggered by Bot
c35c81b to cc62513 (Compare)
PR_Github #11718 [ run ] completed with state

/bot run

PR_Github #11720 [ run ] triggered by Bot

PR_Github #11720 [ run ] completed with state
6c53025 to 7393aae (Compare)
/bot run

PR_Github #12070 [ run ] triggered by Bot

PR_Github #12070 [ run ] completed with state
7393aae to 2f9bd01 (Compare)
Walkthrough

A weight-only quantized GEMM (General Matrix Multiply) runner and operator were introduced, supporting INT8 and INT4 quantized weights with FP16/BF16 activations. This includes a CUDA/C++ implementation, Python bindings, a new linear method for weight-only quantization, and comprehensive unit tests for both the GEMM operator and the linear layer integration.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant PythonUser
    participant LinearModule
    participant WeightOnlyQuantLinearMethod
    participant WeightOnlyQuantGemmRunner
    participant TorchScriptClass
    participant CUDA_GEMM
    PythonUser->>LinearModule: forward(input)
    LinearModule->>WeightOnlyQuantLinearMethod: apply(input, bias)
    WeightOnlyQuantLinearMethod->>WeightOnlyQuantGemmRunner: forward(input, weight, weight_scale, tactic, to_userbuffers, out_dtype)
    WeightOnlyQuantGemmRunner->>TorchScriptClass: run_gemm(input, weight, weight_scale, tactic, to_userbuffers, out_dtype)
    TorchScriptClass->>CUDA_GEMM: launch CUDA GEMM kernel
    CUDA_GEMM-->>TorchScriptClass: output tensor
    TorchScriptClass-->>WeightOnlyQuantGemmRunner: output tensor
    WeightOnlyQuantGemmRunner-->>WeightOnlyQuantLinearMethod: output tensor
    WeightOnlyQuantLinearMethod-->>LinearModule: output tensor
    LinearModule-->>PythonUser: output tensor
```
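For orientation, here is a minimal pure-PyTorch sketch of the semantics behind this call chain: per-output-channel dequantization of an INT8 weight followed by a regular GEMM. The tensor layouts and the actual CUDA kernel in the PR differ; all names below are illustrative only.

```python
import torch

def weight_only_gemm_reference(activation: torch.Tensor,
                               qweight: torch.Tensor,
                               weight_scale: torch.Tensor) -> torch.Tensor:
    """Reference semantics of a weight-only-quantized GEMM.

    activation:   [M, K] float activations (fp16/bf16 in the real kernel)
    qweight:      [K, N] int8 quantized weight (int4 would be unpacked first)
    weight_scale: [N]    per-output-channel dequantization scale
    """
    # Dequantize the weight on the fly, then run a regular matmul.
    dequant = qweight.to(activation.dtype) * weight_scale.to(activation.dtype)
    return activation @ dequant

# Shapes-only usage example (float32 on CPU for portability; the PR's operator
# runs fp16/bf16 activations on GPU).
x = torch.randn(4, 64)
w = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
s = torch.rand(32)
print(weight_only_gemm_reference(x, w, s).shape)  # torch.Size([4, 32])
```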
Estimated code review effort: 3 (120 minutes)
2f9bd01 to 4e55b3a (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (2)
tensorrt_llm/_torch/modules/linear.py (2)
113-119: Verify the logic change from AWQ-specific to general weight-only quantization.

The change from `module.has_w4a16_awq` to `module.has_weight_only_quant` broadens the condition significantly. This could affect AWQ models if they don't follow the same preprocessing path. This relates to the previous review comment about whether `has_weight_only_quant` should include `has_w4a16_awq` or if they should remain separate checks.

Fix the line length issue:

```diff
-            weight = preprocess_weights_for_mixed_gemm(
-                weight.T.to(torch.int8).contiguous().cpu(), weight_dtype,
-                torch.float16).cuda().contiguous()
+            weight = preprocess_weights_for_mixed_gemm(
+                weight.T.to(torch.int8).contiguous().cpu(),
+                weight_dtype,
+                torch.float16
+            ).cuda().contiguous()
```
1056-1056: Good fix addressing previous review feedback.

This addresses the previous review comment about unnecessary `.to` and `.contiguous()` calls. The input should already be properly typed and contiguous when passed to the apply method.
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/linear.py (2)
900-934: Clean implementation with minor improvement opportunity.

The `WeightOnlyQuantLinearMethod` class follows established patterns well. One minor suggestion for the `apply` method: the bias handling can be simplified:

```diff
-        bias = bias.contiguous() if bias is not None else None
+        if bias is not None:
+            bias = bias.contiguous()
```

This avoids unnecessary assignment when bias is None.
971-1014: Consider refactoring to reduce code duplication with AWQ method.

The fused weight loading methods work correctly but share significant code patterns with `W4A16_AWQ_LinearMethod`. Consider extracting common preprocessing logic into shared helper functions. The preprocessing logic in lines 979-981 and 1000-1002 is nearly identical to AWQ's implementation and could be refactored into a shared helper.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- cpp/tensorrt_llm/thop/CMakeLists.txt (1 hunks)
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.cpp (1 hunks)
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.h (1 hunks)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1 hunks)
- tensorrt_llm/_torch/modules/linear.py (6 hunks)
- tensorrt_llm/quantization/functional.py (1 hunks)
- tests/unittest/_torch/thop/test_weight_only_quant_gemm.py (1 hunks)
- tests/unittest/_torch/thop/test_weight_only_quant_linear.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- cpp/tensorrt_llm/thop/CMakeLists.txt
🚧 Files skipped from review as they are similar to previous changes (6)
- tensorrt_llm/quantization/functional.py
- tests/unittest/_torch/thop/test_weight_only_quant_linear.py
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.h
- tests/unittest/_torch/thop/test_weight_only_quant_gemm.py
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.cpp
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py
114-114: Line too long (135 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/modules/linear.py (3)
173-192: LGTM! Well-implemented helper function.

The function provides a clean abstraction for determining weight dtype and packing ID based on quantization mode. The documentation is comprehensive and the error handling is appropriate. (A hedged sketch of such a helper follows this list of comments.)
1164-1166: LGTM! Correct factory method logic.

The condition properly checks for weight-only quantization without per-group scaling, and the placement before the more specific AWQ check follows the correct precedence order.

1288-1292: LGTM! Consistent property implementation.

The `has_weight_only_quant` property follows the established pattern of other quantization type properties in the class.
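As a rough sketch of the helper discussed above (lines 173-192 of linear.py), something like the following could map a quantization mode to a weight storage dtype and a packing ID. The mode-check method names and the ID values here are assumptions, and the real helper reportedly takes the Linear module itself rather than a bare quant-mode object.

```python
import torch
from types import SimpleNamespace

def get_weight_dtype_and_id(quant_mode):
    """Hypothetical sketch: map a weight-only quant mode to (dtype, packing id)."""
    if quant_mode.is_int8_weight_only():
        # INT8: one quantized value per int8 element.
        return torch.int8, 1       # packing id 1 is an assumed convention
    if quant_mode.is_int4_weight_only():
        # INT4: two values packed per byte, stored in int8 tensors.
        return torch.quint4x2, 2   # packing id 2 is an assumed convention
    raise NotImplementedError(f"Unsupported weight-only quant mode: {quant_mode}")

# Duck-typed usage example with a stand-in mode object.
int8_mode = SimpleNamespace(is_int8_weight_only=lambda: True,
                            is_int4_weight_only=lambda: False)
print(get_weight_dtype_and_id(int8_mode))  # (torch.int8, 1)
```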
/bot run

PR_Github #12371 [ run ] triggered by Bot

PR_Github #12371 [ run ] completed with state

/bot run

PR_Github #12379 [ run ] triggered by Bot

PR_Github #12379 [ run ] completed with state
4e55b3a to 3434e53 (Compare)
Signed-off-by: Yuening Li <[email protected]>
3434e53 to eb69ef9 (Compare)
Actionable comments posted: 0
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/linear.py (1)
119-127: Good refactoring to use the consolidated weight-only quantization check.

The change from checking specific AWQ variants to the more general `has_weight_only_quant` property aligns well with the new architecture. The logic for determining activation dtype and using the new helper function is correct. However, there's a line length issue that should be addressed:

```diff
-    # NOTE: without the preprocess during the runtime, the gemm output nan's. in order to use the preprocess_weights_for_mixed_gemm
+    # NOTE: without the preprocess during the runtime, the gemm output nan's.
+    # In order to use the preprocess_weights_for_mixed_gemm
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- cpp/tensorrt_llm/thop/CMakeLists.txt (1 hunks)
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.cpp (1 hunks)
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.h (1 hunks)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1 hunks)
- tensorrt_llm/_torch/modules/linear.py (5 hunks)
- tensorrt_llm/quantization/functional.py (1 hunks)
- tests/unittest/_torch/thop/test_weight_only_quant_gemm.py (1 hunks)
- tests/unittest/_torch/thop/test_weight_only_quant_linear.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- cpp/tensorrt_llm/thop/CMakeLists.txt
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
🚧 Files skipped from review as they are similar to previous changes (5)
- tests/unittest/_torch/thop/test_weight_only_quant_linear.py
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.h
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.cpp
- tests/unittest/_torch/thop/test_weight_only_quant_gemm.py
- tensorrt_llm/quantization/functional.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py
120-120: Line too long (135 > 120)
(E501)
🔇 Additional comments (8)
tensorrt_llm/_torch/modules/linear.py (8)
180-199: Well-designed helper function with clear documentation and proper validation.

The function correctly determines weight dtype and packing ID for weight-only quantization modes. The assertions ensure proper usage context, and the error handling covers unsupported cases appropriately.

909-928: Properly implemented weight creation with correct tensor shapes.

The `create_weights` method correctly handles weight packing considerations (INT4 weights packed into INT8) and creates appropriate parameter tensors for weights and scales.
930-940: Clean implementation of the apply method using the new CUDA operator.

The method correctly determines weight dtype, ensures bias contiguity, and calls the new `weight_only_quant_gemm` operator with appropriate parameters (see the sketch after this list of comments).
942-966: Load weight scales method correctly handles fused QKV scenarios.

The method properly loads weight scales for all three components (Q, K, V) with appropriate tensor parallel handling.

978-996: Fused QKV weight loading with proper preprocessing.

The implementation correctly concatenates Q, K, V weights, applies necessary preprocessing using `preprocess_weights_for_mixed_gemm`, and handles weight scales appropriately.

998-1020: Fused gate/up weight loading follows the same solid pattern.

Consistent implementation with the QKV variant, properly handling weight concatenation, preprocessing, and scale loading.

1419-1421: Correct integration of the new weight-only quantization method.

The condition properly identifies weight-only quantization without per-group scaling and returns the appropriate method instance.

1546-1550: Well-implemented property following established patterns.

The `has_weight_only_quant` property correctly checks for weight-only quantization and maintains consistency with other quantization properties in the class.
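To make the reviewed apply() flow concrete, here is a self-contained sketch. Where the PR's method calls the new `weight_only_quant_gemm` CUDA operator, this sketch substitutes a pure-PyTorch dequantize-then-matmul reference; the class and attribute names are illustrative, not the PR's actual API.

```python
from types import SimpleNamespace
from typing import Optional

import torch

def weight_only_gemm_reference(activation, qweight, weight_scale):
    # Stand-in for the PR's weight_only_quant_gemm CUDA op.
    return activation @ (qweight.to(activation.dtype) * weight_scale.to(activation.dtype))

class WeightOnlyQuantLinearMethodSketch:
    """Illustrative apply() flow only; not the PR's actual class."""

    def apply(self, module, input: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # The real method resolves the quantized weight dtype (int8 vs. packed
        # int4) and invokes the custom CUDA op; here we call the reference GEMM.
        output = weight_only_gemm_reference(input, module.weight, module.weight_scale)
        if bias is not None:
            # As the reviews note, bias only needs .contiguous() when present.
            output = output + bias.contiguous()
        return output

# Duck-typed usage with a stand-in "module" holding an int8 weight and scale.
mod = SimpleNamespace(weight=torch.randint(-128, 127, (64, 32), dtype=torch.int8),
                      weight_scale=torch.rand(32))
out = WeightOnlyQuantLinearMethodSketch().apply(mod, torch.randn(4, 64))
print(out.shape)  # torch.Size([4, 32])
```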
Actionable comments posted: 0
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/linear.py (3)
119-126: Address formatting issue and verify preprocessing logic.

The conceptual improvements using `has_weight_only_quant` address previous review feedback well. However, there's a formatting issue and some questions about the preprocessing logic.

Line length violation: Line 120 exceeds the 120-character limit as flagged by static analysis.

Verify preprocessing consistency: The preprocessing involves multiple transformations (`.T.to(torch.int8).contiguous().cpu()` then back to `.cuda().contiguous()`). This seems inefficient; verify if all these operations are necessary.

```diff
-    # NOTE: without the preprocess during the runtime, the gemm output nan's. in order to use the preprocess_weights_for_mixed_gemm
-    # we need to cast the weight to int8 first.
+    # NOTE: without the preprocess during the runtime, the gemm output nan's.
+    # in order to use the preprocess_weights_for_mixed_gemm we need to cast the weight to int8 first.
```

Please verify if the CPU transfer and back to CUDA is intentional or can be optimized.
978-997: Consider refactoring complex fused QKV weight loading.

The fused QKV weight loading method has several complex operations that could benefit from extraction into helper methods for better maintainability.
The method performs:
- Weight concatenation
- Weight preprocessing with dtype conversion
- Scale loading and concatenation
Consider extracting the preprocessing logic into a separate method:
```python
def _preprocess_fused_weights(self, module: Linear, weights: torch.Tensor) -> torch.Tensor:
    weight_dtype, _ = get_weight_dtype_and_id(module)
    return preprocess_weights_for_mixed_gemm(
        weights.to(torch.int8).T.contiguous().cpu(),
        weight_dtype,
        torch.float16
    ).cuda().contiguous()
```

This would improve readability and reduce duplication between the fused methods.
998-1021: Similar complexity in fused gate/up weight loading.

This method has similar complexity to the QKV method and would benefit from the same refactoring approach mentioned above.
The preprocessing logic is duplicated here and could use the same helper method suggested for the QKV case.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- cpp/tensorrt_llm/thop/CMakeLists.txt (1 hunks)
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.cpp (1 hunks)
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.h (1 hunks)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1 hunks)
- tensorrt_llm/_torch/modules/linear.py (5 hunks)
- tensorrt_llm/quantization/functional.py (1 hunks)
- tests/unittest/_torch/thop/test_weight_only_quant_gemm.py (1 hunks)
- tests/unittest/_torch/thop/test_weight_only_quant_linear.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- cpp/tensorrt_llm/thop/CMakeLists.txt
🚧 Files skipped from review as they are similar to previous changes (6)
- tensorrt_llm/quantization/functional.py
- tests/unittest/_torch/thop/test_weight_only_quant_linear.py
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.h
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
- cpp/tensorrt_llm/thop/weightOnlyQuantGemm.cpp
- tests/unittest/_torch/thop/test_weight_only_quant_gemm.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py
120-120: Line too long (135 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/_torch/modules/linear.py (5)
180-199: Well-designed helper function.

This utility function effectively extracts weight dtype and packing logic into a reusable, well-documented component. The input validation, clear return value documentation, and error handling are excellent.

907-929: Clean implementation of weight creation and initialization.

The `create_weights` method properly uses the helper function and follows established patterns for parameter creation and bias handling.

930-941: Verify bias handling in apply method.

The apply method looks correct but has a potential issue with bias handling. The bias is made contiguous unconditionally on line 934, but there's no guarantee that `bias` is not None. While this works because `None.contiguous()` would fail before the ternary operation, it's more explicit to check for None first:

```diff
-        bias = bias.contiguous() if bias is not None else None
+        bias = bias.contiguous() if bias is not None else None
```

Actually, the current code is correct; my concern was unfounded. The method looks good.

1419-1421: Clean addition to quantization method factory.

The new conditional logic properly identifies weight-only quantization without per-group scaling and returns the appropriate method instance. The placement and logic are correct.

1546-1550: Property consolidation addresses previous review feedback.

This new property provides a clean abstraction for weight-only quantization checks and addresses the consolidation concern raised in previous reviews. The implementation follows established patterns perfectly.
/bot run

PR_Github #12409 [ run ] triggered by Bot

PR_Github #12409 [ run ] completed with state
…low (NVIDIA#5850) Signed-off-by: Yuening Li <[email protected]> Co-authored-by: Yuening Li <[email protected]>
…low (NVIDIA#5850) Signed-off-by: Yuening Li <[email protected]> Co-authored-by: Yuening Li <[email protected]> Signed-off-by: Shreyas Misra <[email protected]>
…low (NVIDIA#5850) Signed-off-by: Yuening Li <[email protected]> Co-authored-by: Yuening Li <[email protected]> Signed-off-by: Ransiki Zhang <[email protected]>
Add support for both INT4 and INT8 weight-only quantization in the PyTorch workflow.
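As background for what "weight-only" means here, the following minimal PyTorch sketch shows per-output-channel INT8 quantization: only the weight is stored in INT8 plus a scale, while activations stay in floating point. This is a reference illustration, not the PR's actual packing or preprocessing code.

```python
import torch

def quantize_weight_only_int8(weight: torch.Tensor):
    """Symmetric per-output-channel INT8 quantization of a [K, N] weight."""
    # One scale per output channel (column), so the max magnitude maps to 127.
    scale = weight.abs().amax(dim=0) / 127.0
    qweight = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
    return qweight, scale

# Float32 on CPU for portability; the PR targets fp16/bf16 activations on GPU.
w = torch.randn(64, 32)
x = torch.randn(4, 64)
qw, s = quantize_weight_only_int8(w)
y_ref = x @ w                      # full-precision GEMM
y_wo = x @ (qw.float() * s)        # dequantize-then-GEMM, the weight-only path
print((y_ref - y_wo).abs().max())  # small quantization error
```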
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user-friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.

run

`run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]`

Launch build/test pipelines. All previously running jobs will be killed.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-[Post-Merge]-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md.

kill

`kill`

Kill all running builds associated with pull request.

skip

`skip --comment COMMENT`

Skip testing for latest commit on pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
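For example, a typical interaction on a pull request could look like the following sequence of comments; the flags come from the help text above, and the stage name is just its placeholder:

```
/bot run --disable-fail-fast --stage-list "A10-1"
/bot kill
/bot run --gpu-type "A30, H100_PCIe"
```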
Summary by CodeRabbit
New Features
Bug Fixes
Tests