From f9e48e033e5e9ecff6b97206b5f7dc145a0bbb53 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 5 Apr 2023 11:34:18 +0200 Subject: [PATCH] Raise warning if RVs are present in derived probability graphs --- pymc/logprob/basic.py | 54 ++++++++++++++++++++++++++++++----- tests/logprob/test_basic.py | 57 +++++++++++++++++++++++++++++++++++-- 2 files changed, 102 insertions(+), 9 deletions(-) diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index d0fc6bd3ae..bb2126e33e 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -65,7 +65,33 @@ TensorLike: TypeAlias = Union[Variable, float, np.ndarray] -def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: +def _warn_rvs_in_inferred_graph(graph: Sequence[TensorVariable]): +    """Issue warning if any RVs are found in graph. + +    RVs are usually an (implicit) conditional input of the derived probability expression, +    and meant to be replaced by respective value variables before evaluation. +    However, when the IR graph is built, any non-input nodes (including RVs) are cloned, +    breaking the link with the original ones. +    This makes it impossible (or difficult) to replace them by the respective values afterward, +    so we instruct users to do it beforehand. +    """ +    from pymc.testing import assert_no_rvs + +    try: +        assert_no_rvs(graph) +    except AssertionError: +        warnings.warn( +            "RandomVariables were found in the derived graph. " +            "These variables are clones and do not match the original ones on identity.\n" +            "If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. 
For example: " + "`logp(model.replace_rvs_by_values([rv])[0], value)`", + stacklevel=3, + ) + + +def logp( + rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs +) -> TensorVariable: """Return the log-probability graph of a Random Variable""" value = pt.as_tensor_variable(value, dtype=rv.dtype) @@ -74,10 +100,15 @@ def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: except NotImplementedError: fgraph, _, _ = construct_ir_fgraph({rv: value}) [(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items() - return _logprob_helper(ir_rv, ir_value, **kwargs) + expr = _logprob_helper(ir_rv, ir_value, **kwargs) + if warn_missing_rvs: + _warn_rvs_in_inferred_graph(expr) + return expr -def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: +def logcdf( + rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs +) -> TensorVariable: """Create a graph for the log-CDF of a Random Variable.""" value = pt.as_tensor_variable(value, dtype=rv.dtype) try: @@ -86,10 +117,15 @@ def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: # Try to rewrite rv fgraph, rv_values, _ = construct_ir_fgraph({rv: value}) [ir_rv] = fgraph.outputs - return _logcdf_helper(ir_rv, value, **kwargs) + expr = _logcdf_helper(ir_rv, value, **kwargs) + if warn_missing_rvs: + _warn_rvs_in_inferred_graph(expr) + return expr -def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: +def icdf( + rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs +) -> TensorVariable: """Create a graph for the inverse CDF of a Random Variable.""" value = pt.as_tensor_variable(value, dtype=rv.dtype) try: @@ -98,7 +134,10 @@ def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: # Try to rewrite rv fgraph, rv_values, _ = construct_ir_fgraph({rv: value}) [ir_rv] = fgraph.outputs - return _icdf_helper(ir_rv, value, **kwargs) + expr = 
_icdf_helper(ir_rv, value, **kwargs) +    if warn_missing_rvs: +        _warn_rvs_in_inferred_graph(expr) +    return expr def factorized_joint_logprob( @@ -215,7 +254,8 @@ def factorized_joint_logprob( if warn_missing_rvs: warnings.warn( "Found a random variable that was neither among the observations " - f"nor the conditioned variables: {node.outputs}" + f"nor the conditioned variables: {outputs}.\n" + "These variables are clones and do not match the original ones on identity." ) continue diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py index 0546396dec..92e4df7dff 100644 --- a/tests/logprob/test_basic.py +++ b/tests/logprob/test_basic.py @@ -56,7 +56,9 @@ import pymc as pm from pymc.logprob.basic import factorized_joint_logprob, icdf, joint_logp, logcdf, logp +from pymc.logprob.transforms import LogTransform from pymc.logprob.utils import rvs_to_value_vars, walk_model +from pymc.pytensorf import replace_rvs_by_values from pymc.testing import assert_no_rvs from tests.logprob.utils import joint_logprob @@ -248,16 +250,25 @@ def test_persist_inputs(): y_vv_2 = y_vv * 2 logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2}) +    assert y_vv in ancestors([logp_2]) +    assert y_vv_2 in ancestors([logp_2]) + +    # Even when they are random +    y_vv = pt.random.normal(name="y_vv2") +    y_vv_2 = y_vv * 2 +    logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2}) + +    assert y_vv in ancestors([logp_2]) assert y_vv_2 in ancestors([logp_2]) -def test_warn_random_not_found(): +def test_warn_random_found_factorized_joint_logprob(): x_rv = pt.random.normal(name="x") y_rv = pt.random.normal(x_rv, 1, name="y") y_vv = y_rv.clone() -    with pytest.warns(UserWarning): +    with pytest.warns(UserWarning, match="Found a random variable that was neither among"): factorized_joint_logprob({y_rv: y_vv}) with warnings.catch_warnings(): @@ -457,3 +468,45 @@ def test_probability_inference_fails(func, func_name): match=f"{func_name} method not implemented for Elemwise{{cos,no_inplace}}", ): 
func(pt.cos(pm.Normal.dist()), 1) + + +@pytest.mark.parametrize( +    "func, scipy_func, test_value", +    [ +        (logp, "logpdf", 5.0), +        (logcdf, "logcdf", 5.0), +        (icdf, "ppf", 0.7), +    ], +) +def test_warn_random_found_probability_inference(func, scipy_func, test_value): +    # Fail if unexpected warning is issued +    with warnings.catch_warnings(): +        warnings.simplefilter("error") + +        input_rv = pm.Normal.dist(0, name="input") +        # Note: This graph could correspond to a convolution of two normals +        # In which case the inference should either return that or fail explicitly +        # For now, the logprob submodule treats the input as a stochastic value. +        rv = pt.exp(pm.Normal.dist(input_rv)) +        with pytest.warns(UserWarning, match="RandomVariables were found in the derived graph"): +            assert func(rv, 0.0) + +        res = func(rv, 0.0, warn_missing_rvs=False) +        # This is the problem we are warning about, as now we can no longer identify the original rv in the graph +        # or replace it by the respective value +        assert rv not in ancestors([res]) + +        # Test that the prescribed solution does not raise a warning and works as expected +        input_vv = input_rv.clone() +        [new_rv] = replace_rvs_by_values( +            [rv], +            rvs_to_values={input_rv: input_vv}, +            rvs_to_transforms={input_rv: LogTransform()}, +        ) +        input_vv_test = 1.3 +        np.testing.assert_almost_equal( +            func(new_rv, test_value).eval({input_vv: input_vv_test}), +            getattr(sp.lognorm(s=1, loc=0, scale=np.exp(np.exp(input_vv_test))), scipy_func)( +                test_value +            ), +        )