@@ -49,7 +49,7 @@ class Cholesky(Op):
4949
5050 __props__ = ("lower" , "destructive" , "on_error" )
5151
52- def __init__ (self , lower = True , on_error = "raise" ):
52+ def __init__ (self , * , lower = True , on_error = "raise" ):
5353 self .lower = lower
5454 self .destructive = False
5555 if on_error not in ("raise" , "nan" ):
@@ -127,77 +127,8 @@ def conjugate_solve_triangular(outer, inner):
127127 return [grad ]
128128
129129
130- cholesky = Cholesky ()
131-
132-
133- class CholeskySolve (Op ):
134- __props__ = ("lower" , "check_finite" )
135-
136- def __init__ (
137- self ,
138- lower = True ,
139- check_finite = True ,
140- ):
141- self .lower = lower
142- self .check_finite = check_finite
143-
144- def __repr__ (self ):
145- return "CholeskySolve{%s}" % str (self ._props ())
146-
147- def make_node (self , C , b ):
148- C = as_tensor_variable (C )
149- b = as_tensor_variable (b )
150- assert C .ndim == 2
151- assert b .ndim in (1 , 2 )
152-
153- # infer dtype by solving the most simple
154- # case with (1, 1) matrices
155- o_dtype = scipy .linalg .solve (
156- np .eye (1 ).astype (C .dtype ), np .eye (1 ).astype (b .dtype )
157- ).dtype
158- x = tensor (dtype = o_dtype , shape = b .type .shape )
159- return Apply (self , [C , b ], [x ])
160-
161- def perform (self , node , inputs , output_storage ):
162- C , b = inputs
163- rval = scipy .linalg .cho_solve (
164- (C , self .lower ),
165- b ,
166- check_finite = self .check_finite ,
167- )
168-
169- output_storage [0 ][0 ] = rval
170-
171- def infer_shape (self , fgraph , node , shapes ):
172- Cshape , Bshape = shapes
173- rows = Cshape [1 ]
174- if len (Bshape ) == 1 : # b is a Vector
175- return [(rows ,)]
176- else :
177- cols = Bshape [1 ] # b is a Matrix
178- return [(rows , cols )]
179-
180-
181- cho_solve = CholeskySolve ()
182-
183-
184- def cho_solve (c_and_lower , b , check_finite = True ):
185- """Solve the linear equations A x = b, given the Cholesky factorization of A.
186-
187- Parameters
188- ----------
189- (c, lower) : tuple, (array, bool)
190- Cholesky factorization of a, as given by cho_factor
191- b : array
192- Right-hand side
193- check_finite : bool, optional
194- Whether to check that the input matrices contain only finite numbers.
195- Disabling may give a performance gain, but may result in problems
196- (crashes, non-termination) if the inputs do contain infinities or NaNs.
197- """
198-
199- A , lower = c_and_lower
200- return CholeskySolve (lower = lower , check_finite = check_finite )(A , b )
130+ def cholesky (x , lower = True , on_error = "raise" ):
131+ return Cholesky (lower = lower , on_error = on_error )(x )
201132
202133
203134class SolveBase (Op ):
@@ -210,6 +141,7 @@ class SolveBase(Op):
210141
211142 def __init__ (
212143 self ,
144+ * ,
213145 lower = False ,
214146 check_finite = True ,
215147 ):
@@ -276,28 +208,56 @@ def L_op(self, inputs, outputs, output_gradients):
276208
277209 return [A_bar , b_bar ]
278210
279- def __repr__ (self ):
280- return f"{ type (self ).__name__ } { self ._props ()} "
211+
212+ class CholeskySolve (SolveBase ):
213+ def __init__ (self , ** kwargs ):
214+ kwargs .setdefault ("lower" , True )
215+ super ().__init__ (** kwargs )
216+
217+ def perform (self , node , inputs , output_storage ):
218+ C , b = inputs
219+ rval = scipy .linalg .cho_solve (
220+ (C , self .lower ),
221+ b ,
222+ check_finite = self .check_finite ,
223+ )
224+
225+ output_storage [0 ][0 ] = rval
226+
227+ def L_op (self , * args , ** kwargs ):
228+ raise NotImplementedError ()
229+
230+
231+ def cho_solve (c_and_lower , b , * , check_finite = True ):
232+ """Solve the linear equations A x = b, given the Cholesky factorization of A.
233+
234+ Parameters
235+ ----------
236+ (c, lower) : tuple, (array, bool)
237+ Cholesky factorization of a, as given by cho_factor
238+ b : array
239+ Right-hand side
240+ check_finite : bool, optional
241+ Whether to check that the input matrices contain only finite numbers.
242+ Disabling may give a performance gain, but may result in problems
243+ (crashes, non-termination) if the inputs do contain infinities or NaNs.
244+ """
245+ A , lower = c_and_lower
246+ return CholeskySolve (lower = lower , check_finite = check_finite )(A , b )
281247
282248
283249class SolveTriangular (SolveBase ):
284250 """Solve a system of linear equations."""
285251
286252 __props__ = (
287- "lower" ,
288253 "trans" ,
289254 "unit_diagonal" ,
255+ "lower" ,
290256 "check_finite" ,
291257 )
292258
293- def __init__ (
294- self ,
295- trans = 0 ,
296- lower = False ,
297- unit_diagonal = False ,
298- check_finite = True ,
299- ):
300- super ().__init__ (lower = lower , check_finite = check_finite )
259+ def __init__ (self , * , trans = 0 , unit_diagonal = False , ** kwargs ):
260+ super ().__init__ (** kwargs )
301261 self .trans = trans
302262 self .unit_diagonal = unit_diagonal
303263
@@ -326,6 +286,7 @@ def L_op(self, inputs, outputs, output_gradients):
326286def solve_triangular (
327287 a : TensorVariable ,
328288 b : TensorVariable ,
289+ * ,
329290 trans : Union [int , str ] = 0 ,
330291 lower : bool = False ,
331292 unit_diagonal : bool = False ,
@@ -373,16 +334,11 @@ class Solve(SolveBase):
373334 "check_finite" ,
374335 )
375336
376- def __init__ (
377- self ,
378- assume_a = "gen" ,
379- lower = False ,
380- check_finite = True ,
381- ):
337+ def __init__ (self , * , assume_a = "gen" , ** kwargs ):
382338 if assume_a not in ("gen" , "sym" , "her" , "pos" ):
383339 raise ValueError (f"{ assume_a } is not a recognized matrix structure" )
384340
385- super ().__init__ (lower = lower , check_finite = check_finite )
341+ super ().__init__ (** kwargs )
386342 self .assume_a = assume_a
387343
388344 def perform (self , node , inputs , outputs ):
@@ -396,7 +352,7 @@ def perform(self, node, inputs, outputs):
396352 )
397353
398354
399- def solve (a , b , assume_a = "gen" , lower = False , check_finite = True ):
355+ def solve (a , b , * , assume_a = "gen" , lower = False , check_finite = True ):
400356 """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
401357
402358 If the data matrix is known to be a particular type then supplying the
0 commit comments