diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 50b0ae5de9..78ebf41cae 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -174,8 +174,9 @@ def _str_repr(self, name=None, dist=None, formatting='plain'): return "{var_name} ~ {distr_name}({params})".format(var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string) - def __str__(self, **kwargs): - return self._str_repr(formatting="plain", **kwargs) + # TODO: Enable this one we figure out pickle + # def __str__(self, **kwargs): + # return self._str_repr(formatting="plain", **kwargs) def _repr_latex_(self, **kwargs): """Magic method name for IPython to use for LaTeX formatting.""" diff --git a/pymc3/model.py b/pymc3/model.py index f7a504d78e..85c549ba04 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -80,8 +80,9 @@ def _str_repr(self, name=None, dist=None, formatting="plain"): def _repr_latex_(self, **kwargs): return self._str_repr(formatting="latex", **kwargs) - def __str__(self, **kwargs): - return self._str_repr(formatting="plain", **kwargs) + # TODO: Enable this one we figure out pickle + # def __str__(self, **kwargs): + # return self._str_repr(formatting="plain", **kwargs) __latex__ = _repr_latex_ @@ -1423,8 +1424,9 @@ def _str_repr(self, formatting="plain", **kwargs): for n, d in zip(names, distrs)] return "\n".join(rv_reprs) - def __str__(self, **kwargs): - return self._str_repr(formatting="plain", **kwargs) + # TODO: Enable this one we figure out pickle + # def __str__(self, **kwargs): + # return self._str_repr(formatting="plain", **kwargs) def _repr_latex_(self, **kwargs): return self._str_repr(formatting="latex", **kwargs) @@ -1934,10 +1936,12 @@ def Deterministic(name, var, model=None, dims=None): # simply assigning var.__str__ is not enough, since str() will default to the class- # defined __str__ anyway; see https://stackoverflow.com/a/5918210/1692028 - old_type = type(var) - new_type = type(old_type.__name__ + '_pymc3_Deterministic', (old_type,), - {'__str__': functools.partial(_repr_deterministic_rv, var, formatting='plain')}) - var.__class__ = new_type + + # TODO: Fix enable this once we figure out pickle + # old_type = type(var) + # new_type = type(old_type.__name__ + '_pymc3_Deterministic', (old_type,), + # {'__str__': functools.partial(_repr_deterministic_rv, var, formatting='plain')}) + # var.__class__ = new_type return var diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index c997081b04..aff0e1650b 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1705,6 +1705,8 @@ def test_bound(): BoundPoissonPositionalArgs = Bound(Poisson, upper=6)("x", 2.0) +@pytest.mark.xfail(reason=("Currently failing due to pm.Deterministic issue. " + "See https://github.com/pymc-devs/pymc3/pull/4117 ")) class TestStrAndLatexRepr: def setup_class(self): # True parameter values diff --git a/pymc3/tests/test_model.py b/pymc3/tests/test_model.py index 436ce7568e..e9de589cdc 100644 --- a/pymc3/tests/test_model.py +++ b/pymc3/tests/test_model.py @@ -16,6 +16,7 @@ import theano import theano.tensor as tt import numpy as np +import pickle import pandas as pd import numpy.testing as npt import unittest @@ -421,3 +422,27 @@ def test_tempered_logp_dlogp(): npt.assert_allclose(func_nograd(x), func(x)[0]) npt.assert_allclose(func_temp_nograd(x), func_temp(x)[0]) + + +def test_model_pickle(tmpdir): + """Tests that PyMC3 models are pickleable""" + with pm.Model() as model: + x = pm.Normal('x') + pm.Normal('y', observed=1) + + file_path = tmpdir.join("model.p") + with open(file_path, 'wb') as buff: + pickle.dump(model, buff) + + +def test_model_pickle_deterministic(tmpdir): + """Tests that PyMC3 models with deterministics are pickleable""" + with pm.Model() as model: + x = pm.Normal('x') + z = pm.Normal("z") + pm.Deterministic("w", x/z) + pm.Normal('y', observed=1) + + file_path = tmpdir.join("model.p") + with open(file_path, 'wb') as buff: + pickle.dump(model, buff) diff --git a/pymc3/util.py b/pymc3/util.py index 4f3cdd88cb..5a6d38a8b5 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -20,7 +20,8 @@ import arviz from numpy import asscalar, ndarray -from theano.tensor import TensorVariable +# TODO: Reimplement after pickle fix +# from theano.tensor import TensorVariable LATEX_ESCAPE_RE = re.compile(r"(%|_|\$|#|&)", re.MULTILINE) @@ -161,10 +162,10 @@ def get_var_name(var): string representations to our pymc3.PyMC3Variables, yet we want to use the plain name as e.g. keys in dicts. """ - if isinstance(var, TensorVariable): - return super(TensorVariable, var).__str__() - else: - return str(var) + # if isinstance(var, TensorVariable): + # return super(TensorVariable, var).__str__() + # else: + return str(var) def update_start_vals(a, b, model):