1313from pathlib import Path
1414
1515import torch
16- from torch .ao .quantization .quantize_pt2e import (
17- convert_pt2e ,
18- prepare_pt2e ,
19- )
16+ from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
2017from torch .ao .quantization .quantizer .xnnpack_quantizer import (
21- XNNPACKQuantizer ,
2218 get_symmetric_quantization_config ,
19+ XNNPACKQuantizer ,
2320)
2421from torch .testing ._internal import common_utils
2522from torch .testing ._internal .common_utils import TestCase
2623
2724from torchao import quantize_
28- from torchao ._models .llama .model import Transformer , prepare_inputs_for_model
25+ from torchao ._models .llama .model import prepare_inputs_for_model , Transformer
2926from torchao ._models .llama .tokenizer import get_tokenizer
30- from torchao .dtypes import (
31- AffineQuantizedTensor ,
32- )
33- from torchao .quantization import (
34- LinearActivationQuantizedTensor ,
35- )
27+ from torchao .dtypes import AffineQuantizedTensor
28+ from torchao .quantization import LinearActivationQuantizedTensor
3629from torchao .quantization .quant_api import (
37- Quantizer ,
38- TwoStepQuantizer ,
3930 _replace_with_custom_fn_if_matches_filter ,
4031 int4_weight_only ,
4132 int8_dynamic_activation_int4_weight ,
4233 int8_dynamic_activation_int8_weight ,
4334 int8_weight_only ,
35+ Quantizer ,
36+ TwoStepQuantizer ,
4437)
45- from torchao .quantization .quant_primitives import (
46- MappingType ,
47- )
38+ from torchao .quantization .quant_primitives import MappingType
4839from torchao .quantization .subclass import (
4940 Int4WeightOnlyQuantizedLinearWeight ,
5041 Int8WeightOnlyQuantizedLinearWeight ,
5950
6051
6152def dynamic_quant (model , example_inputs ):
62- m = torch .export .export (model , example_inputs ).module ()
53+ m = torch .export .export (model , example_inputs , strict = True ).module ()
6354 quantizer = XNNPACKQuantizer ().set_global (
6455 get_symmetric_quantization_config (is_dynamic = True )
6556 )
@@ -69,7 +60,7 @@ def dynamic_quant(model, example_inputs):
6960
7061
7162def capture_and_prepare (model , example_inputs ):
72- m = torch .export .export (model , example_inputs )
63+ m = torch .export .export (model , example_inputs , strict = True )
7364 quantizer = XNNPACKQuantizer ().set_global (
7465 get_symmetric_quantization_config (is_dynamic = True )
7566 )
@@ -666,7 +657,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
666657
667658 m_unwrapped = unwrap_tensor_subclass (m )
668659
669- m = torch .export .export (m_unwrapped , example_inputs ).module ()
660+ m = torch .export .export (m_unwrapped , example_inputs , strict = True ).module ()
670661 exported_model_res = m (* example_inputs )
671662
672663 self .assertTrue (torch .equal (exported_model_res , ref ))
0 commit comments