|
| 1 | +from functools import partial |
1 | 2 | from typing import ( |
2 | 3 | Collection, |
3 | 4 | Dict, |
|
10 | 11 | cast, |
11 | 12 | ) |
12 | 13 |
|
13 | | -from pytensor.graph.basic import Constant, Variable |
| 14 | +from pytensor.graph.basic import Constant, Variable, truncated_graph_inputs |
| 15 | +from pytensor.graph.fg import FunctionGraph |
14 | 16 |
|
15 | 17 |
|
16 | 18 | def clone_replace( |
@@ -58,3 +60,92 @@ def clone_replace( |
58 | 60 | _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds) |
59 | 61 |
|
60 | 62 | return cast(List[Variable], outs) |
| 63 | + |
| 64 | + |
| 65 | +def graph_replace( |
| 66 | + outputs: Sequence[Variable], |
| 67 | + replace: Dict[Variable, Variable], |
| 68 | + *, |
| 69 | + strict=True, |
| 70 | +) -> List[Variable]: |
| 71 | + """Replace variables in ``outputs`` by ``replace``. |
| 72 | +
|
| 73 | + Parameters |
| 74 | + ---------- |
| 75 | + outputs: Sequence[Variable] |
| 76 | + Output graph |
| 77 | + replace: Dict[Variable, Variable] |
| 78 | + Replace mapping |
| 79 | + strict: bool |
| 80 | + Raise an error if some replacements were not used |
| 81 | + return_unused: bool |
| 82 | + Return replacements that were not used |
| 83 | +
|
| 84 | + Returns |
| 85 | + ------- |
| 86 | + List[Variable] |
| 87 | + Output graph with subgraphs replaced |
| 88 | +
|
| 89 | + Raises |
| 90 | + ------ |
| 91 | + ValueError |
| 92 | + If some replacemens could not be applied and strict is True |
| 93 | + """ |
| 94 | + # collect minimum graph inputs which is required to compute outputs |
| 95 | + # and depend on replacements |
| 96 | + # additionally remove constants, they do not matter in clone get equiv |
| 97 | + conditions = [ |
| 98 | + c |
| 99 | + for c in truncated_graph_inputs(outputs, replace) |
| 100 | + if not isinstance(c, Constant) |
| 101 | + ] |
| 102 | + # for the function graph we need the clean graph where |
| 103 | + # inputs do not have owners |
| 104 | + # this is exactly the reason to clone conditions |
| 105 | + equiv = {c: c.clone(name=f"i-{i}") for i, c in enumerate(conditions)} |
| 106 | + # some replace keys may dissapear |
| 107 | + # the reason is they are outside the graph |
| 108 | + # clone the graph but preserve the equiv mapping |
| 109 | + fg = FunctionGraph( |
| 110 | + conditions, |
| 111 | + outputs, |
| 112 | + # clone_get_equiv kwargs |
| 113 | + copy_orphans=False, |
| 114 | + copy_inputs=False, |
| 115 | + memo=equiv, |
| 116 | + ) |
| 117 | + # replace the conditions back |
| 118 | + fg_replace = {equiv[c]: c for c in conditions} |
| 119 | + # add the replacements on top of input mappings |
| 120 | + fg_replace.update({equiv[r]: v for r, v in replace.items() if r in equiv}) |
| 121 | + # replacements have to be done in reverse topological order so that nested |
| 122 | + # expressions get recursively replaced correctly |
| 123 | + |
| 124 | + # some replacements may be initially outside the graph |
| 125 | + # but later introduced by a replacement |
| 126 | + # So far FunctionGraph does these replacements inplace it is thus unsafe |
| 127 | + # apply them using fg.replace, it may change the original graph |
| 128 | + if strict: |
| 129 | + non_fg_replace = {r: v for r, v in replace.items() if r not in equiv} |
| 130 | + if non_fg_replace: |
| 131 | + raise ValueError(f"Some replacements were not used: {non_fg_replace}") |
| 132 | + toposort = fg.toposort() |
| 133 | + |
| 134 | + def toposort_key(fg: FunctionGraph, ts, pair): |
| 135 | + key, _ = pair |
| 136 | + if key.owner is not None: |
| 137 | + return ts.index(key.owner) |
| 138 | + else: |
| 139 | + if key in fg.variables: |
| 140 | + return -1 |
| 141 | + else: |
| 142 | + raise ValueError(f"{key} is not a part of graph") |
| 143 | + |
| 144 | + sorted_replacements = sorted( |
| 145 | + tuple(fg_replace.items()), |
| 146 | + # sort based on the fg toposort, if a variable has no owner, it goes first |
| 147 | + key=partial(toposort_key, fg, toposort), |
| 148 | + reverse=True, |
| 149 | + ) |
| 150 | + fg.replace_all(sorted_replacements, import_missing=True) |
| 151 | + return list(fg.outputs) |
0 commit comments