 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, List, Optional, Union

 import torch

+from torchao.core.config import AOBaseConfig
 from torchao.quantization.granularity import (
     Granularity,
     PerAxis,
     TorchAODType,
     ZeroPointDomain,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.unified import TwoStepQuantizer

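The two imports added above carry the weight of this migration: `AOBaseConfig` is the base class for configuration objects passed to `quantize_`, and `register_quantize_module_handler` maps a config class to the module transform that `quantize_` dispatches to. A minimal sketch of the pattern, using an illustrative `ToyConfig` that is not a real torchao API:

```python
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)


# Illustrative only: a config is a plain dataclass holding the
# parameters of the transform, with no behavior of its own.
@dataclass
class ToyConfig(AOBaseConfig):
    scale: float = 1.0


# quantize_(model, ToyConfig()) looks up this handler by config type and
# calls it on each module selected by quantize_'s filter function; the
# module is replaced by whatever the handler returns.
@register_quantize_module_handler(ToyConfig)
def _toy_transform(module: torch.nn.Module, config: ToyConfig) -> torch.nn.Module:
    return module  # a real handler would return a transformed module
```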
@@ -241,12 +245,26 @@ def __setattr__(self, name: str, value: Any):
         super().__setattr__(name, value)


-def intx_quantization_aware_training(
-    activation_config: Optional[FakeQuantizeConfig] = None,
-    weight_config: Optional[FakeQuantizeConfig] = None,
-) -> Callable:
+@dataclass
+class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
+    activation_config: Optional[FakeQuantizeConfig] = None
+    weight_config: Optional[FakeQuantizeConfig] = None
+
+
+# for BC
+intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
+
+
+@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
+def _intx_quantization_aware_training_transform(
+    module: torch.nn.Module,
+    config: IntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
     """
-    Return a function that applies fake quantization to a `torch.nn.Module`.
+    THIS IS NOT A PUBLIC API - any usage of this outside of torchao
+    can break at any time.
+
+    Apply fake quantization to a `torch.nn.Module`,
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.

     Example usage::
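The hunk boundary below elides the body of this docstring example. Judging from the tail that remains visible after the hunk header, it presumably constructs the two `FakeQuantizeConfig` objects first; the argument values in this sketch are assumptions modeled on torchao's QAT documentation, not lines recovered from this file, and the import path assumes the `qat` package re-exports the new config:

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32))  # stand-in model

# Assumed settings: int8 per-token asymmetric activations,
# int4 per-group weights.
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)
```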
@@ -261,45 +279,40 @@ def intx_quantization_aware_training(
         )
         quantize_(
             model,
-            intx_quantization_aware_training(activation_config, weight_config),
+            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
         )

     Note: If the returned function is applied on a module that is not
     `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
     `torch.nn.Embedding` with an activation config, then we will raise
     ValueError as these are not supported.
     """
-
-    def _insert_fake_quantize(mod: torch.nn.Module):
-        """
-        Swap the given module with its corresponding fake quantized version.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
-
-        if isinstance(mod, torch.nn.Linear):
-            return FakeQuantizedLinear.from_linear(
-                mod,
-                activation_config,
-                weight_config,
-            )
-        elif isinstance(mod, torch.nn.Embedding):
-            if activation_config is not None:
-                raise ValueError(
-                    "Activation fake quantization is not supported for embedding"
-                )
-            return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
-        else:
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    mod = module
+    activation_config = config.activation_config
+    weight_config = config.weight_config
+
+    if isinstance(mod, torch.nn.Linear):
+        return FakeQuantizedLinear.from_linear(
+            mod,
+            activation_config,
+            weight_config,
+        )
+    elif isinstance(mod, torch.nn.Embedding):
+        if activation_config is not None:
             raise ValueError(
-                "Module of type '%s' does not have QAT support" % type(mod)
+                "Activation fake quantization is not supported for embedding"
             )
+        return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
+    else:
+        raise ValueError("Module of type '%s' does not have QAT support" % type(mod))

-    return _insert_fake_quantize

-
-def from_intx_quantization_aware_training() -> Callable:
+class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
     """
-    Return a function that converts a model with fake quantized modules,
+    Config object for converting a model with fake quantized modules,
     such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
     and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
     back to model with the original, corresponding modules without
@@ -311,26 +324,35 @@ def from_intx_quantization_aware_training() -> Callable:
         from torchao.quantization import quantize_
         quantize_(
             model_with_fake_quantized_linears,
-            from_intx_quantization_aware_training(),
+            FromIntXQuantizationAwareTrainingConfig(),
         )
     """

-    def _remove_fake_quantize(mod: torch.nn.Module):
-        """
-        If the given module is a fake quantized module, return the original
-        corresponding version of the module without fake quantization.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
+    pass
+
+
+# for BC
+from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig

-        if isinstance(mod, FakeQuantizedLinear):
-            return mod.to_linear()
-        elif isinstance(mod, FakeQuantizedEmbedding):
-            return mod.to_embedding()
-        else:
-            return mod

-    return _remove_fake_quantize
+@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
+def _from_intx_quantization_aware_training_transform(
+    mod: torch.nn.Module,
+    config: FromIntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
+    """
+    If the given module is a fake quantized module, return the original
+    corresponding version of the module without fake quantization.
+    """
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    if isinstance(mod, FakeQuantizedLinear):
+        return mod.to_linear()
+    elif isinstance(mod, FakeQuantizedEmbedding):
+        return mod.to_embedding()
+    else:
+        return mod


 class ComposableQATQuantizer(TwoStepQuantizer):
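Putting the two configs together: the `# for BC` aliases mean that old callable-style call sites now simply construct config instances, and prepare/convert becomes a symmetric pair of `quantize_` calls. A sketch of the round trip, assuming the `qat` package re-exports both configs; the training loop and the `FakeQuantizeConfig` arguments are placeholders:

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
    intx_quantization_aware_training,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32))
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)  # assumed args

# The BC alias is the class itself, so legacy call sites keep working:
# calling the old "function" now just constructs the config object.
assert intx_quantization_aware_training is IntXQuantizationAwareTrainingConfig

# Prepare: swap each nn.Linear for a FakeQuantizedLinear.
quantize_(model, IntXQuantizationAwareTrainingConfig(weight_config=weight_config))

# ... QAT fine-tuning loop would run here ...

# Convert: swap each FakeQuantizedLinear back to a plain nn.Linear,
# ready for a real post-training quantization config to be applied.
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
```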