Skip to content

Commit b054e8d

Browse files
make etrecord set representive IO (#13130)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13052 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/33/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/33/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/33/orig @diff-train-skip-merge Differential Revision: D79386896 Co-authored-by: gasoonjia <[email protected]> Co-authored-by: Gasoonjia <[email protected]>
1 parent 9d86cbe commit b054e8d

File tree

2 files changed

+370
-5
lines changed

2 files changed

+370
-5
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
6969
] = None,
7070
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None,
71-
_representative_inputs: Optional[List[ProgramOutput]] = None,
71+
_representative_inputs: Optional[List[ProgramInput]] = None,
7272
):
7373
self.exported_program = exported_program
7474
self.export_graph_id = export_graph_id
@@ -345,6 +345,56 @@ def add_edge_dialect_program(
345345
# Set the extracted data
346346
self.edge_dialect_program = processed_edge_dialect_program
347347

348+
def update_representative_inputs(
349+
self,
350+
representative_inputs: Union[List[ProgramInput], BundledProgram],
351+
) -> None:
352+
"""
353+
Update the representative inputs in the ETRecord.
354+
355+
This method allows users to customize the representative inputs that will be
356+
included when the ETRecord is saved. The representative inputs can be provided
357+
directly as a list or extracted from a BundledProgram.
358+
359+
Args:
360+
representative_inputs: Either a list of ProgramInput objects or a BundledProgram
361+
from which representative inputs will be extracted.
362+
"""
363+
if isinstance(representative_inputs, BundledProgram):
364+
self._representative_inputs = _get_representative_inputs(
365+
representative_inputs
366+
)
367+
else:
368+
self._representative_inputs = representative_inputs
369+
370+
def update_reference_outputs(
371+
self,
372+
reference_outputs: Union[
373+
Dict[str, List[ProgramOutput]], List[ProgramOutput], BundledProgram
374+
],
375+
) -> None:
376+
"""
377+
Update the reference outputs in the ETRecord.
378+
379+
This method allows users to customize the reference outputs that will be
380+
included when the ETRecord is saved. The reference outputs can be provided
381+
directly as a dictionary mapping method names to lists of outputs, as a
382+
single list of outputs (which will be treated as {"forward": List[ProgramOutput]}),
383+
or extracted from a BundledProgram.
384+
385+
Args:
386+
reference_outputs: Either a dictionary mapping method names to lists of
387+
ProgramOutput objects, a single list of ProgramOutput objects (treated
388+
as outputs for the "forward" method), or a BundledProgram from which
389+
reference outputs will be extracted.
390+
"""
391+
if isinstance(reference_outputs, BundledProgram):
392+
self._reference_outputs = _get_reference_outputs(reference_outputs)
393+
elif isinstance(reference_outputs, list):
394+
self._reference_outputs = {"forward": reference_outputs}
395+
else:
396+
self._reference_outputs = reference_outputs
397+
348398

349399
def _get_reference_outputs(
350400
bundled_program: BundledProgram,

0 commit comments

Comments
 (0)