
Commit f470f49

Refactor QAT into its own module
Summary: Refactor QAT into its own module so future QAT features can live under the same folder without making qat.py longer, and a separate QAT README can be added in the future.

Test Plan: python test/quantization/test_qat.py
1 parent ec317fc commit f470f49
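
For orientation, the resulting import layout is sketched below. The paths are taken directly from the diff that follows: the private fake-quant helpers move to a new qat/utils.py, the QAT linears and quantizers live in qat/api.py (the renamed qat.py), and the public surface is re-exported from the package's new __init__.py, so existing imports from the package root keep working.

```python
# Import layout after this commit (paths as they appear in the diff below).

# Private fake-quant helpers now live in the utils submodule:
from torchao.quantization.prototype.qat.utils import (
    _choose_qparams_per_token_asymmetric,
    _fake_quantize_per_channel_group,
    _fake_quantize_per_token,
)

# QAT linears, quantizers, and fake-quant toggles live in the api submodule:
from torchao.quantization.prototype.qat.api import (
    Int8DynActInt4WeightQATLinear,
    Int8DynActInt4WeightQATQuantizer,
)

# The quantizers and enable/disable helpers are re-exported by the new
# __init__.py, so package-root imports continue to work unchanged:
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
```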

4 files changed: +185 -158 lines
test/quantization/test_qat.py

Lines changed: 13 additions & 13 deletions
@@ -12,11 +12,11 @@
 
 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
-from torchao.quantization.prototype.qat import (
+from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
+    _fake_quantize_per_channel_group,
+    _fake_quantize_per_token,
     _GenericFakeQuantize,
-    fake_quantize_per_channel_group,
-    fake_quantize_per_token,
 )
 from torchao.quantization.quant_primitives import (
     fake_quantize_affine,
@@ -85,7 +85,7 @@ def test_fake_quantize_per_channel_group(self):
         x2 = copy.deepcopy(x)
 
         # fake quant op
-        out = fake_quantize_per_channel_group(
+        out = _fake_quantize_per_channel_group(
             x, s, zp, qmin, qmax, group_size,
         )
         out.sum().backward()
@@ -110,7 +110,7 @@ def test_fake_quantize_per_token(self):
         (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
 
         # fake quant op
-        out = fake_quantize_per_token(x, s, zp, qmin, qmax)
+        out = _fake_quantize_per_token(x, s, zp, qmin, qmax)
         out.sum().backward()
 
         # compare against PTQ ops
@@ -135,7 +135,7 @@ def _set_ptq_weight(
             Int8DynActInt4WeightLinear,
             WeightOnlyInt4Linear,
         )
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.prototype.qat.api import (
             Int8DynActInt4WeightQATLinear,
             Int4WeightOnlyQATLinear,
         )
@@ -167,7 +167,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -192,7 +192,7 @@ def test_qat_8da4w_linear(self):
 
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_8da4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
         group_size = 16
@@ -226,7 +226,7 @@ def test_qat_8da4w_quantizer(self):
 
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_8da4w_quantizer_meta_weights(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATQuantizer
 
         with torch.device("meta"):
             m = M()
@@ -241,7 +241,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.prototype.qat.api import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
             enable_8da4w_fake_quant,
@@ -294,7 +294,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.prototype.qat.api import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
         )
@@ -425,7 +425,7 @@ def test_qat_4w_primitives(self):
    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
    def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
+        from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -455,7 +455,7 @@ def test_qat_4w_linear(self):
    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
    def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
         group_size = 32
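
The updated tests above still call out.sum().backward() on the output of the renamed private helpers; gradients flow because those helpers are built on a straight-through estimator, _GenericFakeQuantize, whose full body appears in the api.py diff further down (it additionally masks gradients for values clamped outside [quant_min, quant_max]). The following is a standalone illustration of the STE pattern only, not the torchao helper itself:

```python
import torch

class _RoundSTE(torch.autograd.Function):
    """Toy straight-through estimator: round in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # round() has zero gradient almost everywhere; treat it as the identity instead
        return grad_output

x = torch.randn(4, requires_grad=True)
y = _RoundSTE.apply(x)
y.sum().backward()
print(x.grad)  # tensor of ones: gradients passed straight through the rounding
```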
torchao/quantization/prototype/qat/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from .api import (
+    disable_4w_fake_quant,
+    disable_8da4w_fake_quant,
+    enable_4w_fake_quant,
+    enable_8da4w_fake_quant,
+    Int4WeightOnlyQATQuantizer,
+    Int8DynActInt4WeightQATQuantizer,
+)
+
+__all__ = [
+    "disable_4w_fake_quant",
+    "disable_8da4w_fake_quant",
+    "enable_4w_fake_quant",
+    "enable_8da4w_fake_quant",
+    "Int4WeightOnlyQATQuantizer",
+    "Int8DynActInt4WeightQATQuantizer",
+]
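
These re-exports keep the quantizers and the per-module enable/disable helpers available from the package root. Since each helper takes a single torch.nn.Module (compare disable_4w_fake_quant in the api.py diff below), a natural way to toggle fake quantization across a prepared model is nn.Module.apply. A hedged sketch of that usage; the toy model, the groupsize value, and the prepare() call follow the patterns exercised in test_qat.py rather than a documented API:

```python
import torch
from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

# Toy model for illustration; groupsize must divide the linear's in_features.
model = torch.nn.Sequential(torch.nn.Linear(64, 64))
model = Int8DynActInt4WeightQATQuantizer(groupsize=16).prepare(model)

# Each helper checks isinstance(mod, Int8DynActInt4WeightQATLinear) and flips its
# fake-quant flag, so Module.apply toggles only the QAT linears in the tree.
model.apply(disable_8da4w_fake_quant)  # forward now matches plain nn.Linear
model.apply(enable_8da4w_fake_quant)   # restore fake quantization for QAT fine-tuning
```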

torchao/quantization/prototype/qat.py renamed to torchao/quantization/prototype/qat/api.py

Lines changed: 10 additions & 145 deletions
@@ -4,12 +4,10 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
-from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
-from torch.library import impl
 
 from torchao.quantization.GPTQ import (
     _check_linear_int4_k,
@@ -20,14 +18,13 @@
     Int8DynActInt4WeightLinear,
     WeightOnlyInt4Linear,
 )
-from torchao.quantization.quant_primitives import (
-    fake_quantize_affine_cachemask,
-    ZeroPointDomain,
-)
+from torchao.quantization.quant_primitives import ZeroPointDomain
 from torchao.quantization.unified import TwoStepQuantizer
-from torchao.quantization.utils import (
-    _get_per_token_block_size,
-    get_group_qparams_symmetric,
+from torchao.quantization.utils import get_group_qparams_symmetric
+from .utils import (
+    _choose_qparams_per_token_asymmetric,
+    _fake_quantize_per_channel_group,
+    _fake_quantize_per_token,
 )
 
 
@@ -163,7 +160,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 x, self.scales_precision, self.zero_points_precision,
             )
             (act_qmin, act_qmax) = self._get_qmin_qmax(8)
-            x_fq = fake_quantize_per_token(
+            x_fq = _fake_quantize_per_token(
                 x, act_scales, act_zp, act_qmin, act_qmax,
             )
         else:
@@ -177,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
             weight_zp = weight_zp.to(self.zero_points_precision)
             (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
-            w_fq = fake_quantize_per_channel_group(
+            w_fq = _fake_quantize_per_channel_group(
                 self.weight,
                 weight_scales,
                 weight_zp,
@@ -349,7 +346,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         scales, zero_points = get_groupwise_affine_qparams(
             self.weight, n_bit, self.groupsize, self.scales_precision,
         )
-        w_fq = fake_quantize_per_channel_group(
+        w_fq = _fake_quantize_per_channel_group(
             self.weight,
             scales,
             zero_points,
@@ -373,135 +370,3 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
     """
     if isinstance(mod, Int4WeightOnlyQATLinear):
         mod.disable_fake_quant()
-
-
-# ========================
-# |  QUANT PRIMITIVES  |
-# ========================
-
-class _GenericFakeQuantize(torch.autograd.Function):
-    """
-    Implementation of generic fake quantize with backward STE.
-
-    With the appropriate input tensor shape, this can be used to express
-    grouped per channel fake quantize or per token fake quantize.
-    """
-
-    @staticmethod
-    def forward(
-        ctx: torch.autograd.function.FunctionCtx,
-        input: torch.Tensor,
-        scales: torch.Tensor,
-        zero_points: torch.Tensor,
-        quant_min: int,
-        quant_max: int,
-        block_size: List[int],
-        zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT,
-    ) -> torch.Tensor:
-        # Note: for bf16 inputs, casting them to fp32 has the unexpected
-        # side effect of reducing memory footprint significantly, presumably
-        # because bf16 * fp32 kernels are not as memory efficient
-        assert input.dtype == torch.float32
-        assert scales.dtype == torch.float32
-        assert zero_points.dtype == torch.int32
-
-        (fq, mask) = fake_quantize_affine_cachemask(
-            input,
-            block_size,
-            scales,
-            zero_points,
-            torch.int32,
-            quant_min,
-            quant_max,
-            zero_point_domain,
-        )
-
-        ctx.save_for_backward(mask)
-        return fq
-
-    @staticmethod
-    def backward(ctx, gy):
-        (mask,) = ctx.saved_tensors
-        return gy * mask, None, None, None, None, None, None
-
-def fake_quantize_per_channel_group(
-    input: torch.Tensor,
-    scales: torch.Tensor,
-    zero_points: torch.Tensor,
-    quant_min: int,
-    quant_max: int,
-    group_size: int,
-    zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT,
-) -> torch.Tensor:
-    assert group_size > 1
-    assert input.shape[-1] % group_size == 0
-    assert input.dim() == 2
-    block_size = (1, group_size)
-    return _GenericFakeQuantize.apply(
-        input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain,
-    )
-
-def fake_quantize_per_token(
-    input: torch.Tensor,
-    scales: torch.Tensor,
-    zero_points: torch.Tensor,
-    quant_min: int,
-    quant_max: int,
-) -> torch.Tensor:
-    from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check
-
-    _per_token_quant_qparam_dim_check(input, scales, zero_points)
-    block_size = _get_per_token_block_size(input)
-    fq_input = input.to(torch.float32)
-    fq = _GenericFakeQuantize.apply(
-        fq_input, scales, zero_points, quant_min, quant_max, block_size,
-    )
-    return fq.reshape_as(input).to(input.dtype)
-
-# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
-# The version in pytorch does not have backward support yet so we add
-# it here for now until https://github.com/pytorch/pytorch/pull/123452
-# is landed.
-def _choose_qparams_per_token_asymmetric(
-    input: torch.Tensor,
-    scales_precision: torch.dtype = torch.float32,
-    zero_points_precision: torch.dtype = torch.float32,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
-    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
-    every N elements with the same quantization parameter. The dimension for scales/zero_points
-    will be (M1 * M2 ... * Mn)
-
-    Args:
-        input (torch.Tensor): original float32/float16 Tensor
-        scales_precision (torch.dtype): precision of returned scales
-        zero_points_precision (torch.dtype): precision of returned zero points
-
-    Returns:
-        scales and zero_points, both float32 Tensors
-    """
-    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
-    qmin, qmax = -128, 127
-    min_val = torch.amin(input, dim=-1, keepdim=True)
-    max_val = torch.amax(input, dim=-1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?
-
-    # scale
-    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
-    scale = scale.clamp(min=eps)
-
-    # zero point
-    descaled_min = min_val_neg / scale
-    descaled_max = max_val_pos / scale
-    zero_point_from_min_error = qmin + descaled_min
-    zero_point_from_max_error = qmax + descaled_max
-    zero_point = torch.where(
-        zero_point_from_min_error + zero_point_from_max_error > 0,
-        qmin - descaled_min,
-        qmax - descaled_max,
-    )
-    zero_point = torch.clamp(zero_point, qmin, qmax).round()
-
-    return scale.to(scales_precision), zero_point.to(zero_points_precision)
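
The block removed above is the code that moved into the new qat/utils.py (imported as .utils in the hunks above). To make the per-token asymmetric qparams arithmetic concrete, here is a small self-contained sketch that mirrors the scale/zero-point computation of _choose_qparams_per_token_asymmetric on a single row; it re-implements the formula for illustration and is not the torchao API:

```python
import torch

# One "token" with values spanning [-1.0, 3.0]; int8 target range is [-128, 127].
row = torch.tensor([[-1.0, 0.5, 3.0]])
qmin, qmax = -128, 127

# Include 0 in the range so that zero is exactly representable.
min_val_neg = row.amin(dim=-1, keepdim=True).clamp(max=0)  # -1.0
max_val_pos = row.amax(dim=-1, keepdim=True).clamp(min=0)  #  3.0

# The scale maps the full float range onto the full integer range: 4 / 255 ~= 0.0157.
scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
scale = scale.clamp(min=torch.finfo(torch.float32).eps)

# Choose the zero point from either endpoint (as in the removed helper),
# then clamp to the integer range and round.
descaled_min = min_val_neg / scale  # ~= -63.75
descaled_max = max_val_pos / scale  # ~= 191.25
zero_point = torch.where(
    (qmin + descaled_min) + (qmax + descaled_max) > 0,
    qmin - descaled_min,
    qmax - descaled_max,
)
zero_point = zero_point.clamp(qmin, qmax).round()  # -64.0

print(scale.item(), zero_point.item())
```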
