From 9c670b11a65f9a338c35bf92e15f1ac9ea58780e Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 8 Mar 2023 15:06:22 +0100
Subject: [PATCH 1/3] Use better `assert_no_rvs` from logprob submodule

This utility can find RVs in inner graphs
---
 pymc/testing.py        | 22 ++++++++++++++++++----
 tests/logprob/utils.py | 28 +---------------------------
 2 files changed, 19 insertions(+), 31 deletions(-)

diff --git a/pymc/testing.py b/pymc/testing.py
index a0747bdaca..ea3ccfc46f 100644
--- a/pymc/testing.py
+++ b/pymc/testing.py
@@ -24,7 +24,8 @@
 from numpy import random as nr
 from numpy import testing as npt
 from pytensor.compile.mode import Mode
-from pytensor.graph.basic import ancestors
+from pytensor.graph.basic import walk
+from pytensor.graph.op import HasInnerGraph
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
@@ -37,7 +38,7 @@
 from pymc.distributions.shape_utils import change_dist_size
 from pymc.initial_point import make_initial_point_fn
 from pymc.logprob import joint_logp
-from pymc.logprob.abstract import icdf
+from pymc.logprob.abstract import MeasurableVariable, icdf
 from pymc.logprob.utils import ParameterValueError
 from pymc.pytensorf import (
     compile_pymc,
@@ -958,5 +959,18 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
 
 
 def assert_no_rvs(var):
-    assert not any(isinstance(v.owner.op, RandomVariable) for v in ancestors([var]) if v.owner)
-    return var
+    """Assert that there are no `MeasurableVariable` nodes in a graph."""
+
+    def expand(r):
+        owner = r.owner
+        if owner:
+            inputs = list(reversed(owner.inputs))
+
+            if isinstance(owner.op, HasInnerGraph):
+                inputs += owner.op.inner_outputs
+
+            return inputs
+
+    for v in walk([var], expand, False):
+        if v.owner and isinstance(v.owner.op, (RandomVariable, MeasurableVariable)):
+            raise AssertionError(f"RV found in graph: {v}")
diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py
index 644d2a83db..5a1c9c1656 100644
--- a/tests/logprob/utils.py
+++ b/tests/logprob/utils.py
@@ -39,19 +39,11 @@
 import numpy as np
 
 from pytensor import tensor as pt
-from pytensor.graph.basic import walk
-from pytensor.graph.op import HasInnerGraph
 from pytensor.tensor.var import TensorVariable
 from scipy import stats as stats
 
 from pymc.logprob import factorized_joint_logprob
-from pymc.logprob.abstract import (
-    MeasurableVariable,
-    get_measurable_outputs,
-    icdf,
-    logcdf,
-    logprob,
-)
+from pymc.logprob.abstract import get_measurable_outputs, icdf, logcdf, logprob
 from pymc.logprob.utils import ignore_logprob
 
 
@@ -82,24 +74,6 @@ def joint_logprob(*args, sum: bool = True, **kwargs) -> Optional[TensorVariable]
         return pt.add(*logprob.values())
 
 
-def assert_no_rvs(var):
-    """Assert that there are no `MeasurableVariable` nodes in a graph."""
-
-    def expand(r):
-        owner = r.owner
-        if owner:
-            inputs = list(reversed(owner.inputs))
-
-            if isinstance(owner.op, HasInnerGraph):
-                inputs += owner.op.inner_outputs
-
-            return inputs
-
-    for v in walk([var], expand, False):
-        if v.owner and isinstance(v.owner.op, MeasurableVariable):
-            raise AssertionError(f"Variable {v} is a MeasurableVariable")
-
-
 def simulate_poiszero_hmm(
     N, mu=10.0, pi_0_a=np.r_[1, 1], p_0_a=np.r_[5, 1], p_1_a=np.r_[1, 1], seed=None
 ):

From 98458f880d8a09460452a673b438f5c9076bb30a Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 8 Mar 2023 16:03:31 +0100
Subject: [PATCH 2/3] Add test for scans over sequences

---
 tests/logprob/test_scan.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py
index f7655a2336..4f52d98002 100644
--- a/tests/logprob/test_scan.py
+++ b/tests/logprob/test_scan.py
@@ -478,3 +478,27 @@ def test_scan_non_pure_rv_output():
         grw_logp.eval({grw_vv: grw_vv_test}),
         stats.norm.logpdf(np.ones(10)),
     )
