# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from typing import cast, Mapping, Optional

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)
from torch._guards import detect_fake_mode
from torch.export import ExportedProgram
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.utils import _pytree as pytree


26+
27+ # Avoid propagating constants for `exir.ops.edge.aten.full.default`.
28+ # Propagating aten.full can significantly increase compiled model size.
29+ _DEFAULT_SKIP_TARGETS = {exir_ops .edge .aten .full .default }
1230
31+ _PRIMITIVE_TYPES = (
32+ float ,
33+ int ,
34+ bool ,
35+ str ,
36+ torch .Tensor ,
37+ torch .device ,
38+ torch .dtype ,
39+ torch .layout ,
40+ )
1341
14- def is_const (arg , exported_program , const_data_list ) -> bool :
42+
43+ def is_const (
44+ arg ,
45+ exported_program : ExportedProgram ,
46+ const_node_to_tensor : Mapping [torch .fx .Node , torch .Tensor ],
47+ ) -> bool :
1548 if isinstance (arg , (tuple , list )):
16- return all (is_const (x , exported_program , const_data_list ) for x in arg )
49+ return all (is_const (x , exported_program , const_node_to_tensor ) for x in arg )
1750 elif isinstance (arg , dict ):
18- return all (is_const (x , exported_program , const_data_list ) for x in arg .values ())
19- elif not isinstance (arg , torch .fx .Node ) or arg .op != "placeholder" :
51+ return all (
52+ is_const (x , exported_program , const_node_to_tensor ) for x in arg .values ()
53+ )
54+ elif isinstance (arg , _PRIMITIVE_TYPES ):
55+ return True
56+ elif not isinstance (arg , torch .fx .Node ):
2057 return False
21- elif (
22- is_param (exported_program , arg )
23- or is_buffer (exported_program , arg )
24- or arg .name in const_data_list
25- ):
58+ elif arg in const_node_to_tensor :
2659 return True
2760 return False
2861
2962
30- def get_data (exported_program , arg ):
63+ def get_data (
64+ arg ,
65+ exported_program : ExportedProgram ,
66+ const_node_to_tensor : Mapping [torch .fx .Node , torch .Tensor ],
67+ ):
3168 if isinstance (arg , (tuple , list )):
32- return [get_data (exported_program , x ) for x in arg ]
33- elif is_param (exported_program , arg ):
34- return get_param (exported_program , arg )
35- elif is_buffer (exported_program , arg ):
36- return get_buffer (exported_program , arg )
69+ return type (arg )(
70+ get_data (x , exported_program , const_node_to_tensor ) for x in arg
71+ )
72+ elif isinstance (arg , _PRIMITIVE_TYPES ):
73+ return arg
74+ elif arg in const_node_to_tensor :
75+ return const_node_to_tensor [arg ]
3776 return None
3877
3978
40- def constant_prop_pass (exported_program : ExportedProgram ) -> ExportedProgram :
79+ def get_constant_placeholder_dict (
80+ exported_program : ExportedProgram ,
81+ ) -> OrderedDict [torch .fx .Node , torch .Tensor ]:
4182 """
42- This pass is for constant propagation for Exported Program with lifted parameters,
43- as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
83+ Returns a dictionary of placeholder node -> constant tensor.
4484 """
45- if (
46- len ([node for node in exported_program .graph .nodes if node .op == "placeholder" ])
47- == 0
48- ):
49- return exported_program
85+ const_node_to_tensor : OrderedDict [torch .fx .Node , torch .Tensor ] = OrderedDict ()
86+ for node in exported_program .graph .nodes :
87+ if node .op != "placeholder" :
88+ continue
89+
90+ if is_param (exported_program , node ):
91+ const_node_to_tensor [node ] = cast (
92+ torch .Tensor , get_param (exported_program , node )
93+ )
94+ elif is_buffer (exported_program , node ):
95+ const_node_to_tensor [node ] = cast (
96+ torch .Tensor , get_buffer (exported_program , node )
97+ )
98+ elif is_lifted_tensor_constant (exported_program , node ):
99+ const_node_to_tensor [node ] = cast (
100+ torch .Tensor , get_lifted_tensor_constant (exported_program , node )
101+ )
102+ return const_node_to_tensor
50103
51- has_cond = [
52- node
53- for node in exported_program .graph .nodes
54- if node .target == torch .ops .higher_order .cond
55- ]
56- if len (has_cond ) > 0 :
57- raise RuntimeError ("constant_prop_pass for control flow is not supported yet." )
58104
105+ def get_propagated_const_tensor_dict (
106+ exported_program : ExportedProgram ,
107+ custom_skip_targets : Optional [set [EdgeOpOverload ]],
108+ ) -> OrderedDict [torch .fx .Node , torch .Tensor ]:
109+ """
110+ Propagates constants and returns a dictionary of node->constant tensors.
111+ """
112+ # Initialize dict with all constant placeholders.
113+ const_node_to_tensor = get_constant_placeholder_dict (exported_program )
114+
115+ all_skip_targets : set [EdgeOpOverload ] = set ()
116+ # Default set of targets to skip.
117+ all_skip_targets .update (_DEFAULT_SKIP_TARGETS )
118+ if custom_skip_targets is not None :
119+ all_skip_targets .update (custom_skip_targets )
120+
121+ for node in exported_program .graph .nodes :
122+ if node .op != "call_function" or node .target in all_skip_targets :
123+ continue
124+
125+ if not is_const (
126+ node .args ,
127+ exported_program ,
128+ const_node_to_tensor ,
129+ ):
130+ continue
131+
132+ args_data , kwargs_data = pytree .tree_map (
133+ lambda x : get_data (x , exported_program , const_node_to_tensor ),
134+ (node .args , node .kwargs ),
135+ )
136+
137+ # Execute the `node.target` and create a new propagated constant tensor.
138+ prop_constant_tensor = node .target (* args_data , ** kwargs_data )
139+ const_node_to_tensor [node ] = prop_constant_tensor
140+
141+ return const_node_to_tensor
142+
143+
144+ def get_first_user_input (exported_program : ExportedProgram ) -> torch .fx .Node :
145+ """Returns the first user input node in the graph."""
59146 first_user_input = None
60147 for node in exported_program .graph .nodes :
61148 if (
@@ -64,11 +151,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
64151 ):
65152 first_user_input = node
66153 break
154+ return first_user_input
155+
156+
157+ def replace_with_constant_node (
158+ node : torch .fx .Node ,
159+ prop_constant_tensor : torch .Tensor ,
160+ first_user_input : torch .fx .Node ,
161+ fake_mode ,
162+ exported_program : ExportedProgram ,
163+ ) -> tuple [torch .fx .Node , str ]:
164+ # Add `prop_constant_tensor` to program.state_dict.
165+ prop_constant_tensor_fqn = f"_prop_tensor_constant{ len (exported_program .constants )} "
166+ exported_program .constants [prop_constant_tensor_fqn ] = prop_constant_tensor
167+
168+ # Insert a new placeholder node for the propagated constant tensor.
169+ with exported_program .graph .inserting_before (first_user_input ):
170+ const_placeholder_node = exported_program .graph .placeholder (
171+ prop_constant_tensor_fqn
172+ )
173+
174+ # Update the meta data of the new placeholder (buffer) node.
175+ for k , v in node .meta .items ():
176+ const_placeholder_node .meta [k ] = v
177+ const_placeholder_node .meta ["val" ] = fake_mode .from_tensor (
178+ prop_constant_tensor , static_shapes = True
179+ )
180+ const_placeholder_node .meta ["val" ].constant = prop_constant_tensor
181+
182+ # Replace the original node with the new constant node.
183+ node .replace_all_uses_with (const_placeholder_node )
184+ exported_program .graph .erase_node (node )
185+
186+ return const_placeholder_node , prop_constant_tensor_fqn
67187
68- buffers = exported_program .graph_signature .buffers
69- prop_constant_data = []
70- const_data_to_be_removed = set ()
71188
189+ def get_fake_mode (exported_program : ExportedProgram ):
72190 fake_mode = detect_fake_mode (
73191 tuple (
74192 node .meta ["val" ]
@@ -77,57 +195,115 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
77195 )
78196 )
79197 assert fake_mode is not None
198+ return fake_mode
80199
200+
201+ def erase_constant_node (
202+ exported_program : ExportedProgram ,
203+ node : torch .fx .Node ,
204+ ) -> None :
205+ # Remove corresponding tensor from param/constants dict.
206+ signature = exported_program .graph_signature
207+ if name := signature .inputs_to_parameters .pop (node .name , None ):
208+ exported_program .state_dict .pop (name , None )
209+ elif name := signature .inputs_to_lifted_tensor_constants .pop (node .name , None ):
210+ exported_program .constants .pop (name , None )
211+ elif name := signature .inputs_to_buffers .pop (node .name , None ):
212+ exported_program .constants .pop (name , None )
213+ exported_program .state_dict .pop (name , None )
214+
215+ # Remove from graph.
216+ exported_program .graph .erase_node (node )
217+
218+
219+ def create_constant_nodes_and_return_specs (
220+ const_node_to_tensor : Mapping [torch .fx .Node , torch .Tensor ],
221+ exported_program : ExportedProgram ,
222+ ) -> dict [str , InputSpec ]:
223+ """
224+ Creates constant nodes for all entries in `const_node_to_tensor` and returns a node.name -> InputSpec dict.
225+ """
226+ name_to_spec_dict : dict [str , InputSpec ] = {}
227+
228+ fake_mode = get_fake_mode (exported_program )
229+ first_user_input = get_first_user_input (exported_program )
230+
231+ # Iterate over nodes in reverse order.
232+ for node , prop_constant_tensor in reversed (const_node_to_tensor .items ()):
233+ if all (x in const_node_to_tensor for x in node .users ):
234+ # All users of this constant node are also constant, so we don't need to create a new constant node.
235+ erase_constant_node (exported_program , node )
236+ continue
237+
238+ if node .op == "placeholder" :
239+ continue
240+
241+ const_placeholder_node , prop_constant_tensor_fqn = replace_with_constant_node (
242+ node , prop_constant_tensor , first_user_input , fake_mode , exported_program
243+ )
244+
245+ # Create input spec for lifted constant.
246+ name_to_spec_dict [const_placeholder_node .name ] = InputSpec (
247+ kind = InputKind .CONSTANT_TENSOR ,
248+ arg = TensorArgument (name = const_placeholder_node .name ),
249+ target = prop_constant_tensor_fqn ,
250+ persistent = True ,
251+ )
252+ return name_to_spec_dict
253+
254+
255+ def constant_prop_pass (
256+ exported_program : ExportedProgram ,
257+ custom_skip_targets : Optional [set [EdgeOpOverload ]] = None ,
258+ ) -> ExportedProgram :
259+ """
260+ This pass is for constant propagation for Exported Program with lifted parameters,
261+ as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
262+
263+ Args:
264+ exported_program: The ExportedProgram to perform constant propagation on.
265+ custom_skip_targets: Optional set of EdgeOpOverload targets to skip during constant propagation.
266+
267+ Returns:
268+ The modified ExportedProgram with constant propagation applied.
269+ """
270+ if (
271+ len ([node for node in exported_program .graph .nodes if node .op == "placeholder" ])
272+ == 0
273+ ):
274+ return exported_program
275+
276+ has_control_flow = [
277+ node
278+ for node in exported_program .graph .nodes
279+ if node .target == torch .ops .higher_order .cond
280+ ]
281+ if len (has_control_flow ) > 0 :
282+ raise RuntimeError ("constant_prop_pass for control flow is not supported yet." )
283+
284+ const_node_to_tensor = get_propagated_const_tensor_dict (
285+ exported_program , custom_skip_targets
286+ )
287+
288+ # Get old input specs.
289+ name_to_spec_dict = {
290+ s .arg .name : s for s in exported_program .graph_signature .input_specs
291+ }
292+ # Add the new constants to input specs dict.
293+ name_to_spec_dict .update (
294+ create_constant_nodes_and_return_specs (const_node_to_tensor , exported_program )
295+ )
296+
297+ # Generate new input spec.
298+ new_input_specs = []
81299 for node in exported_program .graph .nodes :
82- if node .op == "call_function" :
83- constant_data_name_list = [
84- input_spec .target for input_spec in prop_constant_data
85- ]
86- if is_const (node .args , exported_program , constant_data_name_list ):
87- args_data = [get_data (exported_program , arg ) for arg in node .args ]
88- kwargs_data = node .kwargs
89- const_data_to_be_removed .update (node .args )
90- prop_constant_tensor = node .target (* args_data , ** kwargs_data )
91- prop_constant_tensor_fqn = f"_prop_tensor_constant{ len (buffers )} "
92-
93- with exported_program .graph .inserting_before (first_user_input ):
94- const_placeholder_node = exported_program .graph .placeholder (
95- prop_constant_tensor_fqn
96- )
97- # Update the meta data of the new placeholder (buffer) node
98- for k , v in node .meta .items ():
99- const_placeholder_node .meta [k ] = v
100- const_placeholder_node .meta ["val" ] = fake_mode .from_tensor (
101- prop_constant_tensor , static_shapes = True
102- )
103- const_placeholder_node .meta ["val" ].constant = prop_constant_tensor
104-
105- node .replace_all_uses_with (const_placeholder_node )
106- exported_program .graph .erase_node (node )
107- prop_constant_node_input_spec = InputSpec (
108- kind = InputKind .BUFFER ,
109- arg = TensorArgument (name = const_placeholder_node .name ),
110- target = prop_constant_tensor_fqn ,
111- persistent = True ,
112- )
113- prop_constant_data .append (prop_constant_node_input_spec )
114- buffers .append (prop_constant_tensor_fqn )
115- exported_program .state_dict [prop_constant_tensor_fqn ] = (
116- prop_constant_tensor
117- )
118- exported_program .graph_signature .input_specs .append (
119- prop_constant_node_input_spec
120- )
121-
122- # Remove the propogated buffer from the state dict
123- for node in exported_program .graph .nodes :
124- if (
125- node .op == "placeholder"
126- and node in const_data_to_be_removed
127- and len (node .users ) == 0
128- ):
129- exported_program .state_dict .pop (node .name , None )
130- exported_program .graph .erase_node (node )
300+ if node .op != "placeholder" :
301+ continue
302+ new_input_specs .append (name_to_spec_dict [node .name ])
303+ exported_program .graph_signature .input_specs = new_input_specs
131304
305+ # Cleanup the graph.
306+ exported_program .graph .eliminate_dead_code ()
132307 exported_program .graph_module .recompile ()
308+
133309 return exported_program