|
3 | 3 | import traceback
|
4 | 4 | from abc import ABCMeta
|
5 | 5 | from io import StringIO
|
6 |
| -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, TypeVar, Union |
| 6 | +from typing import ( |
| 7 | + TYPE_CHECKING, |
| 8 | + Any, |
| 9 | + Dict, |
| 10 | + List, |
| 11 | + Optional, |
| 12 | + Sequence, |
| 13 | + Tuple, |
| 14 | + TypeVar, |
| 15 | + Union, |
| 16 | +) |
7 | 17 |
|
8 | 18 |
|
9 | 19 | if TYPE_CHECKING:
|
@@ -420,3 +430,49 @@ def toposort(prereqs_d):
|
420 | 430 | "some orderings contain invalid elements."
|
421 | 431 | )
|
422 | 432 | return seq
|
| 433 | + |
| 434 | + |
| 435 | +def graph_replace( |
| 436 | + outputs: Sequence["Variable"], replace: Dict["Variable", "Variable"] |
| 437 | +) -> List["Variable"]: |
| 438 | + from pytensor.graph.basic import condition_subset |
| 439 | + from pytensor.graph.fg import FunctionGraph |
| 440 | + |
| 441 | + # collect minimum graph inputs which is required to compute outputs |
| 442 | + # and depend on replacements |
| 443 | + conditions = list(condition_subset(outputs, replace)) |
| 444 | + # for the function graph we need the clean graph where |
| 445 | + # inputs do not have owners |
| 446 | + # this is exactly the reason to clone conditions |
| 447 | + equiv = {c: c.clone() for c in conditions} |
| 448 | + # some replace keys may dissapear |
| 449 | + # the reason is they are inside the graph |
| 450 | + # and depend on some vars in conditions |
| 451 | + # but we need to keep references to make replacements |
| 452 | + # we clone the replace keys to get them |
| 453 | + equiv.update({r: r.clone() for r in replace}) |
| 454 | + # clone the graph but preserve the equiv mapping |
| 455 | + fg = FunctionGraph( |
| 456 | + conditions, |
| 457 | + outputs, |
| 458 | + # clone_get_equiv kwargs |
| 459 | + copy_orphans=False, |
| 460 | + copy_inputs=False, |
| 461 | + memo=equiv, |
| 462 | + ) |
| 463 | + # replace the conditions back |
| 464 | + fg_replace = {equiv[c]: c for c in conditions} |
| 465 | + # add the replacements on top of input mappings |
| 466 | + fg_replace.update({equiv[r]: v for r, v in replace.items()}) |
| 467 | + # replacements have to be done in reverse topological order so that nested |
| 468 | + # expressions get recursively replaced correctly |
| 469 | + toposort = fg.toposort() |
| 470 | + sorted_replacements = sorted( |
| 471 | + tuple(fg_replace.items()), |
| 472 | + # sort based on the fg toposort, if a variable has no owner, it goes first |
| 473 | + key=lambda pair: (toposort.index(pair[0].owner) if pair[0].owner else -1), |
| 474 | + reverse=True, |
| 475 | + ) |
| 476 | + fg.replace_all(sorted_replacements, import_missing=True) |
| 477 | + |
| 478 | + return list(fg.outputs) |
0 commit comments