Skip to content

Commit 4d9b0a2

Browse files
committed
fix(numba): cholesky did not set off-diag entries to zero
1 parent 6857517 commit 4d9b0a2

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,15 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
346346
INFO,
347347
)
348348

349+
if lower:
350+
for j in range(1, _N):
351+
for i in range(j):
352+
A_copy[i, j] = 0.0
353+
else:
354+
for j in range(_N):
355+
for i in range(j + 1, _N):
356+
A_copy[i, j] = 0.0
357+
349358
return A_copy, int_ptr_to_val(INFO)
350359

351360
return impl

tests/link/numba/test_slinalg.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22

33
import numpy as np
44
import pytest
5+
import scipy
56

67
import pytensor
78
import pytensor.tensor as pt
89
from pytensor import config
9-
from pytensor.compile import SharedVariable
10-
from pytensor.graph import Constant, FunctionGraph
10+
from pytensor.graph import FunctionGraph
1111
from tests.link.numba.test_basic import compare_numba_and_py
12-
from tests.tensor.test_extra_ops import set_test_value
1312

1413

1514
numba = pytest.importorskip("numba")
@@ -109,23 +108,25 @@ def test_solve_triangular_raises_on_nan_inf(value):
109108

110109

111110
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
112-
def test_numba_Cholesky(lower):
113-
x = set_test_value(
114-
pt.tensor(dtype=config.floatX, shape=(3, 3)),
115-
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype(config.floatX)),
116-
)
111+
@pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"])
112+
def test_numba_Cholesky(lower, trans):
113+
cov = pt.matrix("cov")
117114

118-
g = pt.linalg.cholesky(x, lower=lower)
119-
g_fg = FunctionGraph(outputs=[g])
115+
if trans:
116+
cov_ = cov.T
117+
else:
118+
cov_ = cov
119+
chol = pt.linalg.cholesky(cov_, lower=lower)
120120

121-
compare_numba_and_py(
122-
g_fg,
123-
[
124-
i.tag.test_value
125-
for i in g_fg.inputs
126-
if not isinstance(i, SharedVariable | Constant)
127-
],
128-
)
121+
fg = FunctionGraph(outputs=[chol])
122+
123+
x = np.array([0.1, 0.2, 0.3])
124+
val = np.eye(3) + x[None, :] + x[:, None]
125+
126+
compare_numba_and_py(fg, [val])
127+
128+
func = pytensor.function([cov], chol, mode="NUMBA")
129+
np.testing.assert_allclose(func(val), scipy.linalg.cholesky(val, lower=lower))
129130

130131

131132
def test_numba_Cholesky_raises_on_nan_input():

0 commit comments

Comments
 (0)