11from functools import partial
2- from typing import (
3- Collection ,
4- Dict ,
5- Iterable ,
6- List ,
7- Optional ,
8- Sequence ,
9- Tuple ,
10- Union ,
11- cast ,
12- )
13-
14- from pytensor .graph .basic import Constant , Variable , truncated_graph_inputs
2+ from typing import Iterable , Optional , Sequence , Union , cast , overload
3+
4+ from pytensor .graph .basic import Apply , Constant , Variable , truncated_graph_inputs
155from pytensor .graph .fg import FunctionGraph
166
177
8+ ReplaceTypes = Union [Iterable [tuple [Variable , Variable ]], dict [Variable , Variable ]]
9+
10+
11+ def _format_replace (replace : Optional [ReplaceTypes ] = None ) -> dict [Variable , Variable ]:
12+ items : dict [Variable , Variable ]
13+ if isinstance (replace , dict ):
14+ # PyLance has issues with type resolution
15+ items = cast (dict [Variable , Variable ], replace )
16+ elif isinstance (replace , Iterable ):
17+ items = dict (replace )
18+ elif replace is None :
19+ items = {}
20+ else :
21+ raise ValueError (
22+ "replace is neither a dictionary, list, "
23+ f"tuple or None ! The value provided is { replace } ,"
24+ f"of type { type (replace )} "
25+ )
26+ return items
27+
28+
29+ @overload
30+ def clone_replace (
31+ output : Sequence [Variable ],
32+ replace : Optional [ReplaceTypes ] = None ,
33+ ** rebuild_kwds ,
34+ ) -> list [Variable ]:
35+ ...
36+
37+
38+ @overload
1839def clone_replace (
19- output : Collection [ Variable ] ,
40+ output : Variable ,
2041 replace : Optional [
21- Union [Iterable [Tuple [Variable , Variable ]], Dict [Variable , Variable ]]
42+ Union [Iterable [tuple [Variable , Variable ]], dict [Variable , Variable ]]
2243 ] = None ,
2344 ** rebuild_kwds ,
24- ) -> List [Variable ]:
45+ ) -> Variable :
46+ ...
47+
48+
49+ def clone_replace (
50+ output : Union [Sequence [Variable ], Variable ],
51+ replace : Optional [ReplaceTypes ] = None ,
52+ ** rebuild_kwds ,
53+ ) -> Union [list [Variable ], Variable ]:
2554 """Clone a graph and replace subgraphs within it.
2655
2756 It returns a copy of the initial subgraph with the corresponding
@@ -39,40 +68,49 @@ def clone_replace(
3968 """
4069 from pytensor .compile .function .pfunc import rebuild_collect_shared
4170
42- items : Union [List [Tuple [Variable , Variable ]], Tuple [Tuple [Variable , Variable ], ...]]
43- if isinstance (replace , dict ):
44- items = list (replace .items ())
45- elif isinstance (replace , (list , tuple )):
46- items = replace
47- elif replace is None :
48- items = []
49- else :
50- raise ValueError (
51- "replace is neither a dictionary, list, "
52- f"tuple or None ! The value provided is { replace } ,"
53- f"of type { type (replace )} "
54- )
71+ items = list (_format_replace (replace ).items ())
72+
5573 tmp_replace = [(x , x .type ()) for x , y in items ]
5674 new_replace = [(x , y ) for ((_ , x ), (_ , y )) in zip (tmp_replace , items )]
5775 _ , _outs , _ = rebuild_collect_shared (output , [], tmp_replace , [], ** rebuild_kwds )
5876
5977 # TODO Explain why we call it twice ?!
6078 _ , outs , _ = rebuild_collect_shared (_outs , [], new_replace , [], ** rebuild_kwds )
6179
62- return cast ( List [ Variable ], outs )
80+ return outs
6381
6482
83+ @overload
84+ def graph_replace (
85+ outputs : Variable ,
86+ replace : Optional [ReplaceTypes ] = None ,
87+ * ,
88+ strict = True ,
89+ ) -> Variable :
90+ ...
91+
92+
93+ @overload
6594def graph_replace (
6695 outputs : Sequence [Variable ],
67- replace : Dict [Variable , Variable ],
96+ replace : Optional [ReplaceTypes ] = None ,
97+ * ,
98+ strict = True ,
99+ ) -> list [Variable ]:
100+ ...
101+
102+
103+ def graph_replace (
104+ outputs : Union [Sequence [Variable ], Variable ],
105+ replace : Optional [ReplaceTypes ] = None ,
68106 * ,
69107 strict = True ,
70- ) -> List [ Variable ]:
108+ ) -> Union [ list [ Variable ], Variable ]:
71109 """Replace variables in ``outputs`` by ``replace``.
72110
73111 Parameters
74112 ----------
75- outputs: Sequence[Variable]
113+ outputs: Union[ Sequence[Variable], Variable]
76114 Output graph
77115 replace: Dict[Variable, Variable]
78116 Replace mapping
@@ -83,20 +121,26 @@ def graph_replace(
83121
84122 Returns
85123 -------
86- List[Variable]
87- Output graph with subgraphs replaced
124+ Union[Variable, List[Variable] ]
125+ Output graph with subgraphs replaced, see function overload for the exact type
88126
89127 Raises
90128 ------
91129 ValueError
92- If some replacemens could not be applied and strict is True
130+ If some replacements could not be applied and strict is True
93131 """
132+ as_list = False
133+ if not isinstance (outputs , Sequence ):
134+ outputs = [outputs ]
135+ else :
136+ as_list = True
137+ replace_dict = _format_replace (replace )
94138 # collect minimum graph inputs which is required to compute outputs
95139 # and depend on replacements
96140 # additionally remove constants, they do not matter in clone get equiv
97141 conditions = [
98142 c
99- for c in truncated_graph_inputs (outputs , replace )
143+ for c in truncated_graph_inputs (outputs , replace_dict )
100144 if not isinstance (c , Constant )
101145 ]
102146 # for the function graph we need the clean graph where
@@ -117,7 +161,7 @@ def graph_replace(
117161 # replace the conditions back
118162 fg_replace = {equiv [c ]: c for c in conditions }
119163 # add the replacements on top of input mappings
120- fg_replace .update ({equiv [r ]: v for r , v in replace .items () if r in equiv })
164+ fg_replace .update ({equiv [r ]: v for r , v in replace_dict .items () if r in equiv })
121165 # replacements have to be done in reverse topological order so that nested
122166 # expressions get recursively replaced correctly
123167
@@ -126,12 +170,14 @@ def graph_replace(
126170 # So far FunctionGraph does these replacements inplace it is thus unsafe
127171 # apply them using fg.replace, it may change the original graph
128172 if strict :
129- non_fg_replace = {r : v for r , v in replace .items () if r not in equiv }
173+ non_fg_replace = {r : v for r , v in replace_dict .items () if r not in equiv }
130174 if non_fg_replace :
131175 raise ValueError (f"Some replacements were not used: { non_fg_replace } " )
132176 toposort = fg .toposort ()
133177
134- def toposort_key (fg : FunctionGraph , ts , pair ):
178+ def toposort_key (
179+ fg : FunctionGraph , ts : list [Apply ], pair : tuple [Variable , Variable ]
180+ ) -> int :
135181 key , _ = pair
136182 if key .owner is not None :
137183 return ts .index (key .owner )
@@ -148,4 +194,7 @@ def toposort_key(fg: FunctionGraph, ts, pair):
148194 reverse = True ,
149195 )
150196 fg .replace_all (sorted_replacements , import_missing = True )
151- return list (fg .outputs )
197+ if as_list :
198+ return list (fg .outputs )
199+ else :
200+ return fg .outputs [0 ]
0 commit comments