110110_logger = logging .getLogger ("pytensor.tensor.blas" )
111111
112112try :
113- import scipy .linalg .blas
113+ from scipy .linalg .blas import get_blas_funcs
114114
115115 have_fblas = True
116- try :
117- fblas = scipy .linalg .blas .fblas
118- except AttributeError :
119- # A change merged in Scipy development version on 2012-12-02 replaced
120- # `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
121- # See http://github.com/scipy/scipy/pull/358
122- fblas = scipy .linalg .blas
123- _blas_gemv_fns = {
124- np .dtype ("float32" ): fblas .sgemv ,
125- np .dtype ("float64" ): fblas .dgemv ,
126- np .dtype ("complex64" ): fblas .cgemv ,
127- np .dtype ("complex128" ): fblas .zgemv ,
128- }
129116except ImportError as e :
130117 have_fblas = False
131118 # This is used in Gemv and ScipyGer. We use CGemv and CGer
@@ -146,18 +133,18 @@ def check_init_y():
146133 if check_init_y ._result is None :
147134 if not have_fblas :
148135 check_init_y ._result = False
149-
150- y = float ("NaN" ) * np .ones ((2 ,))
151- x = np .ones ((2 ,))
152- A = np .ones ((2 , 2 ))
153- gemv = _blas_gemv_fns [ y .dtype ]
154- gemv (1.0 , A .T , x , 0.0 , y , overwrite_y = True , trans = True )
155- check_init_y ._result = np .isnan (y ).any ()
136+ else :
137+ y = float ("NaN" ) * np .ones ((2 ,))
138+ x = np .ones ((2 ,))
139+ A = np .ones ((2 , 2 ))
140+ gemv = get_blas_funcs ( names = "gemv" , dtype = y .dtype )
141+ gemv (1.0 , A .T , x , 0.0 , y , overwrite_y = True , trans = True )
142+ check_init_y ._result = np .isnan (y ).any ()
156143
157144 return check_init_y ._result
158145
159146
160- check_init_y ._result = None
147+ check_init_y ._result = None # type: ignore
161148
162149
163150class Gemv (Op ):
@@ -210,14 +197,11 @@ def make_node(self, y, alpha, A, x, beta):
210197
211198 def perform (self , node , inputs , out_storage ):
212199 y , alpha , A , x , beta = inputs
213- if (
214- have_fblas
215- and y .shape [0 ] != 0
216- and x .shape [0 ] != 0
217- and y .dtype in _blas_gemv_fns
218- ):
219- gemv = _blas_gemv_fns [y .dtype ]
220-
200+ try :
201+ gemv = get_blas_funcs (names = "gemv" , dtype = y .dtype )
202+ except Exception :
203+ gemv = None
204+ if have_fblas and y .shape [0 ] != 0 and x .shape [0 ] != 0 and gemv is not None :
221205 if A .shape [0 ] != y .shape [0 ] or A .shape [1 ] != x .shape [0 ]:
222206 raise ValueError (
223207 "Incompatible shapes for gemv "
0 commit comments