Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions pytensor/link/jax/dispatch/slinalg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import jax

from pytensor.link.jax.dispatch.basic import jax_funcify
Expand Down Expand Up @@ -39,13 +41,29 @@

@jax_funcify.register(Solve)
def jax_funcify_Solve(op, **kwargs):
if op.assume_a != "gen" and op.lower:
lower = True
assume_a = op.assume_a
lower = op.lower

if assume_a == "tridiagonal":
# jax.scipy.solve does not yet support tridiagonal matrices
# But there's a jax.lax.linalg.tridiaonal_solve we can use instead.
def solve(a, b):
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower)

Check warning on line 54 in pytensor/link/jax/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/jax/dispatch/slinalg.py#L50-L54

Added lines #L50 - L54 were not covered by tests

else:
lower = False
if assume_a not in ("gen", "sym", "her", "pos"):
warnings.warn(

Check warning on line 58 in pytensor/link/jax/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/jax/dispatch/slinalg.py#L58

Added line #L58 was not covered by tests
f"JAX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.\n"
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her' or 'tridiagonal' to improve performance.",
UserWarning,
)
assume_a = "gen"

Check warning on line 63 in pytensor/link/jax/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/jax/dispatch/slinalg.py#L63

Added line #L63 was not covered by tests

def solve(a, b, lower=lower):
return jax.scipy.linalg.solve(a, b, lower=lower)
def solve(a, b):
return jax.scipy.linalg.solve(a, b, lower=lower, assume_a=assume_a)

return solve

Expand Down
29 changes: 17 additions & 12 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Callable

import numba
Expand Down Expand Up @@ -653,7 +654,7 @@

def _sysv(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> tuple[np.ndarray, np.ndarray, int]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
"""
Expand All @@ -664,7 +665,8 @@
def sysv_impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int]
[np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sysv")
Expand Down Expand Up @@ -740,8 +742,8 @@
)

if B_is_1d:
return B_copy[..., 0], IPIV, int_ptr_to_val(INFO)
return B_copy, IPIV, int_ptr_to_val(INFO)
B_copy = B_copy[..., 0]
return A_copy, B_copy, IPIV, int_ptr_to_val(INFO)

Check warning on line 746 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L745-L746

Added lines #L745 - L746 were not covered by tests

return impl

Expand Down Expand Up @@ -770,7 +772,7 @@

N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
UPLO = val_to_int_ptr(ord("L"))
UPLO = val_to_int_ptr(ord("U"))

Check warning on line 775 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L775

Added line #L775 was not covered by tests
ANORM = np.array(anorm, dtype=dtype)
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(2 * _N, dtype=dtype)
Expand Down Expand Up @@ -843,10 +845,10 @@
) -> np.ndarray:
_solve_check_input_shapes(A, B)

x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)

Check warning on line 848 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L848

Added line #L848 was not covered by tests
_solve_check(A.shape[-1], info)

rcond, info = _sycon(A, ipiv, _xlange(A, order="I"))
rcond, info = _sycon(lu, ipiv, _xlange(A, order="I"))

Check warning on line 851 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L851

Added line #L851 was not covered by tests
_solve_check(A.shape[-1], info, True, rcond)

return x
Expand Down Expand Up @@ -1070,14 +1072,17 @@
elif assume_a == "sym":
solve_fn = _solve_symmetric
elif assume_a == "her":
raise NotImplementedError(
'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, '
"please open an issue on github."
)
# We already ruled out complex inputs
solve_fn = _solve_symmetric

Check warning on line 1076 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L1076

Added line #L1076 was not covered by tests
elif assume_a == "pos":
solve_fn = _solve_psd
else:
raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode")
warnings.warn(

Check warning on line 1080 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L1080

Added line #L1080 was not covered by tests
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', or 'her' to improve performance.",
UserWarning,
)
solve_fn = _solve_gen

Check warning on line 1085 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L1085

Added line #L1085 was not covered by tests

@numba_basic.numba_njit(inline="always")
def solve(a, b):
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4369,7 +4369,7 @@ def atleast_Nd(
atleast_3d = partial(atleast_Nd, n=3)


def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
def expand_dims(a: "TensorLike", axis: Sequence[int] | int) -> TensorVariable:
"""Expand the shape of an array.

Insert a new axis that will appear at the `axis` position in the expanded
Expand Down
121 changes: 98 additions & 23 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.shape import reshape
Expand Down Expand Up @@ -260,10 +261,10 @@
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")

# Infer dtype by solving the most simple case with 1x1 matrices
inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
out_arr = [[None]]
self.perform(None, inp_arr, out_arr)
o_dtype = out_arr[0][0].dtype
o_dtype = scipy_linalg.solve(
np.ones((1, 1), dtype=A.dtype),
np.ones((1,), dtype=b.dtype),
).dtype
x = tensor(dtype=o_dtype, shape=b.type.shape)
return Apply(self, [A, b], [x])

Expand Down Expand Up @@ -315,7 +316,7 @@

b = as_tensor_variable(b)
if b_ndim is None:
return min(b.ndim, 2) # By default assume the core case is a matrix
return min(b.ndim, 2) # By default, assume the core case is a matrix


class CholeskySolve(SolveBase):
Expand All @@ -332,6 +333,19 @@
kwargs.setdefault("lower", True)
super().__init__(**kwargs)

def make_node(self, *inputs):
# Allow base class to do input validation
super_apply = super().make_node(*inputs)
A, b = super_apply.inputs
[super_out] = super_apply.outputs
# The dtype of chol_solve does not match solve, which the base class checks
dtype = scipy_linalg.cho_solve(
(np.ones((1, 1), dtype=A.dtype), False),
np.ones((1,), dtype=b.dtype),
).dtype
out = tensor(dtype=dtype, shape=super_out.type.shape)
return Apply(self, [A, b], [out])

def perform(self, node, inputs, output_storage):
C, b = inputs
rval = scipy_linalg.cho_solve(
Expand Down Expand Up @@ -499,8 +513,33 @@
)

def __init__(self, *, assume_a="gen", **kwargs):
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")
# Triangular and diagonal are handled outside of Solve
valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]

