diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py
index 68c25821ee..36765fb9b5 100644
--- a/test/prototype/test_parq.py
+++ b/test/prototype/test_parq.py
@@ -21,10 +21,12 @@ from torchao.prototype.parq.quant import (
     Int4UnifTorchaoQuantizer,
     LSBQuantizer,
+    StretchedUnifTorchaoQuantizer,
     TernaryUnifQuantizer,
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
+from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.qat import (
@@ -35,11 +37,11 @@ from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
-    MappingType,
     _is_linear,
     int4_weight_only,
     quantize_,
 )
+from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_6,
@@ -74,6 +76,59 @@ def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
     ]
 
 
+def compare_quantized_models(
+    model: nn.Module,
+    m_ref: nn.Module,
+    quantizer: UnifTorchaoQuantizer,
+    b: int,
+    group_size: int,
+):
+    for n, module in model.named_children():
+        if not _is_linear(module):
+            continue
+
+        # simulate grouping from QuantOptimizer.step
+        p = module.weight
+        original_shape = p.shape
+        p = p.view(-1, group_size)
+
+        q, Q = quantizer.quantize(p, b=b, dim=-1)
+
+        # compare to AffineQuantizedTensor instance
+        q = q.view(original_shape)
+        ref = getattr(m_ref, n).weight.dequantize()
+        torch.testing.assert_close(q, ref, atol=0, rtol=0)
+
+
+def compare_parq_convert(
+    model: nn.Module,
+    m_ref: nn.Module,
+    optimizer: QuantOptimizer,
+    config: AOBaseConfig,
+):
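+    # quantize the model in place via the PARQ optimizer, then check that
+    # converting it with quantize_ matches a natively quantized reference model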
+    # do not update model weights, just quantize
+    optimizer.zero_grad()
+    optimizer.step()
+
+    orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
+
+    # equivalent to torchao's convert step
+    model.eval()
+    optimizer.restore_latent_params()
+    quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))
+
+    for n, module in model.named_modules():
+        if not _is_linear(module):
+            continue
+
+        p_orig = getattr(orig_model, n).weight  # PARQ weight
+        p = module.weight.dequantize()  # PARQ weight after quantize_
+        p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_
+
+        torch.testing.assert_close(p_orig, p_ref, atol=0, rtol=0)
+        torch.testing.assert_close(p, p_ref, atol=0, rtol=0)
+
+
 class M(nn.Module):
     def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
         super().__init__()
@@ -143,59 +198,6 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
 
-    def compare_quantized_models(
-        self,
-        model: nn.Module,
-        m_ref: nn.Module,
-        quantizer: UnifTorchaoQuantizer,
-        b: int,
-        group_size: int,
-    ):
-        for n, module in model.named_children():
-            if not _is_linear(module):
-                continue
-
-            # simulate grouping from QuantOptimizer.step
-            p = module.weight
-            original_shape = p.shape
-            p = p.view(-1, group_size)
-
-            q, Q = quantizer.quantize(p, b=b, dim=-1)
-
-            # compare to AffineQuantizedTensor instance
-            q = q.view(original_shape)
-            ref = getattr(m_ref, n).weight.dequantize()
-            torch.testing.assert_close(q, ref, atol=0, rtol=0)
-
-    def compare_parq_convert(
-        self,
-        model: nn.Module,
-        m_ref: nn.Module,
-        optimizer: QuantOptimizer,
-        config: AOBaseConfig,
-    ):
-        # do not update model weights, just quantize
-        optimizer.zero_grad()
-        optimizer.step()
-
-        orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
-
-        # equivalent to torchao's convert step
-        model.eval()
-        optimizer.restore_latent_params()
-        quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))
-
-        for n, module in model.named_modules():
-            if not _is_linear(module):
-                continue
-
-            p_orig = getattr(orig_model, n).weight  # PARQ weight
-            p = module.weight.dequantize()  # PARQ weight after quantize_
-            p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_
-
-            torch.testing.assert_close(p_orig, p_ref, atol=0, rtol=0)
-            torch.testing.assert_close(p, p_ref, atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
@@ -209,7 +211,7 @@ def test_int4_weight_only(self, group_size: int = 32):
         quantize_(m_ref, config)
 
         b = 4
-        self.compare_quantized_models(
+        compare_quantized_models(
             model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size
         )
 
@@ -229,7 +231,7 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         )
 
         quantizer = UnifTorchaoQuantizer()
-        self.compare_quantized_models(model, m_ref, quantizer, b, group_size)
+        compare_quantized_models(model, m_ref, quantizer, b, group_size)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@@ -251,7 +253,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        self.compare_parq_convert(model, m_ref, optimizer, config)
+        compare_parq_convert(model, m_ref, optimizer, config)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@@ -273,7 +275,84 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        self.compare_parq_convert(model, m_ref, optimizer, config)
+        compare_parq_convert(model, m_ref, optimizer, config)
+
+
+class TestStretchedUnifTorchaoQuantizer(common_utils.TestCase):
+    def setUp(self):
+        torch.manual_seed(123)
+
+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize("group_size", [32, 256])
+    def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
+
+        quantizer_ref = UnifQuantizer()
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+
+        for n, module in model.named_children():
+            if not _is_linear(module):
+                continue
+
+            # simulate grouping from QuantOptimizer.step
+            p = module.weight
+            p = p.view(-1, group_size)
+
+            q_ref, Q_ref = quantizer_ref.quantize(p, b=b, dim=-1)
+            q, Q = quantizer.quantize(p, b=b, dim=-1)
+
+            torch.testing.assert_close(q, q_ref, atol=0, rtol=0)
+            torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize("group_size", [32, 512])
+    def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        quantize_(
+            m_ref,
+            StretchedIntxWeightOnlyConfig(
+                b=b,
+                quant_min=quantizer.quant_min,
+                quant_max=quantizer.quant_max,
+                granularity=PerGroup(group_size),
+            ),
+        )
+
+        compare_quantized_models(model, m_ref, quantizer, b, group_size)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    @common_utils.parametrize("b", [2, 3])
+    def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        config = StretchedIntxWeightOnlyConfig(
+            b=b,
+            quant_min=quantizer.quant_min,
+            quant_max=quantizer.quant_max,
+            granularity=PerGroup(group_size),
+        )
+        quantize_(m_ref, config)
+
+        base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optimizer = QuantOptimizer(
+            base_optimizer,
+            quantizer,
+            ProxHardQuant(),
+            quant_per_channel=True,
+        )
+        compare_parq_convert(model, m_ref, optimizer, config)
 
 
 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
diff --git a/torchao/prototype/parq/quant/__init__.py b/torchao/prototype/parq/quant/__init__.py
index c8b8365725..9b84d8bccf 100644
--- a/torchao/prototype/parq/quant/__init__.py
+++ b/torchao/prototype/parq/quant/__init__.py
@@ -13,5 +13,6 @@
 )
 from .uniform_torchao import (  # noqa: F401
     Int4UnifTorchaoQuantizer,
+    StretchedUnifTorchaoQuantizer,
     UnifTorchaoQuantizer,
 )
diff --git a/torchao/prototype/parq/quant/quant_api.py b/torchao/prototype/parq/quant/quant_api.py
new file mode 100644
index 0000000000..47dabb73f6
--- /dev/null
+++ b/torchao/prototype/parq/quant/quant_api.py
@@ -0,0 +1,221 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from torchao.dtypes import AffineQuantizedTensor, Layout, QDQLayout
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig
+from torchao.quantization.quant_primitives import (
+    _SUB_BYTE_UINT_BOUNDS,
+    MappingType,
+    ZeroPointDomain,
+    _get_reduction_params,
+    dequantize_affine,
+)
+from torchao.quantization.transform_module import register_quantize_module_handler
+
+
+def choose_qparams_stretched_affine(
+    input_float: torch.Tensor,
+    mapping_type: MappingType,
+    block_size: Tuple[int, ...],
+    target_dtype: torch.dtype,
+    b: int,
+    quant_min: Optional[Union[int, float]] = None,
+    quant_max: Optional[Union[int, float]] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if scale_dtype is None:
+        scale_dtype = input_float.dtype
+    if eps is None:
+        eps = torch.finfo(input_float.dtype).eps
+    if zero_point_dtype is None:
+        zero_point_dtype = input_float.dtype
+
+    assert len(block_size) == input_float.dim(), f"Got {input_float.dim()=}, {block_size=}"
+    shape_for_reduction, reduction_dims = _get_reduction_params(
+        block_size, input_float.size()
+    )
+    input_float = input_float.view(shape_for_reduction)
+
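+    # scale maps the clipped range max_val to quant_max; clipping max|x| at
+    # b * mean(|x|) bounds the influence of outliers. The fixed -0.5 zero point
+    # below then shifts integer codes onto half-integer multiples of scale,
+    # i.e. a symmetric "stretched" grid with no representable zero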
+    q_abs = input_float.abs()
+    max_val = torch.minimum(
+        b * q_abs.mean(dim=reduction_dims, keepdim=True),
+        torch.amax(q_abs, dim=reduction_dims, keepdim=True),
+    ).clamp_(min=eps)
+
+    scale = max_val / quant_max
+    scale = scale.to(dtype=scale_dtype, device=input_float.device)
+    zero_point = torch.full_like(scale, -0.5, dtype=zero_point_dtype)
+    return scale, zero_point
+
+
+def quantize_stretched_affine(
+    input_float: torch.Tensor,
+    block_size: Tuple[int, ...],
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+) -> torch.Tensor:
+    if target_dtype in _SUB_BYTE_UINT_BOUNDS:
+        target_dtype = torch.uint8
+    assert input_float.dtype in (torch.float32, torch.float16, torch.bfloat16), (
+        f"Unsupported input_float dtype: {input_float.dtype}"
+    )
+    assert len(block_size) == input_float.dim(), (
+        f"Got {input_float.dim()=}, {block_size=}"
+    )
+    shape_for_reduction, reduction_dims = _get_reduction_params(
+        block_size, input_float.size()
+    )
+    original_shape = input_float.shape
+    input_float = input_float.view(shape_for_reduction)
+    shape_after_reduction = shape_for_reduction
+    for i in reduction_dims:
+        shape_after_reduction[i] = 1
+    scale = scale.view(shape_after_reduction)
+
+    if zero_point is not None and zero_point.numel() > 0:
+        zero_point = zero_point.view(shape_after_reduction)
+    else:
+        zero_point = None
+
+    max_val = scale.mul(quant_max)
+    input_float = input_float.clamp(min=-max_val, max=max_val)
+    with torch.no_grad():
+        # difference from quantize_affine: add zero_point before rounding
+        quant = torch.round(input_float / scale + zero_point)
+        quant = quant.to(dtype=target_dtype).view(original_shape)
+    return quant
+
+
+class StretchedAffineQuantizedTensor(AffineQuantizedTensor):
+    @classmethod
+    def from_hp_to_intx(
+        cls,
+        input_float: torch.Tensor,
+        mapping_type: MappingType,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        b: int,
+        quant_min: Optional[float] = None,
+        quant_max: Optional[float] = None,
+        scale_dtype: Optional[torch.dtype] = None,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
+        _layout: Layout = QDQLayout(),  # noqa: B008
+    ):
+        original_shape = input_float.shape
+        input_float = _layout.pre_process(input_float)
+
+        scale, zero_point = choose_qparams_stretched_affine(
+            input_float,
+            mapping_type,
+            block_size,
+            target_dtype,
+            b,
+            quant_min=quant_min,
+            quant_max=quant_max,
+        )
+        data = quantize_stretched_affine(
+            input_float,
+            block_size,
+            scale,
+            zero_point,
+            target_dtype,
+            quant_min=quant_min,
+            quant_max=quant_max,
+        )
+        data, scale, zero_point = _layout.post_process(
+            data, scale, zero_point, block_size
+        )
+        tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
+
+        if not isinstance(self._layout, QDQLayout):
+            raise NotImplementedError(
+                f"StretchedAffineQuantizedTensor only supports QDQLayout but got {self._layout}"
+            )
+
+        data, scale, zero_point = self.tensor_impl.get_plain()
+        dq = dequantize_affine(
+            data,
+            self.block_size,
+            scale,
+            zero_point,
+            data.dtype,
+            self.quant_min,
+            self.quant_max,
+            output_dtype=output_dtype,
+        )
+        return dq
+
+
+to_stretched_affine_quantized_intx = StretchedAffineQuantizedTensor.from_hp_to_intx
+
+
+@dataclass
+class StretchedIntxWeightOnlyConfig(IntxWeightOnlyConfig):
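+    """Weight-only config for the stretched grid: `b` is the bit width, and
+    quant_min/quant_max may be fractional (e.g. -1.5/1.5 for b=2)."""
+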
+    b: Optional[int] = None
+    quant_min: Optional[int] = None
+    quant_max: Optional[int] = None
+
+
+@register_quantize_module_handler(StretchedIntxWeightOnlyConfig)
+def _stretched_intx_weight_only_transform(
+    module: nn.Module, config: StretchedIntxWeightOnlyConfig
+) -> nn.Module:
+    weight = module.weight
+    granularity = config.granularity
+    mapping_type = MappingType.ASYMMETRIC
+
+    assert weight.dim() == 2, (
+        f"StretchedIntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
+    )
+    if isinstance(granularity, PerGroup):
+        group_size = granularity.group_size
+    elif isinstance(granularity, PerAxis):
+        assert granularity.axis == 0, (
+            f"axis must be 0 with PerAxis, but got {granularity.axis}"
+        )
+        group_size = weight.shape[-1]
+    else:
+        raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")
+
+    weight = to_stretched_affine_quantized_intx(
+        input_float=weight,
+        mapping_type=mapping_type,
+        block_size=(1, group_size),
+        target_dtype=torch.int8,
+        b=config.b,
+        quant_min=config.quant_min,
+        quant_max=config.quant_max,
+        scale_dtype=config.scale_dtype,
+        _layout=config.layout,
+    )
+    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return module
diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py
index ebe4e775e6..6d895452e8 100644
--- a/torchao/prototype/parq/quant/uniform_torchao.py
+++ b/torchao/prototype/parq/quant/uniform_torchao.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
+from functools import partial
 from typing import Optional, Union
 
 import torch
@@ -25,6 +27,10 @@
     quantize_affine,
 )
 
+from .quant_api import (
+    choose_qparams_stretched_affine,
+    quantize_stretched_affine,
+)
 from .quantizer import Quantizer
 
 _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
@@ -56,17 +62,16 @@ def __init__(
         self._quantize = quantize_affine
         self._dequantize = dequantize_affine
 
-        if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
-            self._choose_qparams = _choose_qparams_affine_tinygemm
-            self._quantize = _quantize_affine_tinygemm
-            self._dequantize = _dequantize_affine_tinygemm
-        elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero:
-            self._choose_qparams = _choose_qparams_affine_dont_preserve_zero
-            self._quantize = quantize_affine
-            self._dequantize = dequantize_affine
-        elif zero_point_domain == ZeroPointDomain.NONE:
+        if zero_point_domain == ZeroPointDomain.NONE and not preserve_zero:
             self._quantize = _quantize_affine_no_zero_point
             self._dequantize = _dequantize_affine_no_zero_point
+        elif mapping_type == MappingType.ASYMMETRIC:
+            if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
+                self._choose_qparams = _choose_qparams_affine_tinygemm
+                self._quantize = _quantize_affine_tinygemm
+                self._dequantize = _dequantize_affine_tinygemm
+            elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero:
+                self._choose_qparams = _choose_qparams_affine_dont_preserve_zero
 
     def _init_quant_min_max(self, b: int) -> None:
         if self.quant_min is None or self.quant_max is None:
@@ -113,9 +118,12 @@ def quantize(
             quant_max=self.quant_max,
         )
 
-        Q = torch.arange(
-            self.quant_min, self.quant_max + 1, dtype=self.target_dtype, device=p.device
-        )
+        Q = torch.arange(self.quant_min, self.quant_max + 1e-5, device=p.device)
+
+        if isinstance(self.quant_min, float):
+            Q = Q.floor()
+        Q = Q.to(dtype=self.target_dtype)
+
         if dim is not None:
             Q = Q.view(1, -1).expand(q.size(0), -1)
             block_size = (1, Q.size(-1))
@@ -133,6 +141,26 @@ def quantize(
         return q, Q
 
 
+class StretchedUnifTorchaoQuantizer(UnifTorchaoQuantizer):
+    def __init__(self, b: int, int_shift: float = 0.5) -> None:
+        quant_absmax = 2 ** (b - 1) - int_shift
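+        # default int_shift=0.5 gives quant_min/max = -/+(2**(b - 1) - 0.5),
+        # e.g. -1.5/1.5 for b=2: a grid of 2**b half-integer codes with no zero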
+        self.quant_min = -quant_absmax
+        self.quant_max = quant_absmax
+        self.int_shift = int_shift
+
+        super().__init__(
+            mapping_type=MappingType.ASYMMETRIC,
+            quant_min=self.quant_min,
+            quant_max=self.quant_max,
+        )
+
+        self._choose_qparams = partial(choose_qparams_stretched_affine, b=b)
+        self._quantize = quantize_stretched_affine
+
+    def get_quant_size(self, b: int) -> int:
+        return math.floor(2**b - 2 * self.int_shift) + 1
+
+
 class Int4UnifTorchaoQuantizer(UnifTorchaoQuantizer):
     """Based on torchao.quantization.quant_api._int4_weight_only_transform"""