
Commit 6472d74

Faster graph traversal functions
* Uses simpler new apply_toposort
* Removes client side-effect on previous toposort functions
1 parent 8079686 commit 6472d74
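
Note: the new traversal helpers live in pytensor/graph/traversal.py, which is among the 11 changed files but is not shown in this excerpt. Purely as an illustration (not the committed implementation), an Apply-node toposort that stops at "blockers" can be written as an iterative depth-first search; the only graph attributes it assumes (Apply.inputs, Variable.owner) are the ones visible in the diffs below.

# Illustrative sketch only -- not the code added by this commit.
# Assumes PyTensor's graph conventions as seen in the diffs below: an Apply
# node has `.inputs` (variables) and a Variable has `.owner` (the Apply that
# produced it, or None for graph inputs/constants).
def apply_toposort_sketch(root_apply_nodes, blockers=()):
    """Return Apply nodes reachable from `root_apply_nodes`, producers first.

    Nodes in `blockers` are never emitted, and the traversal does not descend
    through them, which is how a caller can restrict the result to "new" nodes.
    """
    seen = set(blockers)
    order = []
    for root in root_apply_nodes:
        if root is None or root in seen:
            continue
        seen.add(root)
        stack = [(root, iter(root.inputs))]
        while stack:
            node, pending_inputs = stack[-1]
            for var in pending_inputs:
                parent = var.owner
                if parent is not None and parent not in seen:
                    # Descend into an unvisited producer before emitting `node`.
                    seen.add(parent)
                    stack.append((parent, iter(parent.inputs)))
                    break
            else:
                # All producers of `node` have been emitted already.
                stack.pop()
                order.append(node)
    return order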

File tree

11 files changed: +353 / -332 lines

pytensor/graph/basic.py

Lines changed: 2 additions & 2 deletions
@@ -1000,7 +1000,7 @@ def clone_get_equiv(
         Keywords passed to `Apply.clone_with_new_inputs`.

     """
-    from pytensor.graph.traversal import io_toposort
+    from pytensor.graph.traversal import variable_bounded_apply_toposort

     if memo is None:
         memo = {}
@@ -1016,7 +1016,7 @@ def clone_get_equiv(
             memo.setdefault(input, input)

     # go through the inputs -> outputs graph cloning as we go
-    for apply in io_toposort(inputs, outputs):
+    for apply in variable_bounded_apply_toposort(outputs, blockers=inputs):
         for input in apply.inputs:
             if input not in memo:
                 if not isinstance(input, Constant) and copy_orphans:
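
Note the flipped call signature: the old io_toposort(inputs, outputs) took the inputs first and used them as the boundary of the traversal, while variable_bounded_apply_toposort(outputs, blockers=inputs) takes the outputs first and names that boundary explicitly as blockers. A usage sketch, assuming a PyTensor build that already contains this commit so the new function is importable:

# Usage sketch; assumes pytensor.graph.traversal exposes
# variable_bounded_apply_toposort (added in this commit).
import pytensor.tensor as pt
from pytensor.graph.traversal import variable_bounded_apply_toposort

x = pt.vector("x")
y = pt.exp(x).sum()

# Old call:  io_toposort([x], [y])
# New call:  start from the outputs, stop at the blockers.
for node in variable_bounded_apply_toposort([y], blockers=[x]):
    print(node.op)  # the elementwise exp node, then the sum node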

pytensor/graph/fg.py

Lines changed: 18 additions & 37 deletions
@@ -17,13 +17,13 @@
 from pytensor.graph.features import AlreadyThere, Feature, ReplaceValidate
 from pytensor.graph.op import Op
 from pytensor.graph.traversal import (
+    apply_toposort,
     applys_between,
     graph_inputs,
-    io_toposort,
+    variable_bounded_apply_toposort,
     vars_between,
 )
 from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
-from pytensor.misc.ordered_set import OrderedSet


 ClientType = tuple[Apply, int]
@@ -132,7 +132,6 @@ def __init__(
             features = []

         self._features: list[Feature] = []
-
         # All apply nodes in the subgraph defined by inputs and
         # outputs are cached in this field
         self.apply_nodes: set[Apply] = set()
@@ -355,21 +354,16 @@ def import_node(
         apply_node : Apply
             The node to be imported.
         check : bool
-            Check that the inputs for the imported nodes are also present in
-            the `FunctionGraph`.
+            Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
         reason : str
             The name of the optimization or operation in progress.
         import_missing : bool
             Add missing inputs instead of raising an exception.
         """
         # We import the nodes in topological order. We only are interested in
-        # new nodes, so we use all variables we know of as if they were the
-        # input set. (The functions in the graph module only use the input set
-        # to know where to stop going down.)
-        new_nodes = io_toposort(self.variables, apply_node.outputs)
-
-        if check:
-            for node in new_nodes:
+        # new nodes, so we use all nodes we know of as inputs to interrupt the toposort
+        for node in apply_toposort([apply_node], blockers=self.apply_nodes):
+            if check:
                 for var in node.inputs:
                     if (
                         var.owner is None
@@ -389,8 +383,6 @@ def import_node(
                         )
                         raise MissingInputError(error_msg, variable=var)

-        for node in new_nodes:
-            assert node not in self.apply_nodes
             self.apply_nodes.add(node)
             if not hasattr(node.tag, "imported_by"):
                 node.tag.imported_by = []
@@ -755,11 +747,13 @@ def toposort(self) -> list[Apply]:
         :meth:`FunctionGraph.orderings`.

         """
-        if len(self.apply_nodes) < 2:
-            # No sorting is necessary
-            return list(self.apply_nodes)
-
-        return io_toposort(self.inputs, self.outputs, self.orderings())
+        if orderings := self.orderings():
+            return list(
+                variable_bounded_apply_toposort(self.outputs, self.inputs, orderings)
+            )
+        else:
+            # Faster implementation when no orderings are needed
+            return list(apply_toposort(o.owner for o in self.outputs))

     def orderings(self) -> dict[Apply, list[Apply]]:
         """Return a map of node to node evaluation dependencies.
@@ -778,29 +772,16 @@ def orderings(self) -> dict[Apply, list[Apply]]:
         take care of computing the dependencies by itself.

         """
-        assert isinstance(self._features, list)
         all_orderings: list[dict] = []

         for feature in self._features:
             if hasattr(feature, "orderings"):
-                orderings = feature.orderings(self)
-                if not isinstance(orderings, dict):
-                    raise TypeError(
-                        "Non-deterministic return value from "
-                        + str(feature.orderings)
-                        + ". Nondeterministic object is "
-                        + str(orderings)
-                    )
-                if len(orderings) > 0:
+                if orderings := feature.orderings(self):
                     all_orderings.append(orderings)
-                    for node, prereqs in orderings.items():
-                        if not isinstance(prereqs, list | OrderedSet):
-                            raise TypeError(
-                                "prereqs must be a type with a "
-                                "deterministic iteration order, or toposort "
-                                " will be non-deterministic."
-                            )
-        if len(all_orderings) == 1:
+
+        if not all_orderings:
+            return {}
+        elif len(all_orderings) == 1:
             # If there is only 1 ordering, we reuse it directly.
             return all_orderings[0].copy()
         else:
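
FunctionGraph.toposort now only pays for ordering constraints when a feature actually reports them; otherwise it seeds the plain apply_toposort from the output owners. The extra edges returned by Feature.orderings() have the shape {node: [prerequisite nodes]}, and honoring them is what calls for a Kahn-style sort rather than a simple depth-first walk. A generic sketch of that idea follows; it is not PyTensor's code, and it assumes every key and prerequisite in depends_on is also listed in nodes and that depends_on already combines the data-flow edges with the feature-supplied ones.

# Generic sketch of a Kahn-style toposort over explicit dependency edges.
# Not PyTensor's implementation; shown only to illustrate why extra
# `orderings` constraints need a different traversal than plain DFS.
from collections import defaultdict, deque

def toposort_with_prereqs(nodes, depends_on):
    """`depends_on[n]` lists nodes that must be scheduled before `n`."""
    indegree = {n: 0 for n in nodes}
    dependents = defaultdict(list)
    for node, prereqs in depends_on.items():
        for prereq in prereqs:
            indegree[node] += 1
            dependents[prereq].append(node)
    ready = deque(n for n in nodes if indegree[n] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for child in dependents[node]:
            indegree[child] -= 1
            if indegree[child] == 0:
                ready.append(child)
    if len(order) != len(nodes):
        raise ValueError("dependency constraints contain a cycle")
    return order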

pytensor/graph/replace.py

Lines changed: 5 additions & 2 deletions
@@ -10,7 +10,10 @@
 )
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
-from pytensor.graph.traversal import io_toposort, truncated_graph_inputs
+from pytensor.graph.traversal import (
+    truncated_graph_inputs,
+    variable_bounded_apply_toposort,
+)


 ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable]
@@ -295,7 +298,7 @@ def vectorize_graph(
     new_inputs = [replace.get(inp, inp) for inp in inputs]

     vect_vars = dict(zip(inputs, new_inputs, strict=True))
-    for node in io_toposort(inputs, seq_outputs):
+    for node in variable_bounded_apply_toposort(seq_outputs, blockers=inputs):
         vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
         vect_node = vectorize_node(node, *vect_inputs)
         for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
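
From the caller's point of view vectorize_graph is unchanged; only the internal node traversal differs. A small usage example, kept to elementwise ops so the vectorization is trivially well defined:

# vectorize_graph is existing public API; only its internal traversal changed here.
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.vector("x")
y = pt.exp(x) * 2

# Swap the vector input for a batched (matrix) input; the graph is rebuilt
# node by node, producers before the nodes that consume them.
xs = pt.matrix("xs")
ys = vectorize_graph(y, replace={x: xs})
assert ys.ndim == 2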

pytensor/graph/rewriting/basic.py

Lines changed: 3 additions & 8 deletions
@@ -27,7 +27,7 @@
 from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.op import Op
 from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
-from pytensor.graph.traversal import applys_between, io_toposort, vars_between
+from pytensor.graph.traversal import apply_toposort, applys_between, vars_between
 from pytensor.graph.utils import AssocList, InconsistencyError
 from pytensor.misc.ordered_set import OrderedSet
 from pytensor.utils import flatten
@@ -1834,7 +1834,7 @@ def apply(self, fgraph, start_from=None):
         callback_before = fgraph.execute_callbacks_time
         nb_nodes_start = len(fgraph.apply_nodes)
         t0 = time.perf_counter()
-        q = deque(io_toposort(fgraph.inputs, start_from))
+        q = deque(apply_toposort(o.owner for o in start_from))
         io_t = time.perf_counter() - t0

         def importer(node):
@@ -2081,11 +2081,6 @@ def add_requirements(self, fgraph):
     def apply(self, fgraph, start_from=None):
         change_tracker = ChangeTracker()
         fgraph.attach_feature(change_tracker)
-        if start_from is None:
-            start_from = fgraph.outputs
-        else:
-            for node in start_from:
-                assert node in fgraph.outputs

         changed = True
         max_use_abort = False
@@ -2164,7 +2159,7 @@ def apply_cleanup(profs_dict):
             changed |= apply_cleanup(iter_cleanup_sub_profs)

             topo_t0 = time.perf_counter()
-            q = deque(io_toposort(fgraph.inputs, start_from))
+            q = deque(apply_toposort(o.owner for o in fgraph.outputs))
             io_toposort_timing.append(time.perf_counter() - topo_t0)

             nb_nodes.append(len(q))
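
The rewriter queues are now seeded from the Apply nodes that own the requested outputs, instead of re-walking the whole graph from fgraph.inputs with io_toposort. An output that is a graph input or constant has owner None; presumably the new apply_toposort simply skips such entries, since FunctionGraph.toposort above relies on the same generator pattern. A hedged sanity check of the producers-first property, assuming a build that contains this commit:

# Hedged check; assumes pytensor.graph.traversal exposes apply_toposort
# (added in this commit).
import pytensor.tensor as pt
from pytensor.graph.traversal import apply_toposort

x = pt.vector("x")
outputs = [pt.exp(x) + x]

# New-style seeding: start from the Apply nodes that own the requested outputs.
nodes = list(apply_toposort(o.owner for o in outputs))

# Every node appears after the owners of its inputs (producers first).
position = {node: i for i, node in enumerate(nodes)}
for node in nodes:
    for var in node.inputs:
        if var.owner is not None:
            assert position[var.owner] < position[node]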
