11import logging
2+ import typing
23import warnings
34from typing import TYPE_CHECKING , Literal , Union
45
1213from pytensor .tensor import as_tensor_variable
1314from pytensor .tensor import basic as at
1415from pytensor .tensor import math as atm
16+ from pytensor .tensor .nlinalg import matrix_dot
1517from pytensor .tensor .shape import reshape
1618from pytensor .tensor .type import matrix , tensor , vector
1719from pytensor .tensor .var import TensorVariable
@@ -321,9 +323,6 @@ def L_op(self, inputs, outputs, output_gradients):
321323 return res
322324
323325
324- solvetriangular = SolveTriangular ()
325-
326-
327326def solve_triangular (
328327 a : TensorVariable ,
329328 b : TensorVariable ,
@@ -397,9 +396,6 @@ def perform(self, node, inputs, outputs):
397396 )
398397
399398
400- solve = Solve ()
401-
402-
403399def solve (a , b , assume_a = "gen" , lower = False , check_finite = True ):
404400 """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
405401
@@ -748,13 +744,9 @@ def grad(self, inputs, output_grads):
748744
749745
750746_solve_continuous_lyapunov = SolveContinuousLyapunov ()
751- _solve_bilinear_direct_lyapunov = BilinearSolveDiscreteLyapunov ()
752-
753-
754- def iscomplexobj (x ):
755- type_ = x .type
756- dtype = type_ .dtype
757- return "complex" in dtype
747+ _solve_bilinear_direct_lyapunov = typing .cast (
748+ typing .Callable , BilinearSolveDiscreteLyapunov ()
749+ )
758750
759751
760752def _direct_solve_discrete_lyapunov (A : "TensorLike" , Q : "TensorLike" ) -> TensorVariable :
@@ -767,7 +759,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV
767759 AA = kron (A_ , A_ )
768760
769761 X = solve (pt .eye (AA .shape [0 ]) - AA , Q_ .ravel ())
770- return reshape (X , Q_ .shape )
762+ return typing . cast ( TensorVariable , reshape (X , Q_ .shape ) )
771763
772764
773765def solve_discrete_lyapunov (
@@ -803,7 +795,7 @@ def solve_discrete_lyapunov(
803795 if method == "direct" :
804796 return _direct_solve_discrete_lyapunov (A , Q )
805797 if method == "bilinear" :
806- return _solve_bilinear_direct_lyapunov (A , Q )
798+ return typing . cast ( TensorVariable , _solve_bilinear_direct_lyapunov (A , Q ) )
807799
808800
809801def solve_continuous_lyapunov (A : "TensorLike" , Q : "TensorLike" ) -> TensorVariable :
@@ -823,7 +815,90 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
823815
824816 """
825817
826- return _solve_continuous_lyapunov (A , Q )
818+ return typing .cast (TensorVariable , _solve_continuous_lyapunov (A , Q ))
819+
820+
821+ class SolveDiscreteARE (pt .Op ):
822+ __props__ = ("enforce_Q_symmetric" ,)
823+
824+ def __init__ (self , enforce_Q_symmetric = False ):
825+ self .enforce_Q_symmetric = enforce_Q_symmetric
826+
827+ def make_node (self , A , B , Q , R ):
828+ A = as_tensor_variable (A )
829+ B = as_tensor_variable (B )
830+ Q = as_tensor_variable (Q )
831+ R = as_tensor_variable (R )
832+
833+ out_dtype = pytensor .scalar .upcast (A .dtype , B .dtype , Q .dtype , R .dtype )
834+ X = pytensor .tensor .matrix (dtype = out_dtype )
835+
836+ return pytensor .graph .basic .Apply (self , [A , B , Q , R ], [X ])
837+
838+ def perform (self , node , inputs , output_storage ):
839+ A , B , Q , R = inputs
840+ X = output_storage [0 ]
841+
842+ if self .enforce_Q_symmetric :
843+ Q = 0.5 * (Q + Q .T )
844+
845+ X [0 ] = scipy .linalg .solve_discrete_are (A , B , Q , R ).astype (
846+ node .outputs [0 ].type .dtype
847+ )
848+
849+ def infer_shape (self , fgraph , node , shapes ):
850+ return [shapes [0 ]]
851+
852+ def grad (self , inputs , output_grads ):
853+ # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
854+ A , B , Q , R = inputs
855+
856+ (dX ,) = output_grads
857+ X = self (A , B , Q , R )
858+
859+ K_inner = R + pt .linalg .matrix_dot (B .T , X , B )
860+ K_inner_inv = pt .linalg .solve (K_inner , pt .eye (R .shape [0 ]))
861+ K = matrix_dot (K_inner_inv , B .T , X , A )
862+
863+ A_tilde = A - B .dot (K )
864+
865+ dX_symm = 0.5 * (dX + dX .T )
866+ S = solve_discrete_lyapunov (A_tilde , dX_symm ).astype (dX .type .dtype )
867+
868+ A_bar = 2 * matrix_dot (X , A_tilde , S )
869+ B_bar = - 2 * matrix_dot (X , A_tilde , S , K .T )
870+ Q_bar = S
871+ R_bar = matrix_dot (K , S , K .T )
872+
873+ return [A_bar , B_bar , Q_bar , R_bar ]
874+
875+
876+ def solve_discrete_are (A , B , Q , R , enforce_Q_symmetric = False ) -> TensorVariable :
877+ """
878+ Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
879+
880+ Parameters
881+ ----------
882+ A: ArrayLike
883+ Square matrix of shape M x M
884+ B: ArrayLike
885+ Square matrix of shape M x M
886+ Q: ArrayLike
887+ Symmetric square matrix of shape M x M
888+ R: ArrayLike
889+ Square matrix of shape N x N
890+ enforce_Q_symmetric: bool
891+ If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
892+
893+ Returns
894+ -------
895+ X: pt.matrix
896+ Square matrix of shape M x M, representing the solution to the DARE
897+ """
898+
899+ return typing .cast (
900+ TensorVariable , SolveDiscreteARE (enforce_Q_symmetric )(A , B , Q , R )
901+ )
827902
828903
829904__all__ = [
@@ -832,4 +907,8 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
832907 "eigvalsh" ,
833908 "kron" ,
834909 "expm" ,
910+ "solve_discrete_lyapunov" ,
911+ "solve_continuous_lyapunov" ,
912+ "solve_discrete_are" ,
913+ "solve_triangular" ,
835914]
0 commit comments