diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 36b886044..9fcc0d55d 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -142,6 +142,18 @@ def is_preset_scheme(name: str) -> bool: ), ) +# 4 bit integer weights only asymmetric quantization +W4A16_ASYM = dict( + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + strategy=QuantizationStrategy.GROUP, + group_size=128, + symmetric=False, + dynamic=False, + ), +) + # 4 bit integer weights and 8 bit activations quantization INT8_W4A8 = dict( weights=QuantizationArgs( @@ -205,6 +217,7 @@ def is_preset_scheme(name: str) -> bool: # Integer weight only schemes "W8A16": W8A16, "W4A16": W4A16, + "W4A16_ASYM": W4A16_ASYM, # Integer weight and activation schemes "W8A8": INT8_W8A8, "INT8": INT8_W8A8, # alias for W8A8 diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 9f65ee330..d7e6d5f81 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -64,8 +64,11 @@ def calculate_qparams( :param quantization_args: settings to quantization :return: tuple of the calculated scale(s) and zero point(s) """ + # based on the implementations for consuming quantized values, + # 0.0 must always be representable within the quantized range min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + device = min_vals.device bit_min, bit_max = calculate_range(quantization_args, device) @@ -84,6 +87,9 @@ def calculate_qparams( zero_points = torch.clamp(zero_points, bit_min, bit_max) # match zero-points to quantized type + # if casting to int, use round instead of truncate + if quantization_args.type == QuantizationType.INT: + zero_points = torch.round(zero_points) zero_points = zero_points.to(zp_dtype) if scales.ndim == 0: @@ -96,7 +102,7 @@ def calculate_qparams( def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs): """ Returns the computed scales and zero points for dynamic activation - qunatization. + quantization. :param value: tensor to calculate quantization parameters for :param args: quantization args