"""Stretched intx weight-only quantization: a quantize_ config, its registered
module transform, and a helper that maps quantizer instances to torchao configs."""

from dataclasses import dataclass
from typing import Optional

import torch

from torchao.core.config import AOBaseConfig
from torchao.dtypes import Int4CPULayout
from torchao.quantization import MappingType, PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int4WeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
)
from torchao.quantization.quantize_.common.packing_format import PackingFormat
from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
from torchao.quantization.transform_module import register_quantize_module_handler
from torchao.utils import check_cpu_version

from .quant_api import (
    choose_qparams_stretched_affine,
    quantize_stretched_affine,
    to_stretched_affine_quantized_intx,
)
from .uniform_torchao import (
    _BIT_WIDTH_TO_DTYPE,
    Int4UnifTorchaoQuantizer,
    StretchedUnifTorchaoQuantizer,
)


@dataclass
class StretchedIntxWeightOnlyConfig(IntxWeightOnlyConfig):
    """IntxWeightOnlyConfig variant for "stretched" integer ranges.

    ``b`` is the weight bit width; ``quant_min``/``quant_max`` are the stretched
    range bounds (supplied by the quantizer rather than derived from ``b``).
    """

    b: Optional[int] = None
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    activation_quantization: Optional[str] = "int8_asym_per_token"
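# Illustrative sketch only (not part of the original commit): a stretched 4-bit,
# group-size-32 weight-only config would typically be built with bounds taken
# from a quantizer, e.g.
#
#     q = StretchedUnifTorchaoQuantizer(4)
#     cfg = StretchedIntxWeightOnlyConfig(
#         b=4, quant_min=q.quant_min, quant_max=q.quant_max,
#         granularity=PerGroup(32),
#     )
#
# StretchedUnifTorchaoQuantizer is assumed here to take the bit width as its
# first argument.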

@register_quantize_module_handler(StretchedIntxWeightOnlyConfig)
def _stretched_intx_weight_only_transform(
    module: torch.nn.Module, config: StretchedIntxWeightOnlyConfig
) -> torch.nn.Module:
    """Replace ``module.weight`` in place with a stretched-affine quantized weight."""
    weight = module.weight
    granularity = config.granularity
    mapping_type = MappingType.ASYMMETRIC

    assert weight.dim() == 2, (
        f"StretchedIntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
    )
    # Resolve the group size: PerGroup carries its own, while PerAxis(0)
    # treats each output row as a single group.
    if isinstance(granularity, PerGroup):
        group_size = granularity.group_size
    elif isinstance(granularity, PerAxis):
        assert granularity.axis == 0, (
            f"axis must be 0 with PerAxis, but got {granularity.axis}"
        )
        group_size = weight.shape[-1]
    else:
        raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")

    block_size = (1, group_size)
    target_dtype = torch.int8
    q_args = (weight, mapping_type, block_size, target_dtype, config.b)
    if config.version == 2:
        # Version 2 computes qparams and quantized data explicitly, then wraps
        # them in an IntxUnpackedToInt8Tensor subclass.
        scale, zero_point = choose_qparams_stretched_affine(
            *q_args,
            quant_min=config.quant_min,
            quant_max=config.quant_max,
        )
        qdata = quantize_stretched_affine(
            weight,
            block_size,
            scale,
            zero_point,
            target_dtype,
            quant_min=config.quant_min,
            quant_max=config.quant_max,
        )
        # Reshape the qparams to one entry per quantization block.
        n_blocks = [qdata.shape[i] // block_size[i] for i in range(len(block_size))]
        scale = scale.reshape(*n_blocks)
        zero_point = zero_point.reshape(*n_blocks)

        weight = IntxUnpackedToInt8Tensor(
            qdata=qdata,
            scale=scale,
            zero_point=zero_point,
            target_dtype=getattr(torch, f"int{config.b}"),
            block_size=block_size,
            dtype=weight.dtype,
            activation_quantization=config.activation_quantization,
        )
    else:
        # Pre-version-2 path: build the quantized weight with the
        # stretched-affine tensor constructor.
        weight = to_stretched_affine_quantized_intx(
            *q_args,
            quant_min=config.quant_min,
            quant_max=config.quant_max,
            scale_dtype=config.scale_dtype,
            _layout=config.layout,
        )
    module.weight = torch.nn.Parameter(weight, requires_grad=False)
    return module

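# Note (added for clarity): because the transform above is registered via
# register_quantize_module_handler, torchao.quantization.quantize_ routes any
# module configured with StretchedIntxWeightOnlyConfig through it; callers
# normally obtain that config from get_config_from_quantizer below.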
def get_config_from_quantizer(
    quantizer,
    is_embed: bool,
    device: torch.device,
    b: int,
    block_size: Optional[int],
    version: int = 2,
) -> AOBaseConfig:
    """Map a quantizer instance to the matching torchao quantization config."""
    # No block size means one group per output channel (axis 0).
    granularity = PerGroup(block_size) if block_size is not None else PerAxis(0)
    weight_dtype = _BIT_WIDTH_TO_DTYPE[b]
    if isinstance(quantizer, Int4UnifTorchaoQuantizer):
        # CPU builds need the int4 CPU layout; other devices use the default.
        kwargs = {"layout": Int4CPULayout()} if check_cpu_version(device) else {}
        config = Int4WeightOnlyConfig(group_size=block_size, **kwargs)
    elif isinstance(quantizer, StretchedUnifTorchaoQuantizer):
        config = StretchedIntxWeightOnlyConfig(
            b=b,
            quant_min=quantizer.quant_min,
            quant_max=quantizer.quant_max,
            granularity=granularity,
            version=version,
        )
    elif is_embed:
        # Embeddings get weight-only quantization (no activation quantization).
        config = IntxWeightOnlyConfig(
            weight_dtype=weight_dtype,
            granularity=granularity,
            mapping_type=quantizer.mapping_type,
            packing_format=PackingFormat.UNPACKED_TO_INT8,
            version=version,
        )
    else:
        config = Int8DynamicActivationIntxWeightConfig(
            weight_dtype=weight_dtype,
            weight_granularity=granularity,
            weight_mapping_type=quantizer.mapping_type,
            act_mapping_type=MappingType.ASYMMETRIC,
            packing_format=PackingFormat.UNPACKED_TO_INT8,
            version=version,
        )
    return config
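
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original commit. Assumes
    # StretchedUnifTorchaoQuantizer takes the bit width as its first argument
    # and that torchao.quantization.quantize_ is available.
    from torchao.quantization import quantize_

    linear = torch.nn.Linear(64, 32)
    quantizer = StretchedUnifTorchaoQuantizer(4)
    config = get_config_from_quantizer(
        quantizer, is_embed=False, device=linear.weight.device, b=4, block_size=32
    )
    quantize_(linear, config)
    # The weight is now an IntxUnpackedToInt8Tensor wrapped in a Parameter.
    print(linear.weight)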