Int4 sparse marlin tensor #2771
Merged
+341 −0

Commits (14)
a260dc8  added marlin sparse to packing format, initial commit (liangel-02)
505e21f  deleting unnecessary functions (liangel-02)
ac3e430  packing (liangel-02)
a8bfed3  linear (liangel-02)
b51b091  add call to from_hp (liangel-02)
641cc71  unit test (liangel-02)
ae14aa9  fix test_linear (liangel-02)
cbd1bae  formatting (liangel-02)
9f2ae7c  remove comments (liangel-02)
30b23f3  update VERSION to version (liangel-02)
dffd0e0  fix module path unit test (liangel-02)
45e8a9e  adding sizes to linear unit test (liangel-02)
ebbd3ab  move pre_process and from_plain to from_hp (liangel-02)
7de0b12  compile test (liangel-02)
test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py (107 additions, 0 deletions)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Int4WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
)

BF16_ACT_CONFIG = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="marlin_sparse",
    version=2,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt4MarlinSparseTensor(TestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @parametrize("config", [BF16_ACT_CONFIG])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 12),
        ],
    )
    def test_linear(self, config, sizes):
        dtype = torch.float16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        apply_fake_sparsity(linear)
        original = linear(input)
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    @unittest.skip("Fix later")
    @parametrize("config", [BF16_ACT_CONFIG])
    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

    @parametrize("config", [BF16_ACT_CONFIG])
    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4MarlinSparseTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4MarlinSparseTensor'>",
            )


instantiate_parametrized_tests(TestInt4MarlinSparseTensor)


if __name__ == "__main__":
    run_tests()
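For readers following the test above, here is a minimal usage sketch of the workflow it exercises: prune an fp16 linear layer to a 2:4 pattern, quantize it with the new `marlin_sparse` packing format, and run a forward pass. The layer sizes and the input batch below are illustrative assumptions; the API calls mirror the test.

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.sparsity.sparse_api import apply_fake_sparsity

# Illustrative layer sizes, chosen to satisfy the Marlin 2:4 kernel's alignment requirements.
linear = torch.nn.Linear(128, 256, dtype=torch.float16, device="cuda")

apply_fake_sparsity(linear)  # prune the weight to a 2:4 sparse pattern
config = Int4WeightOnlyConfig(group_size=128, packing_format="marlin_sparse", version=2)
quantize_(linear, config)  # the weight becomes an Int4MarlinSparseTensor

x = torch.randn(32, 128, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = linear(x)  # dispatches to the marlin_24_gemm sparse kernel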
torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py (216 additions, 0 deletions)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)
from torchao.utils import TorchAOBaseTensor

__all__ = [
    "Int4MarlinSparseTensor",
]

aten = torch.ops.aten


class Int4MarlinSparseTensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point", "meta"]
    tensor_attribute_names = ["block_size", "num_bits", "shape"]

    def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape):
        kwargs = {}
        kwargs["device"] = qdata.device
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape):
        self.qdata = qdata
        self.scale = scale
        self.zero_point = zero_point
        self.meta = meta
        self.block_size = block_size
        self.num_bits = num_bits

    def _quantization_type(self):
        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        from torchao.sparsity.marlin import (
            const,
            inject_24,  # avoid circular import
            pack_to_marlin_24,
        )

        """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel.
        - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format
        - 2º: tensor is injected with 2:4 sparsity
        - 3º: transposes it again because the quantization process will compute the scales for dim=-1
        """

        w_t = w.t()
        w_24, _ = inject_24(w_t, *w_t.shape)
        preprocessed_w = w_24.t()

        assert block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1], (
            f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}"
        )

        quant_min = 0
        quant_max = 15
        target_dtype = torch.int32

        scale, zero_point = choose_qparams_affine(
            input=preprocessed_w,
            mapping_type=MappingType.SYMMETRIC,
            block_size=block_size,
            target_dtype=target_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
            eps=1e-6,
        )

        wq = quantize_affine(
            input=preprocessed_w,
            block_size=block_size,
            scale=scale,
            zero_point=zero_point,
            output_dtype=target_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
        )

        scale = scale.to(w.dtype)
        zero_point = zero_point.to(w.dtype)

        # Linear layers are (in_features, out_features) but the qdata that is reaching this point
        # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
        q_w_24 = wq.t()
        # addressing the case when scale has dimension 1, happens when
        # weight_shape[-1] == group_size == 128
        if scale.ndim == 1:
            scale = scale.reshape(scale.shape[0], -1)

        scale_t = scale.t()

        if torch.cuda.get_device_capability()[0] < 8:
            raise ValueError(
                f"Cannot use the Sparse Marlin 2:4 int4*fp16 kernel on a device with compute capability {torch.cuda.get_device_capability()}; the Marlin kernel requires compute capability 8.0 or higher."
            )

        if q_w_24.dtype != torch.int32:
            raise ValueError("Only `torch.int32` weights are supported.")

        in_features, out_features = q_w_24.shape
        if in_features % 128 != 0 or out_features % 256 != 0:
            raise ValueError(
                "`in_features` must be divisible by 128 and `out_features` by 256."
            )

        # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
        # will require a bit more work to get our current quantization flow to work with it.
        # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
        num_bits = 4 if torch.max(q_w_24) < 16 else -1
        if num_bits not in [4]:
            raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.")

        group_size = in_features // scale_t.shape[0]
        if group_size == 0:
            group_size = in_features
        assert group_size <= in_features, (
            "Group size must be less than or equal to in_features."
        )

        if group_size not in const.SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}."
            )

        # Compress quantized weight to marlin 2:4 format
        marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(
            q_w_24, scale_t, num_bits, group_size
        )

        return cls(
            qdata=marlin_24_q_w_comp,
            scale=marlin_24_s,
            zero_point=zero_point,
            meta=meta,
            block_size=group_size,
            shape=q_w_24.shape,
            num_bits=num_bits,
        )


implements = Int4MarlinSparseTensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    from torchao.ops import marlin_24_gemm
    from torchao.sparsity.marlin import marlin_24_workspace

    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous"
    assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous"
    assert weight_tensor.zero_point.is_contiguous(), (
        "Expected zero_point to be contiguous"
    )

    sparse_w_int4 = weight_tensor.qdata
    scale = weight_tensor.scale
    meta = weight_tensor.meta
    original_shape = weight_tensor.shape
    num_bits = weight_tensor.num_bits

    # Folds batch dimension into the first dimension
    input_2d = input_tensor.view(-1, input_tensor.shape[-1])

    size_m = input_2d.shape[0]
    size_n = scale.shape[1]
    size_k = input_2d.shape[1]
    workspace_24 = marlin_24_workspace(original_shape[1])

    out = marlin_24_gemm(
        input_2d,
        sparse_w_int4,
        meta,
        scale,
        workspace_24,
        num_bits,
        size_m,
        size_n,
        size_k,
    )

    # Unfold the batch dimension
    out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],))

    if bias is not None:
        out += bias.to(out.dtype)
    return out


Int4MarlinSparseTensor.__module__ = "torchao.quantization"

# Allow a model with Int4MarlinSparseTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4MarlinSparseTensor])
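As a usage note on the last two lines: because `Int4MarlinSparseTensor` is registered via `torch.serialization.add_safe_globals`, a state dict containing these weights can round-trip through `torch.save`/`torch.load` with `weights_only=True` (the default in newer PyTorch releases). A short sketch, assuming `linear` has already been quantized with the `marlin_sparse` config (BF16_ACT_CONFIG) as in the unit test above:

import tempfile
import torch

# `linear` is assumed to already carry an Int4MarlinSparseTensor weight,
# e.g. after quantize_(linear.cuda(), BF16_ACT_CONFIG) as in the test file.
with tempfile.NamedTemporaryFile() as f:
    torch.save(linear.state_dict(), f)
    f.seek(0)
    # weights_only=True succeeds because the tensor subclass is a registered safe global
    state_dict = torch.load(f, weights_only=True)

assert type(state_dict["weight"]).__name__ == "Int4MarlinSparseTensor"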