157 changes: 62 additions & 95 deletions pytensor/tensor/random/rewriting/basic.py
@@ -1,7 +1,10 @@
+from itertools import zip_longest
+
 from pytensor.compile import optdb
 from pytensor.configdefaults import config
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import in2out, node_rewriter
+from pytensor.tensor import NoneConst
 from pytensor.tensor.basic import constant, get_vector_length
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
@@ -17,6 +20,7 @@
     get_idx_list,
     indexed_result_shape,
 )
+from pytensor.tensor.type_other import SliceType
 
 
 def is_rv_used_in_graph(base_rv, node, fgraph):
@@ -196,141 +200,104 @@ def local_dimshuffle_rv_lift(fgraph, node):
 def local_subtensor_rv_lift(fgraph, node):
     """Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
 
-    In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions
-    need to be separated into distinct replication-space and (independent)
-    parameter-space ``*Subtensor``s.
-
-    The replication-space ``*Subtensor`` can be used to determine a
-    sub/super-set of the replication-space and, thus, a "smaller"/"larger"
-    ``size`` tuple. The parameter-space ``*Subtensor`` is simply lifted and
-    applied to the distribution parameters.
-
-    Consider the following example graph:
-    ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``. The
-    ``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``,
-    which correspond to all three ``size`` dimensions. Now, depending on the
-    broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op``
-    could be reducing the ``size`` parameter and/or sub-setting the independent
-    ``mu`` and ``std`` parameters. Only once the dimensions are properly
-    separated into the two replication/parameter subspaces can we determine how
-    the ``*Subtensor`` indices are distributed.
-    For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``
-    could become
-    ``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))``
-    if ``mu.shape == std.shape == ()``
-
-    ``normal`` is a rather simple case, because it's univariate. Multivariate
-    cases require a mapping between the parameter space and the image of the
-    random variable. This may not always be possible, but for many common
-    distributions it is. For example, the dimensions of the multivariate
-    normal's image can be mapped directly to each dimension of its parameters.
-    We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]``
-    into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``.
+    For example, ``normal(mu, std)[0] == normal(mu[0], std[0])``.
+
+    This rewrite also applies to multivariate distributions as long
+    as indexing does not happen within core dimensions, such as in
+    ``mvnormal(mu, cov, size=(2,))[0, 0]``.
     """
 
     st_op = node.op
 
     if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
         return False
 
-    base_rv = node.inputs[0]
+    rv = node.inputs[0]
+    rv_node = rv.owner
 
-    rv_node = base_rv.owner
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
-        return False
-
+    rv_op = rv_node.op
+    rng, size, dtype, *dist_params = rv_node.inputs
 
-    # TODO: Remove this once the multi-dimensional changes described below are
-    # in place.
-    if rv_op.ndim_supp > 0:
-        return False
-
-    rv_op = base_rv.owner.op
-    rng, size, dtype, *dist_params = base_rv.owner.inputs
 
+    # Parse indices
     idx_list = getattr(st_op, "idx_list", None)
     if idx_list:
         cdata = get_idx_list(node.inputs, idx_list)
     else:
         cdata = node.inputs[1:]
 
     st_indices, st_is_bool = zip(
         *tuple(
             (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
         )
     )
 
-    # We need to separate dimensions into replications and independents
-    num_ind_dims = None
-    if len(dist_params) == 1:
-        num_ind_dims = dist_params[0].ndim
-    else:
-        # When there is more than one distribution parameter, assume that all
-        # of them will broadcast to the maximum number of dimensions
-        num_ind_dims = max(d.ndim for d in dist_params)
-
-    reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp)
-
-    if len(st_indices) > reps_ind_split_idx:
-        # These are the indices that need to be applied to the parameters
-        ind_indices = tuple(st_indices[reps_ind_split_idx:])
-
-        # We need to broadcast the parameters before applying the `*Subtensor*`
-        # with these indices, because the indices could be referencing broadcast
-        # dimensions that don't exist (yet)
-        bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params)
-
-        # TODO: For multidimensional distributions, we need a map that tells us
-        # which dimensions of the parameters need to be indexed.
-        #
-        # For example, `multivariate_normal` would have the following:
-        # `RandomVariable.param_to_image_dims = ((0,), (0, 1))`
-        #
-        # I.e. the first parameter's (i.e. mean's) first dimension maps directly to
-        # the dimension of the RV's image, and its second parameter's
-        # (i.e. covariance's) first and second dimensions map directly to the
-        # dimension of the RV's image.
-
-        args_lifted = tuple(p[ind_indices] for p in bcast_dist_params)
-    else:
-        # In this case, no indexing is applied to the parameters; only the
-        # `size` parameter is affected.
-        args_lifted = dist_params
+    # Check that indexing does not act on support dims
+    batched_ndims = rv.ndim - rv_op.ndim_supp
+    if len(st_indices) > batched_ndims:
+        # If the trailing indices are just dummy `slice(None)`s, discard them
+        st_is_bool = st_is_bool[:batched_ndims]
+        st_indices, supp_indices = (
+            st_indices[:batched_ndims],
+            st_indices[batched_ndims:],
+        )
+        for index in supp_indices:
+            if not (
+                isinstance(index.type, SliceType)
+                and all(NoneConst.equals(i) for i in index.owner.inputs)
+            ):
+                return False
 
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False
 
+    # Update the size to reflect the indexed dimensions
     # TODO: Could use `ShapeFeature` info. We would need to be sure that
     # `node` isn't in the results, though.
     # if hasattr(fgraph, "shape_feature"):
     #     output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
     # else:
-    output_shape = indexed_result_shape(base_rv.shape, st_indices)
-
-    size_lifted = (
-        output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp]
+    output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
+    new_size_ignoring_boolean = (
+        output_shape_ignoring_bool
+        if rv_op.ndim_supp == 0
+        else output_shape_ignoring_bool[: -rv_op.ndim_supp]
     )
 
-    # Boolean indices can actually change the `size` value (compared to just
-    # *which* dimensions of `size` are used).
+    # Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used).
+    # The `indexed_result_shape` helper does not consider this
     if any(st_is_bool):
-        size_lifted = tuple(
+        new_size = tuple(
             at_sum(idx) if is_bool else s
-            for s, is_bool, idx in zip(
-                size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)]
+            for s, is_bool, idx in zip_longest(
+                new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
             )
         )
+    else:
+        new_size = new_size_ignoring_boolean
 
-    new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted)
-    _, new_rv = new_node.outputs
+    # Update the parameters to reflect the indexed dimensions
+    new_dist_params = []
+    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
+        # Apply indexing on the batched dimensions of the parameter
+        batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
+        batched_param = shape_padleft(param, batched_param_dims_missing)
+        batched_st_indices = []
+        for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
+            # If we have a degenerate dimension, indexing it with 0 always does the job
+            if batched_param_shape == 1:
+                batched_st_indices.append(0)
+            else:
+                batched_st_indices.append(st_index)
+        new_dist_params.append(batched_param[tuple(batched_st_indices)])
+
+    # Create new RV
+    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
+    new_rv = new_node.default_output()
 
     # Calling `Op.make_node` directly circumvents test value computations, so
     # we need to compute the test values manually
     if config.compute_test_value != "off":
         compute_test_value(new_node)
 
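As a sketch of what this rewrite achieves (an illustration, not part of the PR: it assumes the import paths shown in the diff above and the generic `FunctionGraph`/node-rewriter APIs; all variable names are made up):

import pytensor.tensor as at
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.random.basic import normal
from pytensor.tensor.random.rewriting.basic import local_subtensor_rv_lift

mu = at.vector("mu")
sigma = at.vector("sigma")

rv = normal(mu, sigma)  # batch of univariate normals (ndim_supp == 0)
indexed = rv[0]         # Subtensor applied to a batch dimension

fgraph = FunctionGraph(outputs=[indexed], clone=False)
# Apply the node rewriter by hand; normally it runs as part of the
# rewrite pipelines registered with `optdb`.
replacement = local_subtensor_rv_lift.transform(fgraph, indexed.owner)
# If the rewrite fires, the replacement is equivalent to normal(mu[0], sigma[0]):
# the indexing has been lifted onto the distribution parameters.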
Empty file.
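Why indexing must stop at the batch dimensions can be seen with plain NumPy (hypothetical shapes, unrelated to the diffs here): a batch of two 3-dimensional multivariate-normal draws has batch shape (2,) and core shape (3,), and only the batch part can be traded against parameter indexing.

import numpy as np

rng = np.random.default_rng(0)
# Batch shape (2,), core (support) shape (3,)
draws = rng.multivariate_normal(np.zeros(3), np.eye(3), size=2)

draws[0]     # selects a batch member; expressible by shrinking `size`
             # or indexing the parameters
draws[0, 0]  # dips into the core dimension itself; no indexing of
             # `mu`/`cov` yields a distribution over this scalar alone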
@@ -12,6 +12,7 @@
 from pytensor.tensor import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.random.basic import (
+    categorical,
     dirichlet,
     multinomial,
     multivariate_normal,
@@ -36,8 +37,8 @@ def apply_local_rewrite_to_rv(
     rewrite, op_fn, dist_op, dist_params, size, rng, name=None
 ):
     dist_params_at = []
-    for p in dist_params:
-        p_at = at.as_tensor(p).type()
+    for i, p in enumerate(dist_params):
+        p_at = at.as_tensor(p).type(f"p_{i}")
         p_at.tag.test_value = p
         dist_params_at.append(p_at)

@@ -495,8 +496,79 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
             ),
             (3, 2, 2),
         ),
-        # A multi-dimensional case
+        # Only one distribution parameter
+        (
+            (0,),
+            True,
+            poisson,
+            (np.array([[1, 2], [3, 4]], dtype=config.floatX),),
+            (3, 2, 2),
+        ),
+        # Univariate distribution with vector parameters
+        (
+            (np.array([0, 2]),),
+            True,
+            categorical,
+            (np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
+            (4,),
+        ),
+        (
+            (np.array([True, False, True, True]),),
+            True,
+            categorical,
+            (np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
+            (4,),
+        ),
+        (
+            (np.array([True, False, True]),),
+            True,
+            categorical,
+            (
+                np.array(
+                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                    dtype=config.floatX,
+                ),
+            ),
+            (),
+        ),
+        (
+            (
+                slice(None),
+                np.array([True, False, True]),
+            ),
+            True,
+            categorical,
+            (
+                np.array(
+                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                    dtype=config.floatX,
+                ),
+            ),
+            (4, 3),
+        ),
+        # Boolean indexing where output is empty
+        (
+            (np.array([False, False]),),
+            True,
+            normal,
+            (np.array([[1.0, 0.0, 0.0]], dtype=config.floatX),),
+            (2, 3),
+        ),
+        (
+            (np.array([False, False]),),
+            True,
+            categorical,
+            (
+                np.array(
+                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                    dtype=config.floatX,
+                ),
+            ),
+            (2, 3),
+        ),
+        # Multivariate cases, indexing only supported if it does not affect core dimensions
         (
+            # Indexing dips into core dimension
             (np.array([1]), 0),
             False,
             multivariate_normal,
@@ -506,13 +578,30 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
             ),
             (),
         ),
-        # Only one distribution parameter
         (
-            (0,),
+            (np.array([0, 2]),),
             True,
-            poisson,
-            (np.array([[1, 2], [3, 4]], dtype=config.floatX),),
-            (3, 2, 2),
+            multivariate_normal,
+            (
+                np.array(
+                    [[-100, -125, -150], [0, 0, 0], [200, 225, 250]],
+                    dtype=config.floatX,
+                ),
+                np.eye(3, dtype=config.floatX) * 1e-6,
+            ),
+            (),
         ),
+        (
+            (np.array([True, False, True]), slice(None)),
+            True,
+            multivariate_normal,
+            (
+                np.array([200, 250], dtype=config.floatX),
+                # Second covariance is invalid, to test it is not chosen
+                np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
+                * 1e-6,
+            ),
+            (3,),
+        ),
     ],
 )
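The boolean test cases above exercise the `zip_longest`/`at_sum` branch of the rewrite: a boolean mask does not merely select which `size` entries survive, it changes their values, and an all-False mask produces an empty size. A plain-NumPy illustration of the effect (array contents are arbitrary):

import numpy as np

x = np.zeros((4, 3))
mask = np.array([True, False, True, True])

# The mask consumes the first dimension and replaces its length (4) with the
# number of True entries (3) -- the count `at_sum(idx)` computes symbolically.
print(x[mask].shape)  # (3, 3)

# An all-False mask yields an empty result, matching the
# "Boolean indexing where output is empty" cases above.
print(x[np.array([False, False, False, False])].shape)  # (0, 3)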