@@ -213,9 +213,26 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
     return _vectorize_node(op, node, *batched_inputs)
 
 
+@overload
+def vectorize(
+    outputs: Variable,
+    replace: Mapping[Variable, Variable],
+) -> Variable:
+    ...
+
+
+@overload
 def vectorize(
-    outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable]
+    outputs: Sequence[Variable],
+    replace: Mapping[Variable, Variable],
 ) -> Sequence[Variable]:
+    ...
+
+
+def vectorize(
+    outputs: Union[Variable, Sequence[Variable]],
+    replace: Mapping[Variable, Variable],
+) -> Union[Variable, Sequence[Variable]]:
     """Vectorize outputs graph given mapping from old variables to expanded counterparts version.
 
     Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
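The two `@overload` stubs added above follow the standard `typing.overload` idiom: the stubs exist only for the type checker (their bodies are `...`), while the final definition carries the widest signature and the runtime logic. A minimal, self-contained sketch of the same single-vs-sequence pattern (the names below are illustrative, not PyTensor's):

    from typing import List, Sequence, Union, overload


    @overload
    def double(x: int) -> int:
        ...


    @overload
    def double(x: Sequence[int]) -> List[int]:
        ...


    def double(x: Union[int, Sequence[int]]) -> Union[int, List[int]]:
        # Runtime dispatch mirrors the stubs: a sequence maps to a list,
        # a single value maps to a single value.
        if isinstance(x, Sequence):
            return [2 * v for v in x]
        return 2 * x


    double(3)          # type checker infers int
    double([1, 2, 3])  # type checker infers List[int]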
@@ -235,20 +252,44 @@ def vectorize(
 
         # Vectorized graph
         new_x = pt.matrix("new_x")
-        [new_y] = vectorize([y], {x: new_x})
+        new_y = vectorize(y, replace={x: new_x})
 
         fn = pytensor.function([new_x], new_y)
         fn([[0, 1, 2], [2, 1, 0]])
         # array([[0.09003057, 0.24472847, 0.66524096],
         #        [0.66524096, 0.24472847, 0.09003057]])
 
+
+    .. code-block:: python
+
+        import pytensor
+        import pytensor.tensor as pt
+
+        from pytensor.graph import vectorize
+
+        # Original graph
+        x = pt.vector("x")
+        y1 = x[0]
+        y2 = x[-1]
+
+        # Vectorized graph
+        new_x = pt.matrix("new_x")
+        [new_y1, new_y2] = vectorize([y1, y2], replace={x: new_x})
+
+        fn = pytensor.function([new_x], [new_y1, new_y2])
+        fn([[-10, 0, 10], [-11, 0, 11]])
+        # [array([-10., -11.]), array([10., 11.])]
+
     """
-    # Avoid circular import
+    if isinstance(outputs, Sequence):
+        seq_outputs = outputs
+    else:
+        seq_outputs = [outputs]
 
-    inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys())
-    new_inputs = [vectorize.get(inp, inp) for inp in inputs]
+    inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
+    new_inputs = [replace.get(inp, inp) for inp in inputs]
 
-    def transform(var):
+    def transform(var: Variable) -> Variable:
         if var in inputs:
             return new_inputs[inputs.index(var)]
 
@@ -257,7 +298,13 @@ def transform(var):
         batched_node = vectorize_node(node, *batched_inputs)
         batched_var = batched_node.outputs[var.owner.outputs.index(var)]
 
-        return batched_var
+        return cast(Variable, batched_var)
 
     # TODO: MergeOptimization or node caching?
-    return [transform(out) for out in outputs]
+    seq_vect_outputs = [transform(out) for out in seq_outputs]
+
+    if isinstance(outputs, Sequence):
+        return seq_vect_outputs
+    else:
+        [vect_output] = seq_vect_outputs
+        return vect_output
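Net effect: sequence-in/sequence-out behavior is unchanged, and a single `Variable` now round-trips without list wrapping. A quick usage sketch of both calling conventions (assumes a PyTensor checkout that includes this change; it only restates the docstring examples):

    import pytensor.tensor as pt
    from pytensor.graph import vectorize

    x = pt.vector("x")
    new_x = pt.matrix("new_x")

    # New overload: single Variable in, single Variable out
    new_y = vectorize(pt.exp(x), replace={x: new_x})

    # Existing behavior: sequence in, sequence out (argument renamed to `replace`)
    new_y1, new_y2 = vectorize([x[0], x[-1]], replace={x: new_x})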