Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions pymc/logprob/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@
import pytensor.tensor as pt

from pytensor.graph.basic import Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_eqopt1, scan_eqopt2
from pytensor.scan.utils import ScanArgs
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
from pytensor.tensor.var import TensorVariable
from pytensor.updates import OrderedUpdates
Expand All @@ -63,11 +61,12 @@
)
from pymc.logprob.joint_logprob import factorized_joint_logprob
from pymc.logprob.rewriting import (
PreserveRVMappings,
construct_ir_fgraph,
inc_subtensor_ops,
logprob_rewrites_db,
measurable_ir_rewrites_db,
)
from pymc.pytensorf import replace_rvs_by_values


class MeasurableScan(Scan):
Expand Down Expand Up @@ -249,9 +248,27 @@ def remove(x, i):
new_inner_out_nit_sot = tuple(output_scan_args.inner_out_nit_sot) + tuple(
inner_out_fn(remapped_io_to_ii)
)

output_scan_args.inner_out_nit_sot = list(new_inner_out_nit_sot)

# Finally, we need to replace any lingering references to the new
# internal variables that could be in the recurrent states needed
# to compute the new nit_sots
traced_outs = (
output_scan_args.inner_out_mit_sot
+ output_scan_args.inner_out_sit_sot
+ output_scan_args.inner_out_nit_sot
)
traced_outs = replace_rvs_by_values(traced_outs, rvs_to_values=remapped_io_to_ii)
# Update output mappings
n_mit_sot = len(output_scan_args.inner_out_mit_sot)
output_scan_args.inner_out_mit_sot = traced_outs[:n_mit_sot]
offset = n_mit_sot
n_sit_sot = len(output_scan_args.inner_out_sit_sot)
output_scan_args.inner_out_sit_sot = traced_outs[offset : offset + n_sit_sot]
offset += n_sit_sot
n_nit_sot = len(output_scan_args.inner_out_nit_sot)
output_scan_args.inner_out_nit_sot = traced_outs[offset : offset + n_nit_sot]

return output_scan_args


Expand Down Expand Up @@ -331,7 +348,10 @@ def create_inner_out_logp(value_map: Dict[TensorVariable, TensorVariable]) -> Te
for key, value in updates.items():
key.default_update = value

return logp_scan_out
# Return only the logp outputs, not any potentially carried states
logp_outputs = logp_scan_out[-len(values) :]

return logp_outputs


@node_rewriter([Scan])
Expand Down Expand Up @@ -504,19 +524,9 @@ def add_opts_to_inner_graphs(fgraph, node):
if getattr(node.op.mode, "had_logprob_rewrites", False):
return None

inner_fgraph = FunctionGraph(
node.op.inner_inputs,
node.op.inner_outputs,
clone=True,
copy_inputs=False,
copy_orphans=False,
features=[
ShapeFeature(),
PreserveRVMappings({}),
],
)

logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])).rewrite(inner_fgraph)
inner_rv_values = {out: out.type() for out in node.op.inner_outputs}
ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"]))
inner_fgraph, rv_values, _ = construct_ir_fgraph(inner_rv_values, ir_rewriter=ir_rewriter)

new_outputs = list(inner_fgraph.outputs)

Expand All @@ -531,11 +541,23 @@ def add_opts_to_inner_graphs(fgraph, node):


@_get_measurable_outputs.register(MeasurableScan)
def _get_measurable_outputs_MeasurableScan(op, node):
# TODO: This should probably use `get_random_outer_outputs`
# scan_args = ScanArgs.from_node(node)
# rv_outer_outs = get_random_outer_outputs(scan_args)
return [o for o in node.outputs if not isinstance(o.type, RandomType)]
def _get_measurable_outputs_MeasurableScan(op: Scan, node):
    """Collect measurable outputs for Measurable Scans.

    An outer scan output is considered measurable when its inner-graph
    counterpart is produced by a `MeasurableVariable` Op and is listed among
    that Op's own measurable outputs.

    Parameters
    ----------
    op
        The `MeasurableScan` Op being dispatched on.
    node
        The outer `Apply` node of the scan.

    Returns
    -------
    list
        The subset of ``node.outputs`` deemed measurable.
    """
    # Each outer output maps to exactly one inner output
    inner_out_from_outer_out_map = op.get_oinp_iinp_iout_oout_mappings()["inner_out_from_outer_out"]
    inner_outs = op.inner_outputs

    # Measurable scan outputs are those whose inner scan output counterparts are also measurable
    measurable_outputs = []
    for out_idx, out in enumerate(node.outputs):
        [inner_out_idx] = inner_out_from_outer_out_map[out_idx]
        inner_out = inner_outs[inner_out_idx]
        inner_out_node = inner_out.owner
        # An inner output without an owner (e.g., a directly forwarded input
        # or a constant) cannot be measurable; skip it instead of crashing on
        # the attribute access below.
        if inner_out_node is None:
            continue
        if isinstance(
            inner_out_node.op, MeasurableVariable
        ) and inner_out in get_measurable_outputs(inner_out_node.op, inner_out_node):
            measurable_outputs.append(out)

    return measurable_outputs


