|
22 | 22 | from pytensor.tensor import elemwise_cgen as cgen |
23 | 23 | from pytensor.tensor import get_vector_length |
24 | 24 | from pytensor.tensor.basic import _get_vector_length, as_tensor_variable |
| 25 | +from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed |
25 | 26 | from pytensor.tensor.type import ( |
26 | 27 | TensorType, |
27 | 28 | continuous_dtypes, |
28 | 29 | discrete_dtypes, |
29 | 30 | float_dtypes, |
30 | 31 | lvector, |
31 | 32 | ) |
| 33 | +from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string |
32 | 34 | from pytensor.tensor.variable import TensorVariable |
33 | 35 | from pytensor.utils import uniq |
34 | 36 |
|
@@ -232,7 +234,7 @@ def __str__(self): |
232 | 234 | return f"Transpose{{axes={self.shuffle}}}" |
233 | 235 | return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" |
234 | 236 |
|
235 | | - def perform(self, node, inp, out, params): |
| 237 | + def perform(self, node, inp, out, params=None): |
236 | 238 | (res,) = inp |
237 | 239 | (storage,) = out |
238 | 240 |
|
@@ -429,28 +431,12 @@ def get_output_info(self, dim_shuffle, *inputs): |
429 | 431 | # of all inputs in parallel... the all() gives us each output |
430 | 432 | # broadcastable bit in turn. |
431 | 433 |
|
432 | | - def get_most_specialized_shape(shapes): |
433 | | - shapes = set(shapes) |
434 | | - # All shapes are the same |
435 | | - if len(shapes) == 1: |
436 | | - return tuple(shapes)[0] |
437 | | - |
438 | | - # Only valid indeterminate case |
439 | | - if shapes == {None, 1}: |
440 | | - return None |
441 | | - |
442 | | - shapes.discard(1) |
443 | | - shapes.discard(None) |
444 | | - if len(shapes) > 1: |
445 | | - raise ValueError |
446 | | - return tuple(shapes)[0] |
447 | | - |
448 | 434 | # it is multiplied by nout because Elemwise supports multiple outputs |
449 | 435 | # (nout of them) |
450 | 436 | try: |
451 | 437 | out_shapes = [ |
452 | 438 | [ |
453 | | - get_most_specialized_shape(shape) |
| 439 | + broadcast_static_dim_lengths(shape) |
454 | 440 | for shape in zip(*[inp.type.shape for inp in inputs]) |
455 | 441 | ] |
456 | 442 | ] * shadow.nout |
@@ -665,22 +651,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): |
665 | 651 | impl = "c" |
666 | 652 |
|
667 | 653 | if getattr(self, "nfunc_spec", None) and impl != "c": |
668 | | - self.nfunc = getattr(np, self.nfunc_spec[0], None) |
669 | | - if self.nfunc is None: |
670 | | - # Not inside NumPy. So probably another package like scipy. |
671 | | - symb = self.nfunc_spec[0].split(".") |
672 | | - for idx in range(1, len(self.nfunc_spec[0])): |
673 | | - try: |
674 | | - module = __import__(".".join(symb[:idx])) |
675 | | - except ImportError: |
676 | | - break |
677 | | - for sub in symb[1:]: |
678 | | - try: |
679 | | - module = getattr(module, sub) |
680 | | - except AttributeError: |
681 | | - module = None |
682 | | - break |
683 | | - self.nfunc = module |
| 654 | + self.nfunc = import_func_from_string(self.nfunc_spec[0]) |
684 | 655 |
|
685 | 656 | if ( |
686 | 657 | (len(node.inputs) + len(node.outputs)) <= 32 |
@@ -1768,3 +1739,37 @@ def _get_vector_length_Elemwise(op, var): |
1768 | 1739 | return get_vector_length(var.owner.inputs[0]) |
1769 | 1740 |
|
1770 | 1741 | raise ValueError(f"Length of {var} cannot be determined") |
| 1742 | + |
| 1743 | + |
| 1744 | +_vectorize_node.register(Elemwise, vectorize_not_needed) |
| 1745 | + |
| 1746 | + |
| 1747 | +@_vectorize_node.register(DimShuffle) |
| 1748 | +def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply: |
| 1749 | + batched_ndims = x.type.ndim - node.inputs[0].type.ndim |
| 1750 | + if not batched_ndims: |
| 1751 | + return node.op.make_node(x) |
| 1752 | + input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable |
| 1753 | + # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2)) |
| 1754 | + # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x")) |
| 1755 | + new_order = list(range(batched_ndims)) + [ |
| 1756 | + "x" if (o == "x") else (o + batched_ndims) for o in op.new_order |
| 1757 | + ] |
| 1758 | + return DimShuffle(input_broadcastable, new_order).make_node(x) |
| 1759 | + |
| 1760 | + |
| 1761 | +@_vectorize_node.register(CAReduce) |
| 1762 | +def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply: |
| 1763 | + batched_ndims = x.type.ndim - node.inputs[0].type.ndim |
| 1764 | + if not batched_ndims: |
| 1765 | + return node.op.make_node(x) |
| 1766 | + axes = op.axis |
| 1767 | + # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3)) |
| 1768 | + # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,)) |
| 1769 | + if axes is None: |
| 1770 | + axes = list(range(node.inputs[0].type.ndim)) |
| 1771 | + else: |
| 1772 | + axes = list(axes) |
| 1773 | + new_axes = [axis + batched_ndims for axis in axes] |
| 1774 | + new_op = op.clone(axis=new_axes) |
| 1775 | + return new_op.make_node(x) |
0 commit comments