1515from pytensor .tensor import TensorLike , as_tensor_variable
1616from pytensor .tensor import basic as ptb
1717from pytensor .tensor import math as ptm
18+ from pytensor .tensor .basic import diagonal
1819from pytensor .tensor .blockwise import Blockwise
1920from pytensor .tensor .nlinalg import kron , matrix_dot
2021from pytensor .tensor .shape import reshape
@@ -260,10 +261,10 @@ def make_node(self, A, b):
260261 raise ValueError (f"`b` must have { self .b_ndim } dims; got { b .type } instead." )
261262
262263 # Infer dtype by solving the most simple case with 1x1 matrices
263- inp_arr = [ np . eye ( 1 ). astype ( A . dtype ), np . eye ( 1 ). astype ( b . dtype )]
264- out_arr = [[ None ]]
265- self . perform ( None , inp_arr , out_arr )
266- o_dtype = out_arr [ 0 ][ 0 ] .dtype
264+ o_dtype = scipy_linalg . solve (
265+ np . ones (( 1 , 1 ), dtype = A . dtype ),
266+ np . ones (( 1 ,), dtype = b . dtype ),
267+ ) .dtype
267268 x = tensor (dtype = o_dtype , shape = b .type .shape )
268269 return Apply (self , [A , b ], [x ])
269270
@@ -315,7 +316,7 @@ def _default_b_ndim(b, b_ndim):
315316
316317 b = as_tensor_variable (b )
317318 if b_ndim is None :
318- return min (b .ndim , 2 ) # By default assume the core case is a matrix
319+ return min (b .ndim , 2 ) # By default, assume the core case is a matrix
319320
320321
321322class CholeskySolve (SolveBase ):
@@ -332,6 +333,19 @@ def __init__(self, **kwargs):
332333 kwargs .setdefault ("lower" , True )
333334 super ().__init__ (** kwargs )
334335
def make_node(self, *inputs):
    """Build the Apply node, recomputing the output dtype with ``cho_solve``.

    The parent ``make_node`` is called first so that it validates and
    converts the inputs; its inputs and output shape are reused, and only
    the output dtype is replaced.

    Parameters
    ----------
    *inputs
        The ``(A, b)`` pair, forwarded unchanged to the parent ``make_node``.

    Returns
    -------
    Apply
        A node over the validated ``[A, b]`` with a single output whose
        static shape is the one chosen by the parent class and whose dtype
        is the one ``scipy.linalg.cho_solve`` produces for these input dtypes.
    """
    # Input validation is delegated to the parent implementation.
    base_node = super().make_node(*inputs)
    A, b = base_node.inputs
    (base_out,) = base_node.outputs

    # The dtype worked out by the parent class does not match what cho_solve
    # returns, so probe cho_solve itself on a trivial 1x1 system.
    probe_factor = np.ones((1, 1), dtype=A.dtype)
    probe_rhs = np.ones((1,), dtype=b.dtype)
    out_dtype = scipy_linalg.cho_solve((probe_factor, False), probe_rhs).dtype

    result = tensor(dtype=out_dtype, shape=base_out.type.shape)
    return Apply(self, [A, b], [result])
348+
335349 def perform (self , node , inputs , output_storage ):
336350 C , b = inputs
337351 rval = scipy_linalg .cho_solve (
@@ -499,8 +513,33 @@ class Solve(SolveBase):
499513 )
500514
def __init__(self, *, assume_a="gen", **kwargs):
    """Set up the op for a given assumed structure of the coefficient matrix.

    Parameters
    ----------
    assume_a : str, default "gen"
        Assumed structure of ``A``; matched case-insensitively. Accepted
        values are the short names ``"gen"``, ``"sym"``, ``"her"``, ``"pos"``,
        their long forms ``"general"``, ``"symmetric"``, ``"hermitian"``,
        ``"positive definite"`` (stored as the short name), and
        ``"tridiagonal"`` / ``"banded"``.
    **kwargs
        Forwarded unchanged to the parent class constructor.

    Raises
    ------
    ValueError
        If ``assume_a`` is not one of the accepted values.

    Warns
    -----
    UserWarning
        If ``"tridiagonal"`` or ``"banded"`` is requested with a SciPy older
        than 1.15; the op then falls back to ``assume_a="gen"``.
    """
    # Triangular and diagonal are handled outside of Solve
    valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]

    assume_a = assume_a.lower()
    # We use the old names as the different dispatches are more likely to support them
    long_to_short = {
        "general": "gen",
        "symmetric": "sym",
        "hermitian": "her",
        "positive definite": "pos",
    }
    assume_a = long_to_short.get(assume_a, assume_a)

    if assume_a not in valid_options:
        raise ValueError(
            f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short.keys())}"
        )

    if assume_a in ("tridiagonal", "banded"):
        from scipy import __version__ as sp_version

        # Compare only (major, minor): trailing components may be
        # non-numeric (e.g. "0rc1", "dev0+<hash>") and would break int().
        major_minor = tuple(map(int, sp_version.split(".")[:2]))
        if major_minor < (1, 15):
            warnings.warn(
                f"assume_a={assume_a} requires scipy>=1.15.0. Defaulting to assume_a='gen'.",
                UserWarning,
            )
            assume_a = "gen"

    super().__init__(**kwargs)
    self.assume_a = assume_a
