Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions pytensor/tensor/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,20 @@
# If check_init_y() == True we need to initialize y when beta == 0.
def check_init_y():
if check_init_y._result is None:
if not have_fblas:
if not have_fblas: # pragma: no cover
check_init_y._result = False

y = float("NaN") * np.ones((2,))
x = np.ones((2,))
A = np.ones((2, 2))
gemv = _blas_gemv_fns[y.dtype]
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
check_init_y._result = np.isnan(y).any()
else:
y = float("NaN") * np.ones((2,))
x = np.ones((2,))
A = np.ones((2, 2))
gemv = _blas_gemv_fns[y.dtype]
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
check_init_y._result = np.isnan(y).any()

return check_init_y._result


check_init_y._result = None
check_init_y._result = None # type: ignore


class Gemv(Op):
Expand Down
6 changes: 1 addition & 5 deletions pytensor/tensor/blas_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,13 @@


class ScipyGer(Ger):
def prepare_node(self, node, storage_map, compute_map, impl):
if impl == "py":
node.tag.local_ger = _blas_ger_fns[np.dtype(node.inputs[0].type.dtype)]

def perform(self, node, inputs, output_storage):
cA, calpha, cx, cy = inputs
(cZ,) = output_storage
# N.B. some versions of scipy (e.g. mine) don't actually work
# in-place on a, even when I tell it to.
A = cA
local_ger = node.tag.local_ger
local_ger = _blas_ger_fns[cA.dtype]
if A.size == 0:
# We don't have to compute anything, A is empty.
# We need this special case because Numpy considers it
Expand Down
3 changes: 1 addition & 2 deletions scripts/mypy-failing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ pytensor/scalar/basic.py
pytensor/sparse/basic.py
pytensor/sparse/type.py
pytensor/tensor/basic.py
pytensor/tensor/blas.py
pytensor/tensor/blas_c.py
pytensor/tensor/blas_headers.py
pytensor/tensor/elemwise.py
Expand All @@ -31,4 +30,4 @@ pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py
pytensor/tensor/type.py
pytensor/tensor/type_other.py
pytensor/tensor/variable.py
pytensor/tensor/variable.py
13 changes: 13 additions & 0 deletions tests/tensor/test_blas_scipy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pickle

import numpy as np
import pytest

Expand Down Expand Up @@ -58,6 +60,17 @@ def test_scaled_A_plus_scaled_outer(self):
self.assertFunctionContains(f, gemm_no_inplace)
self.run_f(f) # DebugMode tests correctness

def test_pickle(self):
out = ScipyGer(destructive=False)(self.A, self.a, self.x, self.y)
f = pytensor.function([self.A, self.a, self.x, self.y], out)
new_f = pickle.loads(pickle.dumps(f))

assert isinstance(new_f.maker.fgraph.toposort()[-1].op, ScipyGer)
assert np.allclose(
f(self.Aval, 1.0, self.xval, self.yval),
new_f(self.Aval, 1.0, self.xval, self.yval),
)


class TestBlasStridesScipy(TestBlasStrides):
mode = pytensor.compile.get_default_mode()
Expand Down