File tree Expand file tree Collapse file tree 4 files changed +107
-0
lines changed Expand file tree Collapse file tree 4 files changed +107
-0
lines changed Original file line number Diff line number Diff line change 1818 MatrixInverse ,
1919 MatrixPinv ,
2020 QRFull ,
21+ SLogDet ,
2122)
2223
2324
@@ -58,6 +59,25 @@ def det(x):
5859 return det
5960
6061
62+ @numba_funcify .register (SLogDet )
63+ def numba_funcify_SLogDet (op , node , ** kwargs ):
64+
65+ out_dtype_1 = node .outputs [0 ].type .numpy_dtype
66+ out_dtype_2 = node .outputs [1 ].type .numpy_dtype
67+
68+ inputs_cast = int_to_float_fn (node .inputs , out_dtype_1 )
69+
70+ @numba_basic .numba_njit
71+ def slogdet (x ):
72+ sign , det = np .linalg .slogdet (inputs_cast (x ))
73+ return (
74+ numba_basic .direct_cast (sign , out_dtype_1 ),
75+ numba_basic .direct_cast (det , out_dtype_2 ),
76+ )
77+
78+ return slogdet
79+
80+
6181@numba_funcify .register (Eig )
6282def numba_funcify_Eig (op , node , ** kwargs ):
6383
Original file line number Diff line number Diff line change @@ -231,6 +231,45 @@ def __str__(self):
231231det = Det ()
232232
233233
234+ class SLogDet (Op ):
235+ """
236+ Compute sign and log determinant of the matrix. Input should be a square matrix.
237+ """
238+
239+ __props__ = ()
240+
241+ def make_node (self , x ):
242+ x = as_tensor_variable (x )
243+ assert x .ndim == 2
244+ s = scalar (dtype = x .dtype )
245+ d = scalar (dtype = x .dtype )
246+ return Apply (self , [x ], [s , d ])
247+
248+ def perform (self , node , inputs , outputs ):
249+ (x ,) = inputs
250+ (s , d ) = outputs
251+ try :
252+ s [0 ], d [0 ] = (z .astype (x .dtype ) for z in np .linalg .slogdet (x ))
253+ except Exception :
254+ print ("Failed to compute determinant" , x )
255+ raise
256+
257+ def grad (self , inputs , g_outputs ):
258+ (gz ,) = g_outputs
259+ (x ,) = inputs
260+ sign , det = self (x )
261+ return [gz * sign * np .exp (det ) * matrix_inverse (x ).T ]
262+
263+ def infer_shape (self , fgraph , node , shapes ):
264+ return [(), ()]
265+
266+ def __str__ (self ):
267+ return "SLogDet"
268+
269+
270+ slogdet = SLogDet ()
271+
272+
234273class Eig (Op ):
235274 """
236275 Compute the eigenvalues and right eigenvectors of a square array.
Original file line number Diff line number Diff line change @@ -179,6 +179,41 @@ def test_Det(x, exc):
179179 )
180180
181181
182+ @pytest .mark .parametrize (
183+ "x, exc" ,
184+ [
185+ (
186+ set_test_value (
187+ at .dmatrix (),
188+ (lambda x : x .T .dot (x ))(rng .random (size = (3 , 3 )).astype ("float64" )),
189+ ),
190+ None ,
191+ ),
192+ (
193+ set_test_value (
194+ at .lmatrix (),
195+ (lambda x : x .T .dot (x ))(rng .poisson (size = (3 , 3 )).astype ("int64" )),
196+ ),
197+ None ,
198+ ),
199+ ],
200+ )
201+ def test_SLogDet (x , exc ):
202+ g = nlinalg .SLogDet ()(x )
203+ g_fg = FunctionGraph (outputs = g )
204+
205+ cm = contextlib .suppress () if exc is None else pytest .warns (exc )
206+ with cm :
207+ compare_numba_and_py (
208+ g_fg ,
209+ [
210+ i .tag .test_value
211+ for i in g_fg .inputs
212+ if not isinstance (i , (SharedVariable , Constant ))
213+ ],
214+ )
215+
216+
182217# We were seeing some weird results in CI where the following two almost
183218# sign-swapped results were being return from Numba and Python, respectively.
184219# The issue might be related to https://github.com/numba/numba/issues/4519.
Original file line number Diff line number Diff line change 2424 norm ,
2525 pinv ,
2626 qr ,
27+ slogdet ,
2728 svd ,
2829 tensorinv ,
2930 tensorsolve ,
@@ -266,6 +267,18 @@ def test_det():
266267 assert np .allclose (np .linalg .det (r ), f (r ))
267268
268269
270+ def test_slogdet ():
271+ rng = np .random .default_rng (utt .fetch_seed ())
272+
273+ r = rng .standard_normal ((5 , 5 )).astype (config .floatX )
274+ x = matrix ()
275+ f = pytensor .function ([x ], slogdet (x ))
276+ f_sign , f_det = f (r )
277+ sign , det = np .linalg .slogdet (r )
278+ assert np .equal (sign , f_sign )
279+ assert np .allclose (det , f_det )
280+
281+
269282def test_det_grad ():
270283 rng = np .random .default_rng (utt .fetch_seed ())
271284
You can’t perform that action at this time.
0 commit comments