@@ -1,5 +1,4 @@
 import re
-from functools import singledispatch
 from typing import Any, Dict, List, Optional, Sequence, Tuple, cast

 import numpy as np
@@ -9,6 +8,7 @@
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node, vectorize
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
@@ -72,8 +72,8 @@ def operand_sig(operand: Variable, prefix: str) -> str:
     return f"{inputs_sig}->{outputs_sig}"


-@singledispatch
-def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
+@_vectorize_node.register(Op)
+def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
     if hasattr(op, "gufunc_signature"):
         signature = op.gufunc_signature
     else:
@@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
     return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))


-def vectorize_node(node: Apply, *batched_inputs) -> Apply:
-    """Returns vectorized version of node with new batched inputs."""
-    op = node.op
-    return _vectorize_node(op, node, *batched_inputs)
-
-
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.

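With this change, the generic fallback is registered on `_vectorize_node`, the single-dispatch function that now lives in `pytensor.graph.replace`, rather than this module owning the dispatcher itself. A minimal, runnable sketch of that registration pattern, using plain `functools.singledispatch` and hypothetical Op classes rather than the real PyTensor types:

from functools import singledispatch

class Op: ...              # stand-in for pytensor.graph.op.Op
class SpecialOp(Op): ...   # hypothetical Op with its own vectorization rule

@singledispatch
def _vectorize_node(op, node, *batched_inputs):
    raise NotImplementedError(f"No vectorization rule for {op}")

@_vectorize_node.register(Op)          # generic fallback, analogous to vectorize_node_fallback
def vectorize_node_fallback(op, node, *batched_inputs):
    return f"Blockwise({type(op).__name__})"

@_vectorize_node.register(SpecialOp)   # a more specific rule overrides the fallback
def vectorize_special(op, node, *batched_inputs):
    return "SpecialOp rule"

print(_vectorize_node(Op(), None))         # Blockwise(Op)
print(_vectorize_node(SpecialOp(), None))  # SpecialOp rule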
@@ -279,42 +273,18 @@ def as_core(t, core_t):

         core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)

-        batch_ndims = self._batch_ndim_from_outputs(outputs)
-
-        def transform(var):
-            # From a graph of ScalarOps, make a graph of Broadcast ops.
-            if isinstance(var.type, (NullType, DisconnectedType)):
-                return var
-            if var in core_inputs:
-                return inputs[core_inputs.index(var)]
-            if var in core_outputs:
-                return outputs[core_outputs.index(var)]
-            if var in core_ograds:
-                return ograds[core_ograds.index(var)]
-
-            node = var.owner
-
-            # The gradient contains a constant, which may be responsible for broadcasting
-            if node is None:
-                if batch_ndims:
-                    var = shape_padleft(var, batch_ndims)
-                return var
-
-            batched_inputs = [transform(inp) for inp in node.inputs]
-            batched_node = vectorize_node(node, *batched_inputs)
-            batched_var = batched_node.outputs[var.owner.outputs.index(var)]
-
-            return batched_var
-
-        ret = []
-        for core_igrad, ipt in zip(core_igrads, inputs):
-            # Undefined gradient
-            if core_igrad is None:
-                ret.append(None)
-            else:
-                ret.append(transform(core_igrad))
+        igrads = vectorize(
+            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
+            vectorize=dict(
+                zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
+            ),
+        )

-        return ret
+        igrads_iter = iter(igrads)
+        return [
+            None if core_igrad is None else next(igrads_iter)
+            for core_igrad in core_igrads
+        ]

     def L_op(self, inputs, outs, ograds):
         from pytensor.tensor.math import sum as pt_sum
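The rewritten gradient code above batches all defined core gradients with a single `vectorize` call and then splices `None` back into the positions of undefined gradients. A self-contained sketch of that splicing idiom in plain Python, with strings standing in for gradient graphs and a list comprehension standing in for the `vectorize` call:

# Stand-ins for core gradient graphs; None marks an undefined gradient.
core_igrads = ["dx", None, "dz"]

# Stand-in for the single vectorize(...) call over the defined gradients only.
igrads = [g + "_batched" for g in core_igrads if g is not None]

# Walk the batched results once, re-inserting None where the gradient was undefined.
igrads_iter = iter(igrads)
result = [None if g is None else next(igrads_iter) for g in core_igrads]

assert result == ["dx_batched", None, "dz_batched"]

Compared with the removed recursive `transform` helper, vectorizing all defined gradients in one call means the replacement mapping from core inputs, outputs, and output gradients to their batched counterparts is built only once.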