102 changes: 102 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -34,6 +34,7 @@
"_choose_qparams_and_quantize_affine_hqq",
"_choose_qparams_and_quantize_scale_only_hqq",
"_choose_qparams_and_quantize_affine_qqq",
"_choose_qparams_and_quantize_affine_sinq",
"_choose_scale_float8",
"_choose_qparams_gguf",
"_quantize_affine_no_zero_point",
@@ -2219,6 +2220,107 @@ def round_stoch(x: torch.Tensor) -> torch.Tensor:
return qdata, scale


def _choose_qparams_and_quantize_affine_sinq(
tensor: torch.Tensor,
nbits: float = 4,
group_size: int = 64,
niter: int = 20,
compute_dtype: torch.dtype = torch.float16,
device: str = "cuda",
verbose: bool = False,
) -> tuple:
"""
SINQ: Sinkhorn-Normalized Quantization (https://www.arxiv.org/abs/2509.22944)

Iteratively normalizes row and column standard deviations to minimize
matrix imbalance before quantization with dual scales.

Args:
tensor: Input weight tensor
nbits: Number of quantization bits (default: 4)
group_size: Quantization group size (default: 64)
niter: Number of Sinkhorn iterations (default: 20)
compute_dtype: Target compute dtype (default: torch.float16)
        device: Target device for computation (default: "cuda")
        verbose: Currently unused (default: False)

    Returns:
        Tuple of (W_q, scale_row, zero, scale_col, shape): the quantized uint8
        tensor in the original shape, per-group dequantization scales and
        zero-points, per-column Sinkhorn scale factors, and the original shape.
"""
    if group_size is not None:
        assert _is_divisible(tensor.numel(), group_size), (
            f"group_size must divide the number of tensor elements. shape: {tensor.shape}, group_size: {group_size}"
        )

W = tensor.to(device=device, dtype=torch.float32)
shape = W.shape

# Reshape for row-wise grouping (axis=1)
if group_size is not None:
W = W.reshape([-1, group_size])

# Algorithm 1: Sinkhorn Normalization
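    # Target std: the smaller of the minimum column / row stds, floored for numerical stability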
q_min = min(W.std(dim=0).min().item(), W.std(dim=1).min().item())
q_min = max(q_min, 1e-8)

W_hat = W.clone()
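    # Accumulators for the column / row scale factors applied over the iterations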
q_col_acc = torch.ones(W.shape[1], device=device, dtype=torch.float32)
q_row_acc = torch.ones(W.shape[0], device=device, dtype=torch.float32)

for i in range(niter):
# Normalize columns (dim=0)
q_col = W_hat.std(dim=0) / q_min
q_col = torch.clamp(q_col, min=1e-8)
W_hat = W_hat / q_col.unsqueeze(0)
q_col_acc = q_col_acc * q_col

# Normalize rows (dim=1)
q_row = W_hat.std(dim=1) / q_min
q_row = torch.clamp(q_row, min=1e-8)
W_hat = W_hat / q_row.unsqueeze(1)
q_row_acc = q_row_acc * q_row

# RTN quantization on normalized matrix (row-wise, axis=1)
_min = W_hat.min(dim=1, keepdim=True)[0]
_max = W_hat.max(dim=1, keepdim=True)[0]

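    # Unsigned integer quantization range: [0, 2**nbits - 1]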
max_v = round(2**nbits - 1)
min_v = 0

scale = (max_v / (_max - _min)).clamp(max=2e4)
zero = -_min * scale

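    # Round the zero-point to an integer in the 4-bit case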
if nbits in [4]:
zero = _Round.apply(zero)

zero = zero.to(compute_dtype)
scale = scale.to(compute_dtype)
W_q = _Round.apply(W_hat * scale + zero).clamp(min_v, max_v)

# Recover scales with Sinkhorn factors (Algorithm 1, line 10)
scale = 1.0 / scale
scale_row = (scale * q_row_acc.unsqueeze(1)).reshape(shape[0], -1).to(compute_dtype)

    # Expand q_col_acc to the original column dimension:
    # q_col_acc is (group_size,), so tile it to (shape[1],)
    num_groups = shape[1] // group_size if group_size is not None else 1
    scale_col = q_col_acc.repeat(num_groups).to(compute_dtype)

# Reshape outputs
W_q = W_q.reshape(shape)
scale_row = scale_row.reshape(shape[0], -1)
zero = zero.reshape(shape[0], -1)

W_q = W_q.to(dtype=torch.uint8, device=device)
scale_row = scale_row.to(dtype=compute_dtype, device=device)
scale_col = scale_col.to(dtype=compute_dtype, device=device)
zero = zero.to(dtype=compute_dtype, device=device)

del W, W_hat, _min, _max
torch.cuda.empty_cache()

return W_q, scale_row, zero, scale_col, shape


def _choose_qparams_affine_floatx(
tensor: torch.Tensor, ebits: int, mbits: int
) -> torch.Tensor:
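For reference, a minimal dequantization sketch for the dual-scale layout returned above. This helper is not part of the diff; the name _dequantize_sinq and the example shapes are assumptions, it assumes a 2-D weight tensor, and it relies on the grouping convention used in the function (consecutive runs of group_size elements along the last dimension).

import torch

def _dequantize_sinq(W_q, scale_row, zero, scale_col, shape, group_size=64):
    # Hypothetical helper (not in this PR): inverts the SINQ quantization above.
    rows, cols = shape
    num_groups = cols // group_size
    # Undo the per-group RTN step: W_hat ~= (W_q - zero) * scale_row
    Wq = W_q.reshape(rows, num_groups, group_size).to(scale_row.dtype)
    z = zero.reshape(rows, num_groups, 1)
    sr = scale_row.reshape(rows, num_groups, 1)
    W_hat = (Wq - z) * sr
    # Reapply the accumulated column Sinkhorn factors (scale_col[j] == q_col_acc[j % group_size])
    return W_hat.reshape(rows, cols) * scale_col

# Example usage (shapes and device assumed for illustration)
W = torch.randn(256, 512, dtype=torch.float16, device="cuda")
W_q, scale_row, zero, scale_col, shape = _choose_qparams_and_quantize_affine_sinq(
    W, nbits=4, group_size=64
)
W_approx = _dequantize_sinq(W_q, scale_row, zero, scale_col, shape, group_size=64)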