2020from torch ._export .pass_base import PassType
2121from torchao .quantization import quantize_
2222from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
23- from torchao .quantization .pt2e .quantizer import ComposableQuantizer
2423from torchao .utils import unwrap_tensor_subclass
2524
2625
@@ -287,7 +286,7 @@ def run(self, artifact: PipelineArtifact) -> None:
287286 """
288287 if (
289288 not self ._quantization_recipe
290- or not self ._quantization_recipe .ao_base_config
289+ or not self ._quantization_recipe .ao_quantization_configs
291290 ):
292291 logging .info (
293292 "Quantization recipe is invalid to run SourceTransform, returning original artifact"
@@ -303,10 +302,11 @@ def run(self, artifact: PipelineArtifact) -> None:
303302 # Apply torchao quantize_ to each model
304303 for method_name , model in artifact .data .items ():
305304 # pyre-ignore
306- for config in self ._quantization_recipe .ao_base_config :
307- quantize_ (model , config )
305+ for ao_config in self ._quantization_recipe .ao_quantization_configs :
306+ quantize_ (model , ao_config . ao_base_config , ao_config . filter_fn )
308307 unwrap_tensor_subclass (model )
309- self ._transformed_models [method_name ] = model
308+
309+ self ._transformed_models [method_name ] = model
310310
311311 self ._artifact = artifact .copy_with_new_data (self ._transformed_models )
312312
@@ -331,6 +331,38 @@ def valid_predecessor_stages(self) -> List["StageType"]:
331331 def can_start_pipeline (self ) -> bool :
332332 return True
333333
334+ def _get_quantizer_for_prepare_pt2e (self , quantizers : List [Any ]):
335+ torch_ao_quantizers = []
336+ torchao_pt2e_quantizers = []
337+
338+ for quantizer in quantizers :
339+ from torchao .quantization .pt2e .quantizer import (
340+ Quantizer as TorchAOPT2EQuantizer ,
341+ )
342+
343+ if isinstance (quantizer , TorchAOPT2EQuantizer ):
344+ torchao_pt2e_quantizers .append (quantizer )
345+ else :
346+ torch_ao_quantizers .append (quantizer )
347+
348+ if torch_ao_quantizers and torchao_pt2e_quantizers :
349+ raise ValueError ("Mixed quantizer types are not supported" )
350+ if len (torch_ao_quantizers ) > 1 :
351+ raise ValueError (
352+ "Multiple quantizers of torch.ao.quantization.quantizer not supported"
353+ )
354+
355+ if torch_ao_quantizers :
356+ # prepare_pt2e has backward compat with torch.ao quantizer
357+ return torch_ao_quantizers [0 ]
358+ elif torchao_pt2e_quantizers :
359+ # Multiple torchao quantizers - use ComposableQuantizer
360+ from torchao .quantization .pt2e .quantizer import ComposableQuantizer
361+
362+ return ComposableQuantizer (torchao_pt2e_quantizers )
363+ else :
364+ raise ValueError ("No quantizers detected" )
365+
334366 def run (self , artifact : PipelineArtifact ) -> None :
335367 if not self ._quantization_recipe or not self ._quantization_recipe .quantizers :
336368 logging .info (
@@ -355,11 +387,10 @@ def run(self, artifact: PipelineArtifact) -> None:
355387 inputs = example_inputs [method_name ][0 ]
356388 captured_graph = torch .export .export (model , inputs , strict = True ).module ()
357389
358- composed_quantizer = ComposableQuantizer (
359- # pyre-ignore
390+ quantizer = self ._get_quantizer_for_prepare_pt2e (
360391 self ._quantization_recipe .quantizers
361392 )
362- prepared_model = prepare_pt2e (captured_graph , composed_quantizer )
393+ prepared_model = prepare_pt2e (captured_graph , quantizer )
363394
364395 for calibration_input in example_inputs [method_name ]:
365396 prepared_model (* calibration_input )
0 commit comments