1010from pytensor .tensor import basic as at
1111from pytensor .tensor import math as tm
1212from pytensor .tensor .basic import as_tensor_variable , extract_diag
13+ from pytensor .tensor .blockwise import Blockwise
1314from pytensor .tensor .type import dvector , lscalar , matrix , scalar , vector
1415
1516
1617class MatrixPinv (Op ):
1718 __props__ = ("hermitian" ,)
19+ gufunc_signature = "(m,n)->(n,m)"
1820
1921 def __init__ (self , hermitian ):
2022 self .hermitian = hermitian
@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
7577 solve op.
7678
7779 """
78- return MatrixPinv (hermitian = hermitian )(x )
80+ return Blockwise ( MatrixPinv (hermitian = hermitian ) )(x )
7981
8082
8183class MatrixInverse (Op ):
@@ -93,6 +95,8 @@ class MatrixInverse(Op):
9395 """
9496
9597 __props__ = ()
98+ gufunc_signature = "(m,m)->(m,m)"
99+ gufunc_spec = ("numpy.linalg.inv" , 1 , 1 )
96100
97101 def __init__ (self ):
98102 pass
@@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
150154 return shapes
151155
152156
153- inv = matrix_inverse = MatrixInverse ()
157+ inv = matrix_inverse = Blockwise ( MatrixInverse () )
154158
155159
156160def matrix_dot (* args ):
@@ -181,6 +185,8 @@ class Det(Op):
181185 """
182186
183187 __props__ = ()
188+ gufunc_signature = "(m,m)->()"
189+ gufunc_spec = ("numpy.linalg.det" , 1 , 1 )
184190
185191 def make_node (self , x ):
186192 x = as_tensor_variable (x )
@@ -209,7 +215,7 @@ def __str__(self):
209215 return "Det"
210216
211217
212- det = Det ()
218+ det = Blockwise ( Det () )
213219
214220
215221class SLogDet (Op ):
@@ -218,6 +224,8 @@ class SLogDet(Op):
218224 """
219225
220226 __props__ = ()
227+ gufunc_signature = "(m, m)->(),()"
228+ gufunc_spec = ("numpy.linalg.slogdet" , 1 , 2 )
221229
222230 def make_node (self , x ):
223231 x = as_tensor_variable (x )
@@ -242,7 +250,7 @@ def __str__(self):
242250 return "SLogDet"
243251
244252
245- slogdet = SLogDet ()
253+ slogdet = Blockwise ( SLogDet () )
246254
247255
248256class Eig (Op ):
@@ -252,6 +260,8 @@ class Eig(Op):
252260 """
253261
254262 __props__ : Tuple [str , ...] = ()
263+ gufunc_signature = "(m,m)->(m),(m,m)"
264+ gufunc_spec = ("numpy.linalg.eig" , 1 , 2 )
255265
256266 def make_node (self , x ):
257267 x = as_tensor_variable (x )
@@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
270280 return [(n ,), (n , n )]
271281
272282
273- eig = Eig ()
283+ eig = Blockwise ( Eig () )
274284
275285
276286class Eigh (Eig ):
0 commit comments