From eea8bc36bfcf9ef822707425aeaab60bbd469a3f Mon Sep 17 00:00:00 2001 From: lisjin Date: Mon, 21 Jul 2025 08:58:25 -0700 Subject: [PATCH 1/4] Add StretchedUnifTorchaoQuantizer --- test/prototype/test_parq.py | 172 +++++++++++------ torchao/prototype/parq/quant/__init__.py | 1 + torchao/prototype/parq/quant/quant_api.py | 175 ++++++++++++++++++ .../prototype/parq/quant/uniform_torchao.py | 63 ++++++- 4 files changed, 343 insertions(+), 68 deletions(-) create mode 100644 torchao/prototype/parq/quant/quant_api.py diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index 68c25821ee..d3881dd565 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -21,10 +21,12 @@ from torchao.prototype.parq.quant import ( Int4UnifTorchaoQuantizer, LSBQuantizer, + StretchedUnifTorchaoQuantizer, TernaryUnifQuantizer, UnifQuantizer, UnifTorchaoQuantizer, ) +from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE from torchao.quantization.granularity import PerGroup from torchao.quantization.qat import ( @@ -35,11 +37,11 @@ from torchao.quantization.quant_api import ( Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig, - MappingType, _is_linear, int4_weight_only, quantize_, ) +from torchao.quantization.quant_primitives import MappingType from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6, @@ -74,6 +76,59 @@ def build_param_groups(model, b: int = 2, group_size: Optional[int] = None): ] +def compare_quantized_models( + model: nn.Module, + m_ref: nn.Module, + quantizer: UnifTorchaoQuantizer, + b: int, + group_size: int, +): + for n, module in model.named_children(): + if not _is_linear(module): + continue + + # simulate grouping from QuantOptimizer.step + p = module.weight + original_shape = p.shape + p = p.view(-1, group_size) + + q, Q = quantizer.quantize(p, b=b, dim=-1) + + # compare to AffineQuantizedTensor instance + q = q.view(original_shape) + ref = getattr(m_ref, n).weight.dequantize() + torch.testing.assert_close(q, ref, atol=0, rtol=0) + + +def compare_parq_convert( + model: nn.Module, + m_ref: nn.Module, + optimizer: QuantOptimizer, + config: AOBaseConfig, +): + # do not update model weights, just quantize + optimizer.zero_grad() + optimizer.step() + + orig_model = copy.deepcopy(model) # save copy of PARQ quantized model + + # equivalent to torchao's convert step + model.eval() + optimizer.restore_latent_params() + quantize_(model, config, filter_fn=optimizer.get_filter_fn(model)) + + for n, module in model.named_modules(): + if not _is_linear(module): + continue + + p_orig = getattr(orig_model, n).weight # PARQ weight + p = module.weight.dequantize() # PARQ weight after quantize_ + p_ref = getattr(m_ref, n).weight.dequantize() # native quantize_ + + torch.testing.assert_true(p_orig, p_ref, atol=0, rtol=0) + torch.testing.assert_true(p, p_ref, atol=0, rtol=0) + + class M(nn.Module): def __init__(self, m=256, n=128, k=16, bias=False, embedding=True): super().__init__() @@ -143,59 +198,6 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase): def setUp(self): torch.manual_seed(123) - def compare_quantized_models( - self, - model: nn.Module, - m_ref: nn.Module, - quantizer: UnifTorchaoQuantizer, - b: int, - group_size: int, - ): - for n, module in model.named_children(): - if not _is_linear(module): - continue - - # simulate grouping from QuantOptimizer.step - p = module.weight - original_shape = p.shape - p = p.view(-1, 
group_size) - - q, Q = quantizer.quantize(p, b=b, dim=-1) - - # compare to AffineQuantizedTensor instance - q = q.view(original_shape) - ref = getattr(m_ref, n).weight.dequantize() - torch.testing.assert_close(q, ref, atol=0, rtol=0) - - def compare_parq_convert( - self, - model: nn.Module, - m_ref: nn.Module, - optimizer: QuantOptimizer, - config: AOBaseConfig, - ): - # do not update model weights, just quantize - optimizer.zero_grad() - optimizer.step() - - orig_model = copy.deepcopy(model) # save copy of PARQ quantized model - - # equivalent to torchao's convert step - model.eval() - optimizer.restore_latent_params() - quantize_(model, config, filter_fn=optimizer.get_filter_fn(model)) - - for n, module in model.named_modules(): - if not _is_linear(module): - continue - - p_orig = getattr(orig_model, n).weight # PARQ weight - p = module.weight.dequantize() # PARQ weight after quantize_ - p_ref = getattr(m_ref, n).weight.dequantize() # native quantize_ - - torch.testing.assert_true(p_orig, p_ref, atol=0, rtol=0) - torch.testing.assert_true(p, p_ref, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @common_utils.parametrize("group_size", [32, 256]) def test_int4_weight_only(self, group_size: int = 32): @@ -209,7 +211,7 @@ def test_int4_weight_only(self, group_size: int = 32): quantize_(m_ref, config) b = 4 - self.compare_quantized_models( + compare_quantized_models( model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size ) @@ -229,7 +231,7 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32): ) quantizer = UnifTorchaoQuantizer() - self.compare_quantized_models(model, m_ref, quantizer, b, group_size) + compare_quantized_models(model, m_ref, quantizer, b, group_size) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") @@ -251,7 +253,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32): ProxHardQuant(), quant_per_channel=True, ) - self.compare_parq_convert(model, m_ref, optimizer, config) + compare_parq_convert(model, m_ref, optimizer, config) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") @@ -273,7 +275,61 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): ProxHardQuant(), quant_per_channel=True, ) - self.compare_parq_convert(model, m_ref, optimizer, config) + compare_parq_convert(model, m_ref, optimizer, config) + + +class TestStretchedUnifTorchaoQuantizer(common_utils.TestCase): + def setUp(self): + torch.manual_seed(123) + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @common_utils.parametrize("b", [2, 3]) + @common_utils.parametrize("group_size", [32, 512]) + def test_intx_weight_only(self, b: int = 2, group_size: int = 32): + model = M(m=512, n=512).to(_DEVICE) + model.reset_parameters() + + quantizer = StretchedUnifTorchaoQuantizer(b) + + m_ref = copy.deepcopy(model).eval().to(_DEVICE) + quantize_( + m_ref, + StretchedIntxWeightOnlyConfig( + b=b, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=PerGroup(group_size), + ), + ) + + compare_quantized_models(model, m_ref, quantizer, b, group_size) + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") + @common_utils.parametrize("b", [2, 3]) + def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): + model = 
M(m=512, n=512).to(_DEVICE) + model.reset_parameters() + + quantizer = StretchedUnifTorchaoQuantizer(b) + + m_ref = copy.deepcopy(model).eval().to(_DEVICE) + config = StretchedIntxWeightOnlyConfig( + b=b, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=PerGroup(group_size), + ) + quantize_(m_ref, config) + + base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size)) + optimizer = QuantOptimizer( + base_optimizer, + quantizer, + ProxHardQuant(), + quant_per_channel=True, + ) + compare_parq_convert(model, m_ref, optimizer, config) class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase): diff --git a/torchao/prototype/parq/quant/__init__.py b/torchao/prototype/parq/quant/__init__.py index c8b8365725..9b84d8bccf 100644 --- a/torchao/prototype/parq/quant/__init__.py +++ b/torchao/prototype/parq/quant/__init__.py @@ -13,5 +13,6 @@ ) from .uniform_torchao import ( # noqa: F401 Int4UnifTorchaoQuantizer, + StretchedUnifTorchaoQuantizer, UnifTorchaoQuantizer, ) diff --git a/torchao/prototype/parq/quant/quant_api.py b/torchao/prototype/parq/quant/quant_api.py new file mode 100644 index 0000000000..7d42929131 --- /dev/null +++ b/torchao/prototype/parq/quant/quant_api.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn + +from torchao.dtypes import AffineQuantizedTensor, Layout, PlainLayout, QDQLayout +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.quant_api import IntxWeightOnlyConfig +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + choose_qparams_affine_with_min_max, + dequantize_affine, + quantize_affine, +) +from torchao.quantization.transform_module import register_quantize_module_handler + +from .uniform import get_q_max + + +class StretchedAffineQuantizedTensor(AffineQuantizedTensor): + @classmethod + def from_hp_to_intx( + cls, + input_float: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + b: int, + quant_min: Optional[float] = None, + quant_max: Optional[float] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = False, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT, + _layout: Layout = PlainLayout(), # noqa: B008 + scale_method: str = "mean", + ): + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + + dim = None + qmax_shape = [] + for d, size in enumerate(block_size): + if size > 1: + dim = d + qmax_shape.append(original_shape[d] // size) + else: + qmax_shape.append(original_shape[d]) + assert dim is not None, ( + "block_size must have at least one dimension greater than 1" + ) + reduction_shape = [-1 if i != dim else b for i, b in enumerate(block_size)] + input_float = input_float.view(reduction_shape) + q_max = get_q_max(input_float, b, dim=dim, scale_method=scale_method) + q_max = q_max.clamp(min=torch.finfo(input_float.dtype).tiny) + q_max = q_max.view(qmax_shape) + input_float = input_float.view(original_shape) + + scale, zero_point = choose_qparams_affine_with_min_max( + -q_max, + q_max, + mapping_type, + block_size, + target_dtype, + eps=eps, 
+ quant_min=quant_min, + quant_max=quant_max, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + ) + data, scale, zero_point = _layout.post_process( + data, scale, zero_point, block_size + ) + tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) + return cls( + tensor_impl, + block_size, + original_shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + if not isinstance(self._layout, QDQLayout): + raise NotImplementedError( + f"StretchedAffineQuantizedTensor only supports QDQLayout but got {self._layout}" + ) + + data, scale, zero_point = self.tensor_impl.get_plain() + dq = dequantize_affine( + data, + self.block_size, + scale, + zero_point, + data.dtype, + self.quant_min, + self.quant_max, + output_dtype=output_dtype, + ) + return dq + + +to_stretched_affine_quantized_intx = StretchedAffineQuantizedTensor.from_hp_to_intx + + +@dataclass +class StretchedIntxWeightOnlyConfig(IntxWeightOnlyConfig): + b: Optional[int] = None + quant_min: Optional[float] = None + quant_max: Optional[float] = None + mapping_type: MappingType = MappingType.ASYMMETRIC + zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT + + +@register_quantize_module_handler(StretchedIntxWeightOnlyConfig) +def _stretched_intx_weight_only_transform( + module: nn.Module, config: StretchedIntxWeightOnlyConfig +) -> nn.Module: + weight = module.weight + granularity = config.granularity + mapping_type = config.mapping_type + + assert weight.dim() == 2, ( + f"StretchedIntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}" + ) + if isinstance(granularity, PerGroup): + group_size = granularity.group_size + elif isinstance(granularity, PerAxis): + assert granularity.axis == 0, ( + f"axis must be 0 with PerAxis, but got {granularity.axis}" + ) + group_size = weight.shape[-1] + else: + raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}") + + weight = to_stretched_affine_quantized_intx( + input_float=weight, + mapping_type=mapping_type, + block_size=(1, group_size), + target_dtype=torch.int8, + b=config.b, + quant_min=config.quant_min, + quant_max=config.quant_max, + scale_dtype=config.scale_dtype, + zero_point_dtype=torch.int8, + preserve_zero=(mapping_type == MappingType.SYMMETRIC), + zero_point_domain=config.zero_point_domain, + _layout=config.layout, + ) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + return module diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py index ebe4e775e6..247ad31eef 100644 --- a/torchao/prototype/parq/quant/uniform_torchao.py +++ b/torchao/prototype/parq/quant/uniform_torchao.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
+import math +from functools import partial from typing import Optional, Union import torch @@ -21,11 +23,13 @@ _quantize_affine_no_zero_point, _quantize_affine_tinygemm, choose_qparams_affine, + choose_qparams_affine_with_min_max, dequantize_affine, quantize_affine, ) from .quantizer import Quantizer +from .uniform import get_q_max _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} @@ -42,6 +46,7 @@ def __init__( eps: Optional[float] = None, preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + scale_method: str = "max", ) -> None: super().__init__(center=False) @@ -50,23 +55,29 @@ def __init__( self.quant_min = quant_min self.quant_max = quant_max self.eps = eps + self.scale_method = scale_method # defaults: zero_point_domain=ZeroPointDomain.INT, preserve_zero=True self._choose_qparams = choose_qparams_affine self._quantize = quantize_affine self._dequantize = dequantize_affine - if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: - self._choose_qparams = _choose_qparams_affine_tinygemm - self._quantize = _quantize_affine_tinygemm - self._dequantize = _dequantize_affine_tinygemm - elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: - self._choose_qparams = _choose_qparams_affine_dont_preserve_zero - self._quantize = quantize_affine - self._dequantize = dequantize_affine - elif zero_point_domain == ZeroPointDomain.NONE: + if quant_min is not None and quant_max is not None: + self._choose_qparams = partial( + choose_qparams_affine_with_min_max, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + elif zero_point_domain == ZeroPointDomain.NONE and not preserve_zero: self._quantize = _quantize_affine_no_zero_point self._dequantize = _dequantize_affine_no_zero_point + elif mapping_type == MappingType.ASYMMETRIC: + if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: + self._choose_qparams = _choose_qparams_affine_tinygemm + self._quantize = _quantize_affine_tinygemm + self._dequantize = _dequantize_affine_tinygemm + elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: + self._choose_qparams = _choose_qparams_affine_dont_preserve_zero def _init_quant_min_max(self, b: int) -> None: if self.quant_min is None or self.quant_max is None: @@ -89,8 +100,17 @@ def quantize( # assume that p has already been grouped in QuantOptimizer.step block_size = (1, p.size(-1)) if dim is not None else p.size() + if ( + getattr(self._choose_qparams, "func", None) + == choose_qparams_affine_with_min_max + ): + q_max = get_q_max(p, b, dim=dim, scale_method=self.scale_method) + q_max = q_max.clamp_(min=torch.finfo(p.dtype).tiny) + q_args = (-q_max, q_max) + else: + q_args = (p,) s, zero_point = self._choose_qparams( - p, + *q_args, self.mapping_type, block_size, self.target_dtype, @@ -133,6 +153,29 @@ def quantize( return q, Q +class StretchedUnifTorchaoQuantizer(UnifTorchaoQuantizer): + def __init__( + self, b: int, int_shift: float = 0.5, scale_method: str = "mean" + ) -> None: + # use choose_qparams_affine_with_min_max to infer zero_point + quant_absmax = 2 ** (b - 1) - int_shift + self.quant_min = -quant_absmax + self.quant_max = quant_absmax + self.int_shift = int_shift + + super().__init__( + mapping_type=MappingType.ASYMMETRIC, + quant_min=self.quant_min, + quant_max=self.quant_max, + preserve_zero=False, + zero_point_domain=ZeroPointDomain.FLOAT, + scale_method=scale_method, + ) + + def get_quant_size(self, b: int) -> int: + return math.floor(2**b - 2 * self.int_shift) + 
1 + + class Int4UnifTorchaoQuantizer(UnifTorchaoQuantizer): """Based on torchao.quantization.quant_api._int4_weight_only_transform""" From 64f49cc44ff8808ec9d1739aebed7e8e84616deb Mon Sep 17 00:00:00 2001 From: lisjin Date: Mon, 21 Jul 2025 10:42:52 -0700 Subject: [PATCH 2/4] Fix tinygemm test case --- torchao/prototype/parq/quant/uniform_torchao.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py index 247ad31eef..9389c797ec 100644 --- a/torchao/prototype/parq/quant/uniform_torchao.py +++ b/torchao/prototype/parq/quant/uniform_torchao.py @@ -62,13 +62,7 @@ def __init__( self._quantize = quantize_affine self._dequantize = dequantize_affine - if quant_min is not None and quant_max is not None: - self._choose_qparams = partial( - choose_qparams_affine_with_min_max, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - ) - elif zero_point_domain == ZeroPointDomain.NONE and not preserve_zero: + if zero_point_domain == ZeroPointDomain.NONE and not preserve_zero: self._quantize = _quantize_affine_no_zero_point self._dequantize = _dequantize_affine_no_zero_point elif mapping_type == MappingType.ASYMMETRIC: @@ -167,9 +161,13 @@ def __init__( mapping_type=MappingType.ASYMMETRIC, quant_min=self.quant_min, quant_max=self.quant_max, + scale_method=scale_method, + ) + + self._choose_qparams = partial( + choose_qparams_affine_with_min_max, preserve_zero=False, zero_point_domain=ZeroPointDomain.FLOAT, - scale_method=scale_method, ) def get_quant_size(self, b: int) -> int: From bc86a87a2feb110d185ecf6f709fe20c7b88f0c7 Mon Sep 17 00:00:00 2001 From: lisjin Date: Tue, 22 Jul 2025 12:29:48 -0700 Subject: [PATCH 3/4] Test equivalence to PARQ UnifQuantizer; custom choose_qparams, quantize, dequantize --- test/prototype/test_parq.py | 23 +++ torchao/prototype/parq/quant/quant_api.py | 156 +++++++++++++----- .../prototype/parq/quant/uniform_torchao.py | 43 ++--- 3 files changed, 151 insertions(+), 71 deletions(-) diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index d3881dd565..36765fb9b5 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -282,6 +282,29 @@ class TestStretchedUnifTorchaoQuantizer(common_utils.TestCase): def setUp(self): torch.manual_seed(123) + @common_utils.parametrize("b", [2, 3]) + @common_utils.parametrize("group_size", [32, 256]) + def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32): + model = M(m=512, n=512).to(_DEVICE) + model.reset_parameters() + + quantizer_ref = UnifQuantizer() + quantizer = StretchedUnifTorchaoQuantizer(b) + + for n, module in model.named_children(): + if not _is_linear(module): + continue + + # simulate grouping from QuantOptimizer.step + p = module.weight + p = p.view(-1, group_size) + + q_ref, Q_ref = quantizer_ref.quantize(p, b=b, dim=-1) + q, Q = quantizer.quantize(p, b=b, dim=-1) + + torch.testing.assert_close(q, q_ref, atol=0, rtol=0) + torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("b", [2, 3]) @common_utils.parametrize("group_size", [32, 512]) diff --git a/torchao/prototype/parq/quant/quant_api.py b/torchao/prototype/parq/quant/quant_api.py index 7d42929131..87c05bc03d 100644 --- a/torchao/prototype/parq/quant/quant_api.py +++ b/torchao/prototype/parq/quant/quant_api.py @@ -5,24 +5,122 @@ # LICENSE file in the root directory 
of this source tree. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import nn -from torchao.dtypes import AffineQuantizedTensor, Layout, PlainLayout, QDQLayout +from torchao.dtypes import AffineQuantizedTensor, Layout, QDQLayout from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import IntxWeightOnlyConfig from torchao.quantization.quant_primitives import ( + _SUB_BYTE_UINT_BOUNDS, MappingType, ZeroPointDomain, - choose_qparams_affine_with_min_max, + _get_reduction_params, dequantize_affine, - quantize_affine, ) from torchao.quantization.transform_module import register_quantize_module_handler -from .uniform import get_q_max + +def choose_qparams_stretched_affine( + input_float: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + b: int, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale_dtype is None: + scale_dtype = input_float.dtype + if eps is None: + eps = torch.finfo(input_float.dtype).eps + if zero_point_dtype is None: + zero_point_dtype = input_float.dtype + + assert len(block_size) == input_float.dim(), f"Got {input.dim()=}, {block_size=}" + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input_float.size() + ) + input_float = input_float.view(shape_for_reduction) + + q_abs = input_float.abs() + max_val = torch.minimum( + b * q_abs.mean(dim=reduction_dims, keepdim=True), + torch.amax(q_abs, dim=reduction_dims, keepdim=True), + ).clamp_(min=eps) + + scale = max_val / quant_max + scale = scale.to(dtype=scale_dtype, device=input_float.device) + zero_point = torch.full_like(scale, 0.5, dtype=zero_point_dtype) + return scale, zero_point + + +def quantize_stretched_affine( + input_float: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: torch.Tensor, + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, +) -> torch.Tensor: + if target_dtype in _SUB_BYTE_UINT_BOUNDS: + target_dtype = torch.uint8 + assert input_float.dtype in (torch.float32, torch.float16, torch.bfloat16), ( + f"Unsupported input_float dtype: {input_float.dtype}" + ) + assert len(block_size) == input_float.dim(), ( + f"Got {input_float.dim()=}, {block_size=}" + ) + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input_float.size() + ) + original_shape = input_float.shape + input_float = input_float.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + + if zero_point is not None and zero_point.numel() > 0: + zero_point = zero_point.view(shape_after_reduction) + else: + zero_point = None + + max_val = scale.mul(quant_max) + input_float = input_float.clamp(min=-max_val, max=max_val) + with torch.no_grad(): + quant = torch.round(input_float / scale - zero_point) + quant = quant.to(dtype=target_dtype).view(original_shape) + return quant + + +def dequantize_stretched_affine( + data: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: torch.Tensor, + data_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = 
None, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + # allow float data_dtype instead of restricting to _SUB_BYTE_UINT_BOUNDS + return dequantize_affine( + data, + block_size, + scale, + -zero_point, + data_dtype, + quant_min=quant_min, + quant_max=quant_max, + output_dtype=output_dtype, + ) class StretchedAffineQuantizedTensor(AffineQuantizedTensor): @@ -36,48 +134,23 @@ def from_hp_to_intx( b: int, quant_min: Optional[float] = None, quant_max: Optional[float] = None, - eps: Optional[float] = None, scale_dtype: Optional[torch.dtype] = None, - zero_point_dtype: Optional[torch.dtype] = None, - preserve_zero: bool = False, zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT, - _layout: Layout = PlainLayout(), # noqa: B008 - scale_method: str = "mean", + _layout: Layout = QDQLayout(), # noqa: B008 ): original_shape = input_float.shape input_float = _layout.pre_process(input_float) - dim = None - qmax_shape = [] - for d, size in enumerate(block_size): - if size > 1: - dim = d - qmax_shape.append(original_shape[d] // size) - else: - qmax_shape.append(original_shape[d]) - assert dim is not None, ( - "block_size must have at least one dimension greater than 1" - ) - reduction_shape = [-1 if i != dim else b for i, b in enumerate(block_size)] - input_float = input_float.view(reduction_shape) - q_max = get_q_max(input_float, b, dim=dim, scale_method=scale_method) - q_max = q_max.clamp(min=torch.finfo(input_float.dtype).tiny) - q_max = q_max.view(qmax_shape) - input_float = input_float.view(original_shape) - - scale, zero_point = choose_qparams_affine_with_min_max( - -q_max, - q_max, + scale, zero_point = choose_qparams_stretched_affine( + input_float, mapping_type, block_size, target_dtype, - eps=eps, + b, quant_min=quant_min, quant_max=quant_max, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, ) - data = quantize_affine( + data = quantize_stretched_affine( input_float, block_size, scale, @@ -111,7 +184,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor ) data, scale, zero_point = self.tensor_impl.get_plain() - dq = dequantize_affine( + dq = dequantize_stretched_affine( data, self.block_size, scale, @@ -130,10 +203,8 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor @dataclass class StretchedIntxWeightOnlyConfig(IntxWeightOnlyConfig): b: Optional[int] = None - quant_min: Optional[float] = None - quant_max: Optional[float] = None - mapping_type: MappingType = MappingType.ASYMMETRIC - zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT + quant_min: Optional[int] = None + quant_max: Optional[int] = None @register_quantize_module_handler(StretchedIntxWeightOnlyConfig) @@ -142,7 +213,7 @@ def _stretched_intx_weight_only_transform( ) -> nn.Module: weight = module.weight granularity = config.granularity - mapping_type = config.mapping_type + mapping_type = MappingType.ASYMMETRIC assert weight.dim() == 2, ( f"StretchedIntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}" @@ -166,9 +237,6 @@ def _stretched_intx_weight_only_transform( quant_min=config.quant_min, quant_max=config.quant_max, scale_dtype=config.scale_dtype, - zero_point_dtype=torch.int8, - preserve_zero=(mapping_type == MappingType.SYMMETRIC), - zero_point_domain=config.zero_point_domain, _layout=config.layout, ) module.weight = torch.nn.Parameter(weight, requires_grad=False) diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py index 
9389c797ec..bb5c59fad0 100644 --- a/torchao/prototype/parq/quant/uniform_torchao.py +++ b/torchao/prototype/parq/quant/uniform_torchao.py @@ -23,13 +23,16 @@ _quantize_affine_no_zero_point, _quantize_affine_tinygemm, choose_qparams_affine, - choose_qparams_affine_with_min_max, dequantize_affine, quantize_affine, ) +from .quant_api import ( + choose_qparams_stretched_affine, + dequantize_stretched_affine, + quantize_stretched_affine, +) from .quantizer import Quantizer -from .uniform import get_q_max _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} @@ -46,7 +49,6 @@ def __init__( eps: Optional[float] = None, preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - scale_method: str = "max", ) -> None: super().__init__(center=False) @@ -55,7 +57,6 @@ def __init__( self.quant_min = quant_min self.quant_max = quant_max self.eps = eps - self.scale_method = scale_method # defaults: zero_point_domain=ZeroPointDomain.INT, preserve_zero=True self._choose_qparams = choose_qparams_affine @@ -94,17 +95,8 @@ def quantize( # assume that p has already been grouped in QuantOptimizer.step block_size = (1, p.size(-1)) if dim is not None else p.size() - if ( - getattr(self._choose_qparams, "func", None) - == choose_qparams_affine_with_min_max - ): - q_max = get_q_max(p, b, dim=dim, scale_method=self.scale_method) - q_max = q_max.clamp_(min=torch.finfo(p.dtype).tiny) - q_args = (-q_max, q_max) - else: - q_args = (p,) s, zero_point = self._choose_qparams( - *q_args, + p, self.mapping_type, block_size, self.target_dtype, @@ -127,9 +119,12 @@ def quantize( quant_max=self.quant_max, ) - Q = torch.arange( - self.quant_min, self.quant_max + 1, dtype=self.target_dtype, device=p.device - ) + Q = torch.arange(self.quant_min, self.quant_max + 1e-5, device=p.device) + + if isinstance(self.quant_min, float): + Q = Q.floor() + Q = Q.to(dtype=self.target_dtype) + if dim is not None: Q = Q.view(1, -1).expand(q.size(0), -1) block_size = (1, Q.size(-1)) @@ -148,10 +143,7 @@ def quantize( class StretchedUnifTorchaoQuantizer(UnifTorchaoQuantizer): - def __init__( - self, b: int, int_shift: float = 0.5, scale_method: str = "mean" - ) -> None: - # use choose_qparams_affine_with_min_max to infer zero_point + def __init__(self, b: int, int_shift: float = 0.5) -> None: quant_absmax = 2 ** (b - 1) - int_shift self.quant_min = -quant_absmax self.quant_max = quant_absmax @@ -161,14 +153,11 @@ def __init__( mapping_type=MappingType.ASYMMETRIC, quant_min=self.quant_min, quant_max=self.quant_max, - scale_method=scale_method, ) - self._choose_qparams = partial( - choose_qparams_affine_with_min_max, - preserve_zero=False, - zero_point_domain=ZeroPointDomain.FLOAT, - ) + self._choose_qparams = partial(choose_qparams_stretched_affine, b=b) + self._quantize = quantize_stretched_affine + self._dequantize = dequantize_stretched_affine def get_quant_size(self, b: int) -> int: return math.floor(2**b - 2 * self.int_shift) + 1 From 6bdd3f61f961f1e8c7e760973ba43852c9fd50cc Mon Sep 17 00:00:00 2001 From: lisjin Date: Tue, 22 Jul 2025 12:54:19 -0700 Subject: [PATCH 4/4] Remove dequantize_stretched_affine --- torchao/prototype/parq/quant/quant_api.py | 30 +++---------------- .../prototype/parq/quant/uniform_torchao.py | 2 -- 2 files changed, 4 insertions(+), 28 deletions(-) diff --git a/torchao/prototype/parq/quant/quant_api.py b/torchao/prototype/parq/quant/quant_api.py index 87c05bc03d..47dabb73f6 100644 --- a/torchao/prototype/parq/quant/quant_api.py +++ b/torchao/prototype/parq/quant/quant_api.py @@ 
-56,7 +56,7 @@ def choose_qparams_stretched_affine( scale = max_val / quant_max scale = scale.to(dtype=scale_dtype, device=input_float.device) - zero_point = torch.full_like(scale, 0.5, dtype=zero_point_dtype) + zero_point = torch.full_like(scale, -0.5, dtype=zero_point_dtype) return scale, zero_point @@ -95,34 +95,12 @@ def quantize_stretched_affine( max_val = scale.mul(quant_max) input_float = input_float.clamp(min=-max_val, max=max_val) with torch.no_grad(): - quant = torch.round(input_float / scale - zero_point) + # difference from quantize_affine: add zero_point before rounding + quant = torch.round(input_float / scale + zero_point) quant = quant.to(dtype=target_dtype).view(original_shape) return quant -def dequantize_stretched_affine( - data: torch.Tensor, - block_size: Tuple[int, ...], - scale: torch.Tensor, - zero_point: torch.Tensor, - data_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - output_dtype: torch.dtype = torch.float32, -) -> torch.Tensor: - # allow float data_dtype instead of restricting to _SUB_BYTE_UINT_BOUNDS - return dequantize_affine( - data, - block_size, - scale, - -zero_point, - data_dtype, - quant_min=quant_min, - quant_max=quant_max, - output_dtype=output_dtype, - ) - - class StretchedAffineQuantizedTensor(AffineQuantizedTensor): @classmethod def from_hp_to_intx( @@ -184,7 +162,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor ) data, scale, zero_point = self.tensor_impl.get_plain() - dq = dequantize_stretched_affine( + dq = dequantize_affine( data, self.block_size, scale, diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py index bb5c59fad0..6d895452e8 100644 --- a/torchao/prototype/parq/quant/uniform_torchao.py +++ b/torchao/prototype/parq/quant/uniform_torchao.py @@ -29,7 +29,6 @@ from .quant_api import ( choose_qparams_stretched_affine, - dequantize_stretched_affine, quantize_stretched_affine, ) from .quantizer import Quantizer @@ -157,7 +156,6 @@ def __init__(self, b: int, int_shift: float = 0.5) -> None: self._choose_qparams = partial(choose_qparams_stretched_affine, b=b) self._quantize = quantize_stretched_affine - self._dequantize = dequantize_stretched_affine def get_quant_size(self, b: int) -> int: return math.floor(2**b - 2 * self.int_shift) + 1
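
Note (illustrative, not part of the patch series): the two sketches below summarize what the series converges on. The first is a standalone reconstruction, under stated assumptions, of the per-group rule that choose_qparams_stretched_affine and quantize_stretched_affine implement in their final form after patch 4/4 (int_shift = 0.5, zero_point = -0.5). The helper name and the eps default are assumptions, and the whole input is treated as a single quantization group instead of honoring block_size.

import torch


def stretched_quantize_dequantize(x: torch.Tensor, b: int, eps: float = 1e-12) -> torch.Tensor:
    """Round-trip one weight group through the stretched b-bit grid (illustrative sketch)."""
    quant_max = 2 ** (b - 1) - 0.5  # e.g. 1.5 for b=2, matching StretchedUnifTorchaoQuantizer.quant_max
    zero_point = -0.5               # sign convention after patch 4/4

    # scale = min(b * mean|x|, max|x|) / quant_max, kept away from zero
    q_abs = x.abs()
    max_val = torch.minimum(b * q_abs.mean(), q_abs.max()).clamp_(min=eps)
    scale = max_val / quant_max

    # quantize: clamp to the representable range, shift by zero_point, round to an integer code
    q = torch.round(x.clamp(min=-max_val, max=max_val) / scale + zero_point)

    # dequantize (plain affine): levels are (q - zero_point) * scale, i.e. +-0.5*scale, +-1.5*scale, ...
    return (q - zero_point) * scale


x = torch.randn(32)
print(stretched_quantize_dequantize(x, b=2).unique())  # at most 2**b symmetric levels, zero excluded

The second sketch mirrors TestStretchedUnifTorchaoQuantizer.test_intx_weight_only and shows how the new config is meant to be applied with quantize_; the toy model here is an assumption, not taken from the patch.

from torch import nn

from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer
from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

b, group_size = 2, 32
quantizer = StretchedUnifTorchaoQuantizer(b)

model = nn.Sequential(nn.Linear(512, 512)).eval()
quantize_(
    model,
    StretchedIntxWeightOnlyConfig(
        b=b,
        quant_min=quantizer.quant_min,
        quant_max=quantizer.quant_max,
        granularity=PerGroup(group_size),
    ),
)
# model[0].weight now wraps a StretchedAffineQuantizedTensor; .dequantize() lands on the grid sketched above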