diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 0b66a8f45e..55fb67982c 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -34,6 +34,7 @@
     "_choose_qparams_and_quantize_affine_hqq",
     "_choose_qparams_and_quantize_scale_only_hqq",
     "_choose_qparams_and_quantize_affine_qqq",
+    "_choose_qparams_and_quantize_affine_sinq",
     "_choose_scale_float8",
     "_choose_qparams_gguf",
     "_quantize_affine_no_zero_point",
@@ -2219,6 +2220,107 @@ def round_stoch(x: torch.Tensor) -> torch.Tensor:
     return qdata, scale
 
 
+def _choose_qparams_and_quantize_affine_sinq(
+    tensor: torch.Tensor,
+    nbits: float = 4,
+    group_size: int = 64,
+    niter: int = 20,
+    compute_dtype: torch.dtype = torch.float16,
+    device: str = "cuda",
+    verbose: bool = False,
+) -> tuple:
+    """
+    SINQ: Sinkhorn-Normalized Quantization (https://www.arxiv.org/abs/2509.22944)
+
+    Iteratively normalizes row and column standard deviations to minimize
+    matrix imbalance before quantization with dual scales.
+
+    Args:
+        tensor: Input weight tensor
+        nbits: Number of quantization bits (default: 4)
+        group_size: Quantization group size (default: 64)
+        niter: Number of Sinkhorn iterations (default: 20)
+        compute_dtype: Target compute dtype (default: torch.float16)
+        device: Target device for computation (default: "cuda")
+        verbose: Currently unused (default: False)
+
+    Returns:
+        Tuple of (W_q, scale_row, zero, scale_col, shape)
+    """
+    if group_size is not None:
+        assert _is_divisible(tensor.numel(), group_size), (
+            f"group_size must divide tensor elements. shape: {tensor.shape}, group_size: {group_size}"
+        )
+
+    W = tensor.to(device=device, dtype=torch.float32)
+    shape = W.shape
+
+    # Reshape for row-wise grouping (axis=1)
+    if group_size is not None:
+        W = W.reshape([-1, group_size])
+
+    # Algorithm 1: Sinkhorn normalization
+    q_min = min(W.std(dim=0).min().item(), W.std(dim=1).min().item())
+    q_min = max(q_min, 1e-8)
+
+    W_hat = W.clone()
+    q_col_acc = torch.ones(W.shape[1], device=device, dtype=torch.float32)
+    q_row_acc = torch.ones(W.shape[0], device=device, dtype=torch.float32)
+
+    for _ in range(niter):
+        # Normalize columns (dim=0)
+        q_col = W_hat.std(dim=0) / q_min
+        q_col = torch.clamp(q_col, min=1e-8)
+        W_hat = W_hat / q_col.unsqueeze(0)
+        q_col_acc = q_col_acc * q_col
+
+        # Normalize rows (dim=1)
+        q_row = W_hat.std(dim=1) / q_min
+        q_row = torch.clamp(q_row, min=1e-8)
+        W_hat = W_hat / q_row.unsqueeze(1)
+        q_row_acc = q_row_acc * q_row
+
+    # RTN quantization on the normalized matrix (row-wise, axis=1)
+    _min = W_hat.min(dim=1, keepdim=True)[0]
+    _max = W_hat.max(dim=1, keepdim=True)[0]
+
+    max_v = round(2**nbits - 1)
+    min_v = 0
+
+    scale = (max_v / (_max - _min)).clamp(max=2e4)
+    zero = -_min * scale
+
+    if nbits in [4]:
+        zero = _Round.apply(zero)
+
+    zero = zero.to(compute_dtype)
+    scale = scale.to(compute_dtype)
+    W_q = _Round.apply(W_hat * scale + zero).clamp(min_v, max_v)
+
+    # Recover scales with the accumulated Sinkhorn factors (Algorithm 1, line 10)
+    scale = 1.0 / scale
+    scale_row = (scale * q_row_acc.unsqueeze(1)).reshape(shape[0], -1)
+
+    # Expand q_col_acc to the original column dimension:
+    # q_col_acc is (group_size,); tile it to (shape[1],)
+    num_groups = shape[1] // group_size if group_size is not None else 1
+    scale_col = q_col_acc.repeat(num_groups)
+
+    # Reshape outputs back to the original layout
+    W_q = W_q.reshape(shape)
+    zero = zero.reshape(shape[0], -1)
+
+    W_q = W_q.to(dtype=torch.uint8, device=device)
+    scale_row = scale_row.to(dtype=compute_dtype, device=device)
+    scale_col = scale_col.to(dtype=compute_dtype, device=device)
+    zero = zero.to(dtype=compute_dtype, device=device)
+
+    del W, W_hat, _min, _max
+    torch.cuda.empty_cache()
+
+    return W_q, scale_row, zero, scale_col, shape
+
+
 def _choose_qparams_affine_floatx(
     tensor: torch.Tensor, ebits: int, mbits: int
 ) -> torch.Tensor:
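
Reviewer note: a minimal round-trip sketch of the new primitive, to make the dual-scale layout concrete. The `dequant_sinq` helper and the test tensor sizes below are illustrative assumptions for review, not part of this PR:

import torch
from torchao.quantization.quant_primitives import (
    _choose_qparams_and_quantize_affine_sinq,
)

def dequant_sinq(W_q, scale_row, zero, scale_col, shape, group_size=64):
    # Hypothetical inverse of the quantizer above: undo the per-(row, group)
    # affine quantization, then re-apply the accumulated column Sinkhorn
    # factors carried by scale_col (the second scale axis).
    m, n = shape
    num_groups = n // group_size
    W_hat = (
        W_q.to(scale_row.dtype).reshape(m, num_groups, group_size)
        - zero.reshape(m, num_groups, 1)
    ) * scale_row.reshape(m, num_groups, 1)
    return W_hat.reshape(m, n) * scale_col.reshape(1, n)

W = torch.randn(128, 256, dtype=torch.float16, device="cuda")
W_q, scale_row, zero, scale_col, shape = _choose_qparams_and_quantize_affine_sinq(
    W, nbits=4, group_size=64
)
W_dq = dequant_sinq(W_q, scale_row, zero, scale_col, shape, group_size=64)
print((W - W_dq).abs().mean())  # expect a small reconstruction error

Since `scale_col` is built by tiling the per-position group factors `q_col_acc`, broadcasting it over the columns of the reshaped tensor recovers the same factor each element was divided by during normalization.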