Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,20 +124,26 @@
_N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B)

# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d = B.ndim == 1

# This will only copy if A is not already fortran contiguous
A_f = np.asfortranarray(A)

Check warning on line 132 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L132

Added line #L132 was not covered by tests

if overwrite_b:
B_copy = B
if B_is_1d:
B_copy = np.expand_dims(B, -1)

Check warning on line 136 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L136

Added line #L136 was not covered by tests
else:
# This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
B_copy = np.asfortranarray(B)

Check warning on line 140 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L140

Added line #L140 was not covered by tests
else:
if B_is_1d:
# _copy_to_fortran_order does nothing with vectors
B_copy = np.copy(B)
B_copy = np.copy(np.expand_dims(B, -1))

Check warning on line 143 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L143

Added line #L143 was not covered by tests
else:
B_copy = _copy_to_fortran_order(B)

if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)

NRHS = 1 if B_is_1d else int(B_copy.shape[-1])

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
Expand All @@ -155,7 +161,7 @@
DIAG,
N,
NRHS,
np.asfortranarray(A).T.view(w_type).ctypes,
A_f.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
Expand Down
43 changes: 43 additions & 0 deletions tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor.slinalg import SolveTriangular
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py

Expand Down Expand Up @@ -130,6 +131,48 @@ def A_func_pt(x):
)


@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"])
def test_solve_triangular_overwrite_b_correct(overwrite_b):
    """Regression test for issue #1233.

    The numba dispatch of ``SolveTriangular`` must produce the same result as
    the python/scipy implementation, and when ``overwrite_b=True`` it must
    leave the ``b`` input buffer in the same state as the python backend does.
    """
    rng = np.random.default_rng(utt.fetch_seed())
    # Fortran-ordered inputs: the numba overwrite_b path only destroys b
    # in place when b is already fortran contiguous, so F-order is needed
    # for the inplace branch to actually trigger.
    a_test_py = np.asfortranarray(rng.normal(size=(3, 3)))
    # NOTE(review): a is made lower-triangular here but the Op below is built
    # with lower=False — presumably only cross-backend agreement matters for
    # this regression test; confirm the triangle choice is intentional.
    a_test_py = np.tril(a_test_py)
    b_test_py = np.asfortranarray(rng.normal(size=(3, 2)))

    # Independent F-ordered copies for the numba backend, so that each backend
    # mutates only its own buffers when overwrite_b=True.
    a_test_nb = a_test_py.copy(order="F")
    b_test_nb = b_test_py.copy(order="F")

    op = SolveTriangular(
        trans=0,
        unit_diagonal=False,
        lower=False,
        check_finite=True,
        b_ndim=2,
        overwrite_b=overwrite_b,
    )

    a_pt = pt.matrix("a", shape=(3, 3))
    b_pt = pt.matrix("b", shape=(3, 2))
    out = op(a_pt, b_pt)

    # accept_inplace=True lets the destructive Op consume the inputs directly
    # instead of pytensor inserting protective copies.
    py_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True)
    numba_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True, mode="NUMBA")

    x_py = py_fn(a_test_py, b_test_py)
    x_nb = numba_fn(a_test_nb, b_test_nb)

    # Call each backend a second time: after the first call the b buffers may
    # already have been overwritten, so this checks that both backends agree
    # on the result AND leave their b buffers in identical states.
    np.testing.assert_allclose(
        py_fn(a_test_py, b_test_py), numba_fn(a_test_nb, b_test_nb)
    )
    np.testing.assert_allclose(b_test_py, b_test_nb)

    if overwrite_b:
        # With overwrite_b=True the solution must have been written back into
        # the caller-supplied b buffers on both backends.
        np.testing.assert_allclose(b_test_py, x_py)
        np.testing.assert_allclose(b_test_nb, x_nb)


@pytest.mark.parametrize("value", [np.nan, np.inf])
@pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
Expand Down