@@ -970,6 +970,136 @@ def applys_between(
970970 )
971971
972972
973+ def truncated_graph_inputs (
974+ outputs : Sequence [Variable ],
975+ ancestors_to_include : Optional [Collection [Variable ]] = None ,
976+ ) -> List [Variable ]:
977+ """Get the truncate graph inputs.
978+
979+ Unlike :func:`graph_inputs` this function will return
980+ the closest nodes to outputs that do not depend on
981+ ``ancestors_to_include``. So given all the returned
982+ variables provided there is no missing node to
983+ compute the output and all nodes are independent
984+ from each other.
985+
986+ Parameters
987+ ----------
988+ outputs : Collection[Variable]
989+ Variable to get conditions for
990+ ancestors_to_include : Optional[Collection[Variable]]
991+ Additional ancestors to assume, by default None
992+
993+ Returns
994+ -------
995+ List[Variable]
996+ Variables required to compute ``outputs``
997+
998+ Examples
999+ --------
1000+ The returned nodes marked in (parenthesis), ancestors nodes are ``c``, output nodes are ``o``
1001+
1002+ * No ancestors to include
1003+
1004+ .. code-block::
1005+
1006+ n - n - (o)
1007+
1008+ * One ancestors to include
1009+
1010+ .. code-block::
1011+
1012+ n - (c) - o
1013+
1014+ * Two ancestors to include where on depends on another, both returned
1015+
1016+ .. code-block::
1017+
1018+ (c) - (c) - o
1019+
1020+ * Additional nodes are present
1021+
1022+ .. code-block::
1023+
1024+ (c) - n - o
1025+ n - (n) -'
1026+
1027+ * Disconnected ancestors to include not returned
1028+
1029+ .. code-block::
1030+
1031+ (c) - n - o
1032+ c
1033+
1034+ * Disconnected output is present and returned
1035+
1036+ .. code-block::
1037+
1038+ (c) - (c) - o
1039+ (o)
1040+
1041+ * ancestors to include that include itself adds itself
1042+
1043+ .. code-block::
1044+
1045+ n - (c) - (o/c)
1046+
1047+ """
1048+ # simple case, no additional ancestors to include
1049+ truncated_inputs = list ()
1050+ # blockers have known independent nodes and ancestors to include
1051+ candidates = list (outputs )
1052+ if not ancestors_to_include : # None or empty
1053+ # just filter out unique variables
1054+ for node in candidates :
1055+ if node not in truncated_inputs :
1056+ truncated_inputs .append (node )
1057+ # no more actions are needed
1058+ return truncated_inputs
1059+ blockers : Set [Variable ] = set (ancestors_to_include )
1060+ # enforce O(1) check for node in ancestors to include
1061+ ancestors_to_include = blockers .copy ()
1062+
1063+ while candidates :
1064+ # on any new candidate
1065+ node = candidates .pop ()
1066+ # check if the node is independent, never go above blockers
1067+ # blockers are independent nodes and ancestors to include
1068+ if node in ancestors_to_include :
1069+ # The case where node is in ancestors to include so we check if it depends on others
1070+ # it should be removed from the blockers to check against the rest
1071+ dependent = variable_depends_on (node , blockers - {node })
1072+ # ancestors to include that are present in the graph (not disconnected)
1073+ # should be added to truncated_inputs
1074+ truncated_inputs .append (node )
1075+ if dependent :
1076+ # if the ancestors to include is still dependent we need to go above,
1077+ # the search is not yet finished
1078+ # the node _has_ to have owner to be dependent
1079+ # so we do not check it
1080+ # and populate search to go above
1081+ # owner can never be None for a dependent node
1082+ candidates .extend (node .owner .inputs )
1083+ else :
1084+ # A regular node to check
1085+ dependent = variable_depends_on (node , blockers )
1086+ # all regular nodes fall to blockes
1087+ # 1. it is dependent - further search irrelevant
1088+ # 2. it is independent - the search node is inside the closure
1089+ blockers .add (node )
1090+ # if we've found an independent node and it is not in blockers so far
1091+ # it is a new indepenent node not present in ancestors to include
1092+ if not dependent :
1093+ # we've found an independent node
1094+ # do not search beyond
1095+ truncated_inputs .append (node )
1096+ else :
1097+ # populate search otherwise
1098+ # owner can never be None for a dependent node
1099+ candidates .extend (node .owner .inputs )
1100+ return truncated_inputs
1101+
1102+
9731103def clone (
9741104 inputs : List [Variable ],
9751105 outputs : List [Variable ],
0 commit comments