assume_a = assume_a.lower()
# We use the old names as the different dispatches are more likely to support them
long_to_short = {
"general": "gen",
"symmetric": "sym",
"hermitian": "her",
"positive definite": "pos",
}
assume_a = long_to_short.get(assume_a, assume_a)

if assume_a not in valid_options:
raise ValueError(
f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short.keys())}"
)

if assume_a in ("tridiagonal", "banded"):
from scipy import __version__ as sp_version

if tuple(map(int, sp_version.split(".")[:-1])) < (1, 15):
warnings.warn(

Check warning on line 538 in pytensor/tensor/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/slinalg.py#L538

Added line #L538 was not covered by tests
f"assume_a={assume_a} requires scipy>=1.5.0. Defaulting to assume_a='gen'.",
UserWarning,
)
assume_a = "gen"

Check warning on line 542 in pytensor/tensor/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/slinalg.py#L542

Added line #L542 was not covered by tests

super().__init__(**kwargs)
self.assume_a = assume_a
Expand Down Expand Up @@ -536,10 +575,12 @@
a,
b,
*,
assume_a="gen",
lower=False,
transposed=False,
check_finite=True,
lower: bool = False,
overwrite_a: bool = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we decided not to expose this option to users?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't. The docstrings specify it's ignored and PyTensor will perform inplace if possible. This way the signature is the same for the user.

overwrite_b: bool = False,
check_finite: bool = True,
assume_a: str = "gen",
transposed: bool = False,
b_ndim: int | None = None,
):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
Expand All @@ -548,14 +589,19 @@
corresponding string to ``assume_a`` key chooses the dedicated solver.
The available options are

=================== ========
generic matrix 'gen'
symmetric 'sym'
hermitian 'her'
positive definite 'pos'
=================== ========
=================== ================================
diagonal 'diagonal'
tridiagonal 'tridiagonal'
banded 'banded'
upper triangular 'upper triangular'
lower triangular 'lower triangular'
symmetric 'symmetric' (or 'sym')
hermitian 'hermitian' (or 'her')
positive definite 'positive definite' (or 'pos')
general 'general' (or 'gen')
=================== ================================

If omitted, ``'gen'`` is the default structure.
If omitted, ``'general'`` is the default structure.

The datatype of the arrays define which solver is called regardless
of the values. In other words, even when the complex array entries have
Expand All @@ -568,23 +614,52 @@
Square input data
b : (..., N, NRHS) array_like
Input data for the right hand side.
lower : bool, optional
If True, use only the data contained in the lower triangle of `a`. Default
is to use upper triangle. (ignored for ``'gen'``)
transposed: bool, optional
If True, solves the system A^T x = b. Default is False.
lower : bool, default False
Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
If True, the calculation uses only the data in the lower triangle of `a`;
entries above the diagonal are ignored. If False (default), the
calculation uses only the data in the upper triangle of `a`; entries
below the diagonal are ignored.
overwrite_a : bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
overwrite_b : bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional
Valid entries are explained above.
transposed: bool, default False
If True, solves the system A^T x = b. Default is False.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
By default, we assume b_ndim = b.ndim is 2 if b.ndim > 1, else 1.
"""
assume_a = assume_a.lower()

if assume_a in ("lower triangular", "upper triangular"):
lower = "lower" in assume_a
return solve_triangular(
a,
b,
lower=lower,
trans=transposed,
check_finite=check_finite,
b_ndim=b_ndim,
)

b_ndim = _default_b_ndim(b, b_ndim)

if assume_a == "diagonal":
a_diagonal = diagonal(a, axis1=-2, axis2=-1)
b_transposed = b[None, :] if b_ndim == 1 else b.mT
x = (b_transposed / pt.expand_dims(a_diagonal, -2)).mT
if b_ndim == 1:
x = x.squeeze(-1)
return x

if transposed:
a = a.mT
lower = not lower
Expand Down
Loading