@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torchao.quantization import (
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
)

BF16_ACT_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format="marlin_sparse",
version=2,
)
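# version=2 selects the packing_format-based tensor subclass workflow in
# Int4WeightOnlyConfig; "marlin_sparse" packs the int4 weight for the 2:4 sparse
# Marlin kernel (group_size must be 128 or per-channel, see Int4MarlinSparseTensor.from_hp).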


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt4MarlinSparseTensor(TestCase):
def setUp(self):
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@parametrize("config", [BF16_ACT_CONFIG])
@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 512, 128),
            ((2, 32, 128), 256, 128),
],
)
def test_linear(self, config, sizes):
dtype = torch.float16
device = "cuda"

M, N, K = sizes
input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

apply_fake_sparsity(linear)
original = linear(input)
quantize_(linear, config)
quantized = linear(input)
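        # compute_error returns SQNR in dB; > 20 dB is the usual sanity bar for
        # int4 weight-only quantization in these tests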
self.assertTrue(compute_error(original, quantized) > 20)

compiled_linear = torch.compile(linear)
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

@unittest.skip("Fix later")
@parametrize("config", [BF16_ACT_CONFIG])
def test_to_device(self, config):
for device in self.GPU_DEVICES:
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, config)
linear.to(device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, config)
linear.to(device=device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, config)
linear.to(device)

@parametrize("config", [BF16_ACT_CONFIG])
def test_module_path(self, config):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear.cuda(), config)
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4MarlinSparseTensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4MarlinSparseTensor'>",
)


instantiate_parametrized_tests(TestInt4MarlinSparseTensor)


if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -90,6 +90,7 @@
)
from .quantize_.workflows import (
Float8Tensor,
Int4MarlinSparseTensor,
Int4PreshuffledTensor,
Int4Tensor,
)
@@ -159,6 +160,7 @@
# tensor subclasses
"Int4Tensor",
"Int4PreshuffledTensor",
"Int4MarlinSparseTensor",
"Float8Tensor",
# smooth quant - subject to change
"get_scale",
7 changes: 7 additions & 0 deletions torchao/quantization/quant_api.py
@@ -72,6 +72,7 @@
)
from torchao.quantization.quantize_.workflows import (
Float8Tensor,
Int4MarlinSparseTensor,
Int4PreshuffledTensor,
Int4Tensor,
QuantizeTensorToFloat8Kwargs,
@@ -1068,6 +1069,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
block_size,
)
return new_weight
elif packing_format == PackingFormat.MARLIN_SPARSE:
new_weight = Int4MarlinSparseTensor.from_hp(
weight,
block_size,
)
return new_weight
else:
raise ValueError(f"Unsupported packing format: {packing_format}")

5 changes: 5 additions & 0 deletions torchao/quantization/quantize_/common/packing_format.py
@@ -30,3 +30,8 @@ class PackingFormat(str, Enum):
preshuffled is referring to the preshuffled format used by fbgemm kernels
"""
PRESHUFFLED = "preshuffled"

"""
marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
"""
MARLIN_SPARSE = "marlin_sparse"
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -2,6 +2,9 @@
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
)
from .int4.int4_marlin_sparse_tensor import (
Int4MarlinSparseTensor,
)
from .int4.int4_preshuffled_tensor import (
Int4PreshuffledTensor,
)
@@ -12,6 +15,7 @@
__all__ = [
"Int4Tensor",
"Int4PreshuffledTensor",
"Int4MarlinSparseTensor",
"Float8Tensor",
"QuantizeTensorToFloat8Kwargs",
]
@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch

from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_affine,
quantize_affine,
)
from torchao.utils import TorchAOBaseTensor

__all__ = [
"Int4MarlinSparseTensor",
]

aten = torch.ops.aten


class Int4MarlinSparseTensor(TorchAOBaseTensor):
tensor_data_names = ["qdata", "scale", "zero_point", "meta"]
tensor_attribute_names = ["block_size", "num_bits", "shape"]

def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape):
kwargs = {}
kwargs["device"] = qdata.device
kwargs["dtype"] = scale.dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
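        # The wrapper subclass only reports `shape` and the scale dtype as tensor
        # metadata; the packed Marlin buffers are attached as attributes in __init__.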

def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape):
self.qdata = qdata
self.scale = scale
self.zero_point = zero_point
self.meta = meta
self.block_size = block_size
self.num_bits = num_bits

def _quantization_type(self):
return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

@classmethod
def from_hp(
cls,
w: torch.Tensor,
block_size: List[int],
):
from torchao.sparsity.marlin import (
const,
inject_24, # avoid circular import
pack_to_marlin_24,
)

"""Preprocess the input tensor to be in the correct format for the Marlin sparse kernel.
- 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format
- 2º: tensor is injected with 2:4 sparsity
- 3º: transposes it again because the quantization process will compute the scales for dim=-1
"""

w_t = w.t()
w_24, _ = inject_24(w_t, *w_t.shape)
preprocessed_w = w_24.t()

assert block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1], (
f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}"
)

quant_min = 0
quant_max = 15
target_dtype = torch.int32

scale, zero_point = choose_qparams_affine(
input=preprocessed_w,
mapping_type=MappingType.SYMMETRIC,
block_size=block_size,
target_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=1e-6,
)

wq = quantize_affine(
input=preprocessed_w,
block_size=block_size,
scale=scale,
zero_point=zero_point,
output_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
)

scale = scale.to(w.dtype)
zero_point = zero_point.to(w.dtype)

        # nn.Linear stores its weight as (out_features, in_features), but the Marlin
        # packing code expects (in_features, out_features), so transpose the quantized weight.
q_w_24 = wq.t()
# addressing the case when scale has dimension 1, happens when
# weight_shape[-1] == group_size == 128
if scale.ndim == 1:
scale = scale.reshape(scale.shape[0], -1)

scale_t = scale.t()

        if torch.cuda.get_device_capability()[0] < 8:
            raise ValueError(
                f"Cannot use the Sparse Marlin 2:4 int4*fp16 kernel on a device with compute capability {torch.cuda.get_device_capability()}; the Marlin kernel requires compute capability 8.0 or higher."
            )

if q_w_24.dtype != torch.int32:
raise ValueError("Only `torch.int32` weights are supported.")

in_features, out_features = q_w_24.shape
        if in_features % 128 != 0 or out_features % 256 != 0:
            raise ValueError(
                "`in_features` must be divisible by 128 and `out_features` by 256."
            )

# NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
# will require a bit more work to get our current quantization flow to work with it.
# Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
num_bits = 4 if torch.max(q_w_24) < 16 else -1
if num_bits not in [4]:
raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.")

group_size = in_features // scale_t.shape[0]
if group_size == 0:
group_size = in_features
assert group_size <= in_features, (
"Group size must be less than or equal to in_features."
)

if group_size not in const.SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}."
)

# Compress quantized weight to marlin 2:4 format
marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(
q_w_24, scale_t, num_bits, group_size
)

return cls(
qdata=marlin_24_q_w_comp,
scale=marlin_24_s,
zero_point=zero_point,
meta=meta,
block_size=group_size,
shape=q_w_24.shape,
num_bits=num_bits,
)


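# A minimal end-to-end sketch (mirroring the tests added in this PR): prune the weight
# to a 2:4 pattern, then quantize_ with packing_format="marlin_sparse" to swap the
# Linear weight for an Int4MarlinSparseTensor. Assumes CUDA with compute capability >= 8.0.
#
#   import torch
#   from torchao.quantization import Int4WeightOnlyConfig, quantize_
#   from torchao.sparsity.sparse_api import apply_fake_sparsity
#
#   linear = torch.nn.Linear(128, 256, dtype=torch.float16, device="cuda")
#   apply_fake_sparsity(linear)  # 2:4 prune so inject_24 preserves the mask
#   quantize_(linear, Int4WeightOnlyConfig(group_size=128, packing_format="marlin_sparse", version=2))
#   out = linear(torch.randn(32, 128, dtype=torch.float16, device="cuda"))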
implements = Int4MarlinSparseTensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
from torchao.ops import marlin_24_gemm
from torchao.sparsity.marlin import marlin_24_workspace

input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous"
assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous"
assert weight_tensor.zero_point.is_contiguous(), (
"Expected zero_point to be contiguous"
)

sparse_w_int4 = weight_tensor.qdata
scale = weight_tensor.scale
meta = weight_tensor.meta
original_shape = weight_tensor.shape
num_bits = weight_tensor.num_bits

# Folds batch dimension into the first dimension
input_2d = input_tensor.view(-1, input_tensor.shape[-1])

size_m = input_2d.shape[0]
size_n = scale.shape[1]
size_k = input_2d.shape[1]
workspace_24 = marlin_24_workspace(original_shape[1])

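    # marlin_24_gemm computes the (size_m, size_n) output of input_2d @ W^T from the
    # packed 2:4-sparse int4 weight: size_m = flattened input rows, size_n = out_features
    # (scale.shape[1]), size_k = in_features; workspace_24 is a scratch buffer sized
    # from out_features.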
out = marlin_24_gemm(
input_2d,
sparse_w_int4,
meta,
scale,
workspace_24,
num_bits,
size_m,
size_n,
size_k,
)

# Unfold the batch dimension
out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],))

if bias is not None:
out += bias.to(out.dtype)
return out


Int4MarlinSparseTensor.__module__ = "torchao.quantization"

# Allow a model with Int4MarlinSparseTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4MarlinSparseTensor])