+
+
+def test_scan_over_seqs():
+    """Test that logprob inference works for scans based on sequences (mapping)."""
+    rng = np.random.default_rng(543)
+    n_steps = 10
+
+    xs = pt.random.normal(size=(n_steps,), name="xs")
+    ys, _ = pytensor.scan(
+        fn=lambda x: pt.random.normal(x), sequences=[xs], outputs_info=[None], name="ys"
+    )
+
+    xs_vv = xs.clone()
+    ys_vv = ys.clone()
+    ys_logp = factorized_joint_logprob({xs: xs_vv, ys: ys_vv})[ys_vv]
+
+    assert_no_rvs(ys_logp)
+
+    xs_test = rng.normal(size=(10,))
+    ys_test = rng.normal(size=(10,))
+    np.testing.assert_array_almost_equal(
+        ys_logp.eval({xs_vv: xs_test, ys_vv: ys_test}),
+        stats.norm.logpdf(ys_test, xs_test),
+    )

From eddcc8f488ab9fa3893d587e435ffd8e6a61c7f9 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 8 Mar 2023 16:07:12 +0100
Subject: [PATCH 3/3] Fix logprob inference for scans with carried deterministic states

---
 pymc/logprob/scan.py       | 68 +++++++++++++++++++++++++-------------
 tests/logprob/test_scan.py | 52 ++++++++++++++++++++++++++++
 2 files changed, 97 insertions(+), 23 deletions(-)

diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py
index 151f146a4f..2298b9d038 100644
--- a/pymc/logprob/scan.py
+++ b/pymc/logprob/scan.py
@@ -42,7 +42,6 @@
 import pytensor.tensor as pt
 
 from pytensor.graph.basic import Variable
-from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -50,7 +49,6 @@
 from pytensor.scan.rewriting import scan_eqopt1, scan_eqopt2
 from pytensor.scan.utils import ScanArgs
 from pytensor.tensor.random.type import RandomType
-from pytensor.tensor.rewriting.shape import ShapeFeature
 from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
 from pytensor.tensor.var import TensorVariable
 from pytensor.updates import OrderedUpdates
@@ -63,11 +61,12 @@
 )
 from pymc.logprob.joint_logprob import factorized_joint_logprob
 from pymc.logprob.rewriting import (
-    PreserveRVMappings,
+    construct_ir_fgraph,
     inc_subtensor_ops,
     logprob_rewrites_db,
     measurable_ir_rewrites_db,
 )
+from pymc.pytensorf import replace_rvs_by_values
 
 
 class MeasurableScan(Scan):
@@ -249,9 +248,27 @@ def remove(x, i):
     new_inner_out_nit_sot = tuple(output_scan_args.inner_out_nit_sot) + tuple(
         inner_out_fn(remapped_io_to_ii)
     )
-
     output_scan_args.inner_out_nit_sot = list(new_inner_out_nit_sot)
 
+    # Finally, we need to replace any lingering references to the new
+    # internal variables that could be in the recurrent states needed
+    # to compute the new nit_sots
+    traced_outs = (
+        output_scan_args.inner_out_mit_sot
+        + output_scan_args.inner_out_sit_sot
+        + output_scan_args.inner_out_nit_sot
+    )
+    traced_outs = replace_rvs_by_values(traced_outs, rvs_to_values=remapped_io_to_ii)
+    # Update output mappings
+    n_mit_sot = len(output_scan_args.inner_out_mit_sot)
+    output_scan_args.inner_out_mit_sot = traced_outs[:n_mit_sot]
+    offset = n_mit_sot
+    n_sit_sot = len(output_scan_args.inner_out_sit_sot)
+    output_scan_args.inner_out_sit_sot = traced_outs[offset : offset + n_sit_sot]
+    offset += n_sit_sot
+    n_nit_sot = len(output_scan_args.inner_out_nit_sot)
+    output_scan_args.inner_out_nit_sot = traced_outs[offset : offset + n_nit_sot]
+
     return output_scan_args
 
 
@@ -331,7 +348,10 @@ def create_inner_out_logp(value_map: Dict[TensorVariable, TensorVariable]) -> Te
 
     for key, value in updates.items():
         key.default_update = value
 
-    return logp_scan_out
+    # Return only the logp outputs, not any potentially carried states
+    logp_outputs = logp_scan_out[-len(values) :]
+
+    return logp_outputs
 
 
@@ -504,19 +524,9 @@ def add_opts_to_inner_graphs(fgraph, node):
     if getattr(node.op.mode, "had_logprob_rewrites", False):
         return None
 
