From 08e90957d5781cf2f507c2394ede15b3801e475c Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 21 Sep 2025 20:02:51 +0900 Subject: [PATCH 01/13] Summary: Introduce new tensor subclass API for int8 quantization with clearer interface. The main change can be summarized to the following: - Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling - New: Direct int8 tensor with qdata, scale, and zero_point attributes Test plan: test/quantization/quantize_/workflows/int8/test_int8_tensor.py Future plan: Implement block-wise quantization using `block_size` parameter --- docs/source/quantization_overview.rst | 4 +- .../workflows/int8/test_int8_tensor.py | 73 ++++++++++++ torchao/quantization/__init__.py | 2 + .../quantize_/workflows/__init__.py | 3 + .../quantize_/workflows/int8/int8_tensor.py | 106 ++++++++++++++++++ 5 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 test/quantization/quantize_/workflows/int8/test_int8_tensor.py create mode 100644 torchao/quantization/quantize_/workflows/int8/int8_tensor.py diff --git a/docs/source/quantization_overview.rst b/docs/source/quantization_overview.rst index f5c82bfe5f..df0a924b11 100644 --- a/docs/source/quantization_overview.rst +++ b/docs/source/quantization_overview.rst @@ -5,7 +5,7 @@ First we want to lay out the torchao stack:: Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. --------------------------------------------------------------------------------------------- - Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor + Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor --------------------------------------------------------------------------------------------- Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize --------------------------------------------------------------------------------------------- @@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by dervied dtpype and packing forma - scaled int4 - preshuffled (special format to optimize for loading) - float8 act + int4 weight dynamic quantization and int4 weight only quantization + * - Int8Tensor + - plain .. note:: We don't have granularity specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor, all granularities are implemented in the same Tensor, we typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options. diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py new file mode 100644 index 0000000000..bd7b003654 --- /dev/null +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
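+#
+# Illustrative usage sketch for the API these tests exercise (a hypothetical
+# example, not part of the test suite; per-row symmetric quantization as
+# implemented by Int8PlainInt8Tensor.from_hp in this patch):
+#
+#   w = torch.randn(4, 3)
+#   w_q = Int8PlainInt8Tensor.from_hp(w, block_size=[4, 3])
+#   w_q.qdata         # int8 values in [-128, 127]
+#   w_q.scale         # per-row scales used for dequantization
+#   w_q.zero_point    # zero points (all zeros for symmetric quantization)
+#   w_q.dequantize()  # approximate reconstruction of w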
+ +import unittest + +import torch +from torch.testing._internal.common_utils import run_tests + +from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( + Int8PlainInt8Tensor, +) +from torchao.quantization.utils import compute_error +from torchao.testing.utils import TorchAOIntegrationTestCase + + +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +class TestInt8PlainInt8Tensor(TorchAOIntegrationTestCase): + def setUp(self): + super().setUp() + torch.manual_seed(42) + self.weight_fp = torch.randn(4, 3, dtype=torch.float32) + self.input_fp = torch.randn(2, 3, dtype=torch.float32) + self.bias = torch.randn(4) + self.block_size = [4, 3] + + def test_creation_and_attributes(self): + """Test tensor creation, dtypes, and ranges""" + tensor = Int8PlainInt8Tensor.from_hp(self.weight_fp, self.block_size) + + self.assertEqual(tensor.shape, (4, 3)) + self.assertEqual(tensor.qdata.dtype, torch.int8) + self.assertTrue( + torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) + ) + + def test_linear_operations(self): + """Test fp+int8 and int8+int8 linear ops with quantization error check""" + weight_q8 = Int8PlainInt8Tensor.from_hp(self.weight_fp, self.block_size) + input_q8 = Int8PlainInt8Tensor.from_hp(self.input_fp, self.block_size) + + reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias) + result_fp = torch.nn.functional.linear(self.input_fp, weight_q8, self.bias) + result_q8 = torch.nn.functional.linear(input_q8, weight_q8, self.bias) + + self.assertEqual(result_fp.shape, reference.shape) + self.assertEqual(result_q8.shape, reference.shape) + self.assertTrue(compute_error(result_fp, reference) > 10) + self.assertTrue(compute_error(result_q8, reference) > 10) + + def test_error_handling_and_dequant(self): + """Test input validation and dequantization accuracy""" + # Test 1D tensor validation + with self.assertRaises((AssertionError, ValueError, RuntimeError)): + Int8PlainInt8Tensor.from_hp(torch.randn(5), [1]) + + # Test wrong block_size validation + with self.assertRaises((AssertionError, ValueError, RuntimeError)): + Int8PlainInt8Tensor.from_hp(self.weight_fp, [1]) + + # Test dequantization with exact values + test_data = torch.tensor([[1.0, -1.0]], dtype=torch.float32) + tensor = Int8PlainInt8Tensor.from_hp(test_data, [1, 1]) + + dequantized = tensor.dequantize() + self.assertEqual(dequantized.shape, test_data.shape) + self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index b32868b684..9fa6126e69 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -95,6 +95,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8PlainInt8Tensor, IntxOpaqueTensor, IntxUnpackedToInt8Tensor, ) @@ -168,6 +169,7 @@ "IntxOpaqueTensor", "IntxUnpackedToInt8Tensor", "Int4TilePackedTo4dTensor", + "Int8PlainInt8Tensor", "Float8Tensor", "Int4OpaqueTensor", # smooth quant - subject to change diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 229c94c73a..f5fc2d64e7 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -1,3 +1,5 @@ +from int8.int8_tensor import Int8PlainInt8Tensor + from .float8.float8_tensor import ( Float8Tensor, QuantizeTensorToFloat8Kwargs, @@ -36,6 +38,7 @@ "Int4MarlinSparseTensor", 
"Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", + "Int8PlainInt8Tensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", "Int4OpaqueTensor", diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py new file mode 100644 index 0000000000..f9ac85fba4 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +from torchao.utils import TorchAOBaseTensor + +__all__ = ["Int8PlainInt8Tensor"] + +aten = torch.ops.aten + + +# TODO: Implement block-wise quantization using block_size +class Int8PlainInt8Tensor(TorchAOBaseTensor): + """ + int8 quantized tensor with plain layout + + Tensor Attributes: + qdata: (N, K) int8 quantized weight data + scale: scale factors for dequantization + zero_point: zero points for dequantization + + Non-Tensor Attributes: + block_size: block size for quantization granularity + shape: original tensor shape + """ + + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = ["block_size"] + + def __new__(cls, qdata, scale, zero_point, block_size, shape): + kwargs = {"device": qdata.device, "dtype": scale.dtype, "requires_grad": False} + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, qdata, scale, zero_point, block_size, shape): + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.block_size = block_size + + @classmethod + def from_hp(cls, w: torch.Tensor, block_size: list[int]): + if w.dim() != 2 or len(block_size) != 2: + raise ValueError("Expected 2D tensor and block_size length 2") + + # Rounding function from high precision dtype + scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0 + scale = scale.clamp(min=1e-6) + + int_data = torch.round(w / scale).clamp(-128, 127).to(torch.int8) + + return cls( + int_data, + scale.squeeze(-1), + torch.zeros_like(scale.squeeze(-1), dtype=torch.int8), + block_size, + w.shape, + ) + + +implements = Int8PlainInt8Tensor.implements + + +@implements([aten.dequantize.self]) +def _(func, types, args, kwargs): + """dequantization: int8 -> float""" + tensor = args[0] + return ( + tensor.qdata.to(tensor.scale.dtype) + - tensor.zero_point.to(tensor.scale.dtype).unsqueeze(1) + ) * tensor.scale.unsqueeze(1) + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + """quantization: float -> int8""" + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + if isinstance(input_tensor, Int8PlainInt8Tensor): + # INT8 × INT8 + x_int32 = input_tensor.qdata.to(torch.int32) + w_int32 = weight_tensor.qdata.to(torch.int32).t() + + result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32) + scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0) + result = result.to(scale.dtype) * scale + result = result.view(*input_tensor.shape[:-1], -1) + else: + # FP × INT8 + result = torch.nn.functional.linear( + input_tensor, weight_tensor.dequantize(), None + ) + + return result + bias if bias is not None else result + + +Int8PlainInt8Tensor.__module__ = "torchao.quantization" +torch.serialization.add_safe_globals([Int8PlainInt8Tensor]) From db23cf3c25f8125443db7dad82a6fbd0b3e652ac Mon Sep 17 00:00:00 2001 
From: youn17 Date: Tue, 23 Sep 2025 02:45:28 +0900 Subject: [PATCH 02/13] rename for clearly: Int8PlainInt8Tensor -> Int8Tensor --- .../quantize_/workflows/int8/test_int8_tensor.py | 14 +++++++------- torchao/quantization/__init__.py | 4 ++-- .../quantization/quantize_/workflows/__init__.py | 4 ++-- .../quantize_/workflows/int8/int8_tensor.py | 12 ++++++------ 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index bd7b003654..7eb5278e53 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -10,7 +10,7 @@ from torch.testing._internal.common_utils import run_tests from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( - Int8PlainInt8Tensor, + Int8Tensor, ) from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase @@ -28,7 +28,7 @@ def setUp(self): def test_creation_and_attributes(self): """Test tensor creation, dtypes, and ranges""" - tensor = Int8PlainInt8Tensor.from_hp(self.weight_fp, self.block_size) + tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size) self.assertEqual(tensor.shape, (4, 3)) self.assertEqual(tensor.qdata.dtype, torch.int8) @@ -38,8 +38,8 @@ def test_creation_and_attributes(self): def test_linear_operations(self): """Test fp+int8 and int8+int8 linear ops with quantization error check""" - weight_q8 = Int8PlainInt8Tensor.from_hp(self.weight_fp, self.block_size) - input_q8 = Int8PlainInt8Tensor.from_hp(self.input_fp, self.block_size) + weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) + input_q8 = Int8Tensor.from_hp(self.input_fp, self.block_size) reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias) result_fp = torch.nn.functional.linear(self.input_fp, weight_q8, self.bias) @@ -54,15 +54,15 @@ def test_error_handling_and_dequant(self): """Test input validation and dequantization accuracy""" # Test 1D tensor validation with self.assertRaises((AssertionError, ValueError, RuntimeError)): - Int8PlainInt8Tensor.from_hp(torch.randn(5), [1]) + Int8Tensor.from_hp(torch.randn(5), [1]) # Test wrong block_size validation with self.assertRaises((AssertionError, ValueError, RuntimeError)): - Int8PlainInt8Tensor.from_hp(self.weight_fp, [1]) + Int8Tensor.from_hp(self.weight_fp, [1]) # Test dequantization with exact values test_data = torch.tensor([[1.0, -1.0]], dtype=torch.float32) - tensor = Int8PlainInt8Tensor.from_hp(test_data, [1, 1]) + tensor = Int8Tensor.from_hp(test_data, [1, 1]) dequantized = tensor.dequantize() self.assertEqual(dequantized.shape, test_data.shape) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 9fa6126e69..81ac401bae 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -95,7 +95,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, - Int8PlainInt8Tensor, + Int8Tensor, IntxOpaqueTensor, IntxUnpackedToInt8Tensor, ) @@ -169,7 +169,7 @@ "IntxOpaqueTensor", "IntxUnpackedToInt8Tensor", "Int4TilePackedTo4dTensor", - "Int8PlainInt8Tensor", + "Int8Tensor", "Float8Tensor", "Int4OpaqueTensor", # smooth quant - subject to change diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index f5fc2d64e7..83062b4787 100644 --- 
a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -1,4 +1,4 @@ -from int8.int8_tensor import Int8PlainInt8Tensor +from int8.int8_tensor import Int8Tensor from .float8.float8_tensor import ( Float8Tensor, @@ -38,7 +38,7 @@ "Int4MarlinSparseTensor", "Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", - "Int8PlainInt8Tensor", + "Int8Tensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", "Int4OpaqueTensor", diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index f9ac85fba4..b285207a7f 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -9,13 +9,13 @@ from torchao.utils import TorchAOBaseTensor -__all__ = ["Int8PlainInt8Tensor"] +__all__ = ["Int8Tensor"] aten = torch.ops.aten # TODO: Implement block-wise quantization using block_size -class Int8PlainInt8Tensor(TorchAOBaseTensor): +class Int8Tensor(TorchAOBaseTensor): """ int8 quantized tensor with plain layout @@ -62,7 +62,7 @@ def from_hp(cls, w: torch.Tensor, block_size: list[int]): ) -implements = Int8PlainInt8Tensor.implements +implements = Int8Tensor.implements @implements([aten.dequantize.self]) @@ -84,7 +84,7 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) - if isinstance(input_tensor, Int8PlainInt8Tensor): + if isinstance(input_tensor, Int8Tensor): # INT8 × INT8 x_int32 = input_tensor.qdata.to(torch.int32) w_int32 = weight_tensor.qdata.to(torch.int32).t() @@ -102,5 +102,5 @@ def _(func, types, args, kwargs): return result + bias if bias is not None else result -Int8PlainInt8Tensor.__module__ = "torchao.quantization" -torch.serialization.add_safe_globals([Int8PlainInt8Tensor]) +Int8Tensor.__module__ = "torchao.quantization" +torch.serialization.add_safe_globals([Int8Tensor]) From b861dbc2aad9377ea4fe6bc6f5c773962874cd00 Mon Sep 17 00:00:00 2001 From: youn17 Date: Tue, 23 Sep 2025 15:07:06 +0900 Subject: [PATCH 03/13] add flags for static/dynamic quant --- .../workflows/int8/test_int8_tensor.py | 19 ++- .../common/quantize_tensor_kwargs.py | 8 ++ .../quantize_/workflows/__init__.py | 7 +- .../quantize_/workflows/int8/int8_tensor.py | 111 ++++++++++++++++-- 4 files changed, 129 insertions(+), 16 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 7eb5278e53..f1feaba62a 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -11,13 +11,14 @@ from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( Int8Tensor, + QuantizeTensorToInt8Kwargs, ) from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") -class TestInt8PlainInt8Tensor(TorchAOIntegrationTestCase): +class TestInt8Tensor(TorchAOIntegrationTestCase): def setUp(self): super().setUp() torch.manual_seed(42) @@ -50,6 +51,20 @@ def test_linear_operations(self): self.assertTrue(compute_error(result_fp, reference) > 10) self.assertTrue(compute_error(result_q8, reference) > 10) + def test_dynamic_quantization(self): + weight_q8_dynamic = Int8Tensor.from_hp( + self.weight_fp, + self.block_size, + act_quant_kwargs=QuantizeTensorToInt8Kwargs(), + ) + + reference = 
torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias) + result_dynamic = torch.nn.functional.linear( + self.input_fp, weight_q8_dynamic, self.bias + ) + + self.assertEqual(result_dynamic.shape, reference.shape) + def test_error_handling_and_dequant(self): """Test input validation and dequantization accuracy""" # Test 1D tensor validation @@ -64,7 +79,7 @@ def test_error_handling_and_dequant(self): test_data = torch.tensor([[1.0, -1.0]], dtype=torch.float32) tensor = Int8Tensor.from_hp(test_data, [1, 1]) - dequantized = tensor.dequantize() + dequantized = torch.ops.aten.dequantize.self(tensor) self.assertEqual(dequantized.shape, test_data.shape) self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1) diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py index 0adc8c786d..2c3c6bcab6 100644 --- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor( """ from torchao.quantization.quantize_.workflows import ( Float8Tensor, + Int8Tensor, QuantizeTensorToFloat8Kwargs, + QuantizeTensorToInt8Kwargs, ) if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): @@ -52,5 +54,11 @@ def _choose_quant_func_and_quantize_tensor( quant_kwargs.hp_value_ub, quant_kwargs.kernel_preference, ) + elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs): + return Int8Tensor.from_hp( + tensor, + quant_kwargs.block_size or [1, tensor.shape[-1]], + kernel_preference=quant_kwargs.kernel_preference, + ) raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 83062b4787..db0e5de6c0 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -1,5 +1,3 @@ -from int8.int8_tensor import Int8Tensor - from .float8.float8_tensor import ( Float8Tensor, QuantizeTensorToFloat8Kwargs, @@ -22,6 +20,10 @@ Int4Tensor, ) from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor +from .int8.int8_tensor import ( + Int8Tensor, + QuantizeTensorToInt8Kwargs, +) from .intx.intx_opaque_tensor import ( IntxOpaqueTensor, ) @@ -39,6 +41,7 @@ "Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", "Int8Tensor", + "QuantizeTensorToInt8Kwargs", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", "Int4OpaqueTensor", diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index b285207a7f..62ce4f3e2f 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -4,16 +4,36 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass +from typing import Optional import torch +from torchao.quantization.quantize_.common import ( + KernelPreference, + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) from torchao.utils import TorchAOBaseTensor -__all__ = ["Int8Tensor"] +__all__ = ["Int8Tensor", "QuantizeTensorToInt8Kwargs"] aten = torch.ops.aten +@dataclass +class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): + """Tensor kwargs for creating int8 tensor (either activation or weight) + + Args: + kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. + block_size (Optional[list[int]]): block size for quantization granularity + """ + + kernel_preference: KernelPreference = KernelPreference.AUTO + block_size: Optional[list[int]] = None + + # TODO: Implement block-wise quantization using block_size class Int8Tensor(TorchAOBaseTensor): """ @@ -27,23 +47,70 @@ class Int8Tensor(TorchAOBaseTensor): Non-Tensor Attributes: block_size: block size for quantization granularity shape: original tensor shape + act_quant_kwargs: flags for static/dynamic activation quantization + kernel_preference: kernel preference for operations """ tensor_data_names = ["qdata", "scale", "zero_point"] tensor_attribute_names = ["block_size"] - - def __new__(cls, qdata, scale, zero_point, block_size, shape): - kwargs = {"device": qdata.device, "dtype": scale.dtype, "requires_grad": False} + optional_tensor_attribute_names = [ + "act_quant_kwargs", + "kernel_preference", + "dtype", + ] + + def __new__( + cls, + qdata, + scale, + zero_point, + block_size, + shape, + act_quant_kwargs=None, + kernel_preference=KernelPreference.AUTO, + dtype=None, + ): + kwargs = { + "device": qdata.device, + "dtype": dtype or scale.dtype, + "requires_grad": False, + } return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) - def __init__(self, qdata, scale, zero_point, block_size, shape): + def __init__( + self, + qdata, + scale, + zero_point, + block_size, + shape, + act_quant_kwargs=None, + kernel_preference=KernelPreference.AUTO, + dtype=None, + ): + super().__init__() self.qdata = qdata self.scale = scale self.zero_point = zero_point self.block_size = block_size + self.act_quant_kwargs = act_quant_kwargs + self.kernel_preference = kernel_preference + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " + f"{self.zero_point=}, {self.block_size=}, {self.kernel_preference=}, " + f"{self.shape=}, {self.device=}, {self.dtype=})" + ) @classmethod - def from_hp(cls, w: torch.Tensor, block_size: list[int]): + def from_hp( + cls, + w: torch.Tensor, + block_size: list[int], + act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + kernel_preference: KernelPreference = KernelPreference.AUTO, + ): if w.dim() != 2 or len(block_size) != 2: raise ValueError("Expected 2D tensor and block_size length 2") @@ -59,8 +126,18 @@ def from_hp(cls, w: torch.Tensor, block_size: list[int]): torch.zeros_like(scale.squeeze(-1), dtype=torch.int8), block_size, w.shape, + act_quant_kwargs=act_quant_kwargs, + kernel_preference=kernel_preference, + dtype=w.dtype, ) + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Dequantize int8 tensor to floating point""" + dtype = output_dtype or self.dtype or self.scale.dtype + return ( + self.qdata.to(dtype) - self.zero_point.to(dtype).unsqueeze(1) + ) * self.scale.to(dtype).unsqueeze(1) + implements = Int8Tensor.implements @@ -69,10 +146,10 @@ def 
from_hp(cls, w: torch.Tensor, block_size: list[int]): def _(func, types, args, kwargs): """dequantization: int8 -> float""" tensor = args[0] + dtype = tensor.dtype or tensor.scale.dtype return ( - tensor.qdata.to(tensor.scale.dtype) - - tensor.zero_point.to(tensor.scale.dtype).unsqueeze(1) - ) * tensor.scale.unsqueeze(1) + tensor.qdata.to(dtype) - tensor.zero_point.to(dtype).unsqueeze(1) + ) * tensor.scale.to(dtype).unsqueeze(1) @implements([torch.nn.functional.linear, aten.linear.default]) @@ -84,8 +161,18 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) + assert isinstance(weight_tensor, Int8Tensor), ( + f"Expected weight to be Int8Tensor, got {type(weight_tensor)}" + ) + + # Dynamic activation quantization if enabled + if weight_tensor.act_quant_kwargs is not None: + input_tensor = _choose_quant_func_and_quantize_tensor( + input_tensor, weight_tensor.act_quant_kwargs + ) + if isinstance(input_tensor, Int8Tensor): - # INT8 × INT8 + # INT8 × INT8 (dynamic) x_int32 = input_tensor.qdata.to(torch.int32) w_int32 = weight_tensor.qdata.to(torch.int32).t() @@ -94,7 +181,7 @@ def _(func, types, args, kwargs): result = result.to(scale.dtype) * scale result = result.view(*input_tensor.shape[:-1], -1) else: - # FP × INT8 + # FP × INT8 (static) result = torch.nn.functional.linear( input_tensor, weight_tensor.dequantize(), None ) @@ -103,4 +190,4 @@ def _(func, types, args, kwargs): Int8Tensor.__module__ = "torchao.quantization" -torch.serialization.add_safe_globals([Int8Tensor]) +torch.serialization.add_safe_globals([Int8Tensor, QuantizeTensorToInt8Kwargs]) From 93835507aaf7003e353000ff28275074b4ffb4b2 Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 25 Sep 2025 01:33:54 +0900 Subject: [PATCH 04/13] update static/dynamic quantization workflows --- .../quantize_/workflows/int8/int8_tensor.py | 77 +++++++++++++++---- 1 file changed, 60 insertions(+), 17 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 62ce4f3e2f..3cd22aba09 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -27,6 +27,7 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): Args: kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. 
+ TODO: Implement flags for kernel preference, same as QuantizeTensorToFloat8Kwargs block_size (Optional[list[int]]): block size for quantization granularity """ @@ -165,26 +166,68 @@ def _(func, types, args, kwargs): f"Expected weight to be Int8Tensor, got {type(weight_tensor)}" ) - # Dynamic activation quantization if enabled - if weight_tensor.act_quant_kwargs is not None: - input_tensor = _choose_quant_func_and_quantize_tensor( - input_tensor, weight_tensor.act_quant_kwargs + if isinstance(input_tensor, Int8Tensor): + # INT8 × INT8 (static) + x_vals_int8 = input_tensor.qdata + x_scales = input_tensor.scale + w_vals_int8_t = weight_tensor.qdata.contiguous().t() + w_scales = weight_tensor.scale + + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + + # Cast fp16 scale to float to avoid overflow in y_dot_int32 + intermediate_dtype = ( + torch.float if x_scales_dtype == torch.half else x_scales_dtype ) - if isinstance(input_tensor, Int8Tensor): - # INT8 × INT8 (dynamic) - x_int32 = input_tensor.qdata.to(torch.int32) - w_int32 = weight_tensor.qdata.to(torch.int32).t() - - result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32) - scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0) - result = result.to(scale.dtype) * scale - result = result.view(*input_tensor.shape[:-1], -1) - else: - # FP × INT8 (static) - result = torch.nn.functional.linear( - input_tensor, weight_tensor.dequantize(), None + # First apply input scaling to avoid overflow + y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8_t.to(torch.int32)) + y_dot_scaled = y_dot_int32.to(intermediate_dtype) * x_scales.reshape(-1, 1).to( + intermediate_dtype ) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) + + # Then apply weight scaling + result = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + result = result.to(input_tensor.dtype) + + else: + if weight_tensor.act_quant_kwargs is not None: + # INT8 × INT8 (dynamic) + input_tensor = _choose_quant_func_and_quantize_tensor( + input_tensor, weight_tensor.act_quant_kwargs + ) + + x_vals_int8 = input_tensor.qdata + x_scales = input_tensor.scale + w_vals_int8_t = weight_tensor.qdata.contiguous().t() + w_scales = weight_tensor.scale + + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + + # Cast fp16 scale to float to avoid overflow in y_dot_int32 + intermediate_dtype = ( + torch.float if x_scales_dtype == torch.half else x_scales_dtype + ) + y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8_t.to(torch.int32)) + y_dot_scaled = y_dot_int32.to(intermediate_dtype) * x_scales.reshape( + -1, 1 + ).to(intermediate_dtype) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) + + result = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + result = result.to(input_tensor.dtype) + else: + # FP × INT8 (weight-only) + result = torch.nn.functional.linear( + input_tensor, weight_tensor.dequantize(), None + ) return result + bias if bias is not None else result From 2c84ba4c1012a785199818b1eeb520facdbf6339 Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 25 Sep 2025 01:57:47 +0900 Subject: [PATCH 05/13] add kernel preference unit test --- .../quantize_/workflows/int8/test_int8_tensor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index f1feaba62a..f664fda565 100644 --- 
a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -7,8 +7,10 @@ import unittest import torch +from torch.testing._internal import common_utils from torch.testing._internal.common_utils import run_tests +from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( Int8Tensor, QuantizeTensorToInt8Kwargs, @@ -37,6 +39,18 @@ def test_creation_and_attributes(self): torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) ) + @common_utils.parametrize( + "kernel_preference", + [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], + ) + def test_kernel_preference(self, kernel_preference): + """Test Int8Tensor with different kernels""" + tensor = Int8Tensor.from_hp( + self.weight_fp, self.block_size, kernel_preference=kernel_preference + ) + + self.assertEqual(tensor.kernel_preference, kernel_preference) + def test_linear_operations(self): """Test fp+int8 and int8+int8 linear ops with quantization error check""" weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) From 8ddddd3bc6d61094c7ae04747a09e10bf3463edc Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 25 Sep 2025 01:57:47 +0900 Subject: [PATCH 06/13] add kernel preference unit test --- .../workflows/int8/test_int8_tensor.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index f664fda565..821e916d76 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -9,6 +9,7 @@ import torch from torch.testing._internal import common_utils from torch.testing._internal.common_utils import run_tests +from torch._inductor.utils import run_and_get_code from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( @@ -79,6 +80,31 @@ def test_dynamic_quantization(self): self.assertEqual(result_dynamic.shape, reference.shape) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_expected_kernel_operations(self): + """Test Int8Tensor with FBGEMM kernels""" + + # Setup model with Int8Tensor + weight_q8 = Int8Tensor.from_hp( + self.weight_fp, + self.block_size, + kernel_preference=KernelPreference.FBGEMM + ) + + def model(x): + return torch.nn.functional.linear(x, weight_q8, self.bias) + + compiled_model = torch.compile(model) + + output, code = run_and_get_code(compiled_model, self.input_fp) + + self.assertEqual(output.shape, (2, 4)) + self.assertTrue(len(code) > 0, "Should generate some compiled code") + + # Test dequantization kernel + dequant_output = torch.ops.aten.dequantize.self(weight_q8) + self.assertEqual(dequant_output.shape, self.weight_fp.shape) + def test_error_handling_and_dequant(self): """Test input validation and dequantization accuracy""" # Test 1D tensor validation From b5cb3c84c26de46cb65b6350556e6e4fff76fc3e Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 25 Sep 2025 04:00:06 +0900 Subject: [PATCH 07/13] fix missing attribute --- torchao/quantization/quantize_/workflows/int8/int8_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 
3cd22aba09..a130b362ee 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -53,7 +53,7 @@ class Int8Tensor(TorchAOBaseTensor): """ tensor_data_names = ["qdata", "scale", "zero_point"] - tensor_attribute_names = ["block_size"] + tensor_attribute_names = ["block_size", "_shape"] optional_tensor_attribute_names = [ "act_quant_kwargs", "kernel_preference", @@ -94,6 +94,7 @@ def __init__( self.scale = scale self.zero_point = zero_point self.block_size = block_size + self._shape = shape self.act_quant_kwargs = act_quant_kwargs self.kernel_preference = kernel_preference From 9a51cae97ab15814d15b6bcef096920c591885ac Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 28 Sep 2025 10:10:11 +0900 Subject: [PATCH 08/13] remove kernel preference args --- .../workflows/int8/test_int8_tensor.py | 40 ------------------- .../quantize_/workflows/int8/int8_tensor.py | 13 +----- 2 files changed, 1 insertion(+), 52 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 821e916d76..f1feaba62a 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -7,11 +7,8 @@ import unittest import torch -from torch.testing._internal import common_utils from torch.testing._internal.common_utils import run_tests -from torch._inductor.utils import run_and_get_code -from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( Int8Tensor, QuantizeTensorToInt8Kwargs, @@ -40,18 +37,6 @@ def test_creation_and_attributes(self): torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) ) - @common_utils.parametrize( - "kernel_preference", - [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], - ) - def test_kernel_preference(self, kernel_preference): - """Test Int8Tensor with different kernels""" - tensor = Int8Tensor.from_hp( - self.weight_fp, self.block_size, kernel_preference=kernel_preference - ) - - self.assertEqual(tensor.kernel_preference, kernel_preference) - def test_linear_operations(self): """Test fp+int8 and int8+int8 linear ops with quantization error check""" weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) @@ -80,31 +65,6 @@ def test_dynamic_quantization(self): self.assertEqual(result_dynamic.shape, reference.shape) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_expected_kernel_operations(self): - """Test Int8Tensor with FBGEMM kernels""" - - # Setup model with Int8Tensor - weight_q8 = Int8Tensor.from_hp( - self.weight_fp, - self.block_size, - kernel_preference=KernelPreference.FBGEMM - ) - - def model(x): - return torch.nn.functional.linear(x, weight_q8, self.bias) - - compiled_model = torch.compile(model) - - output, code = run_and_get_code(compiled_model, self.input_fp) - - self.assertEqual(output.shape, (2, 4)) - self.assertTrue(len(code) > 0, "Should generate some compiled code") - - # Test dequantization kernel - dequant_output = torch.ops.aten.dequantize.self(weight_q8) - self.assertEqual(dequant_output.shape, self.weight_fp.shape) - def test_error_handling_and_dequant(self): """Test input validation and dequantization accuracy""" # Test 1D tensor validation diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py 
b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index a130b362ee..21a4eb66bf 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -10,7 +10,6 @@ import torch from torchao.quantization.quantize_.common import ( - KernelPreference, QuantizeTensorKwargs, _choose_quant_func_and_quantize_tensor, ) @@ -26,12 +25,9 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): """Tensor kwargs for creating int8 tensor (either activation or weight) Args: - kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. - TODO: Implement flags for kernel preference, same as QuantizeTensorToFloat8Kwargs block_size (Optional[list[int]]): block size for quantization granularity """ - kernel_preference: KernelPreference = KernelPreference.AUTO block_size: Optional[list[int]] = None @@ -49,14 +45,12 @@ class Int8Tensor(TorchAOBaseTensor): block_size: block size for quantization granularity shape: original tensor shape act_quant_kwargs: flags for static/dynamic activation quantization - kernel_preference: kernel preference for operations """ tensor_data_names = ["qdata", "scale", "zero_point"] tensor_attribute_names = ["block_size", "_shape"] optional_tensor_attribute_names = [ "act_quant_kwargs", - "kernel_preference", "dtype", ] @@ -68,7 +62,6 @@ def __new__( block_size, shape, act_quant_kwargs=None, - kernel_preference=KernelPreference.AUTO, dtype=None, ): kwargs = { @@ -86,7 +79,6 @@ def __init__( block_size, shape, act_quant_kwargs=None, - kernel_preference=KernelPreference.AUTO, dtype=None, ): super().__init__() @@ -96,12 +88,11 @@ def __init__( self.block_size = block_size self._shape = shape self.act_quant_kwargs = act_quant_kwargs - self.kernel_preference = kernel_preference def __repr__(self): return ( f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " - f"{self.zero_point=}, {self.block_size=}, {self.kernel_preference=}, " + f"{self.zero_point=}, {self.block_size=}, " f"{self.shape=}, {self.device=}, {self.dtype=})" ) @@ -111,7 +102,6 @@ def from_hp( w: torch.Tensor, block_size: list[int], act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, - kernel_preference: KernelPreference = KernelPreference.AUTO, ): if w.dim() != 2 or len(block_size) != 2: raise ValueError("Expected 2D tensor and block_size length 2") @@ -129,7 +119,6 @@ def from_hp( block_size, w.shape, act_quant_kwargs=act_quant_kwargs, - kernel_preference=kernel_preference, dtype=w.dtype, ) From c53dad03d0f6f0779c31876ab1232464b93ea566 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 28 Sep 2025 23:48:38 +0900 Subject: [PATCH 09/13] link new API with old API using version 2 --- .../workflows/int8/test_int8_tensor.py | 101 +++++++++++++++++- torchao/quantization/quant_api.py | 80 +++++++++----- .../quantize_/workflows/int8/int8_tensor.py | 2 + 3 files changed, 152 insertions(+), 31 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index f1feaba62a..43fc1324ff 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -4,11 +4,19 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
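+# Illustrative flow for the config-based path added in this patch (a sketch of
+# hypothetical usage, not part of the tests): version=2 configs route through
+# the new Int8Tensor subclass, e.g.
+#
+#   m = torch.nn.Linear(128, 64)
+#   quantize_(m, Int8WeightOnlyConfig(version=2))
+#   m.weight.qdata  # int8 weight payload produced by Int8Tensor.from_hp
+#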
+import copy import unittest +from contextlib import nullcontext +from typing import Tuple import torch -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal import common_utils +from torchao.quantization import ( + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + quantize_, +) from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( Int8Tensor, QuantizeTensorToInt8Kwargs, @@ -17,7 +25,46 @@ from torchao.testing.utils import TorchAOIntegrationTestCase +# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged +class ToyTwoLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + has_bias=False, + dtype=None, + device=None, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device + ) + self.linear2 = torch.nn.Linear( + hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + # Note: tinygemm kernel only uses bfloat16 inputs + def example_inputs(self, batch_size=1): + return ( + torch.randn( + batch_size, + self.linear1.in_features, + dtype=self.dtype, + device=self.device, + ), + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@common_utils.instantiate_parametrized_tests class TestInt8Tensor(TorchAOIntegrationTestCase): def setUp(self): super().setUp() @@ -37,6 +84,56 @@ def test_creation_and_attributes(self): torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) ) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize("compile", [False, True]) + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ], + ) + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + def test_int8_linear_variants( + self, + dtype: torch.dtype, + compile: bool, + sizes: Tuple, + config, + ): + error_message = None + + error_context = ( + self.assertRaisesRegex(AssertionError, error_message) + if error_message + else nullcontext() + ) + + with error_context: + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + + # Create a linear layer + m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda") + m_q = copy.deepcopy(m) + + # Quantize + quantize_(m_q, config) + + output_original = m(input_tensor) + output_quantized = m_q(input_tensor) + + error = compute_error(output_original, output_quantized) + assert compute_error(output_original, output_quantized) > 20, ( + f"Quantization error is too high got a SQNR of {error}" + ) + def test_linear_operations(self): """Test fp+int8 and int8+int8 linear ops with quantization error check""" weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) @@ -85,4 +182,4 @@ def test_error_handling_and_dequant(self): if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 15caddcadc..031dc5cced 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -78,6 +78,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8Tensor, IntxOpaqueTensor, IntxPackingFormat, IntxUnpackedToInt8Tensor, @@ -1352,10 +1353,12 @@ class Int8WeightOnlyConfig(AOBaseConfig): Otherwise, 
applies per-group quantization with the specified group size. set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values for better performance with this quantization scheme. + version: int = 2 - Version of the config to use. Version 1 uses AffineQuantization for quantization, """ group_size: Optional[int] = None set_inductor_config: bool = True + version: int = 1 def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") @@ -1366,22 +1369,30 @@ def __post_init__(self): def _int8_weight_only_quantize_tensor(weight, config): - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - group_size = config.group_size - if group_size is None: - group_size = weight.shape[-1] - block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) - new_weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - ) + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details" + ) + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + group_size = config.group_size + if group_size is None: + group_size = weight.shape[-1] + block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + else: + assert config.version == 2, f"Unexpected version: {config.version}" + block_size = [weight.shape[0], weight.shape[1]] + new_weight = Int8Tensor.from_hp(weight, block_size=block_size) return new_weight @@ -1509,12 +1520,14 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): in original precision during decode operations. set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values for better performance with this quantization scheme. 
+ version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Int8Tensor """ layout: Optional[Layout] = PlainLayout() act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC weight_only_decode: bool = False set_inductor_config: bool = True + version: int = 1 def __post_init__(self): torch._C._log_api_usage_once( @@ -1562,19 +1575,28 @@ def get_weight_block_size(x): else: input_quant_func = _int8_asymm_per_token_quant - block_size = get_weight_block_size(weight) - new_weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - _layout=layout, - zero_point_domain=weight_zero_point_domain, - ) - new_weight = to_linear_activation_quantized(new_weight, input_quant_func) - return new_weight + if config.version == 1: + block_size = get_weight_block_size(weight) + quantized_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + _layout=layout, + zero_point_domain=weight_zero_point_domain, + ) + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func + ) + else: + quantized_weight = Int8Tensor.from_hp( + weight, + block_size=get_weight_block_size(weight), + ) + + return quantized_weight @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 21a4eb66bf..1c8dcb70d2 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -29,6 +29,7 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): """ block_size: Optional[list[int]] = None + kernel_preference: Optional[str] = None # TODO: Implement block-wise quantization using block_size @@ -102,6 +103,7 @@ def from_hp( w: torch.Tensor, block_size: list[int], act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + kernel_preference: Optional[str] = None, ): if w.dim() != 2 or len(block_size) != 2: raise ValueError("Expected 2D tensor and block_size length 2") From d300b0293ec1c5f1d99e4b6285feede32e81ebab Mon Sep 17 00:00:00 2001 From: youn17 Date: Tue, 30 Sep 2025 15:06:51 +0900 Subject: [PATCH 10/13] add granularity, block size support --- .../workflows/int8/test_int8_tensor.py | 72 ++++++++- torchao/quantization/quant_api.py | 7 +- .../quantize_/workflows/int8/int8_tensor.py | 142 +++++++++++------- 3 files changed, 162 insertions(+), 59 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 43fc1324ff..e951793bf5 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -15,6 +15,8 @@ from torchao.quantization import ( Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, + PerRow, + PerTensor, quantize_, ) from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( @@ -70,7 +72,7 @@ def setUp(self): super().setUp() torch.manual_seed(42) self.weight_fp = torch.randn(4, 3, dtype=torch.float32) - self.input_fp = torch.randn(2, 3, dtype=torch.float32) + self.input_fp = torch.randn(4, 3, dtype=torch.float32) self.bias = torch.randn(4) self.block_size = [4, 3] @@ -162,6 +164,72 @@ def test_dynamic_quantization(self): 
self.assertEqual(result_dynamic.shape, reference.shape) + @unittest.skip("granularity parameter not supported in current API") + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + def test_slice_preserves_aliasing(self, granularity): + config = Int8DynamicActivationInt8WeightConfig( + granularity=granularity, version=2 + ) + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter( + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + ) + quantize_(l, config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + # Making sure the aliasing is preserved in sliced quantized Tensor + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() + assert param.data.scale.data_ptr() == param_data.scale.data_ptr() + + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + @common_utils.parametrize("device", ["cpu", "cuda"]) + @common_utils.parametrize("dtype", [torch.bfloat16]) + def test_slice(self, config, device, dtype): + dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) + dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) + dummy1.weight = torch.nn.Parameter( + dummy.weight.narrow(0, 0, 64), requires_grad=False + ) + dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) + dummy2.weight = torch.nn.Parameter( + dummy.weight.narrow(1, 0, 128), requires_grad=False + ) + + quantize_(dummy, config) + weight1 = dummy.weight.clone().narrow(0, 0, 64) + weight2 = dummy.weight.clone().narrow(1, 0, 128) + self.assertEqual( + weight1.qdata, + dummy.weight.qdata.narrow(0, 0, 64), + ) + self.assertEqual( + weight2.qdata, + dummy.weight.qdata.narrow(1, 0, 128), + ) + + # check for sliced weight, before and after int8 quantization + # does not differ too much + input = torch.randn(2, 256, dtype=dtype, device=device) + res_ref = dummy1(input) + dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False) + res = dummy(input) + sqnr = compute_error(res, res_ref) + self.assertTrue(sqnr > 20, f"sqnr: {sqnr}") + + input = torch.randn(2, 128, dtype=dtype, device=device) + res_ref = dummy2(input) + dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False) + res = dummy(input) + sqnr = compute_error(res, res_ref) + self.assertTrue(sqnr > 15, f"sqnr: {sqnr}") + def test_error_handling_and_dequant(self): """Test input validation and dequantization accuracy""" # Test 1D tensor validation @@ -174,7 +242,7 @@ def test_error_handling_and_dequant(self): # Test dequantization with exact values test_data = torch.tensor([[1.0, -1.0]], dtype=torch.float32) - tensor = Int8Tensor.from_hp(test_data, [1, 1]) + tensor = Int8Tensor.from_hp(test_data, [1, 2]) dequantized = torch.ops.aten.dequantize.self(tensor) self.assertEqual(dequantized.shape, test_data.shape) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 031dc5cced..885da6948a 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1353,7 +1353,7 @@ class Int8WeightOnlyConfig(AOBaseConfig): Otherwise, applies per-group quantization with the specified group size. set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values for better performance with this quantization scheme. - version: int = 2 - Version of the config to use. 
Version 1 uses AffineQuantization for quantization, + version - Version of the config to use. Version 1 uses AffineQuantization for quantization, """ group_size: Optional[int] = None @@ -1371,7 +1371,7 @@ def __post_init__(self): def _int8_weight_only_quantize_tensor(weight, config): if config.version == 1: warnings.warn( - "Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details" + "Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" ) mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 @@ -1576,6 +1576,9 @@ def get_weight_block_size(x): input_quant_func = _int8_asymm_per_token_quant if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" + ) block_size = get_weight_block_size(weight) quantized_weight = to_affine_quantized_intx( weight, diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 1c8dcb70d2..ff26cf2952 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -8,7 +8,13 @@ from typing import Optional import torch +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.quantization.quant_primitives import ( + MappingType, + choose_qparams_affine, + quantize_affine, +) from torchao.quantization.quantize_.common import ( QuantizeTensorKwargs, _choose_quant_func_and_quantize_tensor, @@ -32,7 +38,6 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): kernel_preference: Optional[str] = None -# TODO: Implement block-wise quantization using block_size class Int8Tensor(TorchAOBaseTensor): """ int8 quantized tensor with plain layout @@ -108,16 +113,29 @@ def from_hp( if w.dim() != 2 or len(block_size) != 2: raise ValueError("Expected 2D tensor and block_size length 2") - # Rounding function from high precision dtype - scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0 - scale = scale.clamp(min=1e-6) + scale, zero_point = choose_qparams_affine( + input=w, + mapping_type=MappingType.SYMMETRIC, + block_size=tuple(block_size), + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=w.dtype, + zero_point_dtype=torch.int8, + ) - int_data = torch.round(w / scale).clamp(-128, 127).to(torch.int8) + int_data = quantize_affine( + w, + block_size=tuple(block_size), + scale=scale, + zero_point=zero_point, + output_dtype=torch.int8, + ) return cls( int_data, - scale.squeeze(-1), - torch.zeros_like(scale.squeeze(-1), dtype=torch.int8), + scale, + zero_point, block_size, w.shape, act_quant_kwargs=act_quant_kwargs, @@ -127,9 +145,18 @@ def from_hp( def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize int8 tensor to floating point""" dtype = output_dtype or self.dtype or self.scale.dtype - return ( - self.qdata.to(dtype) - self.zero_point.to(dtype).unsqueeze(1) - ) * self.scale.to(dtype).unsqueeze(1) + + qdata_fp = self.qdata.to(dtype) + scale = self.scale.to(dtype) + zero_point = 
self.zero_point.to(dtype) + + # Reshape 1D scale/zero_point to [N, 1] for broadcasting with [N, K] qdata + if scale.ndim == 1: + scale = scale.unsqueeze(1) + if zero_point.ndim == 1: + zero_point = zero_point.unsqueeze(1) + + return (qdata_fp - zero_point) * scale implements = Int8Tensor.implements @@ -138,11 +165,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor @implements([aten.dequantize.self]) def _(func, types, args, kwargs): """dequantization: int8 -> float""" - tensor = args[0] - dtype = tensor.dtype or tensor.scale.dtype - return ( - tensor.qdata.to(dtype) - tensor.zero_point.to(dtype).unsqueeze(1) - ) * tensor.scale.to(dtype).unsqueeze(1) + return args[0].dequantize() @implements([torch.nn.functional.linear, aten.linear.default]) @@ -158,8 +181,14 @@ def _(func, types, args, kwargs): f"Expected weight to be Int8Tensor, got {type(weight_tensor)}" ) - if isinstance(input_tensor, Int8Tensor): - # INT8 × INT8 (static) + if weight_tensor.act_quant_kwargs is not None: + # INT8 × INT8 (dynamic) + # Quantize input if it's not already quantized + if not isinstance(input_tensor, Int8Tensor): + input_tensor = _choose_quant_func_and_quantize_tensor( + input_tensor, weight_tensor.act_quant_kwargs + ) + x_vals_int8 = input_tensor.qdata x_scales = input_tensor.scale w_vals_int8_t = weight_tensor.qdata.contiguous().t() @@ -168,60 +197,63 @@ def _(func, types, args, kwargs): tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) x_scales_dtype = x_scales.dtype - # Cast fp16 scale to float to avoid overflow in y_dot_int32 + # Cast fp16 scale to float intermediate_dtype = ( torch.float if x_scales_dtype == torch.half else x_scales_dtype ) - - # First apply input scaling to avoid overflow - y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8_t.to(torch.int32)) - y_dot_scaled = y_dot_int32.to(intermediate_dtype) * x_scales.reshape(-1, 1).to( + y_dot_int64 = torch.mm(tmp.to(torch.int64), w_vals_int8_t.to(torch.int64)) + y_dot_scaled = y_dot_int64.to(intermediate_dtype) * x_scales.reshape(-1, 1).to( intermediate_dtype ) y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - # Then apply weight scaling result = (y_dot_scaled * w_scales).reshape( *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] ) result = result.to(input_tensor.dtype) - else: - if weight_tensor.act_quant_kwargs is not None: - # INT8 × INT8 (dynamic) - input_tensor = _choose_quant_func_and_quantize_tensor( - input_tensor, weight_tensor.act_quant_kwargs - ) + # FP × INT8 (weight-only) + input_tensor = input_tensor.dequantize() - x_vals_int8 = input_tensor.qdata - x_scales = input_tensor.scale - w_vals_int8_t = weight_tensor.qdata.contiguous().t() - w_scales = weight_tensor.scale + result = func(input_tensor, weight_tensor.dequantize(), None) - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - x_scales_dtype = x_scales.dtype + return result + bias if bias is not None else result - # Cast fp16 scale to float to avoid overflow in y_dot_int32 - intermediate_dtype = ( - torch.float if x_scales_dtype == torch.half else x_scales_dtype - ) - y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8_t.to(torch.int32)) - y_dot_scaled = y_dot_int32.to(intermediate_dtype) * x_scales.reshape( - -1, 1 - ).to(intermediate_dtype) - y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - - result = (y_dot_scaled * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] - ) - result = result.to(input_tensor.dtype) - else: - # FP × INT8 (weight-only) - result = torch.nn.functional.linear( - input_tensor, 
weight_tensor.dequantize(), None - ) - return result + bias if bias is not None else result +@implements([aten.slice.Tensor]) +def _(func, types, args, kwargs): + """Slice operation for Int8Tensor""" + tensor, dim, start, end, step = ( + args[0], + args[1], + args[2], + args[3], + args[4] if len(args) > 4 else 1, + ) + + # Slice scale and zero_point along dimension 0 if slicing rows + sliced_scale = tensor.scale + sliced_zero_point = tensor.zero_point + + if dim == 0 and tensor.scale.ndim >= 1: + sliced_scale = aten.slice.Tensor(tensor.scale, 0, start, end, step) + sliced_zero_point = aten.slice.Tensor(tensor.zero_point, 0, start, end, step) + + sliced_shape = list( + aten.slice.Tensor(torch.empty(tensor.shape), dim, start, end, step).shape + ) + + new = Int8Tensor( + aten.slice.Tensor(tensor.qdata, dim, start, end, step), + sliced_scale, + sliced_zero_point, + tensor.block_size, + sliced_shape, + tensor.act_quant_kwargs, + tensor.dtype, + ) + + return return_and_correct_aliasing(func, args, kwargs, new) Int8Tensor.__module__ = "torchao.quantization" From 590e0b709ddc01ece592ed819d6c95922e9e1a32 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sat, 4 Oct 2025 20:05:10 +0900 Subject: [PATCH 11/13] add transpose, index selector workflows --- .../workflows/int8/test_int8_tensor.py | 121 +++++++----------- torchao/quantization/quant_api.py | 3 +- .../common/quantize_tensor_kwargs.py | 1 - .../quantize_/workflows/int8/int8_tensor.py | 75 +++++++++-- 4 files changed, 107 insertions(+), 93 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index e951793bf5..619cc2fce0 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -6,7 +6,6 @@ import copy import unittest -from contextlib import nullcontext from typing import Tuple import torch @@ -48,17 +47,6 @@ def __init__( hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device ) - # Note: tinygemm kernel only uses bfloat16 inputs - def example_inputs(self, batch_size=1): - return ( - torch.randn( - batch_size, - self.linear1.in_features, - dtype=self.dtype, - device=self.device, - ), - ) - def forward(self, x): x = self.linear1(x) x = self.linear2(x) @@ -71,9 +59,9 @@ class TestInt8Tensor(TorchAOIntegrationTestCase): def setUp(self): super().setUp() torch.manual_seed(42) - self.weight_fp = torch.randn(4, 3, dtype=torch.float32) - self.input_fp = torch.randn(4, 3, dtype=torch.float32) - self.bias = torch.randn(4) + self.weight_fp = torch.randn(4, 3, dtype=torch.bfloat16) + self.input_fp = torch.randn(4, 3, dtype=torch.bfloat16) + self.bias = torch.randn(4, dtype=torch.bfloat16) self.block_size = [4, 3] def test_creation_and_attributes(self): @@ -86,8 +74,7 @@ def test_creation_and_attributes(self): torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) ) - @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) - @common_utils.parametrize("compile", [False, True]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) @common_utils.parametrize( "sizes", [ @@ -105,39 +92,29 @@ def test_creation_and_attributes(self): def test_int8_linear_variants( self, dtype: torch.dtype, - compile: bool, sizes: Tuple, config, ): - error_message = None - - error_context = ( - self.assertRaisesRegex(AssertionError, error_message) - if error_message - else nullcontext() - ) - - with error_context: - M, N, K = sizes - input_tensor = 
torch.randn(*M, K, dtype=dtype, device="cuda") + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") - # Create a linear layer - m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda") - m_q = copy.deepcopy(m) + # Create a linear layer + m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda") + m_q = copy.deepcopy(m) - # Quantize - quantize_(m_q, config) + # Quantize + quantize_(m_q, config) - output_original = m(input_tensor) - output_quantized = m_q(input_tensor) + output_original = m(input_tensor) + output_quantized = m_q(input_tensor) - error = compute_error(output_original, output_quantized) - assert compute_error(output_original, output_quantized) > 20, ( - f"Quantization error is too high got a SQNR of {error}" - ) + error = compute_error(output_original, output_quantized) + assert compute_error(output_original, output_quantized) > 20, ( + f"Quantization error is too high got a SQNR of {error}" + ) def test_linear_operations(self): - """Test fp+int8 and int8+int8 linear ops with quantization error check""" + """Test fp+int8 and int8+int8 linear ops""" weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) input_q8 = Int8Tensor.from_hp(self.input_fp, self.block_size) @@ -151,6 +128,7 @@ def test_linear_operations(self): self.assertTrue(compute_error(result_q8, reference) > 10) def test_dynamic_quantization(self): + """Test dynamic activation quantization""" weight_q8_dynamic = Int8Tensor.from_hp( self.weight_fp, self.block_size, @@ -190,58 +168,45 @@ def test_slice_preserves_aliasing(self, granularity): ], ) @common_utils.parametrize("device", ["cpu", "cuda"]) - @common_utils.parametrize("dtype", [torch.bfloat16]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_slice(self, config, device, dtype): + """Test tensor slicing""" dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) - dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) - dummy1.weight = torch.nn.Parameter( - dummy.weight.narrow(0, 0, 64), requires_grad=False - ) - dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) - dummy2.weight = torch.nn.Parameter( - dummy.weight.narrow(1, 0, 128), requires_grad=False - ) - quantize_(dummy, config) + weight1 = dummy.weight.clone().narrow(0, 0, 64) weight2 = dummy.weight.clone().narrow(1, 0, 128) - self.assertEqual( - weight1.qdata, - dummy.weight.qdata.narrow(0, 0, 64), - ) - self.assertEqual( - weight2.qdata, - dummy.weight.qdata.narrow(1, 0, 128), - ) - # check for sliced weight, before and after int8 quantization - # does not differ too much - input = torch.randn(2, 256, dtype=dtype, device=device) - res_ref = dummy1(input) - dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False) - res = dummy(input) - sqnr = compute_error(res, res_ref) - self.assertTrue(sqnr > 20, f"sqnr: {sqnr}") - - input = torch.randn(2, 128, dtype=dtype, device=device) - res_ref = dummy2(input) - dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False) - res = dummy(input) - sqnr = compute_error(res, res_ref) - self.assertTrue(sqnr > 15, f"sqnr: {sqnr}") + self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64)) + self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128)) + + def test_transpose(self): + """Test transpose operation""" + weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) + transposed = weight_q8.transpose(0, 1) + + self.assertEqual(transposed.shape, (3, 4)) + 
self.assertEqual(transposed.block_size, [3, 4]) + + def test_select(self): + """Test select operation""" + weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) + selected = weight_q8.select(0, 0) + + self.assertEqual(selected.shape, (3,)) + + with self.assertRaises(AssertionError): + weight_q8.select(1, 0) def test_error_handling_and_dequant(self): """Test input validation and dequantization accuracy""" - # Test 1D tensor validation with self.assertRaises((AssertionError, ValueError, RuntimeError)): Int8Tensor.from_hp(torch.randn(5), [1]) - # Test wrong block_size validation with self.assertRaises((AssertionError, ValueError, RuntimeError)): Int8Tensor.from_hp(self.weight_fp, [1]) - # Test dequantization with exact values - test_data = torch.tensor([[1.0, -1.0]], dtype=torch.float32) + test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16) tensor = Int8Tensor.from_hp(test_data, [1, 2]) dequantized = torch.ops.aten.dequantize.self(tensor) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 69c6dca0ff..133f28bfd9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -769,8 +769,7 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): act_mapping_type: MappingType = MappingType.ASYMMETRIC layout: Layout = QDQLayout() intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8 - intx_choose_qparams_algorithm: - = ( + intx_choose_qparams_algorithm: IntxChooseQParamsAlgorithm = ( IntxChooseQParamsAlgorithm.AFFINE ) diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py index 2c3c6bcab6..664ea43b81 100644 --- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -58,7 +58,6 @@ def _choose_quant_func_and_quantize_tensor( return Int8Tensor.from_hp( tensor, quant_kwargs.block_size or [1, tensor.shape[-1]], - kernel_preference=quant_kwargs.kernel_preference, ) raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index ff26cf2952..17c4ef260e 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -35,7 +35,6 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): """ block_size: Optional[list[int]] = None - kernel_preference: Optional[str] = None class Int8Tensor(TorchAOBaseTensor): @@ -108,7 +107,6 @@ def from_hp( w: torch.Tensor, block_size: list[int], act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, - kernel_preference: Optional[str] = None, ): if w.dim() != 2 or len(block_size) != 2: raise ValueError("Expected 2D tensor and block_size length 2") @@ -215,7 +213,7 @@ def _(func, types, args, kwargs): # FP × INT8 (weight-only) input_tensor = input_tensor.dequantize() - result = func(input_tensor, weight_tensor.dequantize(), None) + result = func(input_tensor, weight_tensor.dequantize(input_tensor.dtype), None) return result + bias if bias is not None else result @@ -243,17 +241,70 @@ def _(func, types, args, kwargs): aten.slice.Tensor(torch.empty(tensor.shape), dim, start, end, step).shape ) - new = Int8Tensor( - aten.slice.Tensor(tensor.qdata, dim, start, end, step), - sliced_scale, - sliced_zero_point, - tensor.block_size, - sliced_shape, - 
tensor.act_quant_kwargs, - tensor.dtype, + return return_and_correct_aliasing( + func, + args, + kwargs, + Int8Tensor( + aten.slice.Tensor(tensor.qdata, dim, start, end, step), + sliced_scale, + sliced_zero_point, + tensor.block_size, + sliced_shape, + tensor.act_quant_kwargs, + tensor.dtype, + ), ) - return return_and_correct_aliasing(func, args, kwargs, new) + +@implements(aten.transpose.int) +def _(func, types, args, kwargs): + """Dimension transposer for Int8Tensor""" + self, dim0, dim1 = args + return return_and_correct_aliasing( + func, + args, + kwargs, + Int8Tensor( + self.qdata.transpose(dim0, dim1), + self.scale, + self.zero_point, + [self.block_size[dim1], self.block_size[dim0]], + [self._shape[dim1], self._shape[dim0]], + self.act_quant_kwargs, + self.dtype, + ), + ) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + """Index selector for Int8Tensor""" + self, dim, index = args + assert dim == 0, f"Only dim=0 supported, got {dim}" + + # Handle 0-dim scale/zero_point (per-tensor quantization) + if self.scale.ndim == 0: + selected_scale = self.scale + selected_zero_point = self.zero_point + else: + selected_scale = self.scale[index] + selected_zero_point = self.zero_point[index] + + return return_and_correct_aliasing( + func, + args, + kwargs, + Int8Tensor( + self.qdata[index], + selected_scale, + selected_zero_point, + self.block_size, + list(self.qdata[index].shape), + self.act_quant_kwargs, + self.dtype, + ), + ) Int8Tensor.__module__ = "torchao.quantization" From b3d4f3e2cd606eb3b3c7827af620a0c6cac60249 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sat, 4 Oct 2025 20:38:32 +0900 Subject: [PATCH 12/13] remove external zero point --- .../quantize_/workflows/int8/int8_tensor.py | 34 +++---------------- 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 17c4ef260e..7664cd6fd3 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -44,7 +44,6 @@ class Int8Tensor(TorchAOBaseTensor): Tensor Attributes: qdata: (N, K) int8 quantized weight data scale: scale factors for dequantization - zero_point: zero points for dequantization Non-Tensor Attributes: block_size: block size for quantization granularity @@ -52,7 +51,7 @@ class Int8Tensor(TorchAOBaseTensor): act_quant_kwargs: flags for static/dynamic activation quantization """ - tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_data_names = ["qdata", "scale"] tensor_attribute_names = ["block_size", "_shape"] optional_tensor_attribute_names = [ "act_quant_kwargs", @@ -63,7 +62,6 @@ def __new__( cls, qdata, scale, - zero_point, block_size, shape, act_quant_kwargs=None, @@ -80,7 +78,6 @@ def __init__( self, qdata, scale, - zero_point, block_size, shape, act_quant_kwargs=None, @@ -89,7 +86,6 @@ def __init__( super().__init__() self.qdata = qdata self.scale = scale - self.zero_point = zero_point self.block_size = block_size self._shape = shape self.act_quant_kwargs = act_quant_kwargs @@ -97,8 +93,7 @@ def __init__( def __repr__(self): return ( f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " - f"{self.zero_point=}, {self.block_size=}, " - f"{self.shape=}, {self.device=}, {self.dtype=})" + f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})" ) @classmethod @@ -133,7 +128,6 @@ def from_hp( return cls( int_data, scale, - zero_point, 
block_size, w.shape, act_quant_kwargs=act_quant_kwargs, @@ -146,15 +140,12 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor qdata_fp = self.qdata.to(dtype) scale = self.scale.to(dtype) - zero_point = self.zero_point.to(dtype) - # Reshape 1D scale/zero_point to [N, 1] for broadcasting with [N, K] qdata + # Reshape 1D scale to [N, 1] for broadcasting with [N, K] qdata if scale.ndim == 1: scale = scale.unsqueeze(1) - if zero_point.ndim == 1: - zero_point = zero_point.unsqueeze(1) - return (qdata_fp - zero_point) * scale + return (qdata_fp) * scale implements = Int8Tensor.implements @@ -229,13 +220,9 @@ def _(func, types, args, kwargs): args[4] if len(args) > 4 else 1, ) - # Slice scale and zero_point along dimension 0 if slicing rows sliced_scale = tensor.scale - sliced_zero_point = tensor.zero_point - if dim == 0 and tensor.scale.ndim >= 1: sliced_scale = aten.slice.Tensor(tensor.scale, 0, start, end, step) - sliced_zero_point = aten.slice.Tensor(tensor.zero_point, 0, start, end, step) sliced_shape = list( aten.slice.Tensor(torch.empty(tensor.shape), dim, start, end, step).shape @@ -248,7 +235,6 @@ def _(func, types, args, kwargs): Int8Tensor( aten.slice.Tensor(tensor.qdata, dim, start, end, step), sliced_scale, - sliced_zero_point, tensor.block_size, sliced_shape, tensor.act_quant_kwargs, @@ -259,7 +245,6 @@ def _(func, types, args, kwargs): @implements(aten.transpose.int) def _(func, types, args, kwargs): - """Dimension transposer for Int8Tensor""" self, dim0, dim1 = args return return_and_correct_aliasing( func, @@ -268,7 +253,6 @@ def _(func, types, args, kwargs): Int8Tensor( self.qdata.transpose(dim0, dim1), self.scale, - self.zero_point, [self.block_size[dim1], self.block_size[dim0]], [self._shape[dim1], self._shape[dim0]], self.act_quant_kwargs, @@ -279,17 +263,10 @@ def _(func, types, args, kwargs): @implements(aten.select.int) def _(func, types, args, kwargs): - """Index selector for Int8Tensor""" self, dim, index = args assert dim == 0, f"Only dim=0 supported, got {dim}" - # Handle 0-dim scale/zero_point (per-tensor quantization) - if self.scale.ndim == 0: - selected_scale = self.scale - selected_zero_point = self.zero_point - else: - selected_scale = self.scale[index] - selected_zero_point = self.zero_point[index] + selected_scale = self.scale if self.scale.ndim == 0 else self.scale[index] return return_and_correct_aliasing( func, @@ -298,7 +275,6 @@ def _(func, types, args, kwargs): Int8Tensor( self.qdata[index], selected_scale, - selected_zero_point, self.block_size, list(self.qdata[index].shape), self.act_quant_kwargs, From df79aa8703fcfa61063b9c0bd1416f8282ee0263 Mon Sep 17 00:00:00 2001 From: youn17 Date: Tue, 7 Oct 2025 19:22:08 +0900 Subject: [PATCH 13/13] update int8 quantization API --- .../workflows/int8/test_int8_tensor.py | 75 +++++----------- torchao/quantization/quant_api.py | 4 +- .../quantize_/workflows/int8/int8_tensor.py | 89 ++++++++++--------- 3 files changed, 73 insertions(+), 95 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 619cc2fce0..52ca941224 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -18,10 +18,7 @@ PerTensor, quantize_, ) -from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( - Int8Tensor, - QuantizeTensorToInt8Kwargs, -) +from torchao.quantization.quantize_.workflows.int8.int8_tensor 
import Int8Tensor from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase @@ -109,38 +106,7 @@ def test_int8_linear_variants( output_quantized = m_q(input_tensor) error = compute_error(output_original, output_quantized) - assert compute_error(output_original, output_quantized) > 20, ( - f"Quantization error is too high got a SQNR of {error}" - ) - - def test_linear_operations(self): - """Test fp+int8 and int8+int8 linear ops""" - weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) - input_q8 = Int8Tensor.from_hp(self.input_fp, self.block_size) - - reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias) - result_fp = torch.nn.functional.linear(self.input_fp, weight_q8, self.bias) - result_q8 = torch.nn.functional.linear(input_q8, weight_q8, self.bias) - - self.assertEqual(result_fp.shape, reference.shape) - self.assertEqual(result_q8.shape, reference.shape) - self.assertTrue(compute_error(result_fp, reference) > 10) - self.assertTrue(compute_error(result_q8, reference) > 10) - - def test_dynamic_quantization(self): - """Test dynamic activation quantization""" - weight_q8_dynamic = Int8Tensor.from_hp( - self.weight_fp, - self.block_size, - act_quant_kwargs=QuantizeTensorToInt8Kwargs(), - ) - - reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias) - result_dynamic = torch.nn.functional.linear( - self.input_fp, weight_q8_dynamic, self.bias - ) - - self.assertEqual(result_dynamic.shape, reference.shape) + assert error > 20, f"Quantization error is too high got a SQNR of {error}" @unittest.skip("granularity parameter not supported in current API") @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) @@ -180,23 +146,26 @@ def test_slice(self, config, device, dtype): self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64)) self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128)) - def test_transpose(self): - """Test transpose operation""" - weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) - transposed = weight_q8.transpose(0, 1) - - self.assertEqual(transposed.shape, (3, 4)) - self.assertEqual(transposed.block_size, [3, 4]) - - def test_select(self): - """Test select operation""" - weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) - selected = weight_q8.select(0, 0) - - self.assertEqual(selected.shape, (3,)) - - with self.assertRaises(AssertionError): - weight_q8.select(1, 0) + # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow) + # Int8WeightOnlyConfig uses per-tensor (PerTensor) + if isinstance(config, Int8DynamicActivationInt8WeightConfig): + # PerRow: dim 0 slicing affects scale, dim 1 doesn't + self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64)) + self.assertEqual(weight2.scale, dummy.weight.scale) + else: + # PerTensor: scale unchanged by slicing + self.assertEqual(weight1.scale, dummy.weight.scale) + self.assertEqual(weight2.scale, dummy.weight.scale) + + def test_index_select(self): + """test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`.""" + N, K = 256, 512 + x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) + x_int8 = Int8Tensor.from_hp(x, block_size=[N, K]) + x_int8_0 = x_int8[0] + torch.testing.assert_close( + x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 + ) def test_error_handling_and_dequant(self): """Test input validation and dequantization accuracy""" diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py 
index 133f28bfd9..eb0fdedec1 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1585,11 +1585,11 @@ def get_weight_block_size(x): else: input_quant_func = _int8_asymm_per_token_quant + block_size = get_weight_block_size(weight) if config.version == 1: warnings.warn( "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" ) - block_size = get_weight_block_size(weight) quantized_weight = to_affine_quantized_intx( weight, mapping_type, @@ -1606,7 +1606,7 @@ def get_weight_block_size(x): else: quantized_weight = Int8Tensor.from_hp( weight, - block_size=get_weight_block_size(weight), + block_size, ) return quantized_weight diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 7664cd6fd3..9dc9c7b5c3 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -34,7 +34,7 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): block_size (Optional[list[int]]): block size for quantization granularity """ - block_size: Optional[list[int]] = None + block_size: list[int] class Int8Tensor(TorchAOBaseTensor): @@ -47,23 +47,21 @@ class Int8Tensor(TorchAOBaseTensor): Non-Tensor Attributes: block_size: block size for quantization granularity - shape: original tensor shape act_quant_kwargs: flags for static/dynamic activation quantization """ tensor_data_names = ["qdata", "scale"] - tensor_attribute_names = ["block_size", "_shape"] + tensor_attribute_names = ["block_size"] optional_tensor_attribute_names = [ "act_quant_kwargs", "dtype", ] def __new__( - cls, - qdata, - scale, - block_size, - shape, + cls: type, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], act_quant_kwargs=None, dtype=None, ): @@ -72,14 +70,13 @@ def __new__( "dtype": dtype or scale.dtype, "requires_grad": False, } - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + return torch.Tensor._make_wrapper_subclass(cls, list(qdata.shape), **kwargs) def __init__( self, - qdata, - scale, - block_size, - shape, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], act_quant_kwargs=None, dtype=None, ): @@ -87,7 +84,6 @@ def __init__( self.qdata = qdata self.scale = scale self.block_size = block_size - self._shape = shape self.act_quant_kwargs = act_quant_kwargs def __repr__(self): @@ -129,7 +125,6 @@ def from_hp( int_data, scale, block_size, - w.shape, act_quant_kwargs=act_quant_kwargs, dtype=w.dtype, ) @@ -141,11 +136,11 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor qdata_fp = self.qdata.to(dtype) scale = self.scale.to(dtype) - # Reshape 1D scale to [N, 1] for broadcasting with [N, K] qdata - if scale.ndim == 1: - scale = scale.unsqueeze(1) + # Reshape scale to broadcast + if scale.numel() > 1 and scale.shape != qdata_fp.shape: + scale = scale.view(*scale.shape, *[1] * (qdata_fp.ndim - scale.ndim)) - return (qdata_fp) * scale + return qdata_fp * scale implements = Int8Tensor.implements @@ -160,7 +155,7 @@ def _(func, types, args, kwargs): @implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): """quantization: float -> int8""" - input_tensor, weight_tensor, bias = ( + activation_tensor, weight_tensor, bias = ( args[0], args[1], 
args[2] if len(args) > 2 else None, @@ -172,14 +167,14 @@ def _(func, types, args, kwargs): if weight_tensor.act_quant_kwargs is not None: # INT8 × INT8 (dynamic) - # Quantize input if it's not already quantized - if not isinstance(input_tensor, Int8Tensor): - input_tensor = _choose_quant_func_and_quantize_tensor( - input_tensor, weight_tensor.act_quant_kwargs + # Quantize activation if it's not already quantized + if not isinstance(activation_tensor, Int8Tensor): + activation_tensor = _choose_quant_func_and_quantize_tensor( + activation_tensor, weight_tensor.act_quant_kwargs ) - x_vals_int8 = input_tensor.qdata - x_scales = input_tensor.scale + x_vals_int8 = activation_tensor.qdata + x_scales = activation_tensor.scale w_vals_int8_t = weight_tensor.qdata.contiguous().t() w_scales = weight_tensor.scale @@ -199,12 +194,14 @@ def _(func, types, args, kwargs): result = (y_dot_scaled * w_scales).reshape( *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] ) - result = result.to(input_tensor.dtype) + result = result.to(activation_tensor.dtype) else: # FP × INT8 (weight-only) - input_tensor = input_tensor.dequantize() + activation_tensor = activation_tensor.dequantize() - result = func(input_tensor, weight_tensor.dequantize(input_tensor.dtype), None) + result = func( + activation_tensor, weight_tensor.dequantize(activation_tensor.dtype), None + ) return result + bias if bias is not None else result @@ -220,23 +217,37 @@ def _(func, types, args, kwargs): args[4] if len(args) > 4 else 1, ) - sliced_scale = tensor.scale - if dim == 0 and tensor.scale.ndim >= 1: - sliced_scale = aten.slice.Tensor(tensor.scale, 0, start, end, step) + assert dim in (0, 1), f"Only dim 0 or 1 supported, got {dim}" - sliced_shape = list( - aten.slice.Tensor(torch.empty(tensor.shape), dim, start, end, step).shape - ) + if end >= tensor.shape[dim]: + end = tensor.shape[dim] + + # Always slice the qdata + sliced_qdata = func(tensor.qdata, dim, start, end, step) + + if tensor.scale.numel() == 1: + # Per-tensor quantization - scale doesn't change + sliced_scale = tensor.scale + elif dim < tensor.scale.ndim and tensor.scale.shape[dim] > 1: + # Block-wise quantization - need to slice the scale appropriately + sliced_scale = func(tensor.scale, dim, start, end, step) + else: + sliced_scale = tensor.scale + + # adjust block_size since the shape has changed, block_size[i] should not be greater than shape[i] + block_size = list(tensor.block_size) + + for i in range(len(block_size)): + block_size[i] = min(block_size[i], sliced_qdata.shape[i]) return return_and_correct_aliasing( func, args, kwargs, Int8Tensor( - aten.slice.Tensor(tensor.qdata, dim, start, end, step), + sliced_qdata, sliced_scale, - tensor.block_size, - sliced_shape, + block_size, tensor.act_quant_kwargs, tensor.dtype, ), @@ -254,7 +265,6 @@ def _(func, types, args, kwargs): self.qdata.transpose(dim0, dim1), self.scale, [self.block_size[dim1], self.block_size[dim0]], - [self._shape[dim1], self._shape[dim0]], self.act_quant_kwargs, self.dtype, ), @@ -276,7 +286,6 @@ def _(func, types, args, kwargs): self.qdata[index], selected_scale, self.block_size, - list(self.qdata[index].shape), self.act_quant_kwargs, self.dtype, ),
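
A minimal usage sketch of the version-2 int8 path introduced in this series (not part of the patch itself): it assumes the public `quantize_` entry point and the `Int8WeightOnlyConfig` shown in the diffs above route weights to the new `Int8Tensor` subclass when `version=2`; exact defaults and import paths may differ between torchao releases.

    # Sketch only, assuming Int8WeightOnlyConfig(version=2) dispatches to Int8Tensor.from_hp
    # as in the quant_api.py changes above.
    import torch
    from torchao.quantization import quantize_, Int8WeightOnlyConfig

    model = (
        torch.nn.Sequential(torch.nn.Linear(128, 256, bias=False))
        .to(torch.bfloat16)
        .cuda()
    )
    quantize_(model, Int8WeightOnlyConfig(version=2))
    # After quantization the linear weight is an Int8Tensor holding int8 qdata and a scale;
    # the linear op dequantizes the weight for the fp x int8 (weight-only) path.
    out = model(torch.randn(4, 128, dtype=torch.bfloat16, device="cuda"))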