diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 6e4599c985..8b521d0abc 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -112,11 +112,14 @@ def _random(*args, **kwargs): clsdict["random"] = _random rv_op = clsdict.setdefault("rv_op", None) - rv_type = None + rv_type = clsdict.setdefault("rv_type", None) if isinstance(rv_op, RandomVariable): - rv_type = type(rv_op) - clsdict["rv_type"] = rv_type + if rv_type is not None: + assert isinstance(rv_op, rv_type) + else: + rv_type = type(rv_op) + clsdict["rv_type"] = rv_type new_cls = super().__new__(cls, name, bases, clsdict) @@ -155,8 +158,8 @@ def icdf(op, value, *dist_params, **kwargs): def moment(op, rv, rng, size, dtype, *dist_params): return class_moment(rv, size, *dist_params) - # Register the PyTensor `RandomVariable` type as a subclass of this - # `Distribution` type. + # Register the PyTensor rv_type as a subclass of this + # PyMC Distribution type. new_cls.register(rv_type) return new_cls diff --git a/pymc/tests/distributions/test_distribution.py b/pymc/tests/distributions/test_distribution.py index ae7d625ac3..1bcec9c0f5 100644 --- a/pymc/tests/distributions/test_distribution.py +++ b/pymc/tests/distributions/test_distribution.py @@ -26,6 +26,7 @@ import pymc as pm from pymc.distributions import ( + Censored, DiracDelta, Flat, HalfNormal, @@ -569,3 +570,9 @@ def test_tag_future_warning_dist(): with pytest.warns(FutureWarning, match="Use model.rvs_to_values"): value_var = new_x.tag.value_var assert value_var == "1" + + +def test_distribution_op_registered(): + """Test that returned Ops are registered as virtual subclasses of the respective PyMC distributions.""" + assert isinstance(Normal.dist().owner.op, Normal) + assert isinstance(Censored.dist(Normal.dist(), lower=None, upper=None).owner.op, Censored)