Allow more specialized linalg.solve assume_a cases #1273
@@ -15,6 +15,7 @@
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
+from pytensor.tensor.basic import diagonal
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.nlinalg import kron, matrix_dot
 from pytensor.tensor.shape import reshape
@@ -260,10 +261,10 @@
             raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")

         # Infer dtype by solving the most simple case with 1x1 matrices
-        inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
-        out_arr = [[None]]
-        self.perform(None, inp_arr, out_arr)
-        o_dtype = out_arr[0][0].dtype
+        o_dtype = scipy_linalg.solve(
+            np.ones((1, 1), dtype=A.dtype),
+            np.ones((1,), dtype=b.dtype),
+        ).dtype
         x = tensor(dtype=o_dtype, shape=b.type.shape)
         return Apply(self, [A, b], [x])
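For intuition, the new dtype inference lets SciPy's own promotion rules pick the output dtype by solving a throwaway 1x1 system, instead of round-tripping through `perform` with dummy output storage. A minimal standalone sketch of the trick (the dtypes chosen here are illustrative):

```python
import numpy as np
from scipy import linalg as scipy_linalg

# Solving a trivial 1x1 system is cheap, and its result dtype is the same
# promotion scipy_linalg.solve would apply to real inputs of these dtypes.
o_dtype = scipy_linalg.solve(
    np.ones((1, 1), dtype=np.float32),
    np.ones((1,), dtype=np.int32),
).dtype
print(o_dtype)  # expected float64 here, since int32 promotes past float32
```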
@@ -315,7 +316,7 @@

    b = as_tensor_variable(b)
    if b_ndim is None:
-        return min(b.ndim, 2)  # By default assume the core case is a matrix
+        return min(b.ndim, 2)  # By default, assume the core case is a matrix


 class CholeskySolve(SolveBase):
@@ -332,6 +333,19 @@
         kwargs.setdefault("lower", True)
         super().__init__(**kwargs)

+    def make_node(self, *inputs):
+        # Allow base class to do input validation
+        super_apply = super().make_node(*inputs)
+        A, b = super_apply.inputs
+        [super_out] = super_apply.outputs
+        # The dtype of chol_solve does not match solve, which the base class checks
+        dtype = scipy_linalg.cho_solve(
+            (np.ones((1, 1), dtype=A.dtype), False),
+            np.ones((1,), dtype=b.dtype),
+        ).dtype
+        out = tensor(dtype=dtype, shape=super_out.type.shape)
+        return Apply(self, [A, b], [out])
+
     def perform(self, node, inputs, output_storage):
         C, b = inputs
         rval = scipy_linalg.cho_solve(
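The same 1x1 probing idea reappears here, but aimed at `cho_solve`, whose dtype promotion can disagree with `solve`'s (which is why the base class's inferred dtype is overridden). A standalone sketch, assuming you just want to inspect the promoted dtypes for a couple of input pairs:

```python
import numpy as np
from scipy import linalg as scipy_linalg

# Probe cho_solve's output dtype for a few input dtype pairs; the
# (matrix, bool) tuple is the (Cholesky factor, lower) pair cho_solve expects.
for A_dtype, b_dtype in [(np.float32, np.float32), (np.float32, np.int32)]:
    out = scipy_linalg.cho_solve(
        (np.ones((1, 1), dtype=A_dtype), False),
        np.ones((1,), dtype=b_dtype),
    )
    print(np.dtype(A_dtype), np.dtype(b_dtype), "->", out.dtype)
```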
@@ -499,8 +513,33 @@
     )

     def __init__(self, *, assume_a="gen", **kwargs):
-        if assume_a not in ("gen", "sym", "her", "pos"):
-            raise ValueError(f"{assume_a} is not a recognized matrix structure")
+        # Triangular and diagonal are handled outside of Solve
+        valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]
+
+        assume_a = assume_a.lower()
+        # We use the old names as the different dispatches are more likely to support them
+        long_to_short = {
+            "general": "gen",
+            "symmetric": "sym",
+            "hermitian": "her",
+            "positive definite": "pos",
+        }
+        assume_a = long_to_short.get(assume_a, assume_a)
+
+        if assume_a not in valid_options:
+            raise ValueError(
+                f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short.keys())}"
+            )
+
+        if assume_a in ("tridiagonal", "banded"):
+            from scipy import __version__ as sp_version
+
+            if tuple(map(int, sp_version.split(".")[:-1])) < (1, 15):
+                warnings.warn(
+                    f"assume_a={assume_a} requires scipy>=1.15. Defaulting to assume_a='gen'.",
+                    UserWarning,
+                )
+                assume_a = "gen"
+
         super().__init__(**kwargs)
         self.assume_a = assume_a
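Taken together, the constructor now accepts both the legacy short names and SciPy's long names, normalizing to the short form before validating. A pure-Python sketch of the normalization path in isolation:

```python
# Mirror of the normalization logic above, runnable on its own.
valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]
long_to_short = {
    "general": "gen",
    "symmetric": "sym",
    "hermitian": "her",
    "positive definite": "pos",
}

for name in ("Symmetric", "pos", "tridiagonal", "bogus"):
    key = long_to_short.get(name.lower(), name.lower())
    status = "ok" if key in valid_options else "rejected"
    print(f"{name!r} -> {key!r} ({status})")
```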
@@ -536,10 +575,12 @@
     a,
     b,
     *,
-    assume_a="gen",
-    lower=False,
-    transposed=False,
-    check_finite=True,
+    lower: bool = False,
+    overwrite_a: bool = False,
+    overwrite_b: bool = False,
+    check_finite: bool = True,
+    assume_a: str = "gen",
+    transposed: bool = False,
     b_ndim: int | None = None,
 ):
     """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.

Review comment on the `overwrite_a` line:

> I thought we decided not to expose this option to users?

Reply:

> We don't. The docstrings specify it's ignored and PyTensor will perform inplace if possible. This way the signature is the same for the user.
@@ -548,14 +589,19 @@
     corresponding string to ``assume_a`` key chooses the dedicated solver.
     The available options are

-    =================== ========
-    generic matrix      'gen'
-    symmetric           'sym'
-    hermitian           'her'
-    positive definite   'pos'
-    =================== ========
+    =================== ================================
+    diagonal            'diagonal'
+    tridiagonal         'tridiagonal'
+    banded              'banded'
+    upper triangular    'upper triangular'
+    lower triangular    'lower triangular'
+    symmetric           'symmetric' (or 'sym')
+    hermitian           'hermitian' (or 'her')
+    positive definite   'positive definite' (or 'pos')
+    general             'general' (or 'gen')
+    =================== ================================

-    If omitted, ``'gen'`` is the default structure.
+    If omitted, ``'general'`` is the default structure.

     The datatype of the arrays define which solver is called regardless
     of the values. In other words, even when the complex array entries have
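Based on the table above, a hypothetical usage sketch; the `slinalg` import path matches the module being patched here, but check your installed version for the public alias:

```python
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve  # module patched in this PR

A = pt.matrix("A")
b = pt.vector("b")

x_gen = solve(A, b, assume_a="general")           # same as the old "gen"
x_tri = solve(A, b, assume_a="tridiagonal")       # needs SciPy >= 1.15
x_diag = solve(A, b, assume_a="diagonal")         # rewritten as a division
x_low = solve(A, b, assume_a="lower triangular")  # dispatches to solve_triangular
```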
@@ -568,23 +614,52 @@
         Square input data
     b : (..., N, NRHS) array_like
         Input data for the right hand side.
-    lower : bool, optional
-        If True, use only the data contained in the lower triangle of `a`. Default
-        is to use upper triangle. (ignored for ``'gen'``)
-    transposed: bool, optional
-        If True, solves the system A^T x = b. Default is False.
+    lower : bool, default False
+        Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
+        If True, the calculation uses only the data in the lower triangle of `a`;
+        entries above the diagonal are ignored. If False (default), the
+        calculation uses only the data in the upper triangle of `a`; entries
+        below the diagonal are ignored.
+    overwrite_a : bool
+        Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
+    overwrite_b : bool
+        Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
     check_finite : bool, optional
         Whether to check that the input matrices contain only finite numbers.
         Disabling may give a performance gain, but may result in problems
         (crashes, non-termination) if the inputs do contain infinities or NaNs.
     assume_a : str, optional
         Valid entries are explained above.
+    transposed: bool, default False
+        If True, solves the system A^T x = b. Default is False.
     b_ndim : int
         Whether the core case of b is a vector (1) or matrix (2).
         This will influence how batched dimensions are interpreted.
         By default, we assume b_ndim = b.ndim is 2 if b.ndim > 1, else 1.
     """
+    assume_a = assume_a.lower()
+
+    if assume_a in ("lower triangular", "upper triangular"):
+        lower = "lower" in assume_a
+        return solve_triangular(
+            a,
+            b,
+            lower=lower,
+            trans=transposed,
+            check_finite=check_finite,
+            b_ndim=b_ndim,
+        )
+
     b_ndim = _default_b_ndim(b, b_ndim)

+    if assume_a == "diagonal":
+        a_diagonal = diagonal(a, axis1=-2, axis2=-1)
+        b_transposed = b[None, :] if b_ndim == 1 else b.mT
+        x = (b_transposed / pt.expand_dims(a_diagonal, -2)).mT
+        if b_ndim == 1:
+            x = x.squeeze(-1)
+        return x
+
+    if transposed:
+        a = a.mT
+        lower = not lower
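For intuition on the new `diagonal` branch: solving against a diagonal `a` is just elementwise division by the diagonal, and the transpose/squeeze dance only exists to broadcast batched right-hand sides. A NumPy sketch of the same computation for the vector (`b_ndim == 1`) case:

```python
import numpy as np

rng = np.random.default_rng(0)
a = np.diag(rng.uniform(1.0, 2.0, size=4))  # diagonal system
b = rng.normal(size=4)                      # b_ndim == 1 case

# Same steps as the diagonal branch above (np.swapaxes stands in for .mT)
a_diagonal = np.diagonal(a, axis1=-2, axis2=-1)
b_transposed = b[None, :]                   # vector promoted to a row
x = np.swapaxes(b_transposed / np.expand_dims(a_diagonal, -2), -1, -2)
x = x.squeeze(-1)                           # undo the promotion

np.testing.assert_allclose(x, np.linalg.solve(a, b))
```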