1010
1111import torch
1212from executorch import exir
13- from executorch .exir import CaptureConfig , EdgeCompileConfig
13+ from executorch .exir import EdgeCompileConfig , to_edge
1414from executorch .exir .passes .quant_fusion_pass import QuantFusionPass
1515from executorch .exir .tests .common import register_additional_test_aten_ops
1616from torch .ao .quantization import ( # @manual
2626 _convert_to_reference_decomposed_fx ,
2727 prepare_fx ,
2828)
29+ from torch .export import export
2930from torch .nn import functional as F
3031
3132from torch .testing import FileCheck
@@ -56,9 +57,11 @@ def forward(self, x, y):
5657 )
5758 m = _convert_to_reference_decomposed_fx (m )
5859 config = EdgeCompileConfig (_check_ir_validity = False )
59- m = exir . capture ( m , example_inputs , CaptureConfig ()). to_edge ( config = config )
60+ m = to_edge ( export ( m , example_inputs ), compile_config = config )
6061 # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
61- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
62+ m = m .transform (
63+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
64+ )
6265 # check that we are using functional variant of q/dq/add
6366 FileCheck ().check (
6467 "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
@@ -67,12 +70,12 @@ def forward(self, x, y):
6770 ).check (
6871 "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
6972 ).run (
70- m .exported_program .graph_module .code
73+ m .exported_program () .graph_module .code
7174 )
7275 m = m .to_executorch ()
7376 # check that we are using out variant of q/dq/add
7477 FileCheck ().check ("torch.ops.quantized_decomposed.add.out" ).run (
75- m .exported_program .graph_module .code
78+ m .exported_program () .graph_module .code
7679 )
7780
7881 def test_reshape (self ) -> None :
@@ -95,9 +98,11 @@ def forward(self, x, y):
9598 m (* example_inputs )
9699 m = _convert_to_reference_decomposed_fx (m )
97100 config = EdgeCompileConfig (_check_ir_validity = False )
98- m = exir . capture ( m , example_inputs , CaptureConfig ()). to_edge ( config = config )
101+ m = to_edge ( export ( m , example_inputs ), compile_config = config )
99102 # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
100- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
103+ m = m .transform (
104+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
105+ )
101106 # check that we are using functional variant of q/dq/add/reshape
102107 # make sure we only have two quant and one dequant since the q/dq around reshape
103108 # should be fused
@@ -114,14 +119,14 @@ def forward(self, x, y):
114119 1 ,
115120 exactly = True ,
116121 ).run (
117- m .exported_program .graph_module .code
122+ m .exported_program () .graph_module .code
118123 )
119124
120125 m = m .to_executorch (exir .ExecutorchBackendConfig (remove_view_copy = False ))
121126 # check that we are using out variant of q/dq/add
122127 FileCheck ().check ("torch.ops.quantized_decomposed.add.out" ).check (
123128 "torch.ops.aten.view_copy.out"
124- ).run (m .exported_program .graph_module .code )
129+ ).run (m .exported_program () .graph_module .code )
125130
126131 def test_slice (self ) -> None :
127132 """We don't proactively quantize slice today, but we'll fuse the dq-slice-q
@@ -150,9 +155,11 @@ def forward(self, x, y):
150155 )
151156 m = _convert_to_reference_decomposed_fx (m )
152157 config = EdgeCompileConfig (_check_ir_validity = False )
153- m = exir . capture ( m , example_inputs , CaptureConfig ()). to_edge ( config = config )
158+ m = to_edge ( export ( m , example_inputs ), compile_config = config )
154159 # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
155- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
160+ m = m .transform (
161+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
162+ )
156163 # check that we are using functional variant of q/dq/add/slice
157164 # make sure we only have one quant and one dequant since the q/dq around slice
158165 # should be fused
@@ -169,14 +176,14 @@ def forward(self, x, y):
169176 ).check (
170177 "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
171178 ).run (
172- m .exported_program .graph_module .code
179+ m .exported_program () .graph_module .code
173180 )
174181
175182 m = m .to_executorch ()
176183 # check that we are using out variant of add and slice_copy
177184 FileCheck ().check ("torch.ops.quantized_decomposed.add.out" ).check (
178185 "torch.ops.aten.slice_copy.Tensor_out"
179- ).run (m .dump_graph_module () .code )
186+ ).run (m .exported_program (). graph_module .code )
180187
181188 def test_cat (self ) -> None :
182189 class M (torch .nn .Module ):
@@ -197,9 +204,9 @@ def forward(self, x, y):
197204 m (* example_inputs )
198205 m = _convert_to_reference_decomposed_fx (m )
199206 config = EdgeCompileConfig (_check_ir_validity = False )
200- m = exir . capture ( m , example_inputs , CaptureConfig ()). to_edge ( config = config )
207+ m = to_edge ( export ( m , example_inputs ), compile_config = config )
201208 # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
202- m = m .transform (QuantFusionPass ())
209+ m = m .transform ([ QuantFusionPass ()], check_ir_validity = False )
203210 # check that we are using functional variant of q/dq/cat
204211 FileCheck ().check_count (
205212 "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" ,
@@ -210,7 +217,7 @@ def forward(self, x, y):
210217 1 ,
211218 exactly = True ,
212219 ).run (
213- m .exported_program .graph_module .code
220+ m .exported_program () .graph_module .code
214221 )
215222
216223 m = m .to_executorch ()
@@ -224,7 +231,7 @@ def forward(self, x, y):
224231 ).check ("torch.ops.aten.cat.out" ).check_count (
225232 "torch.ops.quantized_decomposed.dequantize_per_tensor.out" , 1 , exactly = True
226233 ).run (
227- m .dump_graph_module () .code
234+ m .exported_program (). graph_module .code
228235 )
229236
230237 def test_embedding_byte (self ) -> None :
@@ -292,16 +299,18 @@ def forward(self, indices):
292299 _check_ir_validity = False ,
293300 _use_edge_ops = True ,
294301 )
295- m = exir . capture ( m , example_inputs ). to_edge ( config = compile_config )
302+ m = to_edge ( export ( m , example_inputs ), compile_config = compile_config )
296303 # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
297- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
304+ m = m .transform (
305+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
306+ )
298307 # check that we are using functional variant of q/dq/cat
299308 FileCheck ().check (
300309 "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default" ,
301310 ).check (
302311 "executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
303312 ).run (
304- m .exported_program .graph_module .code
313+ m .exported_program () .graph_module .code
305314 )
306315
307316 # TODO: enable after the out variants of quantize_per_channel is supported
@@ -348,17 +357,18 @@ def forward(self, indices):
348357 _check_ir_validity = False ,
349358 _use_edge_ops = True ,
350359 )
351- m = exir . capture ( m , example_inputs ). to_edge ( config = compile_config )
360+ m = to_edge ( export ( m , example_inputs ), compile_config = compile_config )
352361 # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
353- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
354- m (* example_inputs )
362+ m = m .transform (
363+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
364+ )
355365 # check that we are using functional variant of q/dq/cat
356366 FileCheck ().check (
357367 "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default" ,
358368 ).check (
359369 "executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
360370 ).run (
361- m .exported_program .graph_module .code
371+ m .exported_program () .graph_module .code
362372 )
363373
364374 # TODO: enable after the out variants of quantize_per_channel is supported
0 commit comments