diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 9abc480286..3f904a1175 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -144,20 +144,20 @@ # If check_init_y() == True we need to initialize y when beta == 0. def check_init_y(): if check_init_y._result is None: - if not have_fblas: + if not have_fblas: # pragma: no cover check_init_y._result = False - - y = float("NaN") * np.ones((2,)) - x = np.ones((2,)) - A = np.ones((2, 2)) - gemv = _blas_gemv_fns[y.dtype] - gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True) - check_init_y._result = np.isnan(y).any() + else: + y = float("NaN") * np.ones((2,)) + x = np.ones((2,)) + A = np.ones((2, 2)) + gemv = _blas_gemv_fns[y.dtype] + gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True) + check_init_y._result = np.isnan(y).any() return check_init_y._result -check_init_y._result = None +check_init_y._result = None # type: ignore class Gemv(Op): diff --git a/pytensor/tensor/blas_scipy.py b/pytensor/tensor/blas_scipy.py index 527d5150a1..16fb90988b 100644 --- a/pytensor/tensor/blas_scipy.py +++ b/pytensor/tensor/blas_scipy.py @@ -19,17 +19,13 @@ class ScipyGer(Ger): - def prepare_node(self, node, storage_map, compute_map, impl): - if impl == "py": - node.tag.local_ger = _blas_ger_fns[np.dtype(node.inputs[0].type.dtype)] - def perform(self, node, inputs, output_storage): cA, calpha, cx, cy = inputs (cZ,) = output_storage # N.B. some versions of scipy (e.g. mine) don't actually work # in-place on a, even when I tell it to. A = cA - local_ger = node.tag.local_ger + local_ger = _blas_ger_fns[cA.dtype] if A.size == 0: # We don't have to compute anything, A is empty. # We need this special case because Numpy considers it diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 4b32536bec..7d9867113e 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -17,7 +17,6 @@ pytensor/scalar/basic.py pytensor/sparse/basic.py pytensor/sparse/type.py pytensor/tensor/basic.py -pytensor/tensor/blas.py pytensor/tensor/blas_c.py pytensor/tensor/blas_headers.py pytensor/tensor/elemwise.py @@ -31,4 +30,4 @@ pytensor/tensor/slinalg.py pytensor/tensor/subtensor.py pytensor/tensor/type.py pytensor/tensor/type_other.py -pytensor/tensor/variable.py \ No newline at end of file +pytensor/tensor/variable.py diff --git a/tests/tensor/test_blas_scipy.py b/tests/tensor/test_blas_scipy.py index e65e7d90c2..7cdfaadc34 100644 --- a/tests/tensor/test_blas_scipy.py +++ b/tests/tensor/test_blas_scipy.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import pytest @@ -58,6 +60,17 @@ def test_scaled_A_plus_scaled_outer(self): self.assertFunctionContains(f, gemm_no_inplace) self.run_f(f) # DebugMode tests correctness + def test_pickle(self): + out = ScipyGer(destructive=False)(self.A, self.a, self.x, self.y) + f = pytensor.function([self.A, self.a, self.x, self.y], out) + new_f = pickle.loads(pickle.dumps(f)) + + assert isinstance(new_f.maker.fgraph.toposort()[-1].op, ScipyGer) + assert np.allclose( + f(self.Aval, 1.0, self.xval, self.yval), + new_f(self.Aval, 1.0, self.xval, self.yval), + ) + class TestBlasStridesScipy(TestBlasStrides): mode = pytensor.compile.get_default_mode()