@@ -454,17 +454,28 @@ def _get_new_signature( # noqa: C901
454454 new_state_dict = {}
455455 new_constants = {}
456456
457- input_tensor_node_to_sig = {
458- input_spec .arg .name : input_spec
459- for input_spec in old_signature .input_specs
460- if isinstance (input_spec .arg , TensorArgument )
461- }
457+ if tag is None :
458+ # This is only the case where we're reconstructing the graph signature
459+ # for the toplevel graph
460+ placeholder_nodes = [
461+ node .name
462+ for node in original_program .graph .nodes
463+ if node .op == "placeholder"
464+ ]
465+ assert len (placeholder_nodes ) == len (old_signature .input_specs )
466+ input_node_to_sig = dict (zip (placeholder_nodes , old_signature .input_specs ))
467+ else :
468+ input_node_to_sig = {
469+ input_spec .arg .name : input_spec
470+ for input_spec in old_signature .input_specs
471+ if isinstance (input_spec .arg , TensorArgument )
472+ }
462473
463474 for node in gm .graph .nodes :
464475 is_tagged = tag is None or node .meta .get ("delegation_tag" , None ) == tag
465476 if node .op == "placeholder" :
466477
467- if node .name not in input_tensor_node_to_sig :
478+ if node .name not in input_node_to_sig :
468479 assert tag is not None
469480 input_specs .append (
470481 InputSpec (
@@ -475,7 +486,7 @@ def _get_new_signature( # noqa: C901
475486 )
476487 continue
477488
478- orig_input_spec = input_tensor_node_to_sig [node .name ]
489+ orig_input_spec = input_node_to_sig [node .name ]
479490
480491 if not isinstance (orig_input_spec .arg , TensorArgument ):
481492 input_specs .append (orig_input_spec )
0 commit comments