measurable_ir_rewrites_db.register(
Expand Down
22 changes: 18 additions & 4 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from numpy import random as nr
from numpy import testing as npt
from pytensor.compile.mode import Mode
from pytensor.graph.basic import ancestors
from pytensor.graph.basic import walk
from pytensor.graph.op import HasInnerGraph
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
Expand All @@ -37,7 +38,7 @@
from pymc.distributions.shape_utils import change_dist_size
from pymc.initial_point import make_initial_point_fn
from pymc.logprob import joint_logp
from pymc.logprob.abstract import icdf
from pymc.logprob.abstract import MeasurableVariable, icdf
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import (
compile_pymc,
Expand Down Expand Up @@ -958,5 +959,18 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:


def assert_no_rvs(var):
    """Assert that there are no `RandomVariable` or `MeasurableVariable` nodes in a graph.

    The graph is walked from ``var`` towards its inputs, descending into the
    inner graphs of Ops that have one (e.g., `Scan`), so random variables
    hidden inside inner graphs are also detected.

    Parameters
    ----------
    var
        The output variable whose ancestor graph is checked.

    Returns
    -------
    The input ``var``, unchanged, so the call can be used inline.

    Raises
    ------
    AssertionError
        If any ancestor (including inner-graph outputs) is produced by a
        `RandomVariable` or `MeasurableVariable` Op.
    """

    def expand(r):
        owner = r.owner
        if owner:
            inputs = list(reversed(owner.inputs))

            # Also traverse inner graphs (e.g., of Scan/OpFromGraph nodes)
            if isinstance(owner.op, HasInnerGraph):
                inputs += owner.op.inner_outputs

            return inputs

    for v in walk([var], expand, False):
        if v.owner and isinstance(v.owner.op, (RandomVariable, MeasurableVariable)):
            raise AssertionError(f"RV found in graph: {v}")

    # Preserve the legacy contract of returning the checked variable
    return var
76 changes: 76 additions & 0 deletions tests/logprob/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,79 @@ def test_scan_non_pure_rv_output():
grw_logp.eval({grw_vv: grw_vv_test}),
stats.norm.logpdf(np.ones(10)),
)


def test_scan_over_seqs():
    """Test logprob inference for scans driven by sequences (i.e., a pure mapping)."""
    rng = np.random.default_rng(543)
    n_steps = 10

    xs = pt.random.normal(size=(n_steps,), name="xs")
    ys, _ = pytensor.scan(
        fn=lambda x: pt.random.normal(x), sequences=[xs], outputs_info=[None], name="ys"
    )

    # Value variables for the measured variables. `xs_vv` must be cloned from
    # `xs`: the previous `ys.clone()` only worked because both variables
    # happen to share the same TensorType.
    xs_vv = xs.clone()
    ys_vv = ys.clone()
    ys_logp = factorized_joint_logprob({xs: xs_vv, ys: ys_vv})[ys_vv]

    # The logp graph must not reference any remaining random variables
    assert_no_rvs(ys_logp)

    xs_test = rng.normal(size=(10,))
    ys_test = rng.normal(size=(10,))
    # Each ys[t] ~ Normal(xs[t], 1), so the logp factorizes elementwise
    np.testing.assert_array_almost_equal(
        ys_logp.eval({xs_vv: xs_test, ys_vv: ys_test}),
        stats.norm.logpdf(ys_test, xs_test),
    )


def test_scan_carried_deterministic_state():
    """Test logp of scans with carried states downstream of measured variables.

    A moving average model with 2 lags is used for testing.
    """
    rng = np.random.default_rng(490)
    steps = 99

    # MA(2) coefficients and noise scale as symbolic inputs
    rho = pt.vector("rho", shape=(2,))
    sigma = pt.scalar("sigma")

    def ma2_step(eps_tm2, eps_tm1, rho, sigma):
        # mu is a deterministic function of the two carried epsilon states
        mu = eps_tm1 * rho[0] + eps_tm2 * rho[1]
        y = pt.random.normal(mu, sigma)
        # eps is a carried state computed *downstream* of the measurable y —
        # this is exactly the situation the logp machinery must handle
        eps = y - mu
        # Propagate the RNG state: map the step's input RNG to its updated output RNG
        update = {y.owner.inputs[0]: y.owner.outputs[0]}
        return (eps, y), update

    [_, ma2], ma2_updates = pytensor.scan(
        fn=ma2_step,
        # taps range(-2, 0) == [-2, -1]; initial pt.arange(2) gives
        # eps[-2] = 0.0 and eps[-1] = 1.0 (mirrored in ref_logp below)
        outputs_info=[{"initial": pt.arange(2, dtype="float64"), "taps": range(-2, 0)}, None],
        non_sequences=[rho, sigma],
        n_steps=steps,
        strict=True,
        name="ma2",
    )

    def ref_logp(values, rho, sigma):
        # Pure-NumPy reference: replay the recursion, scoring each value
        # against the running mean. Initial epsilons match pt.arange(2) above.
        epsilon_tm2 = 0
        epsilon_tm1 = 1
        step_logps = np.zeros_like(values)
        for t, value in enumerate(values):
            mu = epsilon_tm1 * rho[0] + epsilon_tm2 * rho[1]
            step_logps[t] = stats.norm.logpdf(value, mu, sigma)
            epsilon_tm2 = epsilon_tm1
            epsilon_tm1 = value - mu
        return step_logps

    ma2_vv = ma2.clone()
    logp_expr = logp(ma2, ma2_vv)
    # The logp graph must be free of leftover random variables
    assert_no_rvs(logp_expr)

    ma2_test = rng.normal(size=(steps,))
    rho_test = np.array([0.3, 0.7])
    sigma_test = 0.9

    np.testing.assert_array_almost_equal(
        logp_expr.eval({ma2_vv: ma2_test, rho: rho_test, sigma: sigma_test}),
        ref_logp(ma2_test, rho_test, sigma_test),
    )
28 changes: 1 addition & 27 deletions tests/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,11 @@
import numpy as np

from pytensor import tensor as pt
from pytensor.graph.basic import walk
from pytensor.graph.op import HasInnerGraph
from pytensor.tensor.var import TensorVariable
from scipy import stats as stats

from pymc.logprob import factorized_joint_logprob
from pymc.logprob.abstract import (
MeasurableVariable,
get_measurable_outputs,
icdf,
logcdf,
logprob,
)
from pymc.logprob.abstract import get_measurable_outputs, icdf, logcdf, logprob
from pymc.logprob.utils import ignore_logprob


Expand Down Expand Up @@ -82,24 +74,6 @@ def joint_logprob(*args, sum: bool = True, **kwargs) -> Optional[TensorVariable]
return pt.add(*logprob.values())


def assert_no_rvs(var):
    """Assert that there are no `MeasurableVariable` nodes in a graph."""

    def expand(r):
        # Frontier for the walk: a variable's inputs (right-to-left), plus
        # the inner-graph outputs of Ops such as Scan that carry one.
        node = r.owner
        if not node:
            return None
        frontier = list(reversed(node.inputs))
        if isinstance(node.op, HasInnerGraph):
            frontier.extend(node.op.inner_outputs)
        return frontier

    for candidate in walk([var], expand, False):
        node = candidate.owner
        if node is not None and isinstance(node.op, MeasurableVariable):
            raise AssertionError(f"Variable {candidate} is a MeasurableVariable")


def simulate_poiszero_hmm(
N, mu=10.0, pi_0_a=np.r_[1, 1], p_0_a=np.r_[5, 1], p_1_a=np.r_[1, 1], seed=None
):
Expand Down