diff --git a/test/prototype/test_codebook_quant.py b/test/prototype/test_codebook_quant.py
index 4cd9c1112b..2d8d8f025f 100644
--- a/test/prototype/test_codebook_quant.py
+++ b/test/prototype/test_codebook_quant.py
@@ -5,7 +5,9 @@
 from torchao.prototype.quantization.codebook import (
     CodebookQuantizedTensor,
     choose_qparams_codebook,
+    codebook_weight_only,
 )
+from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 
 
@@ -62,6 +64,11 @@ def test_codebook_quantized_tensor_from_float2(self):
         sqnr = compute_error(dequant, self.input)
         self.assertGreater(sqnr, 30)
 
+    def test_quantize_api(self):
+        m = torch.nn.Sequential(torch.nn.Linear(64, 64))
+        quantize_(m, codebook_weight_only())
+        self.assertIsInstance(m[0].weight, CodebookQuantizedTensor)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
index b7e395b434..4566f94e8e 100644
--- a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
+++ b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py
@@ -1,14 +1,18 @@
+from dataclasses import dataclass
 from typing import Optional, Tuple
 
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxTensor
 from torchao.prototype.quantization.codebook.codebook_ops import (
     choose_qparams_codebook,
     dequantize_codebook,
     quantize_codebook,
 )
-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.utils import TorchAOBaseTensor
 
 aten = torch.ops.aten
@@ -254,10 +258,31 @@ def function_requires_grad_(tensor, *args, **kwargs):
     return tensor.requires_grad_(*args, **kwargs)
 
 
-def codebook_weight_only(
-    dtype=torch.uint4,
-    block_size: Tuple[int, int] = (1, 1),
-    scale_block_size: int = None,
+@dataclass
+class CodebookWeightOnlyConfig(AOBaseConfig):
+    """
+    Configuration for codebook weight-only quantization of a module's weight.
+
+    Args:
+        dtype: passed as ``code_dtype`` to ``CodebookQuantizedTensor.from_float``.
+        block_size: passed as ``block_size`` to ``CodebookQuantizedTensor.from_float``.
+        scale_block_size: passed as ``scale_block_size``; when None, the
+            handler substitutes ``weight.shape[1]``.
+    """
+
+    dtype: torch.dtype = torch.uint4
+    block_size: Tuple[int, int] = (1, 1)
+    scale_block_size: Optional[int] = None
+
+
+# for bc: the old `codebook_weight_only(...)` call now constructs the config
+codebook_weight_only = CodebookWeightOnlyConfig
+
+
+@register_quantize_module_handler(CodebookWeightOnlyConfig)
+def _codebook_weight_only_transform(
+    module: torch.nn.Module,
+    config: CodebookWeightOnlyConfig,
 ):
     """
     Applies codebook weight-only quantization to linear layers.
@@ -269,20 +294,20 @@ def codebook_weight_only(
     Returns:
-        Callable for quantization transformation.
+        The input module (returned unchanged when the weight is too large).
     """
-
-    def apply_codebook_quantization(weight, scale_block_size):
-        if weight.numel() > 2**27:
-            return weight  # k_means is too numerically unstable
-        if scale_block_size is None:
-            scale_block_size = weight.shape[1]
-        quantized = CodebookQuantizedTensor.from_float(
-            weight,
-            block_size=block_size,
-            code_dtype=dtype,
-            scale_block_size=scale_block_size,
-        )
-        return quantized
-
-    return _get_linear_subclass_inserter(
-        apply_codebook_quantization, scale_block_size=scale_block_size
+    dtype = config.dtype
+    block_size = config.block_size
+    scale_block_size = config.scale_block_size
+    weight = module.weight
+
+    if weight.numel() > 2**27:
+        return module  # k_means is too numerically unstable
+    if scale_block_size is None:
+        scale_block_size = weight.shape[1]
+    quantized_weight = CodebookQuantizedTensor.from_float(
+        weight,
+        block_size=block_size,
+        code_dtype=dtype,
+        scale_block_size=scale_block_size,
     )
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    return module