@@ -1,6 +1,5 @@
-import typing
 from functools import partial
-from typing import Callable, Tuple
+from typing import Tuple
 
 import numpy as np
 
@@ -271,7 +270,6 @@ class Eig(Op):
 
     """
 
-    _numop = staticmethod(np.linalg.eig)
     __props__: Tuple[str, ...] = ()
 
     def make_node(self, x):
@@ -284,7 +282,7 @@ def make_node(self, x):
     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (w, v) = outputs
-        w[0], v[0] = (z.astype(x.dtype) for z in self._numop(x))
+        w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
 
     def infer_shape(self, fgraph, node, shapes):
         n = shapes[0][0]
@@ -300,7 +298,6 @@ class Eigh(Eig):
 
     """
 
-    _numop = typing.cast(Callable, staticmethod(np.linalg.eigh))
     __props__ = ("UPLO",)
 
     def __init__(self, UPLO="L"):
@@ -315,15 +312,15 @@ def make_node(self, x):
         # LAPACK. Rather than trying to reproduce the (rather
         # involved) logic, we just probe linalg.eigh with a trivial
         # input.
-        w_dtype = self._numop([[np.dtype(x.dtype).type()]])[0].dtype.name
+        w_dtype = np.linalg.eigh([[np.dtype(x.dtype).type()]])[0].dtype.name
         w = vector(dtype=w_dtype)
         v = matrix(dtype=w_dtype)
         return Apply(self, [x], [w, v])
 
     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (w, v) = outputs
-        w[0], v[0] = self._numop(x, self.UPLO)
+        w[0], v[0] = np.linalg.eigh(x, self.UPLO)
 
     def grad(self, inputs, g_outputs):
         r"""The gradient function should return
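An aside on the dtype probe kept in `Eigh.make_node` above: instead of re-implementing LAPACK's promotion rules, the Op asks `np.linalg.eigh` what eigenvalue dtype it would produce for a trivial 1x1 input of the given dtype. A minimal standalone sketch of that probe (plain NumPy, outside any Op):

```python
import numpy as np

# Probe np.linalg.eigh with a trivial 1x1 array to learn the eigenvalue
# dtype it would return for a given input dtype, instead of reproducing
# LAPACK's promotion logic by hand.
for dtype in ("float32", "float64", "complex64", "int32"):
    w_dtype = np.linalg.eigh([[np.dtype(dtype).type()]])[0].dtype.name
    print(dtype, "->", w_dtype)
```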
@@ -446,7 +443,6 @@ class QRFull(Op):
 
     """
 
-    _numop = staticmethod(np.linalg.qr)
     __props__ = ("mode",)
 
     def __init__(self, mode):
@@ -478,7 +474,7 @@ def make_node(self, x):
     def perform(self, node, inputs, outputs):
         (x,) = inputs
         assert x.ndim == 2, "The input of qr function should be a matrix."
-        res = self._numop(x, self.mode)
+        res = np.linalg.qr(x, self.mode)
         if self.mode != "r":
             outputs[0][0], outputs[1][0] = res
         else:
@@ -547,7 +543,6 @@ class SVD(Op):
     """
 
     # See doc in the docstring of the function just after this class.
-    _numop = staticmethod(np.linalg.svd)
     __props__ = ("full_matrices", "compute_uv")
 
     def __init__(self, full_matrices=True, compute_uv=True):
@@ -575,10 +570,10 @@ def perform(self, node, inputs, outputs):
         assert x.ndim == 2, "The input of svd function should be a matrix."
         if self.compute_uv:
             u, s, vt = outputs
-            u[0], s[0], vt[0] = self._numop(x, self.full_matrices, self.compute_uv)
+            u[0], s[0], vt[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
         else:
             (s,) = outputs
-            s[0] = self._numop(x, self.full_matrices, self.compute_uv)
+            s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
 
     def infer_shape(self, fgraph, node, shapes):
         (x_shape,) = shapes
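For reference on the two branches in `SVD.perform` above: `np.linalg.svd` returns a `(u, s, vt)` triple when `compute_uv=True` and only the singular values when it is `False`, which is why the Op unpacks `outputs` differently in each branch. A small standalone illustration:

```python
import numpy as np

x = np.random.default_rng(0).normal(size=(4, 3))

# compute_uv=True: three outputs, hence the `u, s, vt = outputs` branch.
u, s, vt = np.linalg.svd(x, full_matrices=True, compute_uv=True)

# compute_uv=False: a single array of singular values, hence `(s,) = outputs`.
s_only = np.linalg.svd(x, full_matrices=True, compute_uv=False)

assert np.allclose(s, s_only)
```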
@@ -730,7 +725,6 @@ class TensorInv(Op):
     PyTensor utilization of numpy.linalg.tensorinv;
     """
 
-    _numop = staticmethod(np.linalg.tensorinv)
     __props__ = ("ind",)
 
     def __init__(self, ind=2):
@@ -744,7 +738,7 @@ def make_node(self, a):
     def perform(self, node, inputs, outputs):
         (a,) = inputs
         (x,) = outputs
-        x[0] = self._numop(a, self.ind)
+        x[0] = np.linalg.tensorinv(a, self.ind)
 
     def infer_shape(self, fgraph, node, shapes):
         sp = shapes[0][self.ind :] + shapes[0][: self.ind]
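The `infer_shape` rule above matches what `np.linalg.tensorinv` documents for its result shape: the last `ind` axes of the input move to the front. A quick standalone check (the identity-based array here is just an arbitrary invertible example):

```python
import numpy as np

# A (4, 6, 8, 3) tensor that is invertible when viewed as a (24, 24) matrix.
a = np.eye(4 * 6).reshape(4, 6, 8, 3)

a_inv = np.linalg.tensorinv(a, ind=2)

# Result shape is a.shape[ind:] + a.shape[:ind], matching infer_shape above.
print(a_inv.shape)  # (8, 3, 4, 6)
```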
@@ -790,7 +784,6 @@ class TensorSolve(Op):
 
     """
 
-    _numop = staticmethod(np.linalg.tensorsolve)
     __props__ = ("axes",)
 
     def __init__(self, axes=None):
@@ -809,7 +802,7 @@ def perform(self, node, inputs, outputs):
             b,
         ) = inputs
         (x,) = outputs
-        x[0] = self._numop(a, b, self.axes)
+        x[0] = np.linalg.tensorsolve(a, b, self.axes)
 
 
 def tensorsolve(a, b, axes=None):
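Taken together, the change in this diff is mechanical: each Op previously stored its NumPy routine as a `_numop = staticmethod(...)` class attribute and called `self._numop(...)` inside `perform`; now `perform` calls the `np.linalg` function directly. A minimal before/after sketch of the pattern (class names here are illustrative, not from the codebase):

```python
import numpy as np

class EigLikeBefore:
    # Old pattern: the NumPy routine is held as a class attribute.
    _numop = staticmethod(np.linalg.eig)

    def perform(self, x):
        return self._numop(x)

class EigLikeAfter:
    # New pattern: call the np.linalg routine directly in perform.
    def perform(self, x):
        return np.linalg.eig(x)
```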