Skip to content

Commit c4b00d6

Browse files
committed
add graph replace
1 parent 96906a7 commit c4b00d6

File tree

2 files changed

+80
-1
lines changed

2 files changed

+80
-1
lines changed

pytensor/graph/utils.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,17 @@
33
import traceback
44
from abc import ABCMeta
55
from io import StringIO
6-
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, TypeVar, Union
6+
from typing import (
7+
TYPE_CHECKING,
8+
Any,
9+
Dict,
10+
List,
11+
Optional,
12+
Sequence,
13+
Tuple,
14+
TypeVar,
15+
Union,
16+
)
717

818

919
if TYPE_CHECKING:
@@ -420,3 +430,49 @@ def toposort(prereqs_d):
420430
"some orderings contain invalid elements."
421431
)
422432
return seq
433+
434+
435+
def graph_replace(
436+
outputs: Sequence["Variable"], replace: Dict["Variable", "Variable"]
437+
) -> List["Variable"]:
438+
from pytensor.graph.basic import condition_subset
439+
from pytensor.graph.fg import FunctionGraph
440+
441+
# collect minimum graph inputs which is required to compute outputs
442+
# and depend on replacements
443+
conditions = list(condition_subset(outputs, replace))
444+
# for the function graph we need the clean graph where
445+
# inputs do not have owners
446+
# this is exactly the reason to clone conditions
447+
equiv = {c: c.clone() for c in conditions}
448+
# some replace keys may dissapear
449+
# the reason is they are inside the graph
450+
# and depend on some vars in conditions
451+
# but we need to keep references to make replacements
452+
# we clone the replace keys to get them
453+
equiv.update({r: r.clone() for r in replace})
454+
# clone the graph but preserve the equiv mapping
455+
fg = FunctionGraph(
456+
conditions,
457+
outputs,
458+
# clone_get_equiv kwargs
459+
copy_orphans=False,
460+
copy_inputs=False,
461+
memo=equiv,
462+
)
463+
# replace the conditions back
464+
fg_replace = {equiv[c]: c for c in conditions}
465+
# add the replacements on top of input mappings
466+
fg_replace.update({equiv[r]: v for r, v in replace.items()})
467+
# replacements have to be done in reverse topological order so that nested
468+
# expressions get recursively replaced correctly
469+
toposort = fg.toposort()
470+
sorted_replacements = sorted(
471+
tuple(fg_replace.items()),
472+
# sort based on the fg toposort, if a variable has no owner, it goes first
473+
key=lambda pair: (toposort.index(pair[0].owner) if pair[0].owner else -1),
474+
reverse=True,
475+
)
476+
fg.replace_all(sorted_replacements, import_missing=True)
477+
478+
return list(fg.outputs)

tests/graph/test_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytensor
2+
from pytensor.graph.utils import graph_replace
23
from pytensor.tensor.type import vector
4+
from tests.graph.utils import MyOp, MyVariable
35

46

57
def test_stack_trace():
@@ -12,3 +14,24 @@ def test_stack_trace():
1214
v = vector()
1315
assert len(v.tag.trace) == 1
1416
assert len(v.tag.trace[0]) == 2
17+
18+
19+
def test_replacements():
20+
x = MyVariable("x")
21+
y = MyVariable("y")
22+
x2 = MyOp("xop")(x)
23+
x2.name = "x2"
24+
y2 = MyOp("yop")(y)
25+
y2.name = "y2"
26+
27+
yc = graph_replace([x2], {x: y2})[0]
28+
assert yc.owner.inputs[0] is y2
29+
30+
# the case where inputs have to be replaed in reverse topological order
31+
o = MyOp("xyop")(x2, y2)
32+
new_x = x.clone()
33+
new_y2 = y2.clone()
34+
35+
oc = graph_replace([o], {x: new_x, y2: new_y2})[0]
36+
assert oc.owner.inputs[1] is new_y2
37+
assert oc.owner.inputs[0].owner.inputs[0] is new_x

0 commit comments

Comments
 (0)