1- from copy import copy
21from itertools import chain
3- from typing import Optional , Sequence , Tuple , cast
2+ from typing import Optional , Sequence , Tuple
43
54from pytensor .compile import rebuild_collect_shared
65from pytensor .graph import Constant , FunctionGraph , Variable , clone
7- from pytensor .graph .rewriting .basic import MergeOptimizer
8- from pytensor .scalar .basic import ScalarInnerGraphOp , ScalarOp , as_scalar
6+ from pytensor .scalar .basic import ScalarInnerGraphOp , as_scalar
97
108
119class ScalarLoop (ScalarInnerGraphOp ):
@@ -62,44 +60,38 @@ def __init__(
6260 if not len (init ) == len (update ):
6361 raise ValueError ("An update must be given for each init variable" )
6462 if until :
65- inputs , (* outputs , until ) = clone ([* init , * constant ], [* update , until ])
66- self .outputs = copy ([* outputs , until ])
63+ inputs , outputs = clone ([* init , * constant ], [* update , until ])
6764 else :
6865 inputs , outputs = clone ([* init , * constant ], update )
69- self .outputs = copy (outputs )
70- self .inputs = copy (inputs )
7166
7267 self .is_while = bool (until )
73- self .inputs_type = tuple (input .type for input in inputs )
74- self .outputs_type = tuple (output .type for output in outputs )
75- if self .is_while :
76- self .outputs_type = self .outputs_type + (cast (Variable , until ).type ,)
77- self .nin = len (inputs ) + 1 # n_steps is not part of the inner graph
78- self .nout = len (outputs ) + (1 if self .is_while else 0 )
68+ self .inputs , self .outputs = self ._cleanup_graph (inputs , outputs )
69+ self ._validate_updates (self .inputs , self .outputs )
70+
71+ self .inputs_type = tuple (input .type for input in self .inputs )
72+ self .outputs_type = tuple (output .type for output in self .outputs )
73+ self .nin = len (self .inputs ) + 1 # n_steps is not part of the inner graph
74+ self .nout = len (self .outputs )
7975 self .name = name
80- self . _validate_fgraph ( FunctionGraph ( self . inputs , self . outputs , clone = False ))
76+
8177 super ().__init__ ()
8278
8379 def output_types (self , input_types ):
8480 return self .outputs_type
8581
86- def _validate_fgraph (self , fgraph : FunctionGraph ) -> None :
87- for node in fgraph .apply_nodes :
88- if not isinstance (node .op , ScalarOp ):
89- raise TypeError (
90- "The fgraph of ScalarLoop must be composed exclusively of ScalarOp nodes"
91- )
92-
93- init = fgraph .inputs
94- update = fgraph .outputs
95-
82+ def _validate_updates (
83+ self , inputs : Sequence [Variable ], outputs : Sequence [Variable ]
84+ ) -> None :
85+ init = inputs
86+ update : Sequence [Variable ]
9687 if self .is_while :
97- * update , until = update
88+ * update , until = outputs
9889 if not until .type .dtype == "bool" :
9990 raise TypeError (
10091 f"Until condition must be boolean, got { until } ({ until .type .dtype } )"
10192 )
102-
93+ else :
94+ update = outputs
10395 for i , u in zip (init , update ):
10496 if i .type != u .type :
10597 raise TypeError (
@@ -116,28 +108,9 @@ def _validate_fgraph(self, fgraph: FunctionGraph) -> None:
116108 def fgraph (self ):
117109 if hasattr (self , "_fgraph" ):
118110 return self ._fgraph
119-
111+ # fgraph cannot be a property of the base class because it messes up with C caching.
112+ # We also need a `FunctionGraph(clone=True)` (default) according to an old comment
120113 fgraph = FunctionGraph (self .inputs , self .outputs )
121- # TODO: We could convert to TensorVariable, optimize graph,
122- # and then convert back to ScalarVariable.
123- # This would introduce rewrites like `log(1 + x) -> log1p`.
124- MergeOptimizer ().rewrite (fgraph )
125- self ._validate_fgraph (fgraph )
126-
127- # Clone identical outputs that have been merged
128- if len (set (fgraph .outputs )) != len (self .outputs ):
129- old_outputs = fgraph .outputs
130- new_outputs = []
131- for output in old_outputs :
132- if output not in new_outputs :
133- new_outputs .append (output )
134- else :
135- node = output .owner
136- output_idx = node .outputs .index (output )
137- new_output = node .clone ().outputs [output_idx ]
138- new_outputs .append (new_output )
139- fgraph = FunctionGraph (fgraph .inputs , new_outputs , clone = False )
140-
141114 self ._fgraph = fgraph
142115 return self ._fgraph
143116
0 commit comments