@@ -536,10 +575,12 @@ def solve(
536575 a ,
537576 b ,
538577 * ,
539- assume_a = "gen" ,
540- lower = False ,
541- transposed = False ,
542- check_finite = True ,
578+ lower : bool = False ,
579+ overwrite_a : bool = False ,
580+ overwrite_b : bool = False ,
581+ check_finite : bool = True ,
582+ assume_a : str = "gen" ,
583+ transposed : bool = False ,
543584 b_ndim : int | None = None ,
544585):
545586 """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -548,14 +589,19 @@ def solve(
548589 corresponding string to ``assume_a`` key chooses the dedicated solver.
549590 The available options are
550591
551- =================== ========
552- generic matrix 'gen'
553- symmetric 'sym'
554- hermitian 'her'
555- positive definite 'pos'
556- =================== ========
592+ =================== ================================
593+ diagonal 'diagonal'
594+ tridiagonal 'tridiagonal'
595+ banded 'banded'
596+ upper triangular 'upper triangular'
597+ lower triangular 'lower triangular'
598+ symmetric 'symmetric' (or 'sym')
599+ hermitian 'hermitian' (or 'her')
600+ positive definite 'positive definite' (or 'pos')
601+ general 'general' (or 'gen')
602+ =================== ================================
557603
558- If omitted, ``'gen '`` is the default structure.
604+ If omitted, ``'general '`` is the default structure.
559605
560606 The datatype of the arrays define which solver is called regardless
561607 of the values. In other words, even when the complex array entries have
@@ -568,23 +614,52 @@ def solve(
568614 Square input data
569615 b : (..., N, NRHS) array_like
570616 Input data for the right hand side.
571- lower : bool, optional
572- If True, use only the data contained in the lower triangle of `a`. Default
573- is to use upper triangle. (ignored for ``'gen'``)
574- transposed: bool, optional
575- If True, solves the system A^T x = b. Default is False.
617+ lower : bool, default False
618+ Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
619+ If True, the calculation uses only the data in the lower triangle of `a`;
620+ entries above the diagonal are ignored. If False (default), the
621+ calculation uses only the data in the upper triangle of `a`; entries
622+ below the diagonal are ignored.
623+ overwrite_a : bool
624+ Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
625+ overwrite_b : bool
626+ Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
576627 check_finite : bool, optional
577628 Whether to check that the input matrices contain only finite numbers.
578629 Disabling may give a performance gain, but may result in problems
579630 (crashes, non-termination) if the inputs do contain infinities or NaNs.
580631 assume_a : str, optional
581632 Valid entries are explained above.
633+ transposed: bool, default False
634+ If True, solves the system A^T x = b. Default is False.
582635 b_ndim : int
583636 Whether the core case of b is a vector (1) or matrix (2).
584637 This will influence how batched dimensions are interpreted.
638+ By default, ``b_ndim`` is ``min(b.ndim, 2)``: 2 if ``b.ndim > 1``, otherwise ``b.ndim``.
585639 """
640+ assume_a = assume_a .lower ()
641+
642+ if assume_a in ("lower triangular" , "upper triangular" ):
643+ lower = "lower" in assume_a
644+ return solve_triangular (
645+ a ,
646+ b ,
647+ lower = lower ,
648+ trans = transposed ,
649+ check_finite = check_finite ,
650+ b_ndim = b_ndim ,
651+ )
652+
586653 b_ndim = _default_b_ndim (b , b_ndim )
587654
655+ if assume_a == "diagonal" :
656+ a_diagonal = diagonal (a , axis1 = - 2 , axis2 = - 1 )
657+ b_transposed = b [None , :] if b_ndim == 1 else b .mT
658+ x = (b_transposed / pt .expand_dims (a_diagonal , - 2 )).mT
659+ if b_ndim == 1 :
660+ x = x .squeeze (- 1 )
661+ return x
662+
588663 if transposed :
589664 a = a .mT
590665 lower = not lower
0 commit comments