1 change: 1 addition & 0 deletions pytensor/tensor/elemwise.py
@@ -173,6 +173,7 @@ def __init__(self, input_broadcastable, new_order):
# List of dimensions of the output that are broadcastable and were not
# in the original input
self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"])
self.drop = drop

if self.inplace:
self.view_map = {0: [0]}
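The added line stores the dropped input dimensions on the op as `self.drop`, alongside the existing `augment` list; the rewrite in rewriting.py below checks it to bail out early. A minimal illustrative sketch (not part of the diff) of what the attribute holds:

from pytensor.tensor.elemwise import DimShuffle

# Drop the leading broadcastable dim and transpose the remaining two.
# Omitted input dims end up in `drop`; "x" entries would end up in `augment`.
op = DimShuffle(input_broadcastable=(True, False, False), new_order=(2, 1))
print(op.augment)  # []
print(op.drop)     # [0]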
5 changes: 3 additions & 2 deletions pytensor/tensor/random/basic.py
@@ -1990,11 +1990,12 @@ def _supp_shape_from_params(self, *args, **kwargs):
raise NotImplementedError()

def _infer_shape(self, size, dist_params, param_shapes=None):
(a, p, _) = dist_params

a, p, _ = dist_params
if isinstance(p.type, pytensor.tensor.type_other.NoneTypeT):
param_shapes = param_shapes[:1] if param_shapes is not None else None
shape = super()._infer_shape(size, (a,), param_shapes)
else:
param_shapes = param_shapes[:2] if param_shapes is not None else None
shape = super()._infer_shape(size, (a, p), param_shapes)

return shape
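With this change, shape inference for `choice` works when `p` is left at its default (the node then receives `NoneConst` for it) by using only `a` and, when given, the matching slice of `param_shapes`. A rough user-level sketch of the fixed behaviour, assuming the `choice` helper from this module:

import pytensor.tensor as at
from pytensor.tensor.random.basic import choice

# `p` is omitted, so only `a` and the requested size determine the shape.
draws = choice(at.as_tensor_variable([1, 2, 3]), size=(2,))
assert tuple(draws.shape.eval()) == (2,)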
13 changes: 13 additions & 0 deletions pytensor/tensor/random/op.py
@@ -192,6 +192,19 @@ def _infer_shape(
size_len = get_vector_length(size)

if size_len > 0:

# Fail early when size is incompatible with parameters
for i, (param, param_ndim_supp) in enumerate(
zip(dist_params, self.ndims_params)
):
param_batched_dims = getattr(param, "ndim", 0) - param_ndim_supp
if param_batched_dims > size_len:
raise ValueError(
f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n"
f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. "
f"Size length must be 0 or >= {param_batched_dims}"
)

if self.ndim_supp == 0:
return size
else:
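The rule enforced by the new check: a non-empty `size` must have at least as many entries as the batched dimensions contributed by each parameter (`param.ndim - ndim_supp`). A hedged sketch using the `normal` RV (illustrative only; the tests added below exercise the same error):

import numpy as np
from pytensor.tensor.random.basic import normal

mu = np.zeros((2, 3))  # two batched dims, since normal has ndim_supp == 0

normal(mu, 1.0, size=(2, 3))     # OK: len(size) == 2
normal(mu, 1.0, size=(5, 2, 3))  # OK: extra leading replication dim
# normal(mu, 1.0, size=(3,))     # now raises ValueError: len(size) = 1 < 2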
178 changes: 52 additions & 126 deletions pytensor/tensor/random/rewriting.py
@@ -8,7 +8,7 @@
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
from pytensor.tensor.subtensor import (
AdvancedSubtensor,
AdvancedSubtensor1,
@@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node):

For example, ``normal(mu, std).T == normal(mu.T, std.T)``.

The basic idea behind this rewrite is that we need to separate the
``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two
distinct sub-spaces: the (set of independent) parameters and ``size``
(i.e. replications) sub-spaces.

If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we
don't do anything.

Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of
those sub-spaces, we can break it apart and apply the parameter-space
``DimShuffle`` to the distribution parameters, and then apply the
replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a
particularly simple rearranging of a tuple, but the former requires a
little more work.

TODO: Currently, multivariate support for this rewrite is disabled.
This rewrite is only applicable when the Dimshuffle operation does
not affect support dimensions.

TODO: Support dimension dropping
"""

ds_op = node.op
@@ -142,128 +129,67 @@ def local_dimshuffle_rv_lift(fgraph, node):
base_rv = node.inputs[0]
rv_node = base_rv.owner

if not (
rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
):
if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False

# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(base_rv, node, fgraph):
# DimShuffles that drop dimensions are not supported yet
if ds_op.drop:
return False

rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs
rv = rv_node.default_output()

# We need to know the dimensions that were *not* added by the `size`
# parameter (i.e. the dimensions corresponding to independent variates with
# different parameter values)
num_ind_dims = None
if len(dist_params) == 1:
num_ind_dims = dist_params[0].ndim
else:
# When there is more than one distribution parameter, assume that all
# of them will broadcast to the maximum number of dimensions
num_ind_dims = max(d.ndim for d in dist_params)

# If the indices in `ds_new_order` are entirely within the replication
# indices group or the independent variates indices group, then we can apply
# this rewrite.

ds_new_order = ds_op.new_order
# Create a map from old index order to new/`DimShuffled` index order
dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]

# Find the index at which the replications/independents split occurs
reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)

ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
ds_in_ind_space = ds_ind_new_dims and all(
d >= reps_ind_split_idx for n, d in ds_ind_new_dims
)
# Check that Dimshuffle does not affect support dims
supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
if (shuffled_dims | augmented_dims) & supp_dims:
return False

if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims):
# If no one else is using the underlying RandomVariable, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(base_rv, node, fgraph):
return False

# Update the `size` array to reflect the `DimShuffle`d dimensions,
# since the trailing dimensions in `size` represent the independent
# variates dimensions (for univariate distributions, at least)
has_size = get_vector_length(size) > 0
new_size = (
[constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order]
if has_size
else size
batched_dims = rv.ndim - rv_op.ndim_supp
batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)

# Make size explicit
missing_size_dims = batched_dims - get_vector_length(size)
if missing_size_dims > 0:
full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
size = full_size[:missing_size_dims] + tuple(size)

# Update the size to reflect the DimShuffled dimensions
new_size = [
constant(1, dtype="int64") if o == "x" else size[o]
for o in batched_dims_ds_order
]

# Update the params to reflect the DimShuffled dimensions
new_dist_params = []
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Add broadcastable dimensions to the parameters that would have been expanded by the size
padleft = batched_dims - (param.ndim - param_ndim_supp)
if padleft > 0:
param = shape_padleft(param, padleft)

# Append the parameter's support dimension indexes (unchanged) to the batched-dims DimShuffle order
param_new_order = batched_dims_ds_order + tuple(
range(batched_dims, batched_dims + param_ndim_supp)
)
new_dist_params.append(param.dimshuffle(param_new_order))

# Compute the new axes parameter(s) for the `DimShuffle` that will be
# applied to the `RandomVariable` parameters (they need to be offset)
if ds_ind_new_dims:
rv_params_new_order = [
d - reps_ind_split_idx if isinstance(d, int) else d
for d in ds_new_order[ds_ind_new_dims[0][0] :]
]

if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0:
# Additional broadcast dimensions need to be added to the
# independent dimensions (i.e. parameters), since there's no
# `size` to which they can be added
rv_params_new_order = (
list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order
)
else:
# This case is reached when, for example, `ds_new_order` only
# consists of new broadcastable dimensions (i.e. `"x"`s)
rv_params_new_order = ds_new_order

# Lift the `DimShuffle`s into the parameters
# NOTE: The parameters might not be broadcasted against each other, so
# we can only apply the parts of the `DimShuffle` that are relevant.
new_dist_params = []
for d in dist_params:
if d.ndim < len(ds_ind_new_dims):
_rv_params_new_order = [
o
for o in rv_params_new_order
if (isinstance(o, int) and o < d.ndim) or o == "x"
]
else:
_rv_params_new_order = rv_params_new_order

new_dist_params.append(
type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
)
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)

if config.compute_test_value != "off":
compute_test_value(new_node)

out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)

ds_in_reps_space = ds_reps_new_dims and all(
d < reps_ind_split_idx for n, d in ds_reps_new_dims
)

if ds_in_reps_space:
# Update the `size` array to reflect the `DimShuffle`d dimensions.
# There should be no need to `DimShuffle` now.
new_size = [
constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
]

new_node = rv_op.make_node(rng, new_size, dtype, *dist_params)

if config.compute_test_value != "off":
compute_test_value(new_node)

out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
if config.compute_test_value != "off":
compute_test_value(new_node)

return False
out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]


@node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
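For reference, the equivalence this rewrite exploits, as stated in the updated docstring (`normal(mu, std).T == normal(mu.T, std.T)`), shown as an illustrative sketch (not part of the diff):

import numpy as np
from pytensor.tensor.random.basic import normal

mu = np.arange(6.0).reshape(2, 3)
sigma = np.ones((2, 3))

transposed_draw = normal(mu, sigma).T  # DimShuffle applied to the RV output
lifted_draw = normal(mu.T, sigma.T)    # equivalent graph after lifting the DimShuffle

assert tuple(transposed_draw.shape.eval()) == (3, 2)
assert tuple(lifted_draw.shape.eval()) == (3, 2)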
2 changes: 1 addition & 1 deletion pytensor/tensor/random/utils.py
@@ -134,7 +134,7 @@ def normalize_size_param(
"Parameter size must be None, an integer, or a sequence with integers."
)
else:
size = cast(as_tensor_variable(size, ndim=1), "int64")
size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64")

if not isinstance(size, Constant):
# This should help ensure that the length of non-constant `size`s
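The practical effect of casting each `size` entry to `int64` during normalization (a sketch assuming the `normal` helper; the new test in `test_RandomVariable_bcast` below covers the same case): constant 1s in a mixed-dtype `size` tuple remain recognizable, so the corresponding output dimensions are still marked broadcastable.

import numpy as np
import pytensor.tensor as at
from pytensor.tensor.random.basic import normal

s = at.lscalar("s")  # symbolic length for the second dimension
x = normal(0, 1, size=(at.as_tensor(1, dtype=np.int32), s))

# The int32 constant 1 is cast element-wise to int64 and still detected as a
# length-1 dimension.
assert x.broadcastable == (True, False)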
13 changes: 13 additions & 0 deletions tests/tensor/random/test_basic.py
@@ -1390,6 +1390,19 @@ def test_choice_samples():
compare_sample_values(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True)


def test_choice_infer_shape():
node = choice([0, 1]).owner
res = node.op._infer_shape((), node.inputs[3:], None)
assert tuple(res.eval()) == ()

node = choice([0, 1]).owner
# The param_shape of a NoneConst is None during shape inference
res = node.op._infer_shape(
(), node.inputs[3:], (node.inputs[3].shape, None, node.inputs[5].shape)
)
assert tuple(res.eval()) == ()


def test_permutation_samples():
compare_sample_values(
permutation,
17 changes: 17 additions & 0 deletions tests/tensor/random/test_op.py
@@ -148,6 +148,9 @@ def test_RandomVariable_bcast():
res = rv(0, 1, size=at.as_tensor(1, dtype=np.int64))
assert res.broadcastable == (True,)

res = rv(0, 1, size=(at.as_tensor(1, dtype=np.int32), s3))
assert res.broadcastable == (True, False)


def test_RandomVariable_bcast_specify_shape():
rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
@@ -214,3 +217,17 @@ def test_random_maker_ops_no_seed():
z = function(inputs=[], outputs=[default_rng()])()
aes_res = z[0]
assert isinstance(aes_res, np.random.Generator)


def test_RandomVariable_incompatible_size():
rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
with pytest.raises(
ValueError, match="Size length is incompatible with batched dimensions"
):
rv_op(np.zeros((1, 3)), 1, size=(3,))

rv_op = RandomVariable("dirichlet", 0, [1], config.floatX, inplace=True)
with pytest.raises(
ValueError, match="Size length is incompatible with batched dimensions"
):
rv_op(np.zeros((2, 4, 3)), 1, size=(4,))