| 
2 | 2 | 
 
  | 
3 | 3 | import numpy as np  | 
4 | 4 | import pytest  | 
 | 5 | +import scipy  | 
5 | 6 | 
 
  | 
6 | 7 | import pytensor  | 
7 | 8 | import pytensor.tensor as pt  | 
8 | 9 | from pytensor import config  | 
9 |  | -from pytensor.compile import SharedVariable  | 
10 |  | -from pytensor.graph import Constant, FunctionGraph  | 
 | 10 | +from pytensor.graph import FunctionGraph  | 
11 | 11 | from tests.link.numba.test_basic import compare_numba_and_py  | 
12 |  | -from tests.tensor.test_extra_ops import set_test_value  | 
13 | 12 | 
 
  | 
14 | 13 | 
 
  | 
15 | 14 | numba = pytest.importorskip("numba")  | 
@@ -109,23 +108,25 @@ def test_solve_triangular_raises_on_nan_inf(value):  | 
109 | 108 | 
 
  | 
110 | 109 | 
 
  | 
111 | 110 | @pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])  | 
112 |  | -def test_numba_Cholesky(lower):  | 
113 |  | -    x = set_test_value(  | 
114 |  | -        pt.tensor(dtype=config.floatX, shape=(3, 3)),  | 
115 |  | -        (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype(config.floatX)),  | 
116 |  | -    )  | 
 | 111 | +@pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"])  | 
 | 112 | +def test_numba_Cholesky(lower, trans):  | 
 | 113 | +    cov = pt.matrix("cov")  | 
117 | 114 | 
 
  | 
118 |  | -    g = pt.linalg.cholesky(x, lower=lower)  | 
119 |  | -    g_fg = FunctionGraph(outputs=[g])  | 
 | 115 | +    if trans:  | 
 | 116 | +        cov_ = cov.T  | 
 | 117 | +    else:  | 
 | 118 | +        cov_ = cov  | 
 | 119 | +    chol = pt.linalg.cholesky(cov_, lower=lower)  | 
120 | 120 | 
 
  | 
121 |  | -    compare_numba_and_py(  | 
122 |  | -        g_fg,  | 
123 |  | -        [  | 
124 |  | -            i.tag.test_value  | 
125 |  | -            for i in g_fg.inputs  | 
126 |  | -            if not isinstance(i, SharedVariable | Constant)  | 
127 |  | -        ],  | 
128 |  | -    )  | 
 | 121 | +    fg = FunctionGraph(outputs=[chol])  | 
 | 122 | + | 
 | 123 | +    x = np.array([0.1, 0.2, 0.3])  | 
 | 124 | +    val = np.eye(3) + x[None, :] + x[:, None]  | 
 | 125 | + | 
 | 126 | +    compare_numba_and_py(fg, [val])  | 
 | 127 | + | 
 | 128 | +    func = pytensor.function([cov], chol, mode="NUMBA")  | 
 | 129 | +    np.testing.assert_allclose(func(val), scipy.linalg.cholesky(val, lower=lower))  | 
129 | 130 | 
 
  | 
130 | 131 | 
 
  | 
131 | 132 | def test_numba_Cholesky_raises_on_nan_input():  | 
 | 
0 commit comments