-    inner_fgraph = FunctionGraph(
-        node.op.inner_inputs,
-        node.op.inner_outputs,
-        clone=True,
-        copy_inputs=False,
-        copy_orphans=False,
-        features=[
-            ShapeFeature(),
-            PreserveRVMappings({}),
-        ],
-    )
-
-    logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])).rewrite(inner_fgraph)
+    inner_rv_values = {out: out.type() for out in node.op.inner_outputs}
+    ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"]))
+    inner_fgraph, rv_values, _ = construct_ir_fgraph(inner_rv_values, ir_rewriter=ir_rewriter)
 
     new_outputs = list(inner_fgraph.outputs)
 
@@ -531,11 +541,23 @@ def add_opts_to_inner_graphs(fgraph, node):
 
 
 @_get_measurable_outputs.register(MeasurableScan)
-def _get_measurable_outputs_MeasurableScan(op, node):
-    # TODO: This should probably use `get_random_outer_outputs`
-    # scan_args = ScanArgs.from_node(node)
-    # rv_outer_outs = get_random_outer_outputs(scan_args)
-    return [o for o in node.outputs if not isinstance(o.type, RandomType)]
+def _get_measurable_outputs_MeasurableScan(op: Scan, node):
+    """Collect measurable outputs for Measurable Scans"""
+    inner_out_from_outer_out_map = op.get_oinp_iinp_iout_oout_mappings()["inner_out_from_outer_out"]
+    inner_outs = op.inner_outputs
+
+    # Measurable scan outputs are those whose inner scan output counterparts are also measurable
+    measurable_outputs = []
+    for out_idx, out in enumerate(node.outputs):
+        [inner_out_idx] = inner_out_from_outer_out_map[out_idx]
+        inner_out = inner_outs[inner_out_idx]
+        inner_out_node = inner_out.owner
+        if isinstance(
+            inner_out_node.op, MeasurableVariable
+        ) and inner_out in get_measurable_outputs(inner_out_node.op, inner_out_node):
+            measurable_outputs.append(out)
+
+    return measurable_outputs
 
 
 measurable_ir_rewrites_db.register(
diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py
index 4f52d98002..7259fa80c8 100644
--- a/tests/logprob/test_scan.py
+++ b/tests/logprob/test_scan.py
@@ -502,3 +502,55 @@ def test_scan_over_seqs():
         ys_logp.eval({xs_vv: xs_test, ys_vv: ys_test}),
         stats.norm.logpdf(ys_test, xs_test),
     )
+
+
+def test_scan_carried_deterministic_state():
+    """Test logp of scans with carried states downstream of measured variables.
+
+    A moving average model with 2 lags is used for testing.
+ """ + rng = np.random.default_rng(490) + steps = 99 + + rho = pt.vector("rho", shape=(2,)) + sigma = pt.scalar("sigma") + + def ma2_step(eps_tm2, eps_tm1, rho, sigma): + mu = eps_tm1 * rho[0] + eps_tm2 * rho[1] + y = pt.random.normal(mu, sigma) + eps = y - mu + update = {y.owner.inputs[0]: y.owner.outputs[0]} + return (eps, y), update + + [_, ma2], ma2_updates = pytensor.scan( + fn=ma2_step, + outputs_info=[{"initial": pt.arange(2, dtype="float64"), "taps": range(-2, 0)}, None], + non_sequences=[rho, sigma], + n_steps=steps, + strict=True, + name="ma2", + ) + + def ref_logp(values, rho, sigma): + epsilon_tm2 = 0 + epsilon_tm1 = 1 + step_logps = np.zeros_like(values) + for t, value in enumerate(values): + mu = epsilon_tm1 * rho[0] + epsilon_tm2 * rho[1] + step_logps[t] = stats.norm.logpdf(value, mu, sigma) + epsilon_tm2 = epsilon_tm1 + epsilon_tm1 = value - mu + return step_logps + + ma2_vv = ma2.clone() + logp_expr = logp(ma2, ma2_vv) + assert_no_rvs(logp_expr) + + ma2_test = rng.normal(size=(steps,)) + rho_test = np.array([0.3, 0.7]) + sigma_test = 0.9 + + np.testing.assert_array_almost_equal( + logp_expr.eval({ma2_vv: ma2_test, rho: rho_test, sigma: sigma_test}), + ref_logp(ma2_test, rho_test, sigma_test), + )