From 0756f391c0dd4dc9382bf961490489a6eb16489a Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 4 Oct 2024 15:52:27 -0700 Subject: [PATCH 01/16] Make module swap the main QAT flow again Summary: Following https://github.com/pytorch/ao/issues/987, this commit makes module swap the main QAT flow today. We remove all tensor subclass fake quantize injection logic since this is not needed in both the long term and the short term plans for QAT. In the short term, we will continue to use a full module swap flow, and only migrate to the long term flow once there is general distributed support for tensor subclasses and when tensor subclass composability provides meaningful benefits. Test Plan: python test/quantization/test_qat.py [ghstack-poisoned] --- test/quantization/test_qat.py | 141 ++----- .../quantization/prototype/qat/__init__.py | 13 +- .../prototype/qat/_module_swap_api.py | 364 +---------------- torchao/quantization/prototype/qat/api.py | 229 +---------- torchao/quantization/prototype/qat/linear.py | 377 ++++++++++++++++++ torchao/quantization/prototype/qat/utils.py | 79 ---- 6 files changed, 418 insertions(+), 785 deletions(-) create mode 100644 torchao/quantization/prototype/qat/linear.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 72ffc23ab6..e1e670d5da 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -18,15 +18,11 @@ from torchao.quantization.prototype.qat.api import ( ComposableQATQuantizer, ) -from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, -) from torchao.quantization.prototype.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, _GenericFakeQuantize, - _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, ) from torchao.quantization.quant_api import ( int4_weight_only, @@ -164,7 +160,7 @@ def _set_ptq_weight( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) - from torchao.quantization.prototype.qat._module_swap_api import ( + from torchao.quantization.prototype.qat.linear import ( Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, ) @@ -196,7 +192,7 @@ def _set_ptq_weight( @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): - from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATLinear + from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear group_size = 128 @@ -219,45 +215,17 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - # TODO: compare against quantize_ API instead @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer - - group_size = 16 - torch.manual_seed(self.SEED) - m = M() - m2 = copy.deepcopy(m) - qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size) - qat_model = qat_quantizer.prepare(m) - ptq_model = ptq_quantizer.quantize(m2) - - # Compare model values - torch.manual_seed(self.SEED) - x = m.example_inputs() - x2 = copy.deepcopy(x) - qat_out = qat_model(*x) - ptq_out = ptq_model(*x2) - 
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - - # Convert QAT model and compare model values - converted_model = qat_quantizer.convert(qat_model) - converted_out = converted_model(*x) - torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") - def test_qat_8da4w_quantizer_module_swap(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap + from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer group_size = 16 torch.manual_seed(self.SEED) m = M() m2 = copy.deepcopy(m) subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - module_swap_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=group_size) + module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) subclass_model = subclass_quantizer.prepare(m) module_swap_model = module_swap_quantizer.prepare(m2) @@ -288,20 +256,6 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - def _copy_subclass_weights( - self, - nn_linear: torch.nn.Linear, - subclass_linear: AffineFakeQuantizedTensor, - ): - nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor) - - def _assert_matches_subclass_weights( - self, - nn_linear: torch.nn.Linear, - subclass_linear: AffineFakeQuantizedTensor, - ): - torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_disable_fake_quant(self): """ @@ -313,16 +267,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): enable_8da4w_fake_quant, ) - def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): - self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor)) - self.assertEqual(m.weight.fake_quant_enabled, enabled) - self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)) - (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) - if enabled: - self.assertIsNotNone(handle) - else: - self.assertIsNone(handle) - group_size = 16 torch.manual_seed(self.SEED) m = M() @@ -331,14 +275,14 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - assert_fake_quant_enabled(qat_model.linear1, enabled=False) - assert_fake_quant_enabled(qat_model.linear2, enabled=False) - assert_fake_quant_enabled(qat_model.sub.linear, enabled=False) + self.assertFalse(qat_model.linear1._fake_quant_enabled) + self.assertFalse(qat_model.linear2._fake_quant_enabled) + self.assertFalse(qat_model.sub.linear._fake_quant_enabled) # Disabled fake quant is just a normal linear - self._copy_subclass_weights(m2.linear1, qat_model.linear1) - self._copy_subclass_weights(m2.linear2, qat_model.linear2) - self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear) + m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight) + m2.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight) + m2.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight) torch.manual_seed(self.SEED) x = m.example_inputs() x2 = 
copy.deepcopy(x) @@ -348,16 +292,16 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): # Renable fake quant qat_model.apply(enable_8da4w_fake_quant) - assert_fake_quant_enabled(qat_model.linear1, enabled=True) - assert_fake_quant_enabled(qat_model.linear2, enabled=True) - assert_fake_quant_enabled(qat_model.sub.linear, enabled=True) + self.assertTrue(qat_model.linear1._fake_quant_enabled) + self.assertTrue(qat_model.linear2._fake_quant_enabled) + self.assertTrue(qat_model.sub.linear._fake_quant_enabled) # Fake quant should be applied as normal quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model2 = quantizer2.prepare(m3) - qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor - qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor - qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor + qat_model2.linear1.weight = qat_model.linear1.weight + qat_model2.linear2.weight = qat_model.linear2.weight + qat_model2.sub.linear.weight = qat_model.sub.linear.weight torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) @@ -382,9 +326,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - self._copy_subclass_weights(nn_model.linear1, qat_model.linear1) - self._copy_subclass_weights(nn_model.linear2, qat_model.linear2) - self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) + nn_model.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight) + nn_model.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight) + nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight) # Simulate training for both models optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) @@ -406,9 +350,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): optimizer2.step() # After 1 training step, weights should match exactly - self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1) - self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2) - self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) + torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0) + torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0) + torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0) def _test_qat_quantized_gradients(self, quantizer): """ @@ -542,7 +486,7 @@ def test_qat_4w_primitives(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): - from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATLinear + from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear group_size = 128 @@ -567,39 +511,6 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") - def test_qat_4w_quantizer(self): - 
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer - from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer - - group_size = 32 - inner_k_tiles = 8 - device = torch.device("cuda") - dtype = torch.bfloat16 - torch.manual_seed(self.SEED) - m = M().to(device).to(dtype) - m2 = copy.deepcopy(m) - qat_quantizer = Int4WeightOnlyQATQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, - ) - qat_model = qat_quantizer.prepare(m) - ptq_model = m2 - quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles))) - - # Compare model values - torch.manual_seed(self.SEED) - x = [i.to(device).to(dtype) for i in m.example_inputs()] - x2 = copy.deepcopy(x) - qat_out = qat_model(*x) - ptq_out = ptq_model(*x2) - self._assert_close_4w(qat_out, ptq_out) - - # Convert QAT model and compare model values - converted_model = qat_quantizer.convert(qat_model) - converted_out = converted_model(*x) - torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_4w_quantizer_gradients(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer @@ -608,9 +519,9 @@ def test_qat_4w_quantizer_gradients(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") - def test_qat_4w_quantizer_module_swap(self): + def test_qat_4w_quantizer(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer - from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATQuantizerModuleSwap + from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer group_size = 32 inner_k_tiles = 8 @@ -622,7 +533,7 @@ def test_qat_4w_quantizer_module_swap(self): subclass_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - module_swap_quantizer = Int4WeightOnlyQATQuantizerModuleSwap( + module_swap_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) subclass_model = subclass_quantizer.prepare(m) diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py index 62740839b7..09ea6e708d 100644 --- a/torchao/quantization/prototype/qat/__init__.py +++ b/torchao/quantization/prototype/qat/__init__.py @@ -1,17 +1,14 @@ from .api import ( + ComposableQATQuantizer, +) +from .linear import ( disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, - int4_weight_only_fake_quantize, - int8_dynamic_activation_int4_weight_fake_quantize, - ComposableQATQuantizer, Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATQuantizer, -) - -from ._module_swap_api import ( Int8DynActInt4WeightQATLinear, + Int8DynActInt4WeightQATQuantizer, ) from .embedding import ( Int4WeightOnlyEmbeddingQATQuantizer, @@ -22,8 +19,6 @@ "disable_8da4w_fake_quant", "enable_4w_fake_quant", "enable_8da4w_fake_quant", - "int4_weight_only_fake_quantize", - "int8_dynamic_activation_int4_weight_fake_quantize", "ComposableQATQuantizer", "Int4WeightOnlyQATQuantizer", "Int4WeightOnlyEmbeddingQATQuantizer" diff --git a/torchao/quantization/prototype/qat/_module_swap_api.py b/torchao/quantization/prototype/qat/_module_swap_api.py index a9239a03d5..0b44974f21 100644 --- a/torchao/quantization/prototype/qat/_module_swap_api.py +++ 
b/torchao/quantization/prototype/qat/_module_swap_api.py @@ -1,355 +1,11 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any - -import torch -import torch.nn.functional as F - -from torchao.quantization.GPTQ import ( - _check_linear_int4_k, - _replace_linear_int4, - _replace_linear_8da4w, - get_groupwise_affine_qparams, - groupwise_affine_quantize_tensor, - Int8DynActInt4WeightLinear, - WeightOnlyInt4Linear, -) -from torchao.quantization.quant_primitives import ZeroPointDomain -from torchao.quantization.utils import get_group_qparams_symmetric -from .api import ( - Int8DynActInt4WeightQATQuantizer, - Int4WeightOnlyQATQuantizer, -) -from .utils import ( - _choose_qparams_per_token_asymmetric, - _fake_quantize_per_channel_group, - _fake_quantize_per_token, - _get_qmin_qmax, +# For backward compatibility only +# These will be removed in the future + +from .linear import ( + Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap, + Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap, + enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap, + disable_8da4w_fake_quant as disable_8da4w_fake_quant_module_swap, + enable_4w_fake_quant as enable_4w_fake_quant_module_swap, + disable_4w_fake_quant as disable_4w_fake_quant_module_swap, ) - - -# TODO: make module swap the main flow again, and remove the quantize_ flow -# TODO: rename this file to linear.py - -# ========================================================= -# | Linear int8 dynamic activations + int4 weight QAT | -# ========================================================= - - -class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have int8 - dynamic per token fake quantized activations and int4 fake quantized - grouped per channel weights. - """ - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _replace_linear_8da4w( - model, - self.groupsize, - self.padding_allowed, - self.precision, - self.scales_precision, - Int8DynActInt4WeightQATLinear, - copy_weights=True, - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _convert_qat_linear_8da4w(model) - return model - - -def _convert_qat_linear_8da4w(module: torch.nn.Module): - """ - Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. 
- """ - for name, child in module.named_children(): - if isinstance(child, Int8DynActInt4WeightQATLinear): - quantized_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - groupsize=child.groupsize, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (qmin, qmax) = _get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, - ) - quantized_linear.weight = q_weight - quantized_linear.scales = s - quantized_linear.zeros = zp - else: - _convert_qat_linear_8da4w(child) - - -class Int8DynActInt4WeightQATLinear(torch.nn.Linear): - """ - This module implements a linear layer with int8 dynamic per token fake - quantized activations with int4 fake quantized grouped per channel weights. - - args: - groupsize: the number of elements in each quantized group for weights - precision: precision of weights - scales_precision: precision of per group scales and zero points - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device: torch.device = None, - groupsize: int = 256, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - super().__init__( - in_features, - out_features, - bias, - device=device, - dtype=precision, - ) - assert ( - in_features % groupsize == 0 - ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" - assert not bias, "require bias=False" - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - # TODO: make this configurable? - self.zero_points_precision = torch.int32 - self._fake_quant_enabled = True - - def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled - - def disable_fake_quant(self): - self.enable_fake_quant(False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # activations: int8 dynamic asymmetric quant - if self._fake_quant_enabled: - (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( - x, self.scales_precision, self.zero_points_precision, - ) - (act_qmin, act_qmax) = _get_qmin_qmax(8) - x_fq = _fake_quantize_per_token( - x, act_scales, act_zp, act_qmin, act_qmax, - ) - else: - x_fq = x - - # weights: int4 grouped per channel symmetric quant - if self._fake_quant_enabled: - (weight_scales, weight_zp) = get_group_qparams_symmetric( - self.weight, 4, self.groupsize, self.scales_precision, - ) - # TODO: pass zp dtype to `get_group_qparams_symmetric` instead - weight_zp = weight_zp.to(self.zero_points_precision) - (weight_qmin, weight_qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group( - self.weight, - weight_scales, - weight_zp, - weight_qmin, - weight_qmax, - self.groupsize, - ) - else: - w_fq = self.weight - return F.linear(x_fq, w_fq) - - -def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Enable fake quantization for `Int8DynActInt4WeightQATLinear`. - """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.enable_fake_quant() - - -def disable_8da4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Disable fake quantization for `Int8DynActInt4WeightQATLinear`. 
- """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.disable_fake_quant() - - -# =================================== -# | Linear int4 weight-only QAT | -# =================================== - - -class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have - int4 fake quantized grouped per channel weights. - """ - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _replace_linear_int4( - model, - self.groupsize, - self.inner_k_tiles, - padding_allowed=True, - precision=self.precision, - scales_precision=self.scales_precision, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True, - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _convert_qat_linear_4w(model) - return model - - -def _convert_qat_linear_4w(module: torch.nn.Module): - """ - Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATLinear): - in_features = child.in_features - out_features = child.out_features - groupsize = child.groupsize - inner_k_tiles = child.inner_k_tiles - quantized_linear = WeightOnlyInt4Linear( - in_features, - out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - child.weight, n_bit, child.groupsize, - ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), child.inner_k_tiles, - ) - quantized_linear.weight = q_weight - quantized_linear.scales_and_zeros = scales_and_zeros - else: - _convert_qat_linear_4w(child) - - -class Int4WeightOnlyQATLinear(torch.nn.Linear): - """ - This module implements a linear layer with int4 fake quantized grouped - per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, - which uses the efficient int4 tinygemm kernel. 
- - args: - groupsize: the number of elements in each quantized group for weights - precision: precision of weights - scales_precision: precision of per group scales and zero points - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device: torch.device = None, - groupsize: int = 256, - inner_k_tiles: int = 8, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - ) -> None: - super().__init__( - in_features, - out_features, - bias, - device=device, - dtype=precision, - ) - assert not bias, "require bias=False" - assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" - if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): - raise ValueError("Padding for QAT 4w is not supported yet") - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.precision = precision - self.scales_precision = scales_precision - self._fake_quant_enabled = True - - def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled - - def disable_fake_quant(self): - self.enable_fake_quant(False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - n_bit = 4 - qmin = 0 - qmax = 2 ** n_bit - 1 - scales, zero_points = get_groupwise_affine_qparams( - self.weight, n_bit, self.groupsize, self.scales_precision, - ) - w_fq = _fake_quantize_per_channel_group( - self.weight, - scales, - zero_points, - qmin, - qmax, - self.groupsize, - ZeroPointDomain.FLOAT, - ) - return F.linear(x, w_fq) - - -def enable_4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Enable fake quantization for `Int4WeightOnlyQATLinear`. - """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.enable_fake_quant() - - -def disable_4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Disable fake quantization for `Int4WeightOnlyQATLinear`. - """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.disable_fake_quant() diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index e1c5221e1e..93717271bb 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -4,34 +4,11 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, List, Optional +from typing import Any, List import torch -import torch.nn.functional as F -from torchao.dtypes import ( - TensorCoreTiledLayoutType, -) -from torchao.quantization.quant_api import ( - _get_linear_subclass_inserter, - _replace_with_custom_fn_if_matches_filter, - int4_weight_only, - int8_dynamic_activation_int4_weight, - quantize_, -) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) from torchao.quantization.unified import TwoStepQuantizer -from torchao.quantization.utils import _get_per_token_block_size -from .affine_fake_quantized_tensor import to_affine_fake_quantized -from .utils import ( - _enable_fake_quant, - _get_qat_linear_subclass_inserter, - _is_linear_with_fq_weight, - _unwrap_affine_fake_quantized_tensor, -) class ComposableQATQuantizer(TwoStepQuantizer): @@ -70,207 +47,3 @@ def convert( for quantizer in self.quantizers: model = quantizer.convert(model) return model - - -# ================= -# | 8da4w QAT | -# ================= - -def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32): - """ - Applies int8 dynamic per token asymmetric activation fake quantization and - int4 per group weight symmetric fake quantization to linear. Please see - :func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details. - - Example usage:: - - from torchao.quantization import quantize_ - quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32)) - """ - # avoid circular dep - from torchao.dtypes import to_affine_quantized_intx - - def _apply_weight_fake_quant(weight: torch.Tensor): - mapping_type = MappingType.SYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - quant_min = -8 - quant_max = 7 - return to_affine_fake_quantized( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - ) - - def _apply_input_activation_fake_quant(x: torch.Tensor): - mapping_type = MappingType.ASYMMETRIC - target_dtype = torch.int8 - return to_affine_fake_quantized( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - ) - - return _get_qat_linear_subclass_inserter( - _apply_weight_fake_quant, - _apply_input_activation_fake_quant, - ) - -class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have int8 - dynamic per token fake quantized activations and int4 fake quantized - grouped per channel weights. 
- """ - - def __init__( - self, - groupsize: int = 256, - padding_allowed: bool = False, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - super().__init__() - self.groupsize: int = groupsize - self.padding_allowed: bool = padding_allowed - self.precision: torch.dtype = precision - self.scales_precision: torch.dtype = scales_precision - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - quantize_( - model, - int8_dynamic_activation_int4_weight_fake_quantize(group_size=self.groupsize), - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) - filter_fn = _is_linear_with_fq_weight - model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) - quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize) - quantize_(model, quantize_fn) - return model - - -def enable_8da4w_fake_quant(mod: torch.nn.Module): - """ - Enable fake quantization for int8 dynamic activations + int4 weight. - """ - _enable_fake_quant(mod, enable=True) - -def disable_8da4w_fake_quant(mod: torch.nn.Module): - """ - Disable fake quantization for int8 dynamic activations + int4 weight. - """ - _enable_fake_quant(mod, enable=False) - - -# ================== -# | int4wo QAT | -# ================== - -def int4_weight_only_fake_quantize(group_size=128): - """ - Applies uint4 weight-only asymmetric per-group fake quantization to linear layers. - Please see :func:`~torchao.quantization.int4_weight_only` for more details. - - Example usage:: - - from torchao.quantization import quantize_ - quantize_(model, int4_weight_only_fake_quantize(group_size=32)) - """ - def _apply_fake_quant(weight): - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_fake_quantized( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - ) - return _get_qat_linear_subclass_inserter(_apply_fake_quant) - -class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have - int4 fake quantized grouped per channel weights. 
- """ - - def __init__( - self, - groupsize: int = 256, - inner_k_tiles: Optional[int] = 8, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - ) -> None: - super().__init__() - assert inner_k_tiles in [2, 4, 8] - assert groupsize in [32, 64, 128, 256] - self.inner_k_tiles = inner_k_tiles - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - quantize_(model, int4_weight_only_fake_quantize(group_size=self.groupsize)) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) - filter_fn = _is_linear_with_fq_weight - model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) - layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles) - quantize_fn = int4_weight_only(self.groupsize, layout_type) - quantize_(model, quantize_fn) - return model - -def enable_4w_fake_quant(mod: torch.nn.Module): - """ - Enable fake quantization for int4 weight only. - """ - _enable_fake_quant(mod, enable=True) - -def disable_4w_fake_quant(mod: torch.nn.Module): - """ - Disable fake quantization for int4 weight only. - """ - _enable_fake_quant(mod, enable=False) diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/prototype/qat/linear.py new file mode 100644 index 0000000000..07276ba84c --- /dev/null +++ b/torchao/quantization/prototype/qat/linear.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from torchao.quantization.GPTQ import ( + _check_linear_int4_k, + _replace_linear_int4, + _replace_linear_8da4w, + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor, + Int8DynActInt4WeightLinear, + WeightOnlyInt4Linear, +) +from torchao.quantization.quant_primitives import ZeroPointDomain +from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.utils import get_group_qparams_symmetric +from .utils import ( + _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, + _get_qmin_qmax, +) + + +# ========================================================= +# | Linear int8 dynamic activations + int4 weight QAT | +# ========================================================= + + +class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have int8 + dynamic per token fake quantized activations and int4 fake quantized + grouped per channel weights. 
+ """ + + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = False, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__() + self.groupsize: int = groupsize + self.padding_allowed: bool = padding_allowed + self.precision: torch.dtype = precision + self.scales_precision: torch.dtype = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_8da4w( + model, + self.groupsize, + self.padding_allowed, + self.precision, + self.scales_precision, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _convert_qat_linear_8da4w(model) + return model + + +def _convert_qat_linear_8da4w(module: torch.nn.Module): + """ + Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int8DynActInt4WeightQATLinear): + quantized_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=child.groupsize, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (qmin, qmax) = _get_qmin_qmax(n_bit) + (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) + from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, + ) + quantized_linear.weight = q_weight + quantized_linear.scales = s + quantized_linear.zeros = zp + else: + _convert_qat_linear_8da4w(child) + + +class Int8DynActInt4WeightQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int8 dynamic per token fake + quantized activations with int4 fake quantized grouped per channel weights. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=device, + dtype=precision, + ) + assert ( + in_features % groupsize == 0 + ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" + assert not bias, "require bias=False" + self.groupsize = groupsize + self.precision = precision + self.scales_precision = scales_precision + # TODO: make this configurable? 
+ self.zero_points_precision = torch.int32 + self._fake_quant_enabled = True + + def enable_fake_quant(self, enabled: bool = True): + self._fake_quant_enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # activations: int8 dynamic asymmetric quant + if self._fake_quant_enabled: + (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( + x, self.scales_precision, self.zero_points_precision, + ) + (act_qmin, act_qmax) = _get_qmin_qmax(8) + x_fq = _fake_quantize_per_token( + x, act_scales, act_zp, act_qmin, act_qmax, + ) + else: + x_fq = x + + # weights: int4 grouped per channel symmetric quant + if self._fake_quant_enabled: + (weight_scales, weight_zp) = get_group_qparams_symmetric( + self.weight, 4, self.groupsize, self.scales_precision, + ) + # TODO: pass zp dtype to `get_group_qparams_symmetric` instead + weight_zp = weight_zp.to(self.zero_points_precision) + (weight_qmin, weight_qmax) = _get_qmin_qmax(4) + w_fq = _fake_quantize_per_channel_group( + self.weight, + weight_scales, + weight_zp, + weight_qmin, + weight_qmax, + self.groupsize, + ) + else: + w_fq = self.weight + return F.linear(x_fq, w_fq) + + +def enable_8da4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int8DynActInt4WeightQATLinear`. + """ + if isinstance(mod, Int8DynActInt4WeightQATLinear): + mod.enable_fake_quant() + + +def disable_8da4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int8DynActInt4WeightQATLinear`. + """ + if isinstance(mod, Int8DynActInt4WeightQATLinear): + mod.disable_fake_quant() + + +# =================================== +# | Linear int4 weight-only QAT | +# =================================== + + +class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have + int4 fake quantized grouped per channel weights. + """ + + def __init__( + self, + groupsize: int = 256, + inner_k_tiles: Optional[int] = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + self.inner_k_tiles = inner_k_tiles + self.groupsize = groupsize + self.precision = precision + self.scales_precision = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + padding_allowed=True, + precision=self.precision, + scales_precision=self.scales_precision, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _convert_qat_linear_4w(model) + return model + + +def _convert_qat_linear_4w(module: torch.nn.Module): + """ + Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. 
+ """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + groupsize = child.groupsize + inner_k_tiles = child.inner_k_tiles + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, child.groupsize, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + _convert_qat_linear_4w(child) + + +class Int4WeightOnlyQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int4 fake quantized grouped + per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, + which uses the efficient int4 tinygemm kernel. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + inner_k_tiles: int = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=device, + dtype=precision, + ) + assert not bias, "require bias=False" + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.precision = precision + self.scales_precision = scales_precision + self._fake_quant_enabled = True + + def enable_fake_quant(self, enabled: bool = True): + self._fake_quant_enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_bit = 4 + qmin = 0 + qmax = 2 ** n_bit - 1 + scales, zero_points = get_groupwise_affine_qparams( + self.weight, n_bit, self.groupsize, self.scales_precision, + ) + w_fq = _fake_quantize_per_channel_group( + self.weight, + scales, + zero_points, + qmin, + qmax, + self.groupsize, + ZeroPointDomain.FLOAT, + ) + return F.linear(x, w_fq) + + +def enable_4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.enable_fake_quant() + + +def disable_4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int4WeightOnlyQATLinear`. 
+ """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.disable_fake_quant() diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py index 1e4b61b8ac..354475e655 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/prototype/qat/utils.py @@ -181,85 +181,6 @@ def _choose_qparams_per_token_asymmetric( return scale.to(scales_precision), zero_point.to(zero_points_precision) -def _forward_pre_hook_handler( - mod: torch.nn.Linear, - prehook: Callable, - handler: torch.utils.hooks.RemovableHandle, -): - """ - Store a 2-tuple (prehook function, handler) as an attribute on the given linear module. - """ - setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handler)) - -def _unwrap_affine_fake_quantized_tensor(t: torch.Tensor): - """ - Return the original, non-fake-quantized float tensor from a `AffineFakeQuantizedTensor`. - """ - # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - assert isinstance(t, AffineFakeQuantizedTensor) - return t.original_tensor - -def _is_linear_with_fq_weight(mod: torch.nn.Module, *args): - """ - Return whether this is a nn.Linear module with `AffineFakeQuantizeTensor` weights. - """ - # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - if not isinstance(mod, torch.nn.Linear) or not hasattr(mod, "weight"): - return False - weight = mod.weight - return isinstance(weight, AffineFakeQuantizedTensor) - -def _enable_fake_quant(mod: torch.nn.Module, enable: bool): - """ - Enable or disable fake quantization in the activations and weights of a `nn.Linear` module. - """ - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - if not _is_linear_with_fq_weight(mod): - return - weight = mod.weight - assert isinstance(weight, AffineFakeQuantizedTensor) - weight.fake_quant_enabled = enable - - # Enable/disable input fake quant - if hasattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK): - (prehook, handle) = getattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) - if enable and handle is None: - handle = mod.register_forward_pre_hook(prehook) - elif not enable and handle is not None: - handle.remove() - handle = None - setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) - -def _get_qat_linear_subclass_inserter( - weight_constructor: Callable, - input_constructor: Optional[Callable] = None, -) -> Callable: - """ - Return a function that inserts wraps the weight and/or input activation of a - linear module in tensor subclasses. 
- - Args: - weight_constructor: constructor of the weight subclass, accepts a tensor - input_constructor: (optional) constructor of the input subclass, accepts a tensor - """ - def insert_subclass(lin): - lin.weight = torch.nn.Parameter(weight_constructor(lin.weight), requires_grad=True) - if input_constructor is not None: - prehook = lambda _, args: tuple([input_constructor(args[0])] + list(args[1:])) - handle = lin.register_forward_pre_hook(prehook) - setattr(lin, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) - return lin - - return insert_subclass - def _get_qmin_qmax(n_bit: int): qmin = -(2 ** (n_bit - 1)) qmax = 2 ** (n_bit - 1) - 1 From 9e9fdef39eb39a7181b2169855e12d4cc053ef0e Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 4 Oct 2024 15:52:31 -0700 Subject: [PATCH 02/16] Add generic fake quantized linear for QAT Summary: This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig( bit_width=8, granularity="per_token", symmetric=False, dynamic=True, ) weight_config = FakeQuantizeConfig( bit_width=4, group_size=8, symmetric=True, dynamic=True, ) fq_linear = FakeQuantizedLinear( 16, 32, False, activation_config, weight_config, ) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. 
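
A further hedged sketch, assuming only the API exercised by the new
`test_fake_quantized_linear_4w` test in this patch: weight-only fake
quantization (the int4 tinygemm path) can be expressed with the same module
by passing `activation_config=None`. The parameter values below are
illustrative, mirroring that test rather than prescribing defaults.

```
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# int4 weights, per-group asymmetric, float zero point domain (tinygemm-style)
weight_config = FakeQuantizeConfig(
    bit_width=4,
    group_size=128,
    symmetric=False,
    zero_point_domain=ZeroPointDomain.FLOAT,
)
fq_linear = FakeQuantizedLinear(
    256,
    688,
    bias=False,
    activation_config=None,
    weight_config=weight_config,
)
```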
Test Plan: python test/quantization/test_qat.py -k test_fake_quantize_config python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] --- test/quantization/test_qat.py | 229 +++++++++--- torchao/quantization/prototype/qat/api.py | 66 +++- .../prototype/qat/fake_quantizer.py | 116 +++++++ torchao/quantization/prototype/qat/linear.py | 327 ++++++++++-------- torchao/quantization/prototype/qat/utils.py | 10 +- 5 files changed, 557 insertions(+), 191 deletions(-) create mode 100644 torchao/quantization/prototype/qat/fake_quantizer.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index e1e670d5da..67a59965f0 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -11,17 +11,27 @@ import unittest import torch +import torch.nn.functional as F from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torchao.dtypes import ( TensorCoreTiledLayoutType, ) from torchao.quantization.prototype.qat.api import ( ComposableQATQuantizer, + FakeQuantizeConfig, + QuantizationGranularity, +) +from torchao.quantization.prototype.qat.fake_quantizer import ( + FakeQuantizer, +) +from torchao.quantization.prototype.qat.linear import ( + FakeQuantizedLinear, ) from torchao.quantization.prototype.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, + _get_qmin_qmax, _GenericFakeQuantize, ) from torchao.quantization.quant_api import ( @@ -92,15 +102,10 @@ def forward(self, x): class TestQAT(unittest.TestCase): SEED = 123 - def _get_qmin_qmax(self, n_bit: int): - qmin = -(2 ** (n_bit - 1)) - qmax = 2 ** (n_bit - 1) - 1 - return (qmin, qmax) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_fake_quantize_per_channel_group(self): n_bit = 4 - (qmin, qmax) = self._get_qmin_qmax(n_bit) + (qmin, qmax) = _get_qmin_qmax(n_bit) group_size = 128 torch.manual_seed(self.SEED) @@ -126,7 +131,7 @@ def test_fake_quantize_per_channel_group(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_fake_quantize_per_token(self): - (qmin, qmax) = self._get_qmin_qmax(8) + (qmin, qmax) = _get_qmin_qmax(8) torch.manual_seed(self.SEED) x = torch.randn(100, 256).requires_grad_() @@ -165,11 +170,11 @@ def _set_ptq_weight( Int4WeightOnlyQATLinear, ) n_bit = 4 - (qmin, qmax) = self._get_qmin_qmax(n_bit) + (qmin, qmax) = _get_qmin_qmax(n_bit) + group_size = qat_linear.weight_fake_quantizer.config.group_size if isinstance(ptq_linear, Int8DynActInt4WeightLinear): assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear) fp32_weight = qat_linear.weight - group_size = qat_linear.groupsize (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, @@ -180,7 +185,7 @@ def _set_ptq_weight( elif isinstance(ptq_linear, WeightOnlyInt4Linear): assert isinstance(qat_linear, Int4WeightOnlyQATLinear) (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - qat_linear.weight, n_bit, qat_linear.groupsize, + qat_linear.weight, n_bit, group_size, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( q_weight.to("cuda"), qat_linear.inner_k_tiles, @@ -218,31 +223,36 @@ def test_qat_8da4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, 
"skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer group_size = 16 torch.manual_seed(self.SEED) m = M() m2 = copy.deepcopy(m) - subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - subclass_model = subclass_quantizer.prepare(m) - module_swap_model = module_swap_quantizer.prepare(m2) + qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size) + qat_model = qat_quantizer.prepare(m) + ptq_model = ptq_quantizer.quantize(m2) # Compare model values torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + qat_out = qat_model(*x) + ptq_out = ptq_model(*x2) + torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) # Convert QAT model and compare model values - subclass_model = subclass_quantizer.convert(subclass_model) - module_swap_model = module_swap_quantizer.convert(module_swap_model) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + converted_model = qat_quantizer.convert(qat_model) + converted_out = converted_model(*x) + torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0) + + # Compare converted state dict + ptq_state_dict = ptq_model.state_dict() + converted_state_dict = converted_model.state_dict() + self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) + for k in ptq_state_dict.keys(): + torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): @@ -275,9 +285,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - self.assertFalse(qat_model.linear1._fake_quant_enabled) - self.assertFalse(qat_model.linear2._fake_quant_enabled) - self.assertFalse(qat_model.sub.linear._fake_quant_enabled) + self.assertFalse(qat_model.linear1.activation_fake_quantizer.enabled) + self.assertFalse(qat_model.linear1.weight_fake_quantizer.enabled) + self.assertFalse(qat_model.linear2.activation_fake_quantizer.enabled) + self.assertFalse(qat_model.linear2.weight_fake_quantizer.enabled) + self.assertFalse(qat_model.sub.linear.activation_fake_quantizer.enabled) + self.assertFalse(qat_model.sub.linear.weight_fake_quantizer.enabled) # Disabled fake quant is just a normal linear m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight) @@ -292,9 +305,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): # Renable fake quant qat_model.apply(enable_8da4w_fake_quant) - self.assertTrue(qat_model.linear1._fake_quant_enabled) - self.assertTrue(qat_model.linear2._fake_quant_enabled) - self.assertTrue(qat_model.sub.linear._fake_quant_enabled) + self.assertTrue(qat_model.linear1.activation_fake_quantizer.enabled) + 
self.assertTrue(qat_model.linear1.weight_fake_quantizer.enabled) + self.assertTrue(qat_model.linear2.activation_fake_quantizer.enabled) + self.assertTrue(qat_model.linear2.weight_fake_quantizer.enabled) + self.assertTrue(qat_model.sub.linear.activation_fake_quantizer.enabled) + self.assertTrue(qat_model.sub.linear.weight_fake_quantizer.enabled) # Fake quant should be applied as normal quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) @@ -407,7 +423,7 @@ def test_qat_generic_fake_quantize(self): the numerics of existing fake quantize ops in Pytorch in both the forward and the backward passes. """ - (qmin, qmax) = self._get_qmin_qmax(4) + (qmin, qmax) = _get_qmin_qmax(4) py_input = torch.randn(16, 64).float().requires_grad_() py_s = torch.randn(16).float() py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32) @@ -521,7 +537,7 @@ def test_qat_4w_quantizer_gradients(self): @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer - from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer group_size = 32 inner_k_tiles = 8 @@ -530,29 +546,34 @@ def test_qat_4w_quantizer(self): torch.manual_seed(self.SEED) m = M().to(device).to(dtype) m2 = copy.deepcopy(m) - subclass_quantizer = Int4WeightOnlyQATQuantizer( + qat_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - module_swap_quantizer = Int4WeightOnlyQATQuantizer( + ptq_quantizer = Int4WeightOnlyQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - subclass_model = subclass_quantizer.prepare(m) - module_swap_model = module_swap_quantizer.prepare(m2) + qat_model = qat_quantizer.prepare(m) + ptq_model = ptq_quantizer.quantize(m2) # Compare model values torch.manual_seed(self.SEED) x = [i.to(device).to(dtype) for i in m.example_inputs()] x2 = copy.deepcopy(x) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + qat_out = qat_model(*x) + ptq_out = ptq_model(*x2) + self._assert_close_4w(qat_out, ptq_out) # Convert QAT model and compare model values - subclass_model = subclass_quantizer.convert(subclass_model) - module_swap_model = module_swap_quantizer.convert(module_swap_model) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + converted_model = qat_quantizer.convert(qat_model) + converted_out = converted_model(*x) + torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) + + # Compare converted state dict + ptq_state_dict = ptq_model.state_dict() + converted_state_dict = converted_model.state_dict() + self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) + for k in ptq_state_dict.keys(): + torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) class _MyQATQuantizer(TwoStepQuantizer): """ @@ -603,5 +624,127 @@ def test_qat_4w_embedding(self): converted = quantizer.convert(model) converted_out = converted(*x) + def test_fake_quantize_config(self): + """ + Test initialization and property setting of `FakeQuantizeConfig`. 
+ """ + # basic configs + per_token_config = FakeQuantizeConfig(8, "per_token") + self.assertEqual(per_token_config.bit_width, 8) + self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN) + self.assertIsNone(per_token_config.group_size) + per_channel_config = FakeQuantizeConfig(4, "per_channel") + self.assertEqual(per_channel_config.bit_width, 4) + self.assertEqual(per_channel_config.granularity, QuantizationGranularity.PER_CHANNEL) + self.assertIsNone(per_channel_config.group_size) + + # initialize per_group config using only group size + per_group_config = FakeQuantizeConfig(4, group_size=32) + self.assertEqual(per_group_config.bit_width, 4) + self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP) + self.assertEqual(per_group_config.group_size, 32) + + # set granularity after initialization, should accept str as before + per_group_config.granularity = "per_token" + self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN) + + # set group_size after initialization, should also update granularity + per_group_config.group_size = 16 + self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP) + self.assertEqual(per_group_config.group_size, 16) + + # bad config1: no granularity or group size provided + with self.assertRaisesRegex(ValueError, "group_size or granularity must be set"): + FakeQuantizeConfig(8) + + # bad config2: 'per_group' but no group size + with self.assertRaisesRegex(ValueError, "no group_size was set"): + FakeQuantizeConfig(8, "per_group") + + # bad config3: group size was set but granularity was not 'per_group' + with self.assertRaisesRegex(ValueError, "group_size was set"): + FakeQuantizeConfig(8, "per_token", group_size=16) + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_fake_quantized_linear_8da4w(self): + """ + Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. + """ + group_size = 128 + torch.manual_seed(self.SEED) + fq_linear = FakeQuantizedLinear( + 256, + 688, + bias=False, + activation_config=FakeQuantizeConfig(8, "per_token", symmetric=False), + weight_config=FakeQuantizeConfig(4, group_size=group_size), + ) + + def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant. + """ + # activations + (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32) + (qmin, qmax) = _get_qmin_qmax(8) + x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax) + + # weights + (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) + zp = zp.to(torch.int32) + (qmin, qmax) = _get_qmin_qmax(4) + w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + return F.linear(x_fq, w_fq) + + # Compare linear values + torch.manual_seed(self.SEED) + x = torch.randn(100, 256) + x2 = copy.deepcopy(x) + fq_out = fq_linear(x) + baseline_out = linear_forward_8da4w(x2, fq_linear.weight) + torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_fake_quantized_linear_4w(self): + """ + Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. 
+ """ + group_size = 128 + weight_config = FakeQuantizeConfig( + bit_width=4, + group_size=group_size, + symmetric=False, + zero_point_domain=ZeroPointDomain.FLOAT, + ) + torch.manual_seed(self.SEED) + fq_linear = FakeQuantizedLinear( + 256, + 688, + bias=False, + activation_config=None, + weight_config=weight_config, + ) + + def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Baseline for int4 weight only fake quantization that simulates the tinygemm kernel. + """ + (qmin, qmax) = _get_qmin_qmax(4, symmetric=False) + (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32) + zp = zp.to(torch.int32) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT, + ) + return F.linear(x, w_fq) + + # Compare linear values + torch.manual_seed(self.SEED) + x = torch.randn(100, 256) + x2 = copy.deepcopy(x) + fq_out = fq_linear(x) + baseline_out = linear_forward_4w(x2, fq_linear.weight) + torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) + + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index 93717271bb..96071a5ce2 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -4,11 +4,75 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List +from dataclasses import dataclass +from enum import Enum +from typing import Any, List, Optional import torch from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.quant_primitives import ZeroPointDomain + + +class QuantizationGranularity(Enum): + PER_CHANNEL = "per_channel" + PER_TOKEN = "per_token" + PER_GROUP = "per_group" + + +@dataclass +class FakeQuantizeConfig: + """ + Config for how to fake quantize weights or activations. + + args: + bit_width: number of bits to simulate during fake quantization + granularity: granularity of scales and zero points, one of: + 'per_token', 'per_channel', or 'per_group' + group_size: size of each group for 'per_group' granularity + symmetric: whether to use symmetric (default) or asymmetric quantization + scale_precision: scale dtype (default torch.fp32) + zero_point_precision: zero point dtype (default torch.int32) + zero_point_domain: whether zero point is in integer (default) or float domain + dynamic: whether to use dynamic (defualt) or static scale and zero points + range_learning: whether to learn scale and zero points during training (coming soon) + """ + bit_width: int + granularity: Optional[QuantizationGranularity] = None + group_size: Optional[int] = None + symmetric: bool = True + scale_precision: torch.dtype = torch.float32 + zero_point_precision: torch.dtype = torch.int32 + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT + dynamic: bool = True + range_learning: bool = False + + def __post_init__(self): + """ + Verify that `group_size` and `granularity` are consistent. 
+ """ + if self.group_size is None and self.granularity is None: + raise ValueError("At least one of group_size or granularity must be set") + if self.granularity == QuantizationGranularity.PER_GROUP and self.group_size is None: + raise ValueError("Granularity is 'per_group' but no group_size was set") + if self.granularity != QuantizationGranularity.PER_GROUP and self.group_size is not None: + if self.granularity is None: + self.granularity = QuantizationGranularity.PER_GROUP + else: + raise ValueError( + "Granularity is '%s' but group_size was set" % self.granularity.value + ) + self._initialized = True + + def __setattr__(self, name: str, value: Any): + """ + Support setting `granularity` by string and through `group_size`. + """ + if name == "group_size" and getattr(self, "_initialized", False): + super().__setattr__("granularity", QuantizationGranularity.PER_GROUP) + if name == "granularity" and isinstance(value, str): + value = QuantizationGranularity(value) + super().__setattr__(name, value) class ComposableQATQuantizer(TwoStepQuantizer): diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/quantization/prototype/qat/fake_quantizer.py new file mode 100644 index 0000000000..7c7e09f8b0 --- /dev/null +++ b/torchao/quantization/prototype/qat/fake_quantizer.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch + +from torchao.quantization.utils import ( + get_group_qparams_symmetric, + get_groupwise_affine_qparams, +) +from .api import ( + FakeQuantizeConfig, + QuantizationGranularity, +) +from .utils import ( + _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, + _get_qmin_qmax, +) + + +class FakeQuantizer(torch.nn.Module): + """ + Generic module for applying fake quantization to a tensor, as specified in the config. + """ + def __init__(self, config: FakeQuantizeConfig): + super().__init__() + self.config = config + self.enabled = True + self.scale: Optional[torch.Tensor] = None + self.zero_point: Optional[torch.Tensor] = None + + # TODO: support range learinng + if self.config.range_learning: + raise NotImplementedError("Range learning is not supported yet") + + def forward(self, x: torch.Tensor): + """ + Apply fake quantization to the tensor based on the bit-width, + granularity, symmetry, and other properties specified in the config. + """ + if not self.enabled: + return x + + if self.config.granularity == QuantizationGranularity.PER_TOKEN: + return self._per_token_forward(x) + elif self.config.granularity in [ + QuantizationGranularity.PER_CHANNEL, + QuantizationGranularity.PER_GROUP, + ]: + return self._per_channel_or_group_forward(x) + else: + raise ValueError("Unknown granularity %s" % self.config.granularity) + + def _per_token_forward(self, x: torch.Tensor): + """ + Perform per token fake quantization on the tensor. 
+ """ + if self.config.symmetric: + raise NotImplementedError("Symmetric per token is not supported yet") + if self._should_compute_qparams(): + (self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric( + x, self.config.scale_precision, self.config.zero_point_precision, + ) + qmin, qmax = _get_qmin_qmax(self.config.bit_width) + return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax) + + def _per_channel_or_group_forward(self, x: torch.Tensor): + """ + Perform per channel or per group fake quantization on the tensor. + We express per channel using per group where the group size is the size + of the last dimension of the tensor. + """ + bit_width = self.config.bit_width + granularity = self.config.granularity + scale_precision = self.config.scale_precision + zero_point_precision = self.config.zero_point_precision + zero_point_domain = self.config.zero_point_domain + symmetric = self.config.symmetric + + # get group size + if granularity == QuantizationGranularity.PER_CHANNEL: + group_size = x.size()[-1] + elif granularity == QuantizationGranularity.PER_GROUP: + assert self.config.group_size is not None + group_size = self.config.group_size + else: + raise ValueError("Group size not defined for granularity %s" % granularity) + + # get scales and zero points + if self._should_compute_qparams(): + if symmetric: + (self.scale, self.zero_point) = get_group_qparams_symmetric( + x, bit_width, group_size, scale_precision, + ) + else: + (self.scale, self.zero_point) = get_groupwise_affine_qparams( + x, bit_width, group_size, scale_precision, + ) + self.zero_point = self.zero_point.to(zero_point_precision) + + qmin, qmax = _get_qmin_qmax(bit_width, symmetric) + return _fake_quantize_per_channel_group( + x, self.scale, self.zero_point, qmin, qmax, group_size, zero_point_domain, + ) + + def _should_compute_qparams(self) -> bool: + """ + Return whether we need to compute new scales and zero points. + """ + return self.config.dynamic or self.scale is None or self.zero_point is None diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/prototype/qat/linear.py index 07276ba84c..32f560189b 100644 --- a/torchao/quantization/prototype/qat/linear.py +++ b/torchao/quantization/prototype/qat/linear.py @@ -21,6 +21,8 @@ from torchao.quantization.quant_primitives import ZeroPointDomain from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric +from .api import FakeQuantizeConfig +from .fake_quantizer import FakeQuantizer from .utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, @@ -29,6 +31,79 @@ ) +class FakeQuantizedLinear(torch.nn.Linear): + """ + General linear layer with fake quantized weights and activations. + + Specific fake quantization bit widths, granularity, schemes etc. are specified + through separate configs for weights and activations. 
+ + Example usage:: + + activation_config = FakeQuantizeConfig( + bit_width=8, + granularity="per_token", + symmetric=False, + ) + weight_config = FakeQuantizeConfig( + bit_width=4, + group_size=8, + symmetric=True, + ) + fq_linear = FakeQuantizedLinear( + 16, 32, False, activation_config, weight_config, + ) + fq_linear(torch.randn(16)) + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + activation_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfig] = None, + *args, + **kwargs, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + *args, + **kwargs, + ) + if bias: + raise NotImplementedError("bias not supported yet") + + # initialize activation fake quantizer + if activation_config is not None: + self.activation_fake_quantizer = FakeQuantizer(activation_config) + else: + self.activation_fake_quantizer = None + + # initialize weight fake quantizer + if weight_config is not None: + group_size = weight_config.group_size + if group_size is not None and in_features % group_size != 0: + raise ValueError( + "in_features (%s) % group_size (%s) must be == 0" % + (in_features, group_size) + ) + self.weight_fake_quantizer = FakeQuantizer(weight_config) + else: + self.weight_fake_quantizer = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.activation_fake_quantizer is not None: + x = self.activation_fake_quantizer(x) + if self.weight_fake_quantizer is not None: + w = self.weight_fake_quantizer(self.weight) + else: + w = self.weight + return F.linear(x, w) + + # ========================================================= # | Linear int8 dynamic activations + int4 weight QAT | # ========================================================= @@ -77,42 +152,42 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_8da4w(model) + self._convert_qat_linear_8da4w(model) return model - -def _convert_qat_linear_8da4w(module: torch.nn.Module): - """ - Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int8DynActInt4WeightQATLinear): - quantized_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - groupsize=child.groupsize, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (qmin, qmax) = _get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, - ) - quantized_linear.weight = q_weight - quantized_linear.scales = s - quantized_linear.zeros = zp - else: - _convert_qat_linear_8da4w(child) - - -class Int8DynActInt4WeightQATLinear(torch.nn.Linear): + def _convert_qat_linear_8da4w(self, module: torch.nn.Module): + """ + Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. 
+ """ + for name, child in module.named_children(): + if isinstance(child, Int8DynActInt4WeightQATLinear): + config = child.weight_fake_quantizer.config + quantized_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=config.group_size, + precision=child.weight.dtype, + scales_precision=config.scale_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (qmin, qmax) = _get_qmin_qmax(n_bit) + (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, config.group_size) + from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + child.weight, s, zp, qmin, qmax, torch.int8, config.group_size, + ) + quantized_linear.weight = q_weight + quantized_linear.scales = s + quantized_linear.zeros = zp + else: + self._convert_qat_linear_8da4w(child) + + +class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear): """ This module implements a linear layer with int8 dynamic per token fake quantized activations with int4 fake quantized grouped per channel weights. @@ -133,63 +208,39 @@ def __init__( precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32, ) -> None: + activation_config = FakeQuantizeConfig( + bit_width=8, + granularity="per_token", + symmetric=False, + dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + ) + weight_config = FakeQuantizeConfig( + bit_width=4, + group_size=groupsize, + symmetric=True, + dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + ) super().__init__( in_features, out_features, bias, + activation_config, + weight_config, device=device, dtype=precision, ) - assert ( - in_features % groupsize == 0 - ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" - assert not bias, "require bias=False" - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - # TODO: make this configurable? 
- self.zero_points_precision = torch.int32 - self._fake_quant_enabled = True def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled + self.activation_fake_quantizer.enabled = enabled + self.weight_fake_quantizer.enabled = enabled def disable_fake_quant(self): self.enable_fake_quant(False) - def forward(self, x: torch.Tensor) -> torch.Tensor: - # activations: int8 dynamic asymmetric quant - if self._fake_quant_enabled: - (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( - x, self.scales_precision, self.zero_points_precision, - ) - (act_qmin, act_qmax) = _get_qmin_qmax(8) - x_fq = _fake_quantize_per_token( - x, act_scales, act_zp, act_qmin, act_qmax, - ) - else: - x_fq = x - - # weights: int4 grouped per channel symmetric quant - if self._fake_quant_enabled: - (weight_scales, weight_zp) = get_group_qparams_symmetric( - self.weight, 4, self.groupsize, self.scales_precision, - ) - # TODO: pass zp dtype to `get_group_qparams_symmetric` instead - weight_zp = weight_zp.to(self.zero_points_precision) - (weight_qmin, weight_qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group( - self.weight, - weight_scales, - weight_zp, - weight_qmin, - weight_qmax, - self.groupsize, - ) - else: - w_fq = self.weight - return F.linear(x_fq, w_fq) - def enable_8da4w_fake_quant(mod: torch.nn.Module): """ @@ -257,46 +308,45 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_4w(model) + self._convert_qat_linear_4w(model) return model - -def _convert_qat_linear_4w(module: torch.nn.Module): - """ - Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATLinear): - in_features = child.in_features - out_features = child.out_features - groupsize = child.groupsize - inner_k_tiles = child.inner_k_tiles - quantized_linear = WeightOnlyInt4Linear( - in_features, - out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - child.weight, n_bit, child.groupsize, - ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), child.inner_k_tiles, - ) - quantized_linear.weight = q_weight - quantized_linear.scales_and_zeros = scales_and_zeros - else: - _convert_qat_linear_4w(child) - - -class Int4WeightOnlyQATLinear(torch.nn.Linear): + def _convert_qat_linear_4w(self, module: torch.nn.Module): + """ + Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. 
+ """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + inner_k_tiles = child.inner_k_tiles + config = child.weight_fake_quantizer.config + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=config.group_size, + inner_k_tiles=inner_k_tiles, + precision=child.weight.dtype, + scales_precision=config.scale_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, config.group_size, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + self._convert_qat_linear_4w(child) + + +class Int4WeightOnlyQATLinear(FakeQuantizedLinear): """ This module implements a linear layer with int4 fake quantized grouped per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, @@ -319,47 +369,36 @@ def __init__( precision: torch.dtype = torch.bfloat16, scales_precision: torch.dtype = torch.bfloat16, ) -> None: + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.inner_k_tiles = inner_k_tiles + weight_config = FakeQuantizeConfig( + bit_width=4, + group_size=groupsize, + symmetric=False, + dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + zero_point_domain=ZeroPointDomain.FLOAT, + ) super().__init__( in_features, out_features, bias, + activation_config=None, + weight_config=weight_config, device=device, dtype=precision, ) - assert not bias, "require bias=False" - assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" - if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): - raise ValueError("Padding for QAT 4w is not supported yet") - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.precision = precision - self.scales_precision = scales_precision - self._fake_quant_enabled = True def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled + self.activation_fake_quantizer.enabled = enabled + self.weight_fake_quantizer.enabled = enabled def disable_fake_quant(self): self.enable_fake_quant(False) - def forward(self, x: torch.Tensor) -> torch.Tensor: - n_bit = 4 - qmin = 0 - qmax = 2 ** n_bit - 1 - scales, zero_points = get_groupwise_affine_qparams( - self.weight, n_bit, self.groupsize, self.scales_precision, - ) - w_fq = _fake_quantize_per_channel_group( - self.weight, - scales, - zero_points, - qmin, - qmax, - self.groupsize, - ZeroPointDomain.FLOAT, - ) - return F.linear(x, w_fq) - def enable_4w_fake_quant(mod: torch.nn.Module): """ diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py index 354475e655..8f2dd9d13f 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/prototype/qat/utils.py @@ -181,7 +181,11 @@ def _choose_qparams_per_token_asymmetric( return scale.to(scales_precision), zero_point.to(zero_points_precision) -def _get_qmin_qmax(n_bit: int): - qmin = -(2 ** (n_bit - 1)) - qmax = 2 ** (n_bit - 1) - 1 +def _get_qmin_qmax(n_bit: int, 
symmetric: bool=True): + if symmetric: + qmin = -(2 ** (n_bit - 1)) + qmax = 2 ** (n_bit - 1) - 1 + else: + qmin = 0 + qmax = 2 ** n_bit - 1 return (qmin, qmax) From 75fcd2167b648bc05929e1a29f45bea53b9dfe94 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 8 Oct 2024 12:44:32 -0700 Subject: [PATCH 03/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig( bit_width=8, granularity="per_token", symmetric=False, dynamic=True, ) weight_config = FakeQuantizeConfig( bit_width=4, group_size=8, symmetric=True, dynamic=True, ) fq_linear = FakeQuantizedLinear( 16, 32, False, activation_config, weight_config, ) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] --- torchao/prototype/awq/core.py | 9 ++- torchao/quantization/observer.py | 96 ++++-------------------- torchao/quantization/quant_primitives.py | 69 +++++++++++++++++ 3 files changed, 88 insertions(+), 86 deletions(-) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 77810a2e4a..15b0ec6c1e 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -9,10 +9,11 @@ from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.quant_primitives import ( MappingType, + Granularity, ZeroPointDomain, ) from torchao.quantization.observer import ( - AffineQuantizedObserverBase, GranularityType + AffineQuantizedObserverBase, ) @@ -20,7 +21,7 @@ class AWQObserver(AffineQuantizedObserverBase): def __init__(self, weight: torch.Tensor, bias: torch.Tensor, - quantization_granularity: GranularityType, + quantization_granularity: Granularity, mapping_type: MappingType, target_dtype: torch.dtype, n_validation_examples: int, @@ -40,7 +41,7 @@ def __init__(self, Args: weight: The weight tensor to be observed. bias: The bias tensor to be observed. - quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point + quantization_granularity: Granularity which specifies how many weights share the same scale/zero point input_dtype: The data type of the input tensor. 
mapping_type: Always set to asymmetric target_dtype: The target data type of the quantized tensor @@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver): observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) observed_linear.weight = float_linear.weight observed_linear.bias = float_linear.bias - return observed_linear \ No newline at end of file + return observed_linear diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index bef4abe710..7e707953ba 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -3,12 +3,12 @@ _get_reduction_params, choose_qparams_affine_with_min_max, MappingType, + Granularity, ZeroPointDomain, ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 from abc import ABCMeta, abstractmethod -from dataclasses import dataclass from typing import Tuple, Optional, Any from functools import partial import logging @@ -16,74 +16,6 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class GranularityType: - """ - Base class for representing the granularity of quantization. - - This class serves as a parent for specific granularity types used in - quantization operations, such as per-tensor or per-axis quantization. - """ - pass - -@dataclass(frozen=True) -class PerTensor(GranularityType): - """ - Represents per-tensor granularity in quantization. - - This granularity type calcualtes the quantization parameters - based off the entire tensor. - """ - pass - -@dataclass(frozen=True) -class PerAxis(GranularityType): - """ - Represents per-axis granularity in quantization. - - This granularity type calcualtes different quantization parameters - along a specified axis of the tensor. - - For example if the input tensor is shape [8, 16] and axis=0, then - the quantization parameters are calculated for each row of the tensor. - Giving a total of 8 quantization parameters. - - - Attributes: - axis (int): The axis along which reduction is performed. - """ - axis: int - -@dataclass(frozen=True) - -class PerGroup(GranularityType): - """ - Represents per-channel group granularity in quantization. - - This granularity type calcualtes different quantization parameters - for each group of elements. - - For example if the input tensor is shape [8, 16], and the group size is 4, then - the input tensor is reshaped to [64, 4] - quantization parameters are calculated for each group of 4 elements, - giving a total of 64 quantization parameters. - - Attributes: - group_size (int): The size of each quantization group - - """ - group_size: int - -class PerRow(GranularityType): - """ - Represents row-wise granularity in quantization. - - This is a special case of per-axis quantization and is unique to Float8 matmuls - where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight - is quantized with a block_size of (1, weight.shape[1]). - """ - pass - # borrowed from torch.ao.quantization.observer class _PartialWrapper: def __init__(self, p): @@ -120,23 +52,23 @@ def _with_args(cls_or_self, *args, **kwargs): def get_block_size( - input_shape: Tuple[int, ...], granularity_type: GranularityType + input_shape: Tuple[int, ...], granularity: Granularity ) -> Tuple[int, ...]: """Get the block size based on the input shape and granularity type. 
Args: input_shape: The input tensor shape possibly more than 2 dimensions - granularity_type: The granularity type of the quantization + granularity: The granularity type of the quantization """ - if isinstance(granularity_type, PerTensor): + if isinstance(granularity, PerTensor): return input_shape - elif isinstance(granularity_type, PerAxis): + elif isinstance(granularity, PerAxis): block_size = list(input_shape) - block_size[granularity_type.axis] = 1 + block_size[granularity.axis] = 1 return tuple(block_size) - elif isinstance(granularity_type, PerRow): + elif isinstance(granularity, PerRow): return (1,) * (len(input_shape) - 1) + (input_shape[-1],) - raise ValueError(f"Unsupported GranularityType: {granularity_type}") + raise ValueError(f"Unsupported Granularity: {granularity}") ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: @@ -146,7 +78,7 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module): """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) Args: - `granularity_type` and `block_size`: The granularity of the quantization, + `granularity` and `block_size`: The granularity of the quantization, must specify at least one, if both are specified `block_size` takes precedence Current supported granularity type are `PerTensor` and `PerAxis` other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` @@ -158,7 +90,7 @@ def __init__( self, mapping_type: MappingType, target_dtype: torch.dtype, - granularity_type: GranularityType, + granularity: Granularity, quant_min: Optional[int] = None, quant_max: Optional[int] = None, eps: Optional[float] = None, @@ -168,11 +100,11 @@ def __init__( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, ): super().__init__() - assert granularity_type is not None, "granularity_type is None" + assert granularity is not None, "granularity is None" self.mapping_type = mapping_type self.target_dtype = target_dtype - self.granularity_type = granularity_type + self.granularity = granularity self.quant_min = quant_min self.quant_max = quant_max self.eps = eps @@ -202,8 +134,8 @@ def forward(self, input: torch.Tensor): return input input_detached = input.detach() - assert self.granularity_type is not None, "granularity_type is None" - block_size = get_block_size(input_detached.shape, self.granularity_type) + assert self.granularity is not None, "granularity is None" + block_size = get_block_size(input_detached.shape, self.granularity) shape_for_reduction, reduction_dims = _get_reduction_params( block_size, input_detached.size() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index b1561e4cff..594bf8c42c 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass from enum import Enum, auto from typing import List, Optional, Tuple, Dict, Callable, Union import torch, math @@ -64,6 +65,74 @@ class ZeroPointDomain(Enum): INT = auto() FLOAT = auto() +@dataclass(frozen=True) +class Granularity: + """ + Base class for representing the granularity of quantization. + + This class serves as a parent for specific granularity types used in + quantization operations, such as per-tensor or per-axis quantization. 
+ """ + pass + +@dataclass(frozen=True) +class PerTensor(Granularity): + """ + Represents per-tensor granularity in quantization. + + This granularity type calcualtes the quantization parameters + based off the entire tensor. + """ + pass + +@dataclass(frozen=True) +class PerAxis(Granularity): + """ + Represents per-axis granularity in quantization. + + This granularity type calcualtes different quantization parameters + along a specified axis of the tensor. + + For example if the input tensor is shape [8, 16] and axis=0, then + the quantization parameters are calculated for each row of the tensor. + Giving a total of 8 quantization parameters. + + + Attributes: + axis (int): The axis along which reduction is performed. + """ + axis: int + +@dataclass(frozen=True) + +class PerGroup(Granularity): + """ + Represents per-channel group granularity in quantization. + + This granularity type calcualtes different quantization parameters + for each group of elements. + + For example if the input tensor is shape [8, 16], and the group size is 4, then + the input tensor is reshaped to [64, 4] + quantization parameters are calculated for each group of 4 elements, + giving a total of 64 quantization parameters. + + Attributes: + group_size (int): The size of each quantization group + + """ + group_size: int + +class PerRow(Granularity): + """ + Represents row-wise granularity in quantization. + + This is a special case of per-axis quantization and is unique to Float8 matmuls + where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight + is quantized with a block_size of (1, weight.shape[1]). + """ + pass + if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) From d671826b0607d931078434505cb8f96ab7056bfb Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 8 Oct 2024 12:59:06 -0700 Subject: [PATCH 04/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig( bit_width=8, granularity="per_token", symmetric=False, dynamic=True, ) weight_config = FakeQuantizeConfig( bit_width=4, group_size=8, symmetric=True, dynamic=True, ) fq_linear = FakeQuantizedLinear( 16, 32, False, activation_config, weight_config, ) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. 
**Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] --- test/dtypes/test_affine_quantized_float.py | 8 ++++++-- test/quantization/test_observer.py | 4 ++-- torchao/_models/llama/eval.py | 4 ++-- torchao/_models/llama/generate.py | 2 +- torchao/prototype/awq/api.py | 2 +- torchao/quantization/README.md | 2 +- torchao/quantization/autoquant.py | 4 +++- torchao/quantization/observer.py | 3 +++ torchao/quantization/quant_api.py | 4 +++- tutorials/calibration_flow/awq_like.py | 4 ++-- tutorials/calibration_flow/gptq_like.py | 2 +- tutorials/calibration_flow/static_quant.py | 4 ++-- 12 files changed, 27 insertions(+), 16 deletions(-) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 621e3596e0..427a26969f 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -26,11 +26,15 @@ float8_weight_only, quantize_, ) -from torchao.quantization.observer import PerRow, PerTensor from torchao.quantization.quant_api import ( float8_static_activation_float8_weight, ) -from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine +from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + PerRow, + PerTensor, +) random.seed(0) torch.manual_seed(0) diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index 8c8007871b..3cca97f076 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -11,14 +11,14 @@ from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerAxis, - PerTensor, ) from torchao.quantization.quant_api import ( insert_observers_, ) from torchao.quantization.quant_primitives import ( MappingType, + PerAxis, + PerTensor, ) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 6d46e45878..1655e450d7 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -24,9 +24,9 @@ float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, ) -from torchao.quantization.observer import PerRow, PerTensor from torchao._models._eval import TransformerEvalWrapper, InputRecorder from torchao._models.llama.model import prepare_inputs_for_model +from torchao.quantization.quant_primitives import PerRow, PerTensor from tokenizer import get_tokenizer import time @@ -255,4 +255,4 @@ def run_evaluation( args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, - ) \ No newline at end of file + ) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 270054e130..1971ec094b 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -216,7 +216,7 @@ def main( float8_weight_only, float8_dynamic_activation_float8_weight, ) - from torchao.quantization.observer import PerTensor, PerRow + from torchao.quantization.quant_primitives import PerTensor, PerRow if "int8wo" in quantization: quantize_(model, int8_weight_only()) if "int8dq" in quantization: diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index e3a8827e2a..6827fe3915 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -3,11 +3,11 @@ from torchao.quantization.quant_primitives import ( MappingType, + PerGroup, ZeroPointDomain, 
_DTYPE_TO_QVALUE_BOUNDS, ) from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata -from torchao.quantization.observer import PerGroup from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType from torchao.dtypes import( diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index c936b7ef83..9d7b049470 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model) ```python # for torch 2.4+ from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight -from torchao.quantization.observer import PerTensor +from torchao.quantization.quant_api import PerTensor quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor())) ``` diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index a5568c4e17..8a02cccf29 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -13,11 +13,13 @@ from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor from torch.utils._python_dispatch import return_and_correct_aliasing from .quant_primitives import ( + PerAxis, + PerRow, + PerTensor, safe_int_mm, ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 from torchao.quantization.utils import quantize_activation_per_token_absmax -from torchao.quantization.observer import PerAxis, PerTensor, PerRow from torchao.float8.inference import Float8MMConfig import torch.nn.functional as F diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 7e707953ba..d702b54f55 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -4,6 +4,9 @@ choose_qparams_affine_with_min_max, MappingType, Granularity, + PerAxis, + PerRow, + PerTensor, ZeroPointDomain, ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6c41425062..7584374687 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -54,6 +54,8 @@ from .quant_primitives import ( MappingType, + PerRow, + PerTensor, ZeroPointDomain, ) from .weight_only import WeightOnlyInt8QuantLinear @@ -71,7 +73,7 @@ ) from torchao.float8.inference import Float8MMConfig -from torchao.quantization.observer import PerTensor, PerRow, get_block_size +from torchao.quantization.observer import get_block_size logger = logging.getLogger(__name__) diff --git a/tutorials/calibration_flow/awq_like.py b/tutorials/calibration_flow/awq_like.py index 037dbae0f6..41a43bda56 100644 --- a/tutorials/calibration_flow/awq_like.py +++ b/tutorials/calibration_flow/awq_like.py @@ -22,11 +22,11 @@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, - PerAxis, ) from torchao.quantization.quant_primitives import ( MappingType, + PerTensor, + PerAxis, FP8_TYPES, ) diff --git a/tutorials/calibration_flow/gptq_like.py b/tutorials/calibration_flow/gptq_like.py index edb1b257ee..07dd2876a8 100644 --- a/tutorials/calibration_flow/gptq_like.py +++ b/tutorials/calibration_flow/gptq_like.py @@ -40,10 +40,10 @@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from 
torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, ) from torchao.quantization.quant_primitives import ( MappingType, + PerTensor, fake_quantize_affine, ) diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index f75485d3d5..d5469d4320 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -17,11 +17,11 @@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, - PerAxis, ) from torchao.quantization.quant_primitives import ( MappingType, + PerTensor, + PerAxis, FP8_TYPES, ) From d4332cb0d5fde91bfa7f9a942f851b30423d687f Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 8 Oct 2024 13:01:26 -0700 Subject: [PATCH 05/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig( bit_width=8, granularity="per_token", symmetric=False, dynamic=True, ) weight_config = FakeQuantizeConfig( bit_width=4, group_size=8, symmetric=True, dynamic=True, ) fq_linear = FakeQuantizedLinear( 16, 32, False, activation_config, weight_config, ) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] --- test/dtypes/test_affine_quantized_float.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 427a26969f..34f6bd2f15 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -30,10 +30,10 @@ float8_static_activation_float8_weight, ) from torchao.quantization.quant_primitives import ( - choose_qparams_affine, MappingType, PerRow, PerTensor, + choose_qparams_affine, ) random.seed(0) From dbad87814a8d655081296bec54e1af6837a0b039 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 8 Oct 2024 13:18:30 -0700 Subject: [PATCH 06/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. 
For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig( bit_width=8, granularity="per_token", symmetric=False, dynamic=True, ) weight_config = FakeQuantizeConfig( bit_width=4, group_size=8, symmetric=True, dynamic=True, ) fq_linear = FakeQuantizedLinear( 16, 32, False, activation_config, weight_config, ) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] --- test/quantization/test_observer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index 3cca97f076..73a5a01e8a 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -42,7 +42,7 @@ def test_min_max_per_tensor_affine(self): obs = AffineQuantizedMinMaxObserver( MappingType.ASYMMETRIC, torch.uint8, - granularity_type=PerTensor(), + granularity=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -54,7 +54,7 @@ def test_min_max_per_channel_affine(self): obs = AffineQuantizedMinMaxObserver( MappingType.ASYMMETRIC, torch.uint8, - granularity_type=PerAxis(axis=0), + granularity=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -68,7 +68,7 @@ def test_block_size_calc_success(self): obs = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerTensor(), + granularity=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -87,7 +87,7 @@ def test_block_size_calc_success(self): obs = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerAxis(1), + granularity=PerAxis(1), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -102,7 +102,7 @@ def test_block_size_row_errors(self): obs = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerAxis(0), + granularity=PerAxis(0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -121,7 +121,7 @@ def test_block_size_row_errors(self): obs = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerAxis(1), + granularity=PerAxis(1), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -149,7 +149,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): input_observer = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerTensor(), + granularity=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -159,7 +159,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): weight_observer = 
AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerTensor(), + granularity=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, From 0153d6659a925d09b33bdc50a58fd7e34555c076 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Thu, 10 Oct 2024 17:51:06 -0700 Subject: [PATCH 07/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig( bit_width=8, granularity="per_token", symmetric=False, dynamic=True, ) weight_config = FakeQuantizeConfig( bit_width=4, group_size=8, symmetric=True, dynamic=True, ) fq_linear = FakeQuantizedLinear( 16, 32, False, activation_config, weight_config, ) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] From 8f48663732e1baa3f5c74426e387a2f113fc2a4e Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 11 Oct 2024 13:15:49 -0700 Subject: [PATCH 08/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=8) fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. 
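A side benefit of moving fake quantization into child `FakeQuantizer` modules is that toggling it (for example, to run a float reference pass) no longer needs per-subclass helpers like `disable_8da4w_fake_quant`. A minimal sketch, assuming only the `FakeQuantizer.enabled` flag introduced earlier in this stack:

```
import torch
from torchao.quantization.prototype.qat.fake_quantizer import FakeQuantizer

def set_fake_quant_enabled(model: torch.nn.Module, enabled: bool) -> None:
    # Walk the module tree and flip every FakeQuantizer, regardless of
    # which linear subclass (or future module type) owns it.
    for module in model.modules():
        if isinstance(module, FakeQuantizer):
            module.enabled = enabled

# set_fake_quant_enabled(qat_model, False)  # float reference run
# set_fake_quant_enabled(qat_model, True)   # resume QAT
```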
**Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config_granularity python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] From 4239d4713f9f18d25426f1c381d54b767533f06f Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 11 Oct 2024 13:22:18 -0700 Subject: [PATCH 09/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=8) fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config_granularity python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] From 75c83eff4bca00f7ff1e43f5d7ffcdfa849575dc Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 11 Oct 2024 13:35:43 -0700 Subject: [PATCH 10/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=8) fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. 
**Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config_granularity python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] From e08517c9829ad1fe27d90590b47d69a97c311b2b Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 11 Oct 2024 13:44:35 -0700 Subject: [PATCH 11/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=8) fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config_granularity python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w [ghstack-poisoned] From d0d9573b935be8cd3a432c2bde191dc6a1579cac Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 11 Oct 2024 14:01:40 -0700 Subject: [PATCH 12/16] Update base for Update on "Add generic fake quantized linear for QAT" **Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=8) fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. 
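For intuition about what these configs control numerically, symmetric per-group weight fake quantization is quantize-then-dequantize within each group of `group_size` elements. The snippet below is only a schematic of that math; it rounds the quantization range slightly differently than `get_group_qparams_symmetric` and omits the backward/straight-through handling that the real fake quantize ops provide:

```
import torch

def fake_quantize_per_group_symmetric(w: torch.Tensor, group_size: int, n_bit: int = 4) -> torch.Tensor:
    # symmetric int4 range, matching _get_qmin_qmax(4)
    qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1
    grouped = w.reshape(-1, group_size)
    # one scale per group, derived from the group's largest magnitude
    scale = grouped.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / qmax
    w_q = torch.clamp(torch.round(grouped / scale), qmin, qmax)
    # dequantize back to float so training sees the quantization error
    return (w_q * scale).reshape(w.shape)

w_fq = fake_quantize_per_group_symmetric(torch.randn(32, 64), group_size=16)
```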
**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

[ghstack-poisoned]

From f9a2f4c3d7b5428da45e5a12a196cace2c3ed7e8 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Mon, 14 Oct 2024 10:52:09 -0700
Subject: [PATCH 13/16] Update base for Update on "Add generic fake quantized linear for QAT"

**Summary:** This commit adds a generic fake quantized linear module to
replace the uses of the existing, more specific QAT linears. For example,
`Int8DynActInt4WeightQATLinear` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform QAT on
models with linear layers. Previously, we would have to create a new
linear class every time we wished to experiment with different fake
quantization settings, e.g. a different group size or bit width. Now we
can express this easily using a single linear module.

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

[ghstack-poisoned]

From 5b4feb0577a0ec86f6d85bb3d23ea7d24fdd3433 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Mon, 14 Oct 2024 11:02:17 -0700
Subject: [PATCH 14/16] Update base for Update on "Add generic fake quantized linear for QAT"

**Summary:** This commit adds a generic fake quantized linear module to
replace the uses of the existing, more specific QAT linears. For example,
`Int8DynActInt4WeightQATLinear` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform QAT on
models with linear layers. Previously, we would have to create a new
linear class every time we wished to experiment with different fake
quantization settings, e.g. a different group size or bit width. Now we
can express this easily using a single linear module.
**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

[ghstack-poisoned]

From 756cb8ded86dadd83f7e0e4053b3d67aef292955 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Mon, 14 Oct 2024 13:26:29 -0700
Subject: [PATCH 15/16] Update base for Update on "Add generic fake quantized linear for QAT"

**Summary:** This commit adds a generic fake quantized linear module to
replace the uses of the existing, more specific QAT linears. For example,
`Int8DynActInt4WeightQATLinear` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform QAT on
models with linear layers. Previously, we would have to create a new
linear class every time we wished to experiment with different fake
quantization settings, e.g. a different group size or bit width. Now we
can express this easily using a single linear module.

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

[ghstack-poisoned]

From 622b6dfc1083ef48aad35340f3a1a9e672ff8927 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Mon, 14 Oct 2024 14:14:03 -0700
Subject: [PATCH 16/16] Update base for Update on "Add generic fake quantized linear for QAT"

**Summary:** This commit adds a generic fake quantized linear module to
replace the uses of the existing, more specific QAT linears. For example,
`Int8DynActInt4WeightQATLinear` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform QAT on
models with linear layers. Previously, we would have to create a new
linear class every time we wished to experiment with different fake
quantization settings, e.g. a different group size or bit width. Now we
can express this easily using a single linear module.
**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

[ghstack-poisoned]