1818from executorch .examples .models .model_factory import EagerModelFactory
1919from executorch .examples .xnnpack import MODEL_NAME_TO_OPTIONS , QuantType
2020from executorch .exir .schema import DelegateCall , Program
21- from executorch .export import export , ExportRecipe , recipe_registry
21+ from executorch .export import export , ExportRecipe , recipe_registry , StageType
2222from torch import nn
2323from torch .testing ._internal .common_quantization import TestHelperModules
24+ from torchao .quantization .utils import compute_error
2425
2526
2627class TestXnnpackRecipes (unittest .TestCase ):
@@ -38,6 +39,29 @@ def check_fully_delegated(self, program: Program) -> None:
3839 self .assertEqual (len (instructions ), 1 )
3940 self .assertIsInstance (instructions [0 ].instr_args , DelegateCall )
4041
42+ # pyre-ignore
43+ def _compare_eager_quantized_model_outputs (
44+ self , session , example_inputs , atol : float
45+ ) -> None :
46+ """Utility to compare eager quantized model output with session output after xnnpack lowering"""
47+ torch_export_stage_output = session .get_stage_artifacts ()[
48+ StageType .TORCH_EXPORT
49+ ]
50+ eager_quantized_model = torch_export_stage_output .data ["forward" ].module ()
51+ output = session .run_method ("forward" , example_inputs [0 ])[0 ]
52+ expected = eager_quantized_model (* example_inputs [0 ])
53+ Tester ._assert_outputs_equal (output , expected , atol = atol )
54+
55+ def _compare_eager_unquantized_model_outputs (
56+ self , session , eager_unquantized_model , example_inputs , sqnr_threshold = 20
57+ ):
58+ """Utility to compare eager unquantized model output with session output using SQNR"""
59+ quantized_output = session .run_method ("forward" , example_inputs [0 ])[0 ]
60+ original_output = eager_unquantized_model (* example_inputs [0 ])
61+ error = compute_error (original_output , quantized_output )
62+ print (f"{ self ._testMethodName } - SQNR: { error } dB" )
63+ self .assertTrue (error > sqnr_threshold )
64+
4165 def test_basic_recipe (self ) -> None :
4266 m_eager = TestHelperModules .TwoLinearModule ().eval ()
4367 example_inputs = [(torch .randn (9 , 8 ),)]
@@ -46,18 +70,13 @@ def test_basic_recipe(self) -> None:
4670 example_inputs = example_inputs ,
4771 export_recipe = ExportRecipe .get_recipe (XNNPackRecipeType .FP32 ),
4872 )
49- self .assertTrue (
50- torch .allclose (
51- session .run_method ("forward" , example_inputs [0 ])[0 ],
52- m_eager (* example_inputs [0 ]),
53- atol = 1e-3 ,
54- )
55- )
73+ self ._compare_eager_quantized_model_outputs (session , example_inputs , 1e-3 )
5674 self .check_fully_delegated (session .get_executorch_program ())
75+ self ._compare_eager_unquantized_model_outputs (session , m_eager , example_inputs )
5776
5877 def test_int8_dynamic_quant_recipe (self ) -> None :
5978 test_cases = [
60- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_DYNAMIC_PER_CHANNEL ),
79+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL ),
6180 ]
6281
6382 for export_recipe in test_cases :
@@ -70,19 +89,18 @@ def test_int8_dynamic_quant_recipe(self) -> None:
7089 example_inputs = example_inputs ,
7190 export_recipe = export_recipe ,
7291 )
73- self .assertTrue (
74- torch .allclose (
75- session .run_method ("forward" , example_inputs [0 ])[0 ],
76- m_eager (* example_inputs [0 ]),
77- atol = 1e-1 ,
78- )
92+ self ._compare_eager_quantized_model_outputs (
93+ session , example_inputs , 1e-1
7994 )
8095 self .check_fully_delegated (session .get_executorch_program ())
96+ self ._compare_eager_unquantized_model_outputs (
97+ session , m_eager , example_inputs
98+ )
8199
82100 def test_int8_static_quant_recipe (self ) -> None :
83101 test_cases = [
84- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_STATIC_PER_CHANNEL ),
85- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_STATIC_PER_TENSOR ),
102+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_STATIC_PER_CHANNEL ),
103+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR ),
86104 ]
87105
88106 for export_recipe in test_cases :
@@ -95,14 +113,13 @@ def test_int8_static_quant_recipe(self) -> None:
95113 example_inputs = example_inputs ,
96114 export_recipe = export_recipe ,
97115 )
98- self .assertTrue (
99- torch .allclose (
100- session .run_method ("forward" , example_inputs [0 ])[0 ],
101- m_eager (* example_inputs [0 ]),
102- atol = 1e-1 ,
103- )
116+ self ._compare_eager_quantized_model_outputs (
117+ session , example_inputs , 1e-2
104118 )
105119 self .check_fully_delegated (session .get_executorch_program ())
120+ self ._compare_eager_unquantized_model_outputs (
121+ session , m_eager , example_inputs
122+ )
106123
107124 def test_8a4w_recipe (self ) -> None :
108125 class SimpleLinearModel (nn .Module ):
@@ -116,40 +133,39 @@ def forward(self, x) -> torch.Tensor:
116133
117134 test_cases = [
118135 ExportRecipe .get_recipe (
119- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL ,
136+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL ,
120137 ),
121138 ExportRecipe .get_recipe (
122- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
123- group_size = 32 ,
139+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
140+ group_size = 8 ,
124141 ),
125142 ]
126143
127144 for export_recipe in test_cases :
128145 with self .subTest (export_recipe = export_recipe ):
129- model = SimpleLinearModel ()
146+ model = SimpleLinearModel (). eval ()
130147 example_inputs = [(torch .randn (1 , 32 ),)]
131148 session = export (
132149 model = model ,
133150 example_inputs = example_inputs ,
134151 export_recipe = export_recipe ,
135152 )
136- self .assertTrue (
137- torch .allclose (
138- session .run_method ("forward" , example_inputs [0 ])[0 ],
139- model (* example_inputs [0 ]),
140- atol = 1e-2 ,
141- )
142- )
143153 self .check_fully_delegated (session .get_executorch_program ())
154+ self ._compare_eager_quantized_model_outputs (
155+ session , example_inputs , 1e-3
156+ )
157+ self ._compare_eager_unquantized_model_outputs (
158+ session , model , example_inputs
159+ )
144160
145161 def _get_recipe_for_quant_type (self , quant_type : QuantType ) -> XNNPackRecipeType :
146162 # Map QuantType to corresponding recipe name.
147163 if quant_type == QuantType .STATIC_PER_CHANNEL :
148- return XNNPackRecipeType .INT8_STATIC_PER_CHANNEL
164+ return XNNPackRecipeType .PT2E_INT8_STATIC_PER_CHANNEL
149165 elif quant_type == QuantType .DYNAMIC_PER_CHANNEL :
150- return XNNPackRecipeType .INT8_DYNAMIC_PER_CHANNEL
166+ return XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL
151167 elif quant_type == QuantType .STATIC_PER_TENSOR :
152- return XNNPackRecipeType .INT8_STATIC_PER_TENSOR
168+ return XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR
153169 elif quant_type == QuantType .NONE :
154170 return XNNPackRecipeType .FP32
155171 else :
@@ -224,12 +240,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(
224240
225241 # Should not raise any exception
226242 recipe_w_default_group = provider .create_recipe (
227- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
243+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
228244 )
229245 self .assertIsNotNone (recipe_w_default_group )
230246
231247 recipe = provider .create_recipe (
232- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR , group_size = 64
248+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
249+ group_size = 64 ,
233250 )
234251 self .assertIsNotNone (recipe )
235252
@@ -240,7 +257,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(
240257
241258 with self .assertRaises (ValueError ) as cm :
242259 provider .create_recipe (
243- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
260+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
244261 group_size = "32" , # String instead of int
245262 )
246263
0 commit comments