@@ -98,6 +98,13 @@ def lift(
9898 )
9999 assert fake_mode is not None
100100
101+ # This map stores the names of outputs (old to new)
102+ # This is necessary to track because the output names can be changed when
103+ # we convert graph constants to placeholder inputs below.
104+ output_names = {}
105+ for output_spec in graph_signature .output_specs :
106+ output_names [output_spec .arg .name ] = output_spec .arg .name
107+
101108 # Locate the user input to insert new placeholders before them
102109 first_user_input = None
103110 for node in gm .graph .nodes :
@@ -139,9 +146,8 @@ def lift(
139146 # Replace get_attr nodes with placeholder nodes and copy metadata.
140147 with gm .graph .inserting_before (first_user_input ):
141148 # Ensure name doesn't contain period as it is used for submodules
142- const_placeholder_node = gm .graph .placeholder (
143- node .target .replace ("." , "_" )
144- )
149+ const_placeholder_name = node .target .replace ("." , "_" )
150+ const_placeholder_node = gm .graph .placeholder (const_placeholder_name )
145151 # Copy the node meta into this new placeholder node
146152 const_placeholder_node .meta = node .meta
147153
@@ -157,6 +163,12 @@ def lift(
157163 node .replace_all_uses_with (const_placeholder_node )
158164 gm .graph .erase_node (node )
159165
166+ # Verify if the const_placeholder being added is one of the output nodes
167+ # This happens if there is just a single static arange op in the graph
168+ # https://github.com/pytorch/TensorRT/issues/3189
169+ if const_placeholder_name in output_names :
170+ output_names [const_placeholder_name ] = const_placeholder_node .name
171+
160172 # Add these parameters/buffers/constants to the existing graph signature
161173 # before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
162174 input_spec_arg = TensorArgument (name = const_placeholder_node .name )
@@ -174,6 +186,11 @@ def lift(
174186 )
175187 non_user_input_idx += 1
176188
189+ # Update output_specs with modified names. This only gets updated if the graph getattr nodes (weights)
190+ # are also the outputs of the graph
191+ for output_spec in graph_signature .output_specs :
192+ output_spec .arg .name = output_names [output_spec .arg .name ]
193+
177194 gm .graph .eliminate_dead_code ()
178195 gm .graph .lint ()
179196
0 commit comments