170 changes: 170 additions & 0 deletions torchao/experimental/ops/mps/test/test_quantizer.py
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import copy
import itertools
import os
import sys

import torch
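# Imported for its side effect of registering the torchao MPS custom ops
# (torch.ops.torchao._pack_weight_*bit and torch.ops.torchao._linear_fp_act_*bit_weight).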
import torchao_mps_ops
import unittest

from parameterized import parameterized
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
from torchao.experimental.quant_api import _quantize


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
BITWIDTHS = range(1, 8)
GROUPSIZES = [32, 64, 128, 256]

    # Currently, the quantization code in quant_api.py only supports values of K
    # that are multiples of group_size.
    # TODO(mcandales): Generalize the code in quant_api.py and add tests to
    # cover values of K that are not multiples of group_size.
def _model_setup(self):
group_size = 32
k0 = 96
k1 = 224
k2 = 160
n = 47
layers = [
torch.nn.Linear(k0, k1, bias=False),
torch.nn.Linear(k1, k2, bias=False),
torch.nn.Linear(k2, n, bias=False),
]
model = torch.nn.Sequential(*layers)
return model, group_size, k0, n

def _quantize_model(self, model, precision, nbit, group_size):
quantizer = UIntxWeightOnlyLinearQuantizer(
device="mps",
precision=precision,
bitwidth=nbit,
groupsize=group_size,
)
quantized_model = copy.deepcopy(model)
quantized_model = quantizer.quantize(quantized_model)
return quantized_model

@parameterized.expand(BITWIDTHS)
def test_export(self, nbit):
model, group_size, k0, n = self._model_setup()
m = 3
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
exported = torch.export.export(quantized_model, (activations,))

for node in exported.graph.nodes:
if node.op == "call_function":
self.assertTrue(
str(node.target)
== f"torchao._linear_fp_act_{nbit}bit_weight.default"
)

@parameterized.expand(BITWIDTHS)
def test_2d_output_device_and_shape(self, nbit):
model, group_size, k0, n = self._model_setup()
m = 3
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (m, n))

@parameterized.expand(BITWIDTHS)
def test_3d_output_device_and_shape(self, nbit):
model, group_size, k0, n = self._model_setup()
leading_shape = (3, 5)
activations = torch.randn(*leading_shape, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (*leading_shape, n))

@parameterized.expand(itertools.product(BITWIDTHS, GROUPSIZES))
def test_valid_groupsizes(self, nbit, group_size):
k0 = 3 * group_size
k1 = 7 * group_size
n = 47
layers = [
torch.nn.Linear(k0, k1, bias=False),
torch.nn.Linear(k1, n, bias=False),
]
model = torch.nn.Sequential(*layers)
m = 5
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (m, n))

@parameterized.expand(BITWIDTHS)
def test_invalid_groupsizes(self, nbit):
group_size = 16
k0 = 3 * group_size
k1 = 7 * group_size
n = 47
layers = [
torch.nn.Linear(k0, k1, bias=False),
torch.nn.Linear(k1, n, bias=False),
]
model = torch.nn.Sequential(*layers)

with self.assertRaises(ValueError):
self._quantize_model(model, torch.float32, nbit, group_size)

# TODO(mcandales): Consolidate with the reference impl in test_lowbit.py
def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z):
N = W.shape[0]
K = W.shape[1]
W = W.to(torch.float32)
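        # S and Z are per-group tensors of shape [K // group_size, N]; expand them
        # to per-element [N, K] so dequantization becomes an elementwise op on W.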
scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
W = scales * W + zeros
return torch.mm(A, W.t())

@parameterized.expand(BITWIDTHS)
def test_accuracy(self, nbit):
group_size = 32
m = 3
n = 7
k = 64
with torch.no_grad():
activations = torch.rand(m, k, dtype=torch.float32, device="mps")
model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
quantized_model = self._quantize_model(
model, torch.float32, nbit, group_size
)
result = quantized_model(activations)

# Compute expected result
weight_cpu = model[0].weight.data
weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize(
weight_cpu, group_size, nbit, True, torch.uint8
)
weight_scales_cpu = weight_scales_cpu.t()
weight_zeros_cpu = -weight_zeros_cpu.t() * weight_scales_cpu
expected = self._reference_linear_lowbit_quant_weights(
activations.cpu(),
weight_qvals_cpu,
group_size,
weight_scales_cpu,
weight_zeros_cpu,
)

# Compare results
torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001)


if __name__ == "__main__":
unittest.main()
130 changes: 126 additions & 4 deletions torchao/experimental/quant_api.py
@@ -25,10 +25,14 @@
logger.addHandler(handler)


def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool):
def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool, signed=True):
assert nbit >= 1 and nbit <= 8
qmin = -(1 << (nbit - 1))
qmax = (1 << (nbit - 1)) - 1
if signed:
qmin = -(1 << (nbit - 1))
qmax = (1 << (nbit - 1)) - 1
else:
qmin = 0
qmax = (1 << nbit) - 1
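    # e.g. for nbit=4: signed gives qmin, qmax = -8, 7; unsigned gives 0, 15.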

n, k = vals.shape
vals = vals.reshape(-1, group_size)
@@ -51,7 +55,7 @@ def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros:
zero_points=group_zeros,
quant_min=qmin,
quant_max=qmax,
dtype=torch.int8,
dtype=torch.int8 if signed else torch.uint8,
group_size=group_size,
)

@@ -516,3 +520,121 @@ def apply(weight):
)

return _get_linear_subclass_inserter(apply)


class UIntxWeightOnlyQuantizedLinear(nn.Module):
def __init__(
self,
pack_weight_op,
linear_op,
):
super().__init__()
self._pack_weights_op = pack_weight_op
self._linear_op = linear_op

def quantize_and_pack_weights(self, weights, nbit, group_size):
self.nbit = nbit
self.group_size = group_size

weight_qvals, weight_scales, weight_zeros = _quantize(
weights, self.group_size, self.nbit, has_weight_zeros=True, signed=False
)
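        # Keep scales/zeros in [K // group_size, N] layout, which is the form
        # passed to the packed linear op below.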
weight_scales = torch.transpose_copy(weight_scales, 1, 0)
weight_zeros = torch.transpose_copy(weight_zeros, 1, 0)
self.weight_scales = weight_scales
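        # Fold the zero point into an additive term so that
        # scale * (q - z) can be computed as scale * q + (-z * scale).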
self.weight_zeros = -weight_zeros * weight_scales

self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps")

def forward(self, x):
assert x.dim() >= 2
if x.dim() == 2:
return self._linear_op(
x, self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
)

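        # For inputs with extra leading dims, flatten to 2D, run the op, and
        # restore the original leading shape on the output.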
lead_shape = x.shape[0:-1]
k = x.shape[-1]
n = self.weight_scales.shape[1]
return self._linear_op(
x.reshape(-1, k), self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
).reshape(*lead_shape, n)

# TODO(mcandales): Consolidate with _replace_linear_with_quantized_linear
def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}):
group_size = kwargs["group_size"]
nbit = kwargs["nbit"]

assert not isinstance(module, nn.Linear)
assert nbit >= 1 and nbit <= 7

for name, child in module.named_children():
if not isinstance(child, nn.Linear):
_replace_linear_with_quantized_linear_mps(child, kwargs)
else:
assert child.bias is None
qlinear = UIntxWeightOnlyQuantizedLinear(
pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"),
linear_op=getattr(
torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight"
),
)
setattr(module, name, qlinear)
qlinear.quantize_and_pack_weights(
child.weight, nbit, group_size
)


class UIntxWeightOnlyLinearQuantizer:
def __init__(
self,
device,
precision,
*,
bitwidth: Optional[int] = None,
groupsize: Optional[int] = None,
):
if device != "mps":
raise NotImplementedError(
"Only device=mps is currently supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.device = device

if precision not in [torch.float32, torch.float16, torch.bfloat16]:
raise ValueError(
"Only precisions float32, float16 & bfloat16 are supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.precision = precision

if bitwidth is None:
bitwidth = 4
logger.warning(f"bitwidth not specified, defaulting to {bitwidth}.")
if bitwidth not in range(1, 8):
raise ValueError(
"Only bitwidts 1 to 7 are supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.bitwidth = bitwidth

if groupsize is None:
groupsize = 128
logger.warning(f"groupsize not specified, defaulting to {groupsize}.")
if groupsize not in [32, 64, 128, 256]:
raise ValueError(
"Only groupsizes 32, 64, 128 & 256 are supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.groupsize = groupsize

def quantize(self, model: nn.Module) -> nn.Module:
model = model.to(self.device).to(self.precision)
_replace_linear_with_quantized_linear_mps(
model,
kwargs={
"group_size": self.groupsize,
"nbit": self.bitwidth,
},
)
return model
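
For reference, a minimal usage sketch of the quantizer added above (assumes an MPS device and the torchao_mps_ops extension are available; every nn.Linear must have bias=False and in_features divisible by groupsize, matching the constraints exercised in the tests):

import copy

import torch
import torchao_mps_ops  # registers the MPS custom ops

from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer

model = torch.nn.Sequential(torch.nn.Linear(256, 128, bias=False))
quantizer = UIntxWeightOnlyLinearQuantizer(
    device="mps",
    precision=torch.float32,
    bitwidth=4,
    groupsize=32,
)
quantized_model = quantizer.quantize(copy.deepcopy(model))
activations = torch.randn(8, 256, dtype=torch.float32, device="mps")
out = quantized_model(activations)  # shape (8, 128), on MPS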