Skip to content

Commit d25e2a1

Browse files
committed
fix(numba): Allow None axis is CumOp
1 parent 71c2361 commit d25e2a1

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numba
44
import numpy as np
55
from numba.misc.special import literal_unroll
6+
from numpy.core.multiarray import normalize_axis_index # type: ignore
67

78
from pytensor import config
89
from pytensor.link.numba.dispatch import basic as numba_basic
@@ -37,10 +38,8 @@ def numba_funcify_CumOp(op, node, **kwargs):
3738
mode = op.mode
3839
ndim = node.outputs[0].ndim
3940

40-
if axis < 0:
41-
axis = ndim + axis
42-
if axis < 0 or axis >= ndim:
43-
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
41+
if axis is not None:
42+
axis = normalize_axis_index(axis, ndim)
4443

4544
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
4645
reaxis_first_inv = tuple(np.argsort(reaxis_first))
@@ -52,6 +51,12 @@ def numba_funcify_CumOp(op, node, **kwargs):
5251
def cumop(x):
5352
return np.cumsum(x)
5453

54+
elif axis is None:
55+
56+
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
57+
def cumop(x):
58+
return np.cumsum(x.ravel())
59+
5560
else:
5661

5762
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
@@ -76,6 +81,12 @@ def cumop(x):
7681
def cumop(x):
7782
return np.cumprod(x)
7883

84+
elif axis is None:
85+
86+
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
87+
def cumop(x):
88+
return np.cumprod(x.ravel())
89+
7990
else:
8091

8192
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)

0 commit comments

Comments
 (0)