Skip to content

Commit e1d432d

Browse files
committed
Use get_blas_funcs instead of storing fortran objects in a dict
1 parent e180927 commit e1d432d

File tree

3 files changed

+17
-46
lines changed

3 files changed

+17
-46
lines changed

pytensor/tensor/blas.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -110,22 +110,9 @@
110110
_logger = logging.getLogger("pytensor.tensor.blas")
111111

112112
try:
113-
import scipy.linalg.blas
113+
from scipy.linalg.blas import get_blas_funcs
114114

115115
have_fblas = True
116-
try:
117-
fblas = scipy.linalg.blas.fblas
118-
except AttributeError:
119-
# A change merged in Scipy development version on 2012-12-02 replaced
120-
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
121-
# See http://github.com/scipy/scipy/pull/358
122-
fblas = scipy.linalg.blas
123-
_blas_gemv_fns = {
124-
np.dtype("float32"): fblas.sgemv,
125-
np.dtype("float64"): fblas.dgemv,
126-
np.dtype("complex64"): fblas.cgemv,
127-
np.dtype("complex128"): fblas.zgemv,
128-
}
129116
except ImportError as e:
130117
have_fblas = False
131118
# This is used in Gemv and ScipyGer. We use CGemv and CGer
@@ -146,18 +133,18 @@ def check_init_y():
146133
if check_init_y._result is None:
147134
if not have_fblas:
148135
check_init_y._result = False
149-
150-
y = float("NaN") * np.ones((2,))
151-
x = np.ones((2,))
152-
A = np.ones((2, 2))
153-
gemv = _blas_gemv_fns[y.dtype]
154-
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
155-
check_init_y._result = np.isnan(y).any()
136+
else:
137+
y = float("NaN") * np.ones((2,))
138+
x = np.ones((2,))
139+
A = np.ones((2, 2))
140+
gemv = get_blas_funcs(names="gemv", dtype=y.dtype)
141+
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
142+
check_init_y._result = np.isnan(y).any()
156143

157144
return check_init_y._result
158145

159146

160-
check_init_y._result = None
147+
check_init_y._result = None # type: ignore
161148

162149

163150
class Gemv(Op):
@@ -210,14 +197,11 @@ def make_node(self, y, alpha, A, x, beta):
210197

211198
def perform(self, node, inputs, out_storage):
212199
y, alpha, A, x, beta = inputs
213-
if (
214-
have_fblas
215-
and y.shape[0] != 0
216-
and x.shape[0] != 0
217-
and y.dtype in _blas_gemv_fns
218-
):
219-
gemv = _blas_gemv_fns[y.dtype]
220-
200+
try:
201+
gemv = get_blas_funcs(names="gemv", dtype=y.dtype)
202+
except Exception:
203+
gemv = None
204+
if have_fblas and y.shape[0] != 0 and x.shape[0] != 0 and gemv is not None:
221205
if A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]:
222206
raise ValueError(
223207
"Incompatible shapes for gemv "

pytensor/tensor/blas_scipy.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,22 @@
22
Implementations of BLAS Ops based on scipy's BLAS bindings.
33
"""
44

5-
import numpy as np
65

76
from pytensor.tensor.blas import Ger, have_fblas
87

98

109
if have_fblas:
11-
from pytensor.tensor.blas import fblas
12-
13-
_blas_ger_fns = {
14-
np.dtype("float32"): fblas.sger,
15-
np.dtype("float64"): fblas.dger,
16-
np.dtype("complex64"): fblas.cgeru,
17-
np.dtype("complex128"): fblas.zgeru,
18-
}
10+
from scipy.linalg.blas import get_blas_funcs
1911

2012

2113
class ScipyGer(Ger):
22-
def prepare_node(self, node, storage_map, compute_map, impl):
23-
if impl == "py":
24-
node.tag.local_ger = _blas_ger_fns[np.dtype(node.inputs[0].type.dtype)]
25-
2614
def perform(self, node, inputs, output_storage):
2715
cA, calpha, cx, cy = inputs
2816
(cZ,) = output_storage
2917
# N.B. some versions of scipy (e.g. mine) don't actually work
3018
# in-place on a, even when I tell it to.
3119
A = cA
32-
local_ger = node.tag.local_ger
20+
local_ger = get_blas_funcs(names="ger", dtype=cA.dtype)
3321
if A.size == 0:
3422
# We don't have to compute anything, A is empty.
3523
# We need this special case because Numpy considers it

scripts/mypy-failing.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ pytensor/scalar/basic.py
1717
pytensor/sparse/basic.py
1818
pytensor/sparse/type.py
1919
pytensor/tensor/basic.py
20-
pytensor/tensor/blas.py
2120
pytensor/tensor/blas_c.py
2221
pytensor/tensor/blas_headers.py
2322
pytensor/tensor/elemwise.py
@@ -31,4 +30,4 @@ pytensor/tensor/slinalg.py
3130
pytensor/tensor/subtensor.py
3231
pytensor/tensor/type.py
3332
pytensor/tensor/type_other.py
34-
pytensor/tensor/variable.py
33+
pytensor/tensor/variable.py

0 commit comments

Comments
 (0)