Add Int8Tensor for clearer interface #3038
@@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest
from typing import Tuple

import torch
from torch.testing._internal import common_utils

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    PerRow,
    PerTensor,
    quantize_,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import Int8Tensor
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase


# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged
class ToyTwoLinearModel(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        has_bias=False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.linear1 = torch.nn.Linear(
            input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device
        )
        self.linear2 = torch.nn.Linear(
            hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8Tensor(TorchAOIntegrationTestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(42)
        self.weight_fp = torch.randn(4, 3, dtype=torch.bfloat16)
        self.input_fp = torch.randn(4, 3, dtype=torch.bfloat16)
        self.bias = torch.randn(4, dtype=torch.bfloat16)
        self.block_size = [4, 3]

    def test_creation_and_attributes(self):
        """Test tensor creation, dtypes, and ranges"""
        tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size)

        self.assertEqual(tensor.shape, (4, 3))
        self.assertEqual(tensor.qdata.dtype, torch.int8)
        self.assertTrue(
            torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127)
        )

    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    @common_utils.parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 64, 256),
        ],
    )
    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_int8_linear_variants(
        self,
        dtype: torch.dtype,
        sizes: Tuple,
        config,
    ):
        M, N, K = sizes
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

        # Create a linear layer
        m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda")
        m_q = copy.deepcopy(m)

        # Quantize
        quantize_(m_q, config)

        output_original = m(input_tensor)
        output_quantized = m_q(input_tensor)

        error = compute_error(output_original, output_quantized)
        assert error > 20, f"Quantization error is too high, got a SQNR of {error}"

    @unittest.skip("granularity parameter not supported in current API")
    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
    def test_slice_preserves_aliasing(self, granularity):
        config = Int8DynamicActivationInt8WeightConfig(
            granularity=granularity, version=2
        )
        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
        )
        quantize_(l, config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
        # Making sure the aliasing is preserved in sliced quantized Tensor
        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    @common_utils.parametrize("device", ["cpu", "cuda"])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_slice(self, config, device, dtype):
        """Test tensor slicing"""
        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
        quantize_(dummy, config)

        weight1 = dummy.weight.clone().narrow(0, 0, 64)
        weight2 = dummy.weight.clone().narrow(1, 0, 128)

        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))
Reviewer (on lines +146 to +147): nit: add assert for scale as well?
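A minimal sketch of the extra scale checks the nit asks for, assuming the same per-row/per-tensor behavior that the granularity-dependent assertions further down already verify (so this largely mirrors the conditional below):

        # sketch only: scale checks to sit next to the qdata checks above
        if isinstance(config, Int8DynamicActivationInt8WeightConfig):
            # per-row scale: slicing dim 0 slices the scale, dim 1 leaves it alone
            self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
            self.assertEqual(weight2.scale, dummy.weight.scale)
        else:
            # per-tensor scale: unchanged by either slice
            self.assertEqual(weight1.scale, dummy.weight.scale)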

        # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow)
        # Int8WeightOnlyConfig uses per-tensor (PerTensor)
        if isinstance(config, Int8DynamicActivationInt8WeightConfig):
            # PerRow: dim 0 slicing affects scale, dim 1 doesn't
            self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
            self.assertEqual(weight2.scale, dummy.weight.scale)
        else:
            # PerTensor: scale unchanged by slicing
            self.assertEqual(weight1.scale, dummy.weight.scale)
            self.assertEqual(weight2.scale, dummy.weight.scale)

    def test_index_select(self):
        """test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`."""
        N, K = 256, 512
        x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
        x_int8 = Int8Tensor.from_hp(x, block_size=[N, K])
        x_int8_0 = x_int8[0]
        torch.testing.assert_close(
            x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
        )

    def test_error_handling_and_dequant(self):
        """Test input validation and dequantization accuracy"""
        with self.assertRaises((AssertionError, ValueError, RuntimeError)):
            Int8Tensor.from_hp(torch.randn(5), [1])

        with self.assertRaises((AssertionError, ValueError, RuntimeError)):
            Int8Tensor.from_hp(self.weight_fp, [1])

        test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16)
        tensor = Int8Tensor.from_hp(test_data, [1, 2])

        dequantized = torch.ops.aten.dequantize.self(tensor)
        self.assertEqual(dequantized.shape, test_data.shape)
        self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1)


if __name__ == "__main__":
    common_utils.run_tests()
@@ -81,6 +81,7 @@
    Int4PreshuffledTensor,
    Int4Tensor,
    Int4TilePackedTo4dTensor,
    Int8Tensor,
    IntxChooseQParamsAlgorithm,
    IntxOpaqueTensor,
    IntxPackingFormat,
@@ -1365,10 +1366,12 @@ class Int8WeightOnlyConfig(AOBaseConfig):
        Otherwise, applies per-group quantization with the specified group size.
        set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
            for better performance with this quantization scheme.
        version - Version of the config to use. Version 1 uses AffineQuantizedTensor for quantization;
            version 2 uses Int8Tensor.
    """

    group_size: Optional[int] = None
    set_inductor_config: bool = True
    version: int = 1

    def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
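For context, a small usage sketch of the new version=2 path, patterned on the added tests (the module and shapes here are illustrative, and it assumes this PR is applied):

import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

# any nn.Linear-bearing model works the same way; CUDA/bf16 matches the tests
model = torch.nn.Sequential(torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda"))
quantize_(model, Int8WeightOnlyConfig(version=2))
# with version=2 the weight payload is stored by the new Int8Tensor
print(model[0].weight.qdata.dtype)  # torch.int8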
@@ -1379,22 +1382,30 @@ def __post_init__(self):


def _int8_weight_only_quantize_tensor(weight, config):
    mapping_type = MappingType.SYMMETRIC
    target_dtype = torch.int8
    eps = torch.finfo(torch.float32).eps
    zero_point_dtype = torch.int64
    group_size = config.group_size
    if group_size is None:
        group_size = weight.shape[-1]
    block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
    new_weight = to_affine_quantized_intx(
        weight,
        mapping_type,
        block_size,
        target_dtype,
        eps=eps,
        zero_point_dtype=zero_point_dtype,
    )
    if config.version == 1:
        warnings.warn(
            "Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details"
        )
        mapping_type = MappingType.SYMMETRIC
        target_dtype = torch.int8
        eps = torch.finfo(torch.float32).eps
        zero_point_dtype = torch.int64
        group_size = config.group_size
        if group_size is None:
            group_size = weight.shape[-1]
        block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
        new_weight = to_affine_quantized_intx(
            weight,
            mapping_type,
            block_size,
            target_dtype,
            eps=eps,
            zero_point_dtype=zero_point_dtype,
        )
    else:
        assert config.version == 2, f"Unexpected version: {config.version}"
        block_size = [weight.shape[0], weight.shape[1]]
Reviewer: this should be the same as L1393 I think, you can extract L1390-L1393 out of the first if branch and use that I think

Author: Isn't dividing the logic safer and easier for deprecating the old API in the future? Other APIs like ...
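A rough sketch of the refactor the first comment suggests (hypothetical, not part of this PR): hoist the block_size computation out of the version-1 branch and share it across both paths. Note this would give the version-1 default per-group block size (1, ..., K) rather than the [N, K] block size the version-2 branch currently builds:

    # hypothetical: shared block-size computation, hoisted above the version branch
    group_size = config.group_size
    if group_size is None:
        group_size = weight.shape[-1]
    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])

    if config.version == 1:
        ...  # existing affine-quantized path, unchanged, reusing block_size
    else:
        assert config.version == 2, f"Unexpected version: {config.version}"
        new_weight = Int8Tensor.from_hp(weight, block_size=list(block_size))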
        new_weight = Int8Tensor.from_hp(weight, block_size=block_size)
    return new_weight

@@ -1522,12 +1533,14 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
        in original precision during decode operations.
        set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
            for better performance with this quantization scheme.
        version (int): the version of the config; version 1 uses AffineQuantizedTensor (which we plan to
            deprecate/split), version 2 uses Int8Tensor
    """

    layout: Optional[Layout] = PlainLayout()
    act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
    weight_only_decode: bool = False
    set_inductor_config: bool = True
    version: int = 1

    def __post_init__(self):
        torch._C._log_api_usage_once(
@@ -1576,18 +1589,30 @@ def get_weight_block_size(x):
        input_quant_func = _int8_asymm_per_token_quant

    block_size = get_weight_block_size(weight)
    new_weight = to_affine_quantized_intx(
        weight,
        mapping_type,
        block_size,
        target_dtype,
        eps=eps,
        zero_point_dtype=zero_point_dtype,
        _layout=layout,
        zero_point_domain=weight_zero_point_domain,
    )
    new_weight = to_linear_activation_quantized(new_weight, input_quant_func)
    return new_weight
    if config.version == 1:
        warnings.warn(
            "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details"
        )
        quantized_weight = to_affine_quantized_intx(
            weight,
            mapping_type,
            block_size,
            target_dtype,
            eps=eps,
            zero_point_dtype=zero_point_dtype,
            _layout=layout,
            zero_point_domain=weight_zero_point_domain,
        )
        quantized_weight = to_linear_activation_quantized(
            quantized_weight, input_quant_func
        )
    else:
        quantized_weight = Int8Tensor.from_hp(
            weight,
            block_size,
        )

    return quantized_weight


@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
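For context on the version=2 path of Int8DynamicActivationInt8WeightConfig added above, a small usage sketch patterned on the new test_int8_linear_variants test (module, shapes, and the SQNR bar are illustrative, and it assumes this PR is applied):

import copy
import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_
from torchao.quantization.utils import compute_error

model = torch.nn.Sequential(torch.nn.Linear(256, 128, dtype=torch.bfloat16, device="cuda"))
model_q = copy.deepcopy(model)
quantize_(model_q, Int8DynamicActivationInt8WeightConfig(version=2))

x = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
sqnr = compute_error(model(x), model_q(x))
assert sqnr > 20  # same accuracy bar the new test uses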
@@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor(
    """
    from torchao.quantization.quantize_.workflows import (
        Float8Tensor,
        Int8Tensor,
        QuantizeTensorToFloat8Kwargs,
        QuantizeTensorToInt8Kwargs,
    )

    if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):

@@ -52,5 +54,10 @@ def __post_init__(self):
            quant_kwargs.hp_value_ub,
            quant_kwargs.kernel_preference,
        )
    elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
        return Int8Tensor.from_hp(
            tensor,
            quant_kwargs.block_size or [1, tensor.shape[-1]],
Reviewer: nit: why not make block_size mandatory?
        )

    raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")
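Regarding the nit above, a hypothetical variant in which block_size is required, so the dispatcher no longer needs a fallback (the class name is taken from the import in this PR, but its actual fields and base class are assumptions):

from dataclasses import dataclass
from typing import List

# Hypothetical shape of the kwargs if block_size were mandatory; the real class
# lives in torchao.quantization.quantize_.workflows per the import above.
@dataclass
class QuantizeTensorToInt8Kwargs:
    block_size: List[int]  # required, no default: the caller chooses the granularity

# the elif branch above could then drop its `or [1, tensor.shape[-1]]` fallback:
#     return Int8Tensor.from_hp(tensor, quant_kwargs.block_size)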
Reviewer: for the tests, maybe try to follow https://github.com/pytorch/ao/blob/main/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py for now, and also add some tests for slicing?
(references: ao/test/quantization/quantize_/workflows/float8/test_float8_tensor.py, lines 158 to 216 in 8e2ca35, and line 278 in 8e2ca35)

Author: Yes, the linked unit test is helpful for the slicing (PerTensor, PerRow) tests, but I haven't implemented granularity in this PR yet, to keep the PR small. Can I address it after this PR?

Reviewer: I don't think the slicing tests are specific to a granularity; you should be able to adapt them for the currently supported granularity, I think.
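A minimal sketch of how the float8 slicing test could be adapted today, without the granularity parameter, using only the default setup this PR supports (essentially the skipped test_slice_preserves_aliasing above with the parametrization removed):

    def test_slice_preserves_aliasing(self):
        config = Int8DynamicActivationInt8WeightConfig(version=2)
        linear = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
        quantize_(linear, config)
        sliced = linear.weight.data.narrow(0, 0, 512)
        # slicing should alias, not copy, the quantized payload and its scale
        assert linear.weight.data.qdata.data_ptr() == sliced.qdata.data_ptr()
        assert linear.weight.data.scale.data_ptr() == sliced.scale.data_ptr()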