From 9bbe3d3e6bce57c00f4ba2f94d5d8b3d2137a47c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jun 2023 07:49:59 +0200 Subject: [PATCH] Make tests compatible with latest release of PyTensor --- pymc/model.py | 2 +- tests/distributions/test_multivariate.py | 7 ++++++- tests/logprob/test_basic.py | 2 +- tests/test_pytensorf.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index b6eac9cbae..fd948230b4 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1232,7 +1232,7 @@ def set_data( if isinstance(length_tensor_origin, TensorConstant): raise ShapeError( f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, " - f"because the dimension length is tied to a {length_tensor_origin}. " + f"because the dimension length is tied to a TensorConstant. " f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, " f"for example by another model variable.", actual=new_length, diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 671fe4fa79..640d7484c9 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -316,7 +316,12 @@ def test_mvnormal_indef(self): f_logp(cov_val, np.ones(2)) dlogp = pt.grad(mvn_logp, cov) f_dlogp = pytensor.function([cov, x], dlogp) - assert not np.all(np.isfinite(f_dlogp(cov_val, np.ones(2)))) + try: + res = f_dlogp(cov_val, np.ones(2)) + except ValueError: + pass # Op raises internally + else: + assert not np.all(np.isfinite(res)) # Otherwise, should return nan def test_mvnormal_init_fail(self): with pm.Model(): diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py index 7e55f7a00f..456a8f277a 100644 --- a/tests/logprob/test_basic.py +++ b/tests/logprob/test_basic.py @@ -426,7 +426,7 @@ def test_probability_inference(func, scipy_func, test_value): def test_probability_inference_fails(func, func_name): with pytest.raises( NotImplementedError, - match=f"{func_name} method not implemented for Elemwise{{cos,no_inplace}}", + match=f"{func_name} method not implemented for (Elemwise{{cos,no_inplace}}|Cos)", ): func(pt.cos(pm.Normal.dist()), 1) diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index dc1852966d..865047a8d7 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -575,7 +575,7 @@ def step_wo_update(x, rng): with pytest.raises( ValueError, - match=r"No update found for at least one RNG used in Scan Op for\{cpu,test_scan\}", + match="No update found for at least one RNG used in Scan Op", ): collect_default_updates([xs])