Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions pytensor/link/numba/dispatch/extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import warnings
from typing import cast

import numba
import numpy as np

from pytensor import config
from pytensor.graph import Apply
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import (
Bartlett,
CumOp,
Expand All @@ -30,21 +33,22 @@ def bartlett(x):


@numba_funcify.register(CumOp)
def numba_funcify_CumOp(op, node, **kwargs):
def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
axis = op.axis
mode = op.mode
ndim = node.outputs[0].ndim
ndim = cast(TensorVariable, node.outputs[0]).ndim

if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
if axis is not None:
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")

reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
reaxis_first_inv = tuple(np.argsort(reaxis_first))
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
reaxis_first_inv = tuple(np.argsort(reaxis_first))

if mode == "add":
if ndim == 1:
if axis is None or ndim == 1:

@numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x):
Expand All @@ -68,7 +72,7 @@ def cumop(x):
return res.transpose(reaxis_first_inv)

else:
if ndim == 1:
if axis is None or ndim == 1:

@numba_basic.numba_njit(fastmath=config.numba__fastmath)
def cumop(x):
Expand Down
4 changes: 2 additions & 2 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Collection
from typing import Iterable, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union

import numpy as np
from numpy.core.multiarray import normalize_axis_index
Expand Down Expand Up @@ -291,7 +291,7 @@ class CumOp(COp):
c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
)

def __init__(self, axis=None, mode="add"):
def __init__(self, axis: Optional[int] = None, mode="add"):
if mode not in ("add", "mul"):
raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
self.axis = axis
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def __iter__(self):
)

@property
def ndim(self):
def ndim(self) -> int:
"""The rank of this tensor."""
return self.type.ndim

Expand Down
14 changes: 14 additions & 0 deletions tests/link/numba/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def test_Bartlett(val):
1,
"add",
),
(
set_test_value(
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
None,
"add",
),
(
set_test_value(
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
Expand All @@ -81,6 +88,13 @@ def test_Bartlett(val):
1,
"mul",
),
(
set_test_value(
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
None,
"mul",
),
],
)
def test_CumOp(val, axis, mode):
Expand Down