Skip to content

Commit 57e5aa4

Browse files
committed
Fix Numba pos solve condition number calculation
1 parent d59922d commit 57e5aa4

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@ def _posv(
885885
overwrite_b: bool,
886886
check_finite: bool,
887887
transposed: bool,
888-
) -> tuple[np.ndarray, int]:
888+
) -> tuple[np.ndarray, np.ndarray, int]:
889889
"""
890890
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
891891
"""
@@ -902,7 +902,8 @@ def posv_impl(
902902
check_finite: bool,
903903
transposed: bool,
904904
) -> Callable[
905-
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
905+
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
906+
tuple[np.ndarray, np.ndarray, int],
906907
]:
907908
ensure_lapack()
908909
_check_scipy_linalg_matrix(A, "solve")
@@ -963,8 +964,9 @@ def impl(
963964
)
964965

965966
if B_is_1d:
966-
return B_copy[..., 0], int_ptr_to_val(INFO)
967-
return B_copy, int_ptr_to_val(INFO)
967+
B_copy = B_copy[..., 0]
968+
969+
return A_copy, B_copy, int_ptr_to_val(INFO)
968970

969971
return impl
970972

@@ -1065,10 +1067,12 @@ def impl(
10651067
) -> np.ndarray:
10661068
_solve_check_input_shapes(A, B)
10671069

1068-
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
1070+
lu, x, info = _posv(
1071+
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
1072+
)
10691073
_solve_check(A.shape[-1], info)
10701074

1071-
rcond, info = _pocon(x, _xlange(A))
1075+
rcond, info = _pocon(lu, _xlange(A))
10721076
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
10731077

10741078
return x

0 commit comments

Comments
 (0)