73 changes: 71 additions & 2 deletions pymc/logprob/mixture.py
@@ -47,7 +47,7 @@
node_rewriter,
pre_greedy_node_rewriter,
)
from pytensor.ifelse import ifelse
from pytensor.ifelse import IfElse, ifelse
from pytensor.scalar.basic import Switch
from pytensor.tensor.basic import Join, MakeVector
from pytensor.tensor.elemwise import Elemwise
@@ -73,10 +73,11 @@
from pymc.logprob.rewriting import (
local_lift_DiracDelta,
logprob_rewrites_db,
measurable_ir_rewrites_db,
subtensor_ops,
)
from pymc.logprob.tensor import naive_bcast_rv_lift
from pymc.logprob.utils import ignore_logprob
from pymc.logprob.utils import ignore_logprob, ignore_logprob_multiple_vars


def is_newaxis(x):
@@ -483,3 +484,71 @@ def logprob_MixtureRV(
"basic",
"mixture",
)


class MeasurableIfElse(IfElse):
"""Measurable subclass of IfElse operator."""


MeasurableVariable.register(MeasurableIfElse)


@node_rewriter([IfElse])
def find_measurable_ifelse_mixture(fgraph, node):
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableIfElse):
return None

    # Check if all components are unvalued measurable variables
if_var, *base_rvs = node.inputs

if not all(
(
rv.owner is not None
and isinstance(rv.owner.op, MeasurableVariable)
and rv not in rv_map_feature.rv_values
)
for rv in base_rvs
):
return None # pragma: no cover

unmeasurable_base_rvs = ignore_logprob_multiple_vars(base_rvs, rv_map_feature.rv_values)

return MeasurableIfElse(n_outs=node.op.n_outs).make_node(if_var, *unmeasurable_base_rvs).outputs


measurable_ir_rewrites_db.register(
"find_measurable_ifelse_mixture",
find_measurable_ifelse_mixture,
"basic",
"mixture",
)


@_logprob.register(MeasurableIfElse)
def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
"""Compute the log-likelihood graph for an `IfElse`."""
from pymc.pytensorf import replace_rvs_by_values

assert len(values) * 2 == len(base_rvs)

rvs_to_values_then = {then_rv: value for then_rv, value in zip(base_rvs[: len(values)], values)}
rvs_to_values_else = {else_rv: value for else_rv, value in zip(base_rvs[len(values) :], values)}

logps_then = [
logprob(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items()

Member:
What’s the behaviour of logprob when the value variable doesn’t have the expected shape? Will it raise a ValueError? I think that would be bad in this case. Imagine if you have a step method on the condition variable. The stepper might choose to move the condition to an infeasible value and that would kill the sampling process. I would like the condition that doesn’t match shapes to simply return -inf logprob. That way the stepper would discard the proposal and stay in reasonable regions.

ricardoV94 (Member Author), Feb 17, 2023:

The behavior of the logprob when the value variable doesn't match shape is exactly the same as the logprob of the underlying components. If it doesn't match, it will try to broadcast and fail if it cannot. Other than that we don't use size information in any of the core logprob functions.

pm.logp(pm.Normal.dist(size=(4,)), np.ones((2,))) will be happy to return a logp with two values.

One other case where this shows up is in graphs of the form pt.ones((5,)) + pm.Normal.dist(), which we infer to have a logp equivalent to that of pm.Normal.dist(shape=(5,)) even though the generative process contains only one true random variable, and not 5.

I think we need a bigger discussion about the role of shape information in the random and logp graphs, so I wouldn't treat IfElse differently for now.
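
A minimal runnable sketch of that behaviour, assuming pymc's public pm.logp helper:

# Hedged sketch: pm.logp does not validate the value's shape against the
# distribution's size; the value is only broadcast against the parameters.
import numpy as np
import pymc as pm

logp = pm.logp(pm.Normal.dist(size=(4,)), np.ones((2,)))
print(logp.eval())  # two logp values, even though size=(4,)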

ricardoV94 (Member Author), Feb 17, 2023:

Opened an issue here: #6530

]
logps_else = [
logprob(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items()
]

# If the multiple variables depend on each other, we have to replace them
# by the respective values
logps_then = replace_rvs_by_values(logps_then, rvs_to_values=rvs_to_values_then)
logps_else = replace_rvs_by_values(logps_else, rvs_to_values=rvs_to_values_else)

return ifelse(if_var, logps_then, logps_else)

Member:
Is the goal only to allow for explicit conditions in the mixture instead of marginalising?

ricardoV94 (Member Author), Feb 17, 2023:

The applications may vary. We don't explicitly marginalize anything in the logprob submodule, but something like MarginalModel would marginalize this just fine if we can give it the logprob function that this PR offers.
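
To make that distinction concrete, here is a hedged sketch contrasting the conditional logp this rewrite provides with a manual marginalisation over the condition. It assumes the factorized_joint_logprob helper used in the tests below; its import path may differ between versions.

import numpy as np
import pytensor.tensor as pt

from pytensor.ifelse import ifelse
# Assumed import path for the helper used in the tests below.
from pymc.logprob.joint_logprob import factorized_joint_logprob

p = 0.3
if_rv = pt.random.bernoulli(p, name="if")
comp_then = pt.random.normal(0.0, 1.0, name="comp_then")
comp_else = pt.random.normal(5.0, 1.0, name="comp_else")
mix_rv = ifelse(if_rv, comp_then, comp_else, name="mix")

if_vv = if_rv.clone()
mix_vv = mix_rv.clone()
# The logp graph is conditional on the condition's value variable.
mix_logp = factorized_joint_logprob({if_rv: if_vv, mix_rv: mix_vv})[mix_vv]
print(mix_logp.eval({if_vv: 1, mix_vv: 0.5}))

# Marginalising over the condition is not done by the logprob submodule;
# it would have to be assembled manually (or by something like MarginalModel):
then_logp = mix_logp.eval({if_vv: 1, mix_vv: 0.5})
else_logp = mix_logp.eval({if_vv: 0, mix_vv: 0.5})
print(np.logaddexp(np.log(p) + then_logp, np.log(1 - p) + else_logp))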

108 changes: 108 additions & 0 deletions tests/logprob/test_mixture.py
@@ -40,7 +40,9 @@
import pytest
import scipy.stats.distributions as sp

from pytensor import function
from pytensor.graph.basic import Variable, equal_computations
from pytensor.ifelse import ifelse
from pytensor.tensor.random.basic import CategoricalRV
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.subtensor import as_index_constant
@@ -942,3 +944,109 @@ def test_switch_mixture():

np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))


def test_ifelse_mixture_one_component():
if_rv = pt.random.bernoulli(0.5, name="if")
scale_rv = pt.random.halfnormal(name="scale")
comp_then = pt.random.normal(0, scale_rv, size=(2,), name="comp_then")
comp_else = pt.random.halfnormal(0, scale_rv, size=(4,), name="comp_else")
mix_rv = ifelse(if_rv, comp_then, comp_else, name="mix")

if_vv = if_rv.clone()
scale_vv = scale_rv.clone()
mix_vv = mix_rv.clone()
mix_logp = factorized_joint_logprob({if_rv: if_vv, scale_rv: scale_vv, mix_rv: mix_vv})[mix_vv]
assert_no_rvs(mix_logp)

fn = function([if_vv, scale_vv, mix_vv], mix_logp)
scale_vv_test = 0.75
mix_vv_test = np.r_[1.0, 2.5]
np.testing.assert_array_almost_equal(
fn(1, scale_vv_test, mix_vv_test),
sp.norm(0, scale_vv_test).logpdf(mix_vv_test),
)
mix_vv_test = np.r_[1.0, 2.5, 3.5, 4.0]
np.testing.assert_array_almost_equal(
fn(0, scale_vv_test, mix_vv_test), sp.halfnorm(0, scale_vv_test).logpdf(mix_vv_test)
)


def test_ifelse_mixture_multiple_components():
rng = np.random.default_rng(968)

if_var = pt.scalar("if_var", dtype="bool")
comp_then1 = pt.random.normal(size=(2,), name="comp_true1")
comp_then2 = pt.random.normal(comp_then1, size=(2, 2), name="comp_then2")
comp_else1 = pt.random.halfnormal(size=(4,), name="comp_else1")
comp_else2 = pt.random.halfnormal(size=(4, 4), name="comp_else2")

mix_rv1, mix_rv2 = ifelse(
if_var, [comp_then1, comp_then2], [comp_else1, comp_else2], name="mix"
)
mix_vv1 = mix_rv1.clone()
mix_vv2 = mix_rv2.clone()
mix_logp1, mix_logp2 = factorized_joint_logprob({mix_rv1: mix_vv1, mix_rv2: mix_vv2}).values()
assert_no_rvs(mix_logp1)
assert_no_rvs(mix_logp2)

fn = function([if_var, mix_vv1, mix_vv2], mix_logp1.sum() + mix_logp2.sum())
mix_vv1_test = np.abs(rng.normal(size=(2,)))
mix_vv2_test = np.abs(rng.normal(size=(2, 2)))
np.testing.assert_almost_equal(
fn(True, mix_vv1_test, mix_vv2_test),
sp.norm(0, 1).logpdf(mix_vv1_test).sum()
+ sp.norm(mix_vv1_test, 1).logpdf(mix_vv2_test).sum(),
)
mix_vv1_test = np.abs(rng.normal(size=(4,)))
mix_vv2_test = np.abs(rng.normal(size=(4, 4)))
np.testing.assert_almost_equal(
fn(False, mix_vv1_test, mix_vv2_test),
sp.halfnorm(0, 1).logpdf(mix_vv1_test).sum() + sp.halfnorm(0, 1).logpdf(mix_vv2_test).sum(),
)


def test_ifelse_mixture_shared_component():
rng = np.random.default_rng(1009)

if_var = pt.scalar("if_var", dtype="bool")
outer_rv = pt.random.normal(name="outer")
# comp_shared need not be an output of ifelse at all,
# but since we allow arbitrary graphs we test it works as expected.
comp_shared = pt.random.normal(size=(2,), name="comp_shared")
comp_then = outer_rv + pt.random.normal(comp_shared, 1, size=(4, 2), name="comp_then")
comp_else = outer_rv + pt.random.normal(comp_shared, 10, size=(8, 2), name="comp_else")
shared_rv, mix_rv = ifelse(
if_var, [comp_shared, comp_then], [comp_shared, comp_else], name="mix"
)

outer_vv = outer_rv.clone()
shared_vv = shared_rv.clone()
mix_vv = mix_rv.clone()
outer_logp, mix_logp1, mix_logp2 = factorized_joint_logprob(
{outer_rv: outer_vv, shared_rv: shared_vv, mix_rv: mix_vv}
).values()
assert_no_rvs(outer_logp)
assert_no_rvs(mix_logp1)
assert_no_rvs(mix_logp2)

fn = function([if_var, outer_vv, shared_vv, mix_vv], mix_logp1.sum() + mix_logp2.sum())
outer_vv_test = rng.normal()
shared_vv_test = rng.normal(size=(2,))
mix_vv_test = rng.normal(size=(4, 2))
np.testing.assert_almost_equal(
fn(True, outer_vv_test, shared_vv_test, mix_vv_test),
(
sp.norm(0, 1).logpdf(shared_vv_test).sum()
+ sp.norm(outer_vv_test + shared_vv_test, 1).logpdf(mix_vv_test).sum()
),
)
mix_vv_test = rng.normal(size=(8, 2))
np.testing.assert_almost_equal(
fn(False, outer_vv_test, shared_vv_test, mix_vv_test),
(
sp.norm(0, 1).logpdf(shared_vv_test).sum()
+ sp.norm(outer_vv_test + shared_vv_test, 10).logpdf(mix_vv_test).sum()
),
decimal=6,
)