Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,14 @@ def _random(*args, **kwargs):
clsdict["random"] = _random

rv_op = clsdict.setdefault("rv_op", None)
rv_type = None
rv_type = clsdict.setdefault("rv_type", None)

if isinstance(rv_op, RandomVariable):
rv_type = type(rv_op)
clsdict["rv_type"] = rv_type
if rv_type is not None:
assert isinstance(rv_op, rv_type)
else:
rv_type = type(rv_op)
clsdict["rv_type"] = rv_type

new_cls = super().__new__(cls, name, bases, clsdict)

Expand Down Expand Up @@ -155,8 +158,8 @@ def icdf(op, value, *dist_params, **kwargs):
def moment(op, rv, rng, size, dtype, *dist_params):
return class_moment(rv, size, *dist_params)

# Register the PyTensor `RandomVariable` type as a subclass of this
# `Distribution` type.
# Register the PyTensor rv_type as a subclass of this
# PyMC Distribution type.
new_cls.register(rv_type)

return new_cls
Expand Down
7 changes: 7 additions & 0 deletions pymc/tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pymc as pm

from pymc.distributions import (
Censored,
DiracDelta,
Flat,
HalfNormal,
Expand Down Expand Up @@ -569,3 +570,9 @@ def test_tag_future_warning_dist():
with pytest.warns(FutureWarning, match="Use model.rvs_to_values"):
value_var = new_x.tag.value_var
assert value_var == "1"


def test_distribution_op_registered():
"""Test that returned Ops are registered as virtual subclasses of the respective PyMC distributions."""
assert isinstance(Normal.dist().owner.op, Normal)
assert isinstance(Censored.dist(Normal.dist(), lower=None, upper=None).owner.op, Censored)