# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
66
7+ from collections import OrderedDict
8+ from typing import cast , Mapping , Optional
9+
710import torch
8- from torch ._export .utils import get_buffer , get_param , is_buffer , is_param
11+ from executorch .exir .dialects ._ops import ops as exir_ops
12+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload
13+ from torch ._export .utils import (
14+ get_buffer ,
15+ get_lifted_tensor_constant ,
16+ get_param ,
17+ is_buffer ,
18+ is_lifted_tensor_constant ,
19+ is_param ,
20+ )
921from torch ._guards import detect_fake_mode
1022from torch .export import ExportedProgram
1123from torch .export .exported_program import InputKind , InputSpec , TensorArgument
24+ from torch .utils import _pytree as pytree
25+
26+
# Targets never constant-propagated by default. Folding
# `exir.ops.edge.aten.full.default` would bake the filled tensor into the
# program as data and can significantly increase compiled model size.
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}
1230
31+ _PRIMITIVE_TYPES = (
32+ float ,
33+ int ,
34+ bool ,
35+ str ,
36+ torch .Tensor ,
37+ torch .device ,
38+ torch .dtype ,
39+ torch .layout ,
40+ )
1341
14- def is_const (arg , exported_program , const_data_list ) -> bool :
42+
43+ def is_const (
44+ arg ,
45+ exported_program : ExportedProgram ,
46+ const_node_to_tensor : Mapping [torch .fx .Node , torch .Tensor ],
47+ ) -> bool :
1548 if isinstance (arg , (tuple , list )):
16- return all (is_const (x , exported_program , const_data_list ) for x in arg )
49+ return all (is_const (x , exported_program , const_node_to_tensor ) for x in arg )
1750 elif isinstance (arg , dict ):
18- return all (is_const (x , exported_program , const_data_list ) for x in arg .values ())
19- elif not isinstance (arg , torch .fx .Node ) or arg .op != "placeholder" :
51+ return all (
52+ is_const (x , exported_program , const_node_to_tensor ) for x in arg .values ()
53+ )
54+ elif isinstance (arg , _PRIMITIVE_TYPES ):
55+ return True
56+ elif not isinstance (arg , torch .fx .Node ):
2057 return False
21- elif (
22- is_param (exported_program , arg )
23- or is_buffer (exported_program , arg )
24- or arg .name in const_data_list
25- ):
58+ elif arg in const_node_to_tensor :
2659 return True
2760 return False
2861
2962
30- def get_data (exported_program , arg ):
63+ def get_data (
64+ arg ,
65+ exported_program : ExportedProgram ,
66+ const_node_to_tensor : Mapping [torch .fx .Node , torch .Tensor ],
67+ ):
3168 if isinstance (arg , (tuple , list )):
32- return [get_data (exported_program , x ) for x in arg ]
33- elif is_param ( exported_program , arg ):
34- return get_param ( exported_program , arg )
35- elif is_buffer ( exported_program , arg ) :
36- return get_buffer ( exported_program , arg )
69+ return [get_data (x , exported_program , const_node_to_tensor ) for x in arg ]
70+ elif isinstance ( arg , _PRIMITIVE_TYPES ):
71+ return arg
72+ elif arg in const_node_to_tensor :
73+ return const_node_to_tensor [ arg ]
3774 return None
3875
3976
40- def constant_prop_pass (exported_program : ExportedProgram ) -> ExportedProgram :
77+ def get_constant_placeholder_dict (
78+ exported_program : ExportedProgram ,
79+ ) -> OrderedDict [torch .fx .Node , torch .Tensor ]:
4180 """
42- This pass is for constant propagation for Exported Program with lifted parameters,
43- as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
81+ Returns a dictionary of placeholder node -> constant tensor.
4482 """
45- if (
46- len ([node for node in exported_program .graph .nodes if node .op == "placeholder" ])
47- == 0
48- ):
49- return exported_program
83+ const_node_to_tensor : OrderedDict [torch .fx .Node , torch .Tensor ] = OrderedDict ()
84+ for node in exported_program .graph .nodes :
85+ if node .op != "placeholder" :
86+ continue
87+
88+ if is_param (exported_program , node ):
89+ const_node_to_tensor [node ] = cast (
90+ torch .Tensor , get_param (exported_program , node )
91+ )
92+ elif is_buffer (exported_program , node ):
93+ const_node_to_tensor [node ] = cast (
94+ torch .Tensor , get_buffer (exported_program , node )
95+ )
96+ elif is_lifted_tensor_constant (exported_program , node ):
97+ const_node_to_tensor [node ] = cast (
98+ torch .Tensor , get_lifted_tensor_constant (exported_program , node )
99+ )
100+ return const_node_to_tensor
50101
def get_propagated_const_tensor_dict(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]],
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Propagates constants and returns a dictionary of node->constant tensors.

    Args:
        exported_program: program whose graph is walked in node order.
        custom_skip_targets: extra op targets to exclude from propagation,
            merged with `_DEFAULT_SKIP_TARGETS`.

    Returns:
        OrderedDict mapping each constant node (original constant
        placeholders plus newly folded call_function nodes) to its tensor.
    """
    # Initialize dict with all constant placeholders (params/buffers/lifted).
    const_node_to_tensor = get_constant_placeholder_dict(exported_program)

    # Default skip set, optionally extended by the caller.
    all_skip_targets: set[EdgeOpOverload] = set(_DEFAULT_SKIP_TARGETS)
    if custom_skip_targets is not None:
        all_skip_targets.update(custom_skip_targets)

    for node in exported_program.graph.nodes:
        if node.op != "call_function" or node.target in all_skip_targets:
            continue

        if not is_const(node.args, exported_program, const_node_to_tensor):
            continue

        # Fix: also require kwargs to be constant. A non-constant Node kwarg
        # would otherwise be resolved to None by `get_data` below and the op
        # would be executed with a bogus argument. Literal None kwargs are
        # accepted, matching the previous (working) behavior for them.
        if not all(
            value is None or is_const(value, exported_program, const_node_to_tensor)
            for value in node.kwargs.values()
        ):
            continue

        args_data, kwargs_data = pytree.tree_map(
            lambda x: get_data(x, exported_program, const_node_to_tensor),
            (node.args, node.kwargs),
        )

        # Execute `node.target` eagerly to produce the propagated constant.
        # NOTE(review): assumes the target returns a single tensor — TODO
        # confirm for multi-output ops.
        const_node_to_tensor[node] = node.target(*args_data, **kwargs_data)

    return const_node_to_tensor
140+
141+
142+ def get_first_user_input (exported_program : ExportedProgram ) -> torch .fx .Node :
143+ """Returns the first user input node in the graph."""
59144 first_user_input = None
60145 for node in exported_program .graph .nodes :
61146 if (
@@ -64,11 +149,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
64149 ):
65150 first_user_input = node
66151 break
152+ return first_user_input
153+
154+
def replace_with_constant_node(
    node: torch.fx.Node,
    prop_constant_tensor: torch.Tensor,
    first_user_input: torch.fx.Node,
    fake_mode,
    exported_program: ExportedProgram,
) -> tuple[torch.fx.Node, str]:
    """Swap `node` for a fresh constant placeholder holding its tensor.

    Registers `prop_constant_tensor` in `exported_program.constants` under a
    generated FQN, inserts a placeholder for it, reroutes all users of
    `node` to the placeholder, and erases `node`.

    Returns:
        (new placeholder node, FQN of the stored constant tensor).
    """
    # Register the tensor under a fresh fully-qualified name.
    fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
    exported_program.constants[fqn] = prop_constant_tensor

    # New placeholders must precede the first user input so the placeholder
    # ordering in the graph stays valid.
    with exported_program.graph.inserting_before(first_user_input):
        const_placeholder_node = exported_program.graph.placeholder(fqn)

    # Carry over the replaced node's metadata, then attach a fake tensor
    # describing the constant.
    const_placeholder_node.meta.update(node.meta)
    fake_val = fake_mode.from_tensor(prop_constant_tensor, static_shapes=True)
    fake_val.constant = prop_constant_tensor
    const_placeholder_node.meta["val"] = fake_val

    # Reroute every user to the new constant node and drop the original.
    node.replace_all_uses_with(const_placeholder_node)
    exported_program.graph.erase_node(node)

    return const_placeholder_node, fqn
67185
68- buffers = exported_program .graph_signature .buffers
69- prop_constant_data = []
70- const_data_to_be_removed = set ()
71186
187+ def get_fake_mode (exported_program : ExportedProgram ):
72188 fake_mode = detect_fake_mode (
73189 tuple (
74190 node .meta ["val" ]
@@ -77,57 +193,115 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
77193 )
78194 )
79195 assert fake_mode is not None
196+ return fake_mode
80197
198+
199+ def erase_constant_node (
200+ exported_program : ExportedProgram ,
201+ node : torch .fx .Node ,
202+ ):
203+ # Remove from graph.
204+ exported_program .graph .erase_node (node )
205+
206+ # Remove corresponding tensor from param/constants dict.
207+ signature = exported_program .graph_signature
208+ if name := signature .inputs_to_parameters .pop (node .name , None ):
209+ exported_program .state_dict .pop (name , None )
210+ elif name := signature .inputs_to_lifted_tensor_constants .pop (node .name , None ):
211+ exported_program .constants .pop (name , None )
212+ elif name := signature .inputs_to_buffers .pop (node .name , None ):
213+ exported_program .constants .pop (name , None )
214+ exported_program .state_dict .pop (name , None )
215+
216+
def create_constant_nodes_and_return_specs(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> dict[str, InputSpec]:
    """Materialize propagated constants as placeholder nodes.

    For every entry in `const_node_to_tensor`: dead constants (all users
    also constant) are erased; live original placeholders are kept; live
    folded call_function nodes are replaced by new constant placeholders.

    Returns:
        node.name -> InputSpec for each newly created constant placeholder.
    """
    name_to_spec_dict: dict[str, InputSpec] = {}

    fake_mode = get_fake_mode(exported_program)
    first_user_input = get_first_user_input(exported_program)

    # Walk in reverse insertion (≈ reverse topological) order so users are
    # handled before their producers.
    for node, tensor in reversed(const_node_to_tensor.items()):
        if all(user in const_node_to_tensor for user in node.users):
            # Every user was itself folded; this constant is dead.
            erase_constant_node(exported_program, node)
            continue

        if node.op == "placeholder":
            # Still-live original constant placeholder; nothing to create.
            continue

        new_node, fqn = replace_with_constant_node(
            node, tensor, first_user_input, fake_mode, exported_program
        )

        # Input spec for the lifted constant tensor.
        name_to_spec_dict[new_node.name] = InputSpec(
            kind=InputKind.CONSTANT_TENSOR,
            arg=TensorArgument(name=new_node.name),
            target=fqn,
            persistent=True,
        )
    return name_to_spec_dict
251+
252+
def constant_prop_pass(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
) -> ExportedProgram:
    """
    Constant propagation for an ExportedProgram with lifted parameters.

    Lifted parameters appear in the graph as `placeholder` nodes rather
    than `get_attr`, so propagation works off the graph signature instead
    of module attributes.

    Args:
        exported_program: the ExportedProgram to mutate in place.
        custom_skip_targets: optional set of EdgeOpOverload targets to
            exclude from propagation.

    Returns:
        The same ExportedProgram, with constants folded.
    """
    placeholders = [
        n for n in exported_program.graph.nodes if n.op == "placeholder"
    ]
    if not placeholders:
        # Nothing to propagate in a graph without inputs/constants.
        return exported_program

    # Control flow would require propagating through subgraphs; unsupported.
    if any(
        n.target == torch.ops.higher_order.cond
        for n in exported_program.graph.nodes
    ):
        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")

    const_node_to_tensor = get_propagated_const_tensor_dict(
        exported_program, custom_skip_targets
    )

    # Existing specs keyed by argument name, overlaid with specs for the
    # newly created constant placeholders.
    name_to_spec_dict = {
        spec.arg.name: spec for spec in exported_program.graph_signature.input_specs
    }
    name_to_spec_dict.update(
        create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
    )

    # Rebuild the input specs in current placeholder order.
    exported_program.graph_signature.input_specs = [
        name_to_spec_dict[n.name]
        for n in exported_program.graph.nodes
        if n.op == "placeholder"
    ]

    # Cleanup the graph.
    exported_program.graph.eliminate_dead_code()
    exported_program.graph_module.recompile()

    return exported_program
0 commit comments