diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py
index 3ba6b9c1f5..ec9b8970d0 100644
--- a/pymc/logprob/mixture.py
+++ b/pymc/logprob/mixture.py
@@ -47,7 +47,7 @@
     node_rewriter,
     pre_greedy_node_rewriter,
 )
-from pytensor.ifelse import ifelse
+from pytensor.ifelse import IfElse, ifelse
 from pytensor.scalar.basic import Switch
 from pytensor.tensor.basic import Join, MakeVector
 from pytensor.tensor.elemwise import Elemwise
@@ -73,10 +73,11 @@
 from pymc.logprob.rewriting import (
     local_lift_DiracDelta,
     logprob_rewrites_db,
+    measurable_ir_rewrites_db,
     subtensor_ops,
 )
 from pymc.logprob.tensor import naive_bcast_rv_lift
-from pymc.logprob.utils import ignore_logprob
+from pymc.logprob.utils import ignore_logprob, ignore_logprob_multiple_vars
 
 
 def is_newaxis(x):
@@ -483,3 +484,77 @@ def logprob_MixtureRV(
     "basic",
     "mixture",
 )
+
+
+class MeasurableIfElse(IfElse):
+    """Measurable subclass of IfElse operator."""
+
+
+MeasurableVariable.register(MeasurableIfElse)
+
+
+@node_rewriter([IfElse])
+def find_measurable_ifelse_mixture(fgraph, node):
+    """Replace an `IfElse` of unvalued measurable variables by a `MeasurableIfElse`."""
+    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
+
+    if rv_map_feature is None:
+        return None  # pragma: no cover
+
+    if isinstance(node.op, MeasurableIfElse):
+        return None
+
+    # Check if all components are unvalued measurable variables
+    if_var, *base_rvs = node.inputs
+
+    if not all(
+        (
+            rv.owner is not None
+            and isinstance(rv.owner.op, MeasurableVariable)
+            and rv not in rv_map_feature.rv_values
+        )
+        for rv in base_rvs
+    ):
+        return None  # pragma: no cover
+
+    unmeasurable_base_rvs = ignore_logprob_multiple_vars(base_rvs, rv_map_feature.rv_values)
+
+    return MeasurableIfElse(n_outs=node.op.n_outs).make_node(if_var, *unmeasurable_base_rvs).outputs
+
+
+measurable_ir_rewrites_db.register(
+    "find_measurable_ifelse_mixture",
+    find_measurable_ifelse_mixture,
+    "basic",
+    "mixture",
+)
+
+
+@_logprob.register(MeasurableIfElse)
+def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
+    """Compute the log-likelihood graph for an `IfElse`.
+
+    `base_rvs` holds the "then" branch variables followed by the "else" branch
+    variables, so it must be exactly twice as long as `values`.
+    """
+    from pymc.pytensorf import replace_rvs_by_values
+
+    n_outs = len(values)
+    assert n_outs * 2 == len(base_rvs)
+
+    rvs_to_values_then = dict(zip(base_rvs[:n_outs], values))
+    rvs_to_values_else = dict(zip(base_rvs[n_outs:], values))
+
+    logps_then = [
+        logprob(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items()
+    ]
+    logps_else = [
+        logprob(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items()
+    ]
+
+    # If the multiple variables depend on each other, we have to replace them
+    # by the respective values
+    logps_then = replace_rvs_by_values(logps_then, rvs_to_values=rvs_to_values_then)
+    logps_else = replace_rvs_by_values(logps_else, rvs_to_values=rvs_to_values_else)
+
+    return ifelse(if_var, logps_then, logps_else)
diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py
index 92c07c3fcf..4e587c2f70 100644
--- a/tests/logprob/test_mixture.py
+++ b/tests/logprob/test_mixture.py
@@ -40,7 +40,9 @@
 import pytest
 import scipy.stats.distributions as sp
 
+from pytensor import function
 from pytensor.graph.basic import Variable, equal_computations
+from pytensor.ifelse import ifelse
 from pytensor.tensor.random.basic import CategoricalRV
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.subtensor import as_index_constant
@@ -942,3 +944,109 @@ def test_switch_mixture():
 
     np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
     np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))
