W4A8 GEMM #6005
Conversation
Signed-off-by: Daniel Afrimi <[email protected]>
/bot run
PR_Github #11818 [ run ] triggered by Bot
PR_Github #11818 [ run ] completed with state
Great work Daniel
Had some annoying comments - didn't find any bugs :)
Signed-off-by: Daniel Afrimi <[email protected]>
Walkthrough
The changes introduce a generalized fine-grained mixed-dtype GEMM runner and operator, replacing the previous W4A16-specific implementation. The new runner supports configurable activation and output data types, an alpha scaling parameter, and extended quantization methods including W4A8 AWQ. The Python API, C++ backend, and test suites are updated to reflect these enhancements, and new tests are added for W4A8 functionality.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant LinearModule
    participant FinegrainedMixedDtypeGemm
    participant CUDAOp
    User->>LinearModule: Forward(input)
    LinearModule->>FinegrainedMixedDtypeGemm: forward(input, weight, scales, group_size, ...)
    FinegrainedMixedDtypeGemm->>CUDAOp: finegrained_mixed_dtype_gemm(input, weight, scales, group_size, output_dtype, alpha, ...)
    CUDAOp-->>FinegrainedMixedDtypeGemm: Output Tensor
    FinegrainedMixedDtypeGemm-->>LinearModule: Output Tensor
    LinearModule-->>User: Output Tensor
```
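For orientation, the call in the diagram could look roughly like the sketch below from the Python side. This is a minimal, hedged illustration: the keyword names follow the diagram, but the op namespace and the wrapper function are assumptions, not the PR's actual code.

```python
import torch


def linear_forward(x, weight_q, scales, group_size, output_dtype, alpha=1.0):
    # Delegate to the generalized mixed-dtype GEMM custom op. The keyword names
    # mirror the sequence diagram above; the torch.ops namespace is an assumption.
    return torch.ops.trtllm.finegrained_mixed_dtype_gemm(
        input=x,                    # fp16/bf16 activations, or fp8 for W4A8
        weight=weight_q,            # packed INT4 weights
        scales=scales,              # per-group weight scales
        group_size=group_size,
        output_dtype=output_dtype,  # requested bf16/fp16 output dtype
        alpha=alpha,                # output rescale factor (used by W4A8)
    )
```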
Signed-off-by: Daniel Afrimi <[email protected]>
/bot run
Actionable comments posted: 0
♻️ Duplicate comments (1)
tests/unittest/_torch/thop/test_w4a8_linear.py (1)
Lines 28-29: Good test coverage with asymmetric dimensions. The test now properly validates the case where input_dim != output_dim (128 != 512), addressing the previous concern.
🧹 Nitpick comments (4)
tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (1)
Line 3: Consider using absolute imports for better maintainability. The relative import from utils.util could break if the test file is moved. Consider using an absolute import path.

```diff
-from utils.util import woq_assert_near_eq, woq_groupwise_gt_matmul
+from tests.unittest._torch.thop.utils.util import woq_assert_near_eq, woq_groupwise_gt_matmul
```

tensorrt_llm/_torch/modules/linear.py (3)
Lines 1086-1086: Consider breaking this long line for better readability.

```diff
- or input.dtype,  # NOTE: output_dtype can only be bf16/fp16 for W4A8
+ or input.dtype,  # output_dtype: bf16/fp16 only for W4A8
```

Lines 1190-1191: Address line length for better readability.

```diff
- # NOTE: pre_quant_scale is the same for q,k,v since modelopt checks which layer shared the same input and create an avg pre_quant_scale
- # Usually when modelopt exports the quantized model, pre_quant_Scale is fused in the layer norm (this case relevant if fused is disabled - modelopt internal)
+ # NOTE: pre_quant_scale is the same for q,k,v since modelopt checks which layer
+ # shared the same input and create an avg pre_quant_scale.
+ # Usually when modelopt exports the quantized model, pre_quant_scale is fused
+ # in the layer norm (this case relevant if fused is disabled - modelopt internal)
```

Lines 1237-1237: Break this long comment line.

```diff
- # NOTE:Create this tensor in load_weights, since not all layer have this tensor and memory is not allocated for it (same as W4A16)
+ # NOTE: Create this tensor in load_weights, since not all layers have this
+ # tensor and memory is not allocated for it (same as W4A16)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp (3 hunks)
- cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h (1 hunks)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (3 hunks)
- tensorrt_llm/_torch/modules/linear.py (8 hunks)
- tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (1 hunks)
- tests/unittest/_torch/thop/test_w4a16_gemm.py (0 hunks)
- tests/unittest/_torch/thop/test_w4a16_linear.py (3 hunks)
- tests/unittest/_torch/thop/test_w4a8_linear.py (1 hunks)
💤 Files with no reviewable changes (1)
- tests/unittest/_torch/thop/test_w4a16_gemm.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/unittest/_torch/thop/test_w4a16_linear.py (1)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2): FinegrainedMixedDtypeGemm (678-717), finegrained_mixed_dtype_gemm (722-764)
tensorrt_llm/_torch/modules/linear.py (6)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (7): finegrained_mixed_dtype_gemm (722-764), _ (216-255), _ (334-342), _ (423-433), _ (605-632), _ (665-675), _ (884-971)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (1): create_weights (165-173)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1): create_weights (296-304)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (9): create_weights (72-85, 227-239, 251-283, 432-464, 551-624, 810-865, 1065-1070, 1173-1185), apply (181-188)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (1): create_weights (139-147)
- tensorrt_llm/quantization/mode.py (2): is_int4_weight_only_per_group (129-130), QuantAlgo (23-44)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py
114-114: Line too long (135 > 120)
(E501)
1098-1098: Line too long (132 > 120)
(E501)
1190-1190: Line too long (143 > 120)
(E501)
1191-1191: Line too long (165 > 120)
(E501)
1237-1237: Line too long (142 > 120)
(E501)
🔇 Additional comments (13)
tests/unittest/_torch/thop/test_w4a16_linear.py (1)
Lines 6-85: Changes look good! The migration from W4A16GemmRunner to FinegrainedMixedDtypeGemm is implemented correctly with proper keyword arguments in the operator call.
cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h (1)
Lines 27-44: Header changes are well-structured. The addition of the outputDtype parameter and alpha scaling factor with a sensible default value properly extends the interface for mixed dtype support.
cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp (2)
Lines 47-143: Constructor implementation properly handles mixed dtype combinations. The validation logic and error messages are clear. Good support for Float8_e4m3fn activation with multiple output dtypes.
Lines 217-261: Output dtype handling and alpha parameter passing are correct. The implementation properly uses mOutputDtype to determine the tensor output type and correctly handles the alpha scaling parameter.
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
Lines 678-691: LGTM! Clean refactoring to support mixed dtype GEMM. The class has been properly renamed and extended to support configurable activation and output data types alongside the quantization mode. The caching mechanism is correctly implemented.
Lines 720-764: Well-structured custom operator with proper validation. The operator correctly validates the presence of zero point tensors based on the quantization mode and includes all necessary parameters for mixed dtype GEMM operations.
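The kind of zero-point validation described here could look roughly like the sketch below. The mode names and helper signature are hypothetical, introduced only for illustration; they are not the PR's actual code.

```python
from enum import Enum
from typing import Optional

import torch


class QuantMode(Enum):
    W4A16_GROUPWISE = 0  # INT4 weights, fp16/bf16 activations
    W4A8_GROUPWISE = 1   # INT4 weights, fp8 activations


def check_zero_points(quant_mode: QuantMode, zeros: Optional[torch.Tensor],
                      expects_zeros: bool) -> None:
    # A mode configured with zero points must receive the tensor; a mode
    # without zero points should not be handed one by mistake.
    if expects_zeros and zeros is None:
        raise ValueError(f"{quant_mode.name}: zero-point tensor is required")
    if not expects_zeros and zeros is not None:
        raise ValueError(f"{quant_mode.name}: unexpected zero-point tensor")
```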
tensorrt_llm/_torch/modules/linear.py (7)
Lines 113-119: LGTM! Correct handling of both W4A16 and W4A8 AWQ methods. The logic properly determines the activation dtype based on the quantization method and applies the necessary preprocessing for both AWQ variants.
Lines 916-924: Good migration to the generalized GEMM operator. The change from w4a16_gemm to finegrained_mixed_dtype_gemm properly supports the unified mixed-dtype GEMM interface with appropriate parameters including output_dtype.
Lines 969-969: Note the consistent transpose pattern for weight scales. The weight scale tensors are now transposed and made contiguous across all loading methods. This aligns with the expected layout for the new finegrained_mixed_dtype_gemm operator. Also applies to: 986-987, 1008-1009.
Lines 1054-1067: Well-documented quantization flow. The docstring clearly explains the modelopt flow for w4a8_awq, which is helpful for understanding the implementation. The flow correctly handles pre_quant_scale multiplication, FP8 quantization, and output rescaling.
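For readers following along, a simplified sketch of that flow (pre_quant_scale multiply, FP8 quantization, GEMM with alpha rescale) is shown below. Tensor and op names are taken from this review; the exact method in the PR differs in detail and this is not a quote of it.

```python
import torch


def w4a8_awq_apply(x, weight_q, weight_scale, pre_quant_scale,
                   input_scale, alpha, group_size, output_dtype):
    if pre_quant_scale is not None:
        x = x * pre_quant_scale                        # AWQ activation smoothing
    if x.dtype != torch.float8_e4m3fn:
        x = (x / input_scale).to(torch.float8_e4m3fn)  # static FP8 quantization
    return torch.ops.trtllm.finegrained_mixed_dtype_gemm(
        input=x, weight=weight_q, scales=weight_scale,
        group_size=group_size, output_dtype=output_dtype,
        alpha=alpha,                                   # rescale the GEMM output
    )
```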
Lines 1071-1075: Good handling of pre-quantized FP8 inputs. The code efficiently handles the case where the input is already in FP8 format, avoiding redundant quantization operations.
Lines 1262-1264: Correct integration of W4A8 AWQ support. The new quantization method is properly integrated into the factory function following the established pattern.
Lines 1389-1394: LGTM! Property correctly identifies W4A8 AWQ quantization. The implementation follows the established pattern for quantization method properties.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/linear.py (2)
Lines 1026-1026: Fix parameter reference in error message. The error message should use the in_features parameter instead of module.in_features for consistency.

```diff
- f"in_features ({module.in_features}) must be divisible by group_size ({group_size}) "
+ f"in_features ({in_features}) must be divisible by group_size ({group_size}) "
```

Lines 114-114: Fix line length violations. Several lines exceed the 120-character limit. Consider breaking these lines for better readability.
Examples of how to fix:

```diff
- if module.has_w4a16_awq or module.has_w4a8_awq:
+ if (module.has_w4a16_awq or module.has_w4a8_awq):
```

```diff
-    def load_weight_scales_w4a8(self,
+    def load_weight_scales_w4a8(
+            self,
             weights: List[Dict],
             tp_size: int = 1,
             tp_rank: int = 0,
-            tp_mode: Optional[TensorParallelMode] = None):
+            tp_mode: Optional[TensorParallelMode] = None
+    ):
```

Also applies to: 1093-1093, 1185-1186, 1232-1232.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tensorrt_llm/_torch/modules/linear.py (8 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py
114-114: Line too long (135 > 120)
(E501)
1093-1093: Line too long (132 > 120)
(E501)
1185-1185: Line too long (143 > 120)
(E501)
1186-1186: Line too long (165 > 120)
(E501)
1232-1232: Line too long (142 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/_torch/modules/linear.py (5)
Lines 113-120: LGTM! Clean extension to support W4A8 AWQ. The logic correctly determines the activation dtype based on the quantization method (float16 for W4A16, float8_e4m3fn for W4A8) and uses it appropriately in the weight preprocessing step.
Lines 916-924: LGTM! Successful migration to generalized GEMM operator. The migration from the specialized W4A16 operator to finegrained_mixed_dtype_gemm is well-implemented. The parameter mapping is correct and the absence of the alpha parameter is appropriate for W4A16 AWQ.
Lines 1012-1238: Excellent implementation of W4A8 AWQ support. The new W4A8_AWQ_LinearMethod class is well-structured and follows established patterns while properly handling FP8-specific requirements. The implementation includes:
- Proper FP8 quantization with input_scale and alpha parameters
- Correct weight scale handling with float16 dtype
- Comprehensive scale computation logic
- Clear documentation of the modelopt flow
- Consistent error handling and validation
The class mirrors the W4A16 implementation while adding the necessary FP8-specific adaptations.
Lines 1257-1259: LGTM! Proper method selection logic. The addition of W4A8 AWQ method selection follows the established pattern and correctly checks for the appropriate quantization algorithm.
Lines 1384-1388: LGTM! Consistent property addition. The has_w4a8_awq property follows the same pattern as other quantization method properties and provides the necessary interface for detecting W4A8 AWQ quantization.
PR_Github #12210 [ run ] triggered by Bot
LGTM
PR_Github #12210 [ run ] completed with state
/bot run
PR_Github #12374 [ run ] triggered by Bot
/bot kill
PR_Github #12375 [ kill ] triggered by Bot
PR_Github #12374 [ run ] completed with state
PR_Github #12375 [ kill ] completed with state
Signed-off-by: Daniel Afrimi <[email protected]>
Force-pushed from 95ca1ea to 213f5df (Compare)
/bot run
Actionable comments posted: 0
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/linear.py (1)
Lines 120-120: Address line length violations for better code style. Several lines exceed the 120-character limit. Consider breaking long lines for better readability:

```diff
- if module.has_w4a16_awq or module.has_w4a8_awq:
+ if module.has_w4a16_awq or module.has_w4a8_awq:
```

```diff
-    def load_weight_scales_w4a8(self,
-                                weights: List[Dict],
+    def load_weight_scales_w4a8(
+        self,
+        weights: List[Dict],
```

Also applies to: 1106-1106, 1203-1204, 1257-1257.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tensorrt_llm/_torch/modules/linear.py (10 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py
120-120: Line too long (135 > 120)
(E501)
1106-1106: Line too long (132 > 120)
(E501)
1203-1203: Line too long (143 > 120)
(E501)
1204-1204: Line too long (165 > 120)
(E501)
1257-1257: Line too long (142 > 120)
(E501)
🔇 Additional comments (15)
tensorrt_llm/_torch/modules/linear.py (15)
Lines 50-54: Well-implemented helper method. The flip method provides a clean way to get the orthogonal tensor parallel mode, which is needed for proper sharding of activation scales. The implementation is correct and the comment clearly explains its purpose.
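A rough illustration of what such a flip helper could look like is below; the enum values are assumptions, and only the column/row relationship comes from the comment above.

```python
from enum import IntEnum
from typing import Optional


class TensorParallelMode(IntEnum):
    COLUMN = 0
    ROW = 1

    @classmethod
    def flip(cls, mode: Optional["TensorParallelMode"]) -> Optional["TensorParallelMode"]:
        # Activation scales are sharded along the dimension orthogonal to the
        # weights, so a column-parallel layer loads them row-parallel and
        # vice versa.
        if mode is None:
            return None
        return cls.ROW if mode == cls.COLUMN else cls.COLUMN
```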
Lines 119-125: Proper extension for W4A8 support. The changes correctly extend the weight preprocessing logic to handle both W4A16 and W4A8 quantization methods by selecting the appropriate activation dtype (torch.float16 for W4A16 vs torch.float8_e4m3fn for W4A8).
Lines 902-902: Improved flexibility with dtype parameter. Using dtype instead of hardcoded torch.float16 makes the weight scale creation more flexible and consistent with the module's data type.
Lines 918-918: Proper conditional application of pre-quantization scaling. Moving the pre-quantization scaling inside the condition ensures it's only applied when the scale tensor exists.
Lines 922-930: Successfully migrated to generalized GEMM operator. The transition from the specialized W4A16 GEMM to the generic finegrained_mixed_dtype_gemm operator is well-implemented, with proper parameter mapping and the addition of required parameters like output_dtype and alpha.
Lines 964-982: Correct pre-quantization scale handling. The use of TensorParallelMode.flip(module.tp_mode) correctly handles the sharding of activation scales along the orthogonal dimension to the weights. The tensor creation and copying logic is also properly implemented.
Lines 999-999: Performance optimization with contiguous memory layout. Adding .T.contiguous() ensures optimal memory layout for the concatenated weight scales, which can improve performance in subsequent operations. Also applies to: 1021-1022.
Lines 1027-1063: Well-structured weight creation for W4A8 AWQ. The create_weights method properly sets up all required parameters for W4A8 quantization:
- Quantized weights with correct shape and dtype
- Float16 weight scales (appropriate for FP8 activations)
- Input scaling parameters for FP8 quantization
- Alpha parameter for output rescaling
The implementation follows established patterns and includes proper validation.
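A hedged sketch of the parameter set this item describes is below. The shapes, packing factor, and attribute names are assumptions made for illustration only; the PR's create_weights will differ.

```python
import torch
from torch import nn


def create_w4a8_weights(module: nn.Module, in_features: int,
                        out_features: int, group_size: int) -> None:
    if in_features % group_size != 0:
        raise ValueError(f"in_features ({in_features}) must be divisible by "
                         f"group_size ({group_size})")
    # INT4 weights packed two per byte into an int8 buffer (assumed layout).
    module.weight = nn.Parameter(
        torch.empty(out_features, in_features // 2, dtype=torch.int8),
        requires_grad=False)
    # Per-group weight scales kept in float16, as required for FP8 activations.
    module.weight_scale = nn.Parameter(
        torch.empty(in_features // group_size, out_features, dtype=torch.float16),
        requires_grad=False)
    # Static FP8 input scale and alpha for rescaling the GEMM output.
    module.input_scale = nn.Parameter(torch.ones(1), requires_grad=False)
    module.alpha = nn.Parameter(torch.ones(1), requires_grad=False)
```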
Lines 1065-1099: Excellent implementation with clear documentation. The apply method is well-documented with a detailed docstring explaining the ModelOpt quantization flow. The implementation correctly:
- Handles conditional pre-quantization scaling
- Supports both pre-quantized FP8 inputs and dynamic quantization
- Uses the generalized GEMM operator with proper parameter mapping
- Includes appropriate output dtype constraints
Lines 1101-1138: Consistent and correct scale loading implementation. The load_weight_scales_w4a8 method follows established patterns while properly handling W4A8-specific requirements:
- Extracts all necessary scaling factors
- Converts weight scales to appropriate float16 dtype
- Includes proper consistency checks with assertions
- Computes alpha correctly for GEMM output rescaling (see the sketch below)
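The consistency checks and alpha computation might look roughly like this. The formula used here (alpha = input_scale * weight_scale_2) is an assumption based on the scales named in this review, not a quote of the PR.

```python
import torch


def load_scales_sketch(shards):
    input_scale, weight_scale_2 = None, None
    for shard in shards:                 # e.g. the q/k/v shards of a fused layer
        if input_scale is None:
            input_scale = shard["input_scale"]
            weight_scale_2 = shard["weight_scale_2"]
        else:
            # Shards that share one activation must agree on these scalars.
            assert torch.equal(input_scale, shard["input_scale"])
            assert torch.equal(weight_scale_2, shard["weight_scale_2"])
    # Fold the activation and global weight scales into a single output rescale.
    alpha = input_scale.float() * weight_scale_2.float()
    return input_scale, weight_scale_2, alpha
```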
Lines 1140-1175: Proper vanilla weight loading for W4A8. The implementation correctly follows the W4A8 quantization requirements:
- Uses appropriate helper functions for weight loading
- Handles pre-quantization scales with correct tensor parallel sharding
- Applies weight_scale_2 division as required by the quantization scheme
- Sets up all scaling parameters properly
Lines 1177-1220: Comprehensive fused QKV weight loading. The implementation properly handles fused QKV weights for W4A8:
- Uses appropriate preprocessing with FP8 activation dtype
- Correctly concatenates and processes weight scales with division
- Includes optional pre-quantization scale handling with proper comments
- Follows established patterns for fused weight operations
Lines 1222-1262: Consistent fused gate-up weight loading implementation. The method correctly implements fused gate-up weight loading for W4A8:
- Proper weight preprocessing and concatenation
- Correct scale processing with weight_scale_2 division
- Optional pre-quantization scale handling consistent with QKV implementation
- Good code comments explaining behavior
Lines 1282-1284: Proper integration of W4A8 AWQ method selection. The addition correctly extends the quantization method selection logic to handle W4A8 AWQ, following the established pattern for other quantization methods.
Lines 1409-1413: Consistent property implementation for W4A8 AWQ detection. The has_w4a8_awq property follows the exact same pattern as has_w4a16_awq, providing a consistent interface for detecting W4A8 AWQ quantization.
PR_Github #12376 [ run ] triggered by Bot
PR_Github #12376 [ run ] completed with state
/bot run
PR_Github #12377 [ run ] triggered by Bot
PR_Github #12377 [ run ] completed with state
/bot run
Actionable comments posted: 0
🧹 Nitpick comments (5)
tensorrt_llm/_torch/modules/linear.py (5)
Lines 119-125: LGTM! Correctly handles both AWQ quantization modes. The logic properly differentiates between W4A16 and W4A8 AWQ modes and sets the appropriate activation dtype. The addition of .contiguous() addresses previous performance feedback.
Minor issue: Line 120 exceeds the 120-character limit. Consider breaking it into multiple lines for better readability.
Lines 1067-1101: Comprehensive apply method with clear documentation. The method correctly implements the modelopt W4A8 AWQ flow:
- Pre-quantization scaling
- FP8 input quantization with proper fallback
- Mixed-dtype GEMM with alpha scaling
The docstring provides excellent clarity on the quantization pipeline.
Minor issue: Line length violations on lines 1096-1097. Consider breaking the comment for better readability.
Lines 1103-1140: Robust weight scale loading implementation. The load_weight_scales_w4a8 method properly handles:
- Shared scaling factors for concatenated weights
- Proper tensor parallel sharding
- Correct dtype conversions to float16
- Alpha computation for GEMM output rescaling
The assertions ensure consistency across multiple weights.
Minor issue: Line 1108 exceeds character limit. Consider breaking the method signature.
Lines 1179-1223: Comprehensive fused QKV weight loading. The method handles the complexities of fused QKV layers:
- Proper weight preprocessing for mixed GEMM
- Correct concatenation and scaling of weight scales
- Conditional pre_quant_scale loading based on modelopt export settings
- Appropriate tensor parallel mode flipping for activation scales
The comments explain the modelopt behavior regarding pre_quant_scale fusion.
Minor issues: Lines 1205-1206 exceed character limits.
Lines 1224-1264: Consistent fused gate-up weight loading. The implementation mirrors the QKV approach with:
- Proper weight preprocessing and concatenation
- Correct scale handling and alpha computation
- Conditional pre_quant_scale loading
- Consistent memory layout optimization
Minor issue: Line 1259 exceeds character limit.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tensorrt_llm/_torch/modules/linear.py (10 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py
120-120: Line too long (135 > 120)
(E501)
1108-1108: Line too long (132 > 120)
(E501)
1205-1205: Line too long (143 > 120)
(E501)
1206-1206: Line too long (165 > 120)
(E501)
1259-1259: Line too long (142 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (11)
tensorrt_llm/_torch/modules/linear.py (11)
Lines 50-54: LGTM! Well-documented helper method. The flip method correctly implements the orthogonal tensor parallel mode logic needed for activation scale sharding. The docstring clearly explains its purpose.
Lines 904-904: Good generalization of dtype handling. Removing the explicit dtype parameter makes the method more flexible and consistent with the generalized approach.
Lines 920-920: Performance optimization applied correctly. The contiguous() call addresses previous review feedback about memory layout optimization during preprocessing.
Lines 924-932: LGTM! Updated to use the generalized mixed-dtype GEMM. The function call correctly migrates from the W4A16-specific interface to the new finegrained_mixed_dtype_gemm with appropriate parameters including output_dtype and proper argument mapping.
Lines 965-984: Correct tensor parallel handling for activation scales. The use of TensorParallelMode.flip(module.tp_mode) is appropriate since pre_quant_scale applies to activations rather than weights, requiring the orthogonal sharding dimension.
Lines 1001-1001: Memory layout optimization applied. The .T.contiguous() call ensures optimal memory layout for the concatenated weight scale tensor.
Lines 1023-1023: Consistent memory layout optimization. The .T.contiguous() call maintains consistency with other weight scale handling and optimizes memory layout.
Lines 1027-1066: Well-structured weight creation for W4A8 AWQ. The create_weights method properly handles the FP8 activation requirements:
- Uses float16 for weight scales (required for FP8 activation)
- Creates necessary parameters for input scaling and alpha
- Proper error handling for group_size divisibility
The implementation follows established patterns from W4A16 AWQ.
Lines 1142-1177: Correct vanilla weight loading with proper tensor parallel handling. The implementation correctly:
- Loads pre_quant_scale with flipped tensor parallel mode (for activation scaling)
- Handles weight scale division by weight_scale_2
- Maintains proper tensor shapes with transpose and contiguous operations
- Sets inverse input scale appropriately
The dtype assertion ensures compatibility between pre_quant_scale and module dtype.
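An illustrative sketch of the loading steps listed in this item is below. The dictionary keys and module attributes (including inv_input_scale) are assumptions made for the example, not the PR's actual names.

```python
import torch


def load_w4a8_vanilla_sketch(module, w):
    # In the PR, pre_quant_scale scales the activations and is therefore sharded
    # with the flipped tensor-parallel mode; this sketch simply copies a full
    # (unsharded) tensor for brevity.
    if "pre_quant_scale" in w:
        module.pre_quant_scale.copy_(w["pre_quant_scale"].to(module.dtype))
    # Divide per-group scales by the global weight_scale_2 and store them
    # transposed/contiguous in float16, the layout the GEMM expects.
    ws = (w["weight_scale"] / w["weight_scale_2"]).to(torch.float16)
    module.weight_scale.copy_(ws.T.contiguous())
    # Record the static FP8 input scale and its inverse.
    module.input_scale.copy_(w["input_scale"])
    module.inv_input_scale.copy_(1.0 / w["input_scale"])
```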
Lines 1284-1286: LGTM! Consistent pattern for W4A8 AWQ method selection. The addition correctly identifies W4A8 AWQ configurations and returns the appropriate quantization method, following the established pattern from W4A16 AWQ.
Lines 1411-1415: LGTM! Consistent property for W4A8 AWQ detection. The property correctly identifies W4A8 AWQ quantization mode, following the same pattern as has_w4a16_awq. This provides a clean interface for checking the quantization configuration.
PR_Github #12381 [ run ] triggered by Bot
PR_Github #12381 [ run ] completed with state
Signed-off-by: Daniel Afrimi <[email protected]>
Signed-off-by: Daniel Afrimi <[email protected]>
Signed-off-by: Daniel Afrimi <[email protected]> Signed-off-by: Shreyas Misra <[email protected]>
Signed-off-by: Daniel Afrimi <[email protected]> Signed-off-by: Ransiki Zhang <[email protected]>
W4A8 GEMM
Support running w4a8_awq quantized models exported from modelopt.
The kernel supports multiple GEMM configurations with mixed dtypes.
Summary by CodeRabbit
New Features
Bug Fixes
Refactor
Tests