Commit 6b65cc7

Add convert path for quantize_ QAT API
Summary: #1415 added a quantize_ QAT API for the prepare path. This commit adds the remaining convert path so users can perform end-to-end QAT using the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

ghstack-source-id: e6ea042
Pull Request resolved: #1540
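For additional context, here is a minimal sketch of where the prepare and convert calls sit relative to fine-tuning. The toy model, data, and single optimizer step are illustrative placeholders (not part of this commit), and the configs mirror the dtypes used in the summary above:

```
import torch

from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

# Illustrative model and data; bias=False keeps the convert path simple
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64, bias=False),
    torch.nn.Linear(64, 16, bias=False),
)
example_input = torch.randn(8, 64)

# Prepare: swap nn.Linear for fake-quantized equivalents so training
# observes the quantization error
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))

# Fine-tune as usual (illustrative single step)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer.zero_grad()
model(example_input).sum().backward()
optimizer.step()

# Convert: restore plain nn.Linear modules, then quantize them for real
quantize_(model, from_intx_quantization_aware_training())
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
```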
1 parent b5b739b commit 6b65cc7

File tree

5 files changed (+134, -2 lines)

test/quantization/test_qat.py

Lines changed: 65 additions & 0 deletions

@@ -25,6 +25,7 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    from_intx_quantization_aware_training,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -42,6 +43,9 @@
     _GenericFakeQuantize,
     _get_qmin_qmax,
 )
+from torchao.quantization.quant_api import (
+    int8_dynamic_activation_int4_weight,
+)
 from torchao.quantization.quant_primitives import (
     MappingType,
     TorchAODType,
@@ -1262,6 +1266,67 @@ def test_quantize_api_errors(self):
             lambda m, _: isinstance(m, torch.nn.ReLU),
         )
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api_convert_path(self):
+        """
+        Test that the following:
+
+            quantize_(model, intx_quantization_aware_training(...))
+            quantize_(model, from_intx_quantization_aware_training(...))
+            quantize_(model, int8_dynamic_activation_int4_weight())
+
+        can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
+        """
+        from torchao.quantization.qat import (
+            Int8DynActInt4WeightQATQuantizer,
+        )
+
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline prepare
+        baseline_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        baseline_model = baseline_quantizer.prepare(baseline_model)
+
+        # quantize_ prepare
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            "per_token",
+            is_symmetric=False,
+        )
+        weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        quantize_(
+            m,
+            intx_quantization_aware_training(activation_config, weight_config),
+        )
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+        # Baseline convert
+        baseline_model = baseline_quantizer.convert(baseline_model)
+
+        # quantize_ convert
+        quantize_(m, from_intx_quantization_aware_training())
+        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+
+        # Compare converted values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()

torchao/quantization/qat/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,6 +1,7 @@
 from .api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    from_intx_quantization_aware_training,
     intx_quantization_aware_training,
 )
 from .embedding import (
@@ -18,4 +19,5 @@
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
     "intx_quantization_aware_training",
+    "from_intx_quantization_aware_training",
 ]

torchao/quantization/qat/api.py

Lines changed: 38 additions & 2 deletions

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union
 
 import torch
 
@@ -242,7 +242,7 @@ def __setattr__(self, name: str, value: Any):
 def intx_quantization_aware_training(
     activation_config: Optional[FakeQuantizeConfig] = None,
     weight_config: Optional[FakeQuantizeConfig] = None,
-) -> torch.nn.Module:
+) -> Callable:
     """
     Return a function that applies fake quantization to a `torch.nn.Module`.
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.
@@ -295,6 +295,42 @@ def _insert_fake_quantize(mod: torch.nn.Module):
     return _insert_fake_quantize
 
 
+def from_intx_quantization_aware_training() -> Callable:
+    """
+    Return a function that converts a model with fake quantized modules,
+    such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
+    and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
+    back to model with the original, corresponding modules without
+    fake quantization. This should be used with
+    :func:`~torchao.quantization.quant_api.quantize_`.
+
+    Example usage::
+
+        from torchao.quantization import quantize_
+        quantize_(
+            model_with_fake_quantized_linears,
+            from_intx_quantization_aware_training(),
+        )
+    """
+
+    def _remove_fake_quantize(mod: torch.nn.Module):
+        """
+        If the given module is a fake quantized module, return the original
+        corresponding version of the module without fake quantization.
+        """
+        from .embedding import FakeQuantizedEmbedding
+        from .linear import FakeQuantizedLinear
+
+        if isinstance(mod, FakeQuantizedLinear):
+            return mod.to_linear()
+        elif isinstance(mod, FakeQuantizedEmbedding):
+            return mod.to_embedding()
+        else:
+            return mod
+
+    return _remove_fake_quantize
+
+
 class ComposableQATQuantizer(TwoStepQuantizer):
     """
     Composable quantizer that users can use to apply multiple QAT quantizers easily.
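For illustration, a minimal round-trip sketch of the transform this function returns, as applied by `quantize_`. The toy model and configs are assumptions that mirror the docstring example and the commit summary:

```
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
from torchao.quantization.qat.linear import FakeQuantizedLinear

model = torch.nn.Sequential(torch.nn.Linear(32, 32, bias=False))

# Prepare: nn.Linear -> FakeQuantizedLinear
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=16)
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))
assert isinstance(model[0], FakeQuantizedLinear)

# Convert back: FakeQuantizedLinear -> nn.Linear, carrying the trained weight
quantize_(model, from_intx_quantization_aware_training())
assert not isinstance(model[0], FakeQuantizedLinear)
assert isinstance(model[0], torch.nn.Linear)
```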

torchao/quantization/qat/embedding.py

Lines changed: 18 additions & 0 deletions

@@ -82,6 +82,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             self.sparse,
         )
 
+    def to_embedding(self) -> torch.nn.Embedding:
+        new_embedding = torch.nn.Embedding(
+            self.num_embeddings,
+            self.embedding_dim,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+            device=self.weight.device,
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if self.weight.device != torch.device("meta"):
+            new_embedding.weight = self.weight
+        return new_embedding
+
     @classmethod
     def from_embedding(
         cls,
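A similar sketch for the embedding round trip, assuming the embedding prepare path from #1415. Since `quantize_` targets `nn.Linear` by default, an explicit `filter_fn` is passed, and only a weight config is used because embedding activations are not fake-quantized:

```
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding

model = torch.nn.Sequential(torch.nn.Embedding(128, 64))
embedding_filter = lambda m, _: isinstance(m, torch.nn.Embedding)

# Prepare: nn.Embedding -> FakeQuantizedEmbedding (weight-only fake quantization)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    intx_quantization_aware_training(weight_config=weight_config),
    filter_fn=embedding_filter,
)
assert isinstance(model[0], FakeQuantizedEmbedding)

# Convert back: to_embedding() rebuilds a plain nn.Embedding that reuses the weight
quantize_(model, from_intx_quantization_aware_training(), filter_fn=embedding_filter)
assert type(model[0]) is torch.nn.Embedding
assert model[0].weight.shape == (128, 64)
```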

torchao/quantization/qat/linear.py

Lines changed: 11 additions & 0 deletions

@@ -105,6 +105,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         w = self.weight
         return F.linear(x, w)
 
+    def to_linear(self) -> torch.nn.Linear:
+        new_linear = torch.nn.Linear(
+            self.in_features, self.out_features, self.bias, device=self.weight.device
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if self.weight.device != torch.device("meta"):
+            new_linear.weight = self.weight
+        return new_linear
+
     @classmethod
     def from_linear(
         cls,
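One detail worth noting: `to_linear()` (like `to_embedding()`) reassigns the existing weight parameter rather than copying it, so the returned module shares storage with the fake-quantized one. A small sketch checking this, with the `FakeQuantizedLinear` built via the prepare path since its constructor is not shown in this diff:

```
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import FakeQuantizeConfig, intx_quantization_aware_training
from torchao.quantization.qat.linear import FakeQuantizedLinear

model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, intx_quantization_aware_training(weight_config=weight_config))

fq_linear = model[0]
assert isinstance(fq_linear, FakeQuantizedLinear)

plain = fq_linear.to_linear()
# Same Parameter object, not a clone, because to_linear() assigns self.weight directly
assert plain.weight is fq_linear.weight
```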