+
+
+def test_ifelse_mixture_one_component():
+    if_rv = pt.random.bernoulli(0.5, name="if")
+    scale_rv = pt.random.halfnormal(name="scale")
+    comp_then = pt.random.normal(0, scale_rv, size=(2,), name="comp_then")
+    comp_else = pt.random.halfnormal(0, scale_rv, size=(4,), name="comp_else")
+    mix_rv = ifelse(if_rv, comp_then, comp_else, name="mix")
+
+    if_vv = if_rv.clone()
+    scale_vv = scale_rv.clone()
+    mix_vv = mix_rv.clone()
+    mix_logp = factorized_joint_logprob({if_rv: if_vv, scale_rv: scale_vv, mix_rv: mix_vv})[mix_vv]
+    assert_no_rvs(mix_logp)
+
+    fn = function([if_vv, scale_vv, mix_vv], mix_logp)
+    scale_vv_test = 0.75
+    mix_vv_test = np.r_[1.0, 2.5]
+    np.testing.assert_array_almost_equal(
+        fn(1, scale_vv_test, mix_vv_test),
+        sp.norm(0, scale_vv_test).logpdf(mix_vv_test),
+    )
+    mix_vv_test = np.r_[1.0, 2.5, 3.5, 4.0]
+    np.testing.assert_array_almost_equal(
+        fn(0, scale_vv_test, mix_vv_test), sp.halfnorm(0, scale_vv_test).logpdf(mix_vv_test)
+    )
+
+
+def test_ifelse_mixture_multiple_components():
+    rng = np.random.default_rng(968)
+
+    if_var = pt.scalar("if_var", dtype="bool")
+    comp_then1 = pt.random.normal(size=(2,), name="comp_true1")
+    comp_then2 = pt.random.normal(comp_then1, size=(2, 2), name="comp_then2")
+    comp_else1 = pt.random.halfnormal(size=(4,), name="comp_else1")
+    comp_else2 = pt.random.halfnormal(size=(4, 4), name="comp_else2")
+
+    mix_rv1, mix_rv2 = ifelse(
+        if_var, [comp_then1, comp_then2], [comp_else1, comp_else2], name="mix"
+    )
+    mix_vv1 = mix_rv1.clone()
+    mix_vv2 = mix_rv2.clone()
+    mix_logp1, mix_logp2 = factorized_joint_logprob({mix_rv1: mix_vv1, mix_rv2: mix_vv2}).values()
+    assert_no_rvs(mix_logp1)
+    assert_no_rvs(mix_logp2)
+
+    fn = function([if_var, mix_vv1, mix_vv2], mix_logp1.sum() + mix_logp2.sum())
+    mix_vv1_test = np.abs(rng.normal(size=(2,)))
+    mix_vv2_test = np.abs(rng.normal(size=(2, 2)))
+    np.testing.assert_almost_equal(
+        fn(True, mix_vv1_test, mix_vv2_test),
+        sp.norm(0, 1).logpdf(mix_vv1_test).sum()
+        + sp.norm(mix_vv1_test, 1).logpdf(mix_vv2_test).sum(),
+    )
+    mix_vv1_test = np.abs(rng.normal(size=(4,)))
+    mix_vv2_test = np.abs(rng.normal(size=(4, 4)))
+    np.testing.assert_almost_equal(
+        fn(False, mix_vv1_test, mix_vv2_test),
+        sp.halfnorm(0, 1).logpdf(mix_vv1_test).sum() + sp.halfnorm(0, 1).logpdf(mix_vv2_test).sum(),
+    )
+
+
+def test_ifelse_mixture_shared_component():
+    rng = np.random.default_rng(1009)
+
+    if_var = pt.scalar("if_var", dtype="bool")
+    outer_rv = pt.random.normal(name="outer")
+    # comp_shared need not be an output of ifelse at all,
+    # but since we allow arbitrary graphs we test it works as expected.
+    comp_shared = pt.random.normal(size=(2,), name="comp_shared")
+    comp_then = outer_rv + pt.random.normal(comp_shared, 1, size=(4, 2), name="comp_then")
+    comp_else = outer_rv + pt.random.normal(comp_shared, 10, size=(8, 2), name="comp_else")
+    shared_rv, mix_rv = ifelse(
+        if_var, [comp_shared, comp_then], [comp_shared, comp_else], name="mix"
+    )
+
+    outer_vv = outer_rv.clone()
+    shared_vv = shared_rv.clone()
+    mix_vv = mix_rv.clone()
+    outer_logp, mix_logp1, mix_logp2 = factorized_joint_logprob(
+        {outer_rv: outer_vv, shared_rv: shared_vv, mix_rv: mix_vv}
+    ).values()
+    assert_no_rvs(outer_logp)
+    assert_no_rvs(mix_logp1)
+    assert_no_rvs(mix_logp2)
+
+    fn = function([if_var, outer_vv, shared_vv, mix_vv], mix_logp1.sum() + mix_logp2.sum())
+    outer_vv_test = rng.normal()
+    shared_vv_test = rng.normal(size=(2,))
+    mix_vv_test = rng.normal(size=(4, 2))
+    np.testing.assert_almost_equal(
+        fn(True, outer_vv_test, shared_vv_test, mix_vv_test),
+        (
+            sp.norm(0, 1).logpdf(shared_vv_test).sum()
+            + sp.norm(outer_vv_test + shared_vv_test, 1).logpdf(mix_vv_test).sum()
+        ),
+    )
+    mix_vv_test = rng.normal(size=(8, 2))
+    np.testing.assert_almost_equal(
+        fn(False, outer_vv_test, shared_vv_test, mix_vv_test),
+        (
+            sp.norm(0, 1).logpdf(shared_vv_test).sum()
+            + sp.norm(outer_vv_test + shared_vv_test, 10).logpdf(mix_vv_test).sum()
+        ),
+        decimal=6,
+    )