1717from pytensor .graph .features import AlreadyThere , Feature , ReplaceValidate
1818from pytensor .graph .op import Op
1919from pytensor .graph .traversal import (
20+ apply_toposort ,
2021 applys_between ,
2122 graph_inputs ,
22- io_toposort ,
23+ variable_bounded_apply_toposort ,
2324 vars_between ,
2425)
2526from pytensor .graph .utils import MetaObject , MissingInputError , TestValueError
26- from pytensor .misc .ordered_set import OrderedSet
2727
2828
2929ClientType = tuple [Apply , int ]
@@ -132,7 +132,6 @@ def __init__(
132132 features = []
133133
134134 self ._features : list [Feature ] = []
135-
136135 # All apply nodes in the subgraph defined by inputs and
137136 # outputs are cached in this field
138137 self .apply_nodes : set [Apply ] = set ()
@@ -355,21 +354,16 @@ def import_node(
355354 apply_node : Apply
356355 The node to be imported.
357356 check : bool
358- Check that the inputs for the imported nodes are also present in
359- the `FunctionGraph`.
357+ Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
360358 reason : str
361359 The name of the optimization or operation in progress.
362360 import_missing : bool
363361 Add missing inputs instead of raising an exception.
364362 """
365363 # We import the nodes in topological order. We only are interested in
366- # new nodes, so we use all variables we know of as if they were the
367- # input set. (The functions in the graph module only use the input set
368- # to know where to stop going down.)
369- new_nodes = io_toposort (self .variables , apply_node .outputs )
370-
371- if check :
372- for node in new_nodes :
364+ # new nodes, so we use all nodes we know of as inputs to interrupt the toposort
365+ for node in apply_toposort ([apply_node ], blockers = self .apply_nodes ):
366+ if check :
373367 for var in node .inputs :
374368 if (
375369 var .owner is None
@@ -389,8 +383,6 @@ def import_node(
389383 )
390384 raise MissingInputError (error_msg , variable = var )
391385
392- for node in new_nodes :
393- assert node not in self .apply_nodes
394386 self .apply_nodes .add (node )
395387 if not hasattr (node .tag , "imported_by" ):
396388 node .tag .imported_by = []
@@ -755,11 +747,13 @@ def toposort(self) -> list[Apply]:
755747 :meth:`FunctionGraph.orderings`.
756748
757749 """
758- if len (self .apply_nodes ) < 2 :
759- # No sorting is necessary
760- return list (self .apply_nodes )
761-
762- return io_toposort (self .inputs , self .outputs , self .orderings ())
750+ if orderings := self .orderings ():
751+ return list (
752+ variable_bounded_apply_toposort (self .outputs , self .inputs , orderings )
753+ )
754+ else :
755+ # Faster implementation when no orderings are needed
756+ return list (apply_toposort (o .owner for o in self .outputs ))
763757
764758 def orderings (self ) -> dict [Apply , list [Apply ]]:
765759 """Return a map of node to node evaluation dependencies.
@@ -778,29 +772,16 @@ def orderings(self) -> dict[Apply, list[Apply]]:
778772 take care of computing the dependencies by itself.
779773
780774 """
781- assert isinstance (self ._features , list )
782775 all_orderings : list [dict ] = []
783776
784777 for feature in self ._features :
785778 if hasattr (feature , "orderings" ):
786- orderings = feature .orderings (self )
787- if not isinstance (orderings , dict ):
788- raise TypeError (
789- "Non-deterministic return value from "
790- + str (feature .orderings )
791- + ". Nondeterministic object is "
792- + str (orderings )
793- )
794- if len (orderings ) > 0 :
779+ if orderings := feature .orderings (self ):
795780 all_orderings .append (orderings )
796- for node , prereqs in orderings .items ():
797- if not isinstance (prereqs , list | OrderedSet ):
798- raise TypeError (
799- "prereqs must be a type with a "
800- "deterministic iteration order, or toposort "
801- " will be non-deterministic."
802- )
803- if len (all_orderings ) == 1 :
781+
782+ if not all_orderings :
783+ return {}
784+ elif len (all_orderings ) == 1 :
804785 # If there is only 1 ordering, we reuse it directly.
805786 return all_orderings [0 ].copy ()
806787 else :
0 commit comments