diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py
index 8b2cac3e97..bdef569905 100644
--- a/pytensor/tensor/random/rewriting/basic.py
+++ b/pytensor/tensor/random/rewriting/basic.py
@@ -1,7 +1,10 @@
+from itertools import zip_longest
+
 from pytensor.compile import optdb
 from pytensor.configdefaults import config
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import in2out, node_rewriter
+from pytensor.tensor import NoneConst
 from pytensor.tensor.basic import constant, get_vector_length
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
@@ -17,6 +20,7 @@
     get_idx_list,
     indexed_result_shape,
 )
+from pytensor.tensor.type_other import SliceType


 def is_rv_used_in_graph(base_rv, node, fgraph):
@@ -196,37 +200,11 @@ def local_dimshuffle_rv_lift(fgraph, node):
 def local_subtensor_rv_lift(fgraph, node):
     """Lift a ``*Subtensor`` through ``RandomVariable`` inputs.

-    In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions
-    need to be separated into distinct replication-space and (independent)
-    parameter-space ``*Subtensor``s.
-
-    The replication-space ``*Subtensor`` can be used to determine a
-    sub/super-set of the replication-space and, thus, a "smaller"/"larger"
-    ``size`` tuple.  The parameter-space ``*Subtensor`` is simply lifted and
-    applied to the distribution parameters.
-
-    Consider the following example graph:
-    ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``.  The
-    ``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``,
-    which correspond to all three ``size`` dimensions.  Now, depending on the
-    broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op``
-    could be reducing the ``size`` parameter and/or sub-setting the independent
-    ``mu`` and ``std`` parameters.  Only once the dimensions are properly
-    separated into the two replication/parameter subspaces can we determine how
-    the ``*Subtensor`` indices are distributed.
-    For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``
-    could become
-    ``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))``
-    if ``mu.shape == std.shape == ()``
-
-    ``normal`` is a rather simple case, because it's univariate.  Multivariate
-    cases require a mapping between the parameter space and the image of the
-    random variable.  This may not always be possible, but for many common
-    distributions it is.  For example, the dimensions of the multivariate
-    normal's image can be mapped directly to each dimension of its parameters.
-    We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]``
-    into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``.
+    For example, ``normal(mu, std)[0] == normal(mu[0], std[0])``.
+
+    This rewrite also applies to multivariate distributions, as long
+    as indexing does not happen within core dimensions (as it does in
+    ``mvnormal(mu, cov, size=(2,))[0, 0]``, where the rewrite is skipped).
     """

     st_op = node.op
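To make the new docstring concrete: the rewrite turns "index the draw" into "draw from the indexed parameters". A minimal sketch of the two graphs, assuming only the public `pytensor.tensor.random.basic` constructors (the variable names are illustrative):

    import pytensor.tensor as at
    from pytensor.tensor.random.basic import normal

    mu = at.vector("mu")
    std = at.vector("std")

    # The pattern the rewrite matches: indexing into a batched draw ...
    indexed_draw = normal(mu, std)[0]
    # ... and the graph it is replaced with: a draw from indexed parameters.
    lifted_draw = normal(mu[0], std[0])

The two graphs are equivalent in distribution, which is why the rewrite below is only applied when no other node consumes the original `RandomVariable` output.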
@@ -234,103 +212,92 @@ def local_subtensor_rv_lift(fgraph, node):
     if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
         return False

-    base_rv = node.inputs[0]
+    rv = node.inputs[0]
+    rv_node = rv.owner

-    rv_node = base_rv.owner
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False

-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
-        return False
-
     rv_op = rv_node.op
     rng, size, dtype, *dist_params = rv_node.inputs

-    # TODO: Remove this once the multi-dimensional changes described below are
-    # in place.
-    if rv_op.ndim_supp > 0:
-        return False
-
-    rv_op = base_rv.owner.op
-    rng, size, dtype, *dist_params = base_rv.owner.inputs
-
+    # Parse indices
    idx_list = getattr(st_op, "idx_list", None)
     if idx_list:
         cdata = get_idx_list(node.inputs, idx_list)
     else:
         cdata = node.inputs[1:]
-
     st_indices, st_is_bool = zip(
         *tuple(
             (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
         )
     )

-    # We need to separate dimensions into replications and independents
-    num_ind_dims = None
-    if len(dist_params) == 1:
-        num_ind_dims = dist_params[0].ndim
-    else:
-        # When there is more than one distribution parameter, assume that all
-        # of them will broadcast to the maximum number of dimensions
-        num_ind_dims = max(d.ndim for d in dist_params)
-
-    reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp)
-
-    if len(st_indices) > reps_ind_split_idx:
-        # These are the indices that need to be applied to the parameters
-        ind_indices = tuple(st_indices[reps_ind_split_idx:])
-
-        # We need to broadcast the parameters before applying the `*Subtensor*`
-        # with these indices, because the indices could be referencing broadcast
-        # dimensions that don't exist (yet)
-        bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params)
-
-        # TODO: For multidimensional distributions, we need a map that tells us
-        # which dimensions of the parameters need to be indexed.
-        #
-        # For example, `multivariate_normal` would have the following:
-        # `RandomVariable.param_to_image_dims = ((0,), (0, 1))`
-        #
-        # I.e. the first parameter's (i.e. mean's) first dimension maps directly to
-        # the dimension of the RV's image, and its second parameter's
-        # (i.e. covariance's) first and second dimensions map directly to the
-        # dimension of the RV's image.
-
-        args_lifted = tuple(p[ind_indices] for p in bcast_dist_params)
-    else:
-        # In this case, no indexing is applied to the parameters; only the
-        # `size` parameter is affected.
-        args_lifted = dist_params
+    # Check that indexing does not act on support dims
+    batched_ndims = rv.ndim - rv_op.ndim_supp
+    if len(st_indices) > batched_ndims:
+        # If the trailing indices are just dummy `slice(None)`s, discard them
+        st_is_bool = st_is_bool[:batched_ndims]
+        st_indices, supp_indices = (
+            st_indices[:batched_ndims],
+            st_indices[batched_ndims:],
+        )
+        for index in supp_indices:
+            if not (
+                isinstance(index.type, SliceType)
+                and all(NoneConst.equals(i) for i in index.owner.inputs)
+            ):
+                return False
+
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False

+    # Update the size to reflect the indexed dimensions
     # TODO: Could use `ShapeFeature` info.  We would need to be sure that
     # `node` isn't in the results, though.
     # if hasattr(fgraph, "shape_feature"):
     #     output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
     # else:
-    output_shape = indexed_result_shape(base_rv.shape, st_indices)
-
-    size_lifted = (
-        output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp]
+    output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
+    new_size_ignoring_boolean = (
+        output_shape_ignoring_bool
+        if rv_op.ndim_supp == 0
+        else output_shape_ignoring_bool[: -rv_op.ndim_supp]
     )

-    # Boolean indices can actually change the `size` value (compared to just
-    # *which* dimensions of `size` are used).
+    # Boolean indices can actually change the `size` value (compared to just
+    # *which* dimensions of `size` are used), which the `indexed_result_shape`
+    # helper does not account for
     if any(st_is_bool):
-        size_lifted = tuple(
+        new_size = tuple(
             at_sum(idx) if is_bool else s
-            for s, is_bool, idx in zip(
-                size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)]
+            for s, is_bool, idx in zip_longest(
+                new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
             )
         )
+    else:
+        new_size = new_size_ignoring_boolean

-    new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted)
-    _, new_rv = new_node.outputs
+    # Update the parameters to reflect the indexed dimensions
+    new_dist_params = []
+    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
+        # Apply indexing on the batched dimensions of the parameter
+        batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
+        batched_param = shape_padleft(param, batched_param_dims_missing)
+        batched_st_indices = []
+        for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
+            # Indexing a degenerate (length-1) dimension with 0 always does the job
+            if batched_param_shape == 1:
+                batched_st_indices.append(0)
+            else:
+                batched_st_indices.append(st_index)
+        new_dist_params.append(batched_param[tuple(batched_st_indices)])
+
+    # Create new RV
+    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
+    new_rv = new_node.default_output()

-    # Calling `Op.make_node` directly circumvents test value computations, so
-    # we need to compute the test values manually
     if config.compute_test_value != "off":
         compute_test_value(new_node)
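The parameter-update loop is the subtle part of the new implementation: a parameter may have fewer batch dimensions than the draw, so it is left-padded first, and any index into a degenerate (length-1) dimension is replaced by `0` so the lifted indexing cannot fail. A hedged NumPy sketch of that logic (`lift_index_to_param` is an illustrative name, not a helper in this module):

    import numpy as np

    def lift_index_to_param(param, param_ndim_supp, batched_ndims, indices):
        # Left-pad missing batch dimensions, mirroring `shape_padleft`
        missing = batched_ndims - (param.ndim - param_ndim_supp)
        batched = param.reshape((1,) * missing + param.shape)
        # Replace indices into length-1 (broadcast) dimensions with 0
        safe_indices = tuple(
            0 if dim_length == 1 else idx
            for idx, dim_length in zip(indices, batched.shape)
        )
        return batched[safe_indices]

    # For normal(mu, 1.0, size=(3, 2))[0]: `mu` has batch shape (2,), the draw
    # has two batch dimensions, and the row index lands on the padded
    # broadcast dimension, leaving `mu` untouched.
    mu = np.array([1.0, 2.0])
    print(lift_index_to_param(mu, 0, 2, (0,)))  # -> [1. 2.]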
diff --git a/tests/tensor/random/rewriting/__init__.py b/tests/tensor/random/rewriting/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/tensor/random/test_rewriting.py b/tests/tensor/random/rewriting/test_basic.py
similarity index 87%
rename from tests/tensor/random/test_rewriting.py
rename to tests/tensor/random/rewriting/test_basic.py
index e20d959881..ef9cf8b3b3 100644
--- a/tests/tensor/random/test_rewriting.py
+++ b/tests/tensor/random/rewriting/test_basic.py
@@ -12,6 +12,7 @@
 from pytensor.tensor import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.random.basic import (
+    categorical,
     dirichlet,
     multinomial,
     multivariate_normal,
@@ -36,8 +37,8 @@ def apply_local_rewrite_to_rv(
     rewrite, op_fn, dist_op, dist_params, size, rng, name=None
 ):
     dist_params_at = []
-    for p in dist_params:
-        p_at = at.as_tensor(p).type()
+    for i, p in enumerate(dist_params):
+        p_at = at.as_tensor(p).type(f"p_{i}")
         p_at.tag.test_value = p
         dist_params_at.append(p_at)
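The `type(f"p_{i}")` change above only names the freshly created parameter variables, so failing parametrizations print `p_0`, `p_1`, ... instead of anonymous variables. A small sketch of the PyTensor behavior being relied on (calling a variable's `type` creates a new variable, with an optional name), assuming nothing beyond what the hunk itself uses:

    import pytensor.tensor as at

    p = at.as_tensor([1.0, 2.0])
    unnamed = p.type()      # a fresh variable of the same TensorType
    named = p.type("p_0")   # same type, but debug output will show "p_0"
    print(unnamed.name, named.name)  # -> None p_0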
@@ -495,8 +496,79 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
             ),
             (3, 2, 2),
         ),
-        # A multi-dimensional case
+        # Only one distribution parameter
+        (
+            (0,),
+            True,
+            poisson,
+            (np.array([[1, 2], [3, 4]], dtype=config.floatX),),
+            (3, 2, 2),
+        ),
+        # Univariate distribution with vector parameters
+        (
+            (np.array([0, 2]),),
+            True,
+            categorical,
+            (np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
+            (4,),
+        ),
+        (
+            (np.array([True, False, True, True]),),
+            True,
+            categorical,
+            (np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
+            (4,),
+        ),
+        (
+            (np.array([True, False, True]),),
+            True,
+            categorical,
+            (
+                np.array(
+                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                    dtype=config.floatX,
+                ),
+            ),
+            (),
+        ),
+        (
+            (
+                slice(None),
+                np.array([True, False, True]),
+            ),
+            True,
+            categorical,
+            (
+                np.array(
+                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                    dtype=config.floatX,
+                ),
+            ),
+            (4, 3),
+        ),
+        # Boolean indexing where output is empty
+        (
+            (np.array([False, False]),),
+            True,
+            normal,
+            (np.array([[1.0, 0.0, 0.0]], dtype=config.floatX),),
+            (2, 3),
+        ),
+        (
+            (np.array([False, False]),),
+            True,
+            categorical,
+            (
+                np.array(
+                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                    dtype=config.floatX,
+                ),
+            ),
+            (2, 3),
+        ),
+        # Multivariate cases, indexing only supported if it does not affect core dimensions
+        (
+            # Indexing dips into core dimension
             (np.array([1]), 0),
             False,
             multivariate_normal,
             (
                 np.array([[-100, -125, -150]], dtype=config.floatX),
                 np.eye(3, dtype=config.floatX) * 1e-6,
             ),
             (),
         ),
@@ -506,13 +578,30 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
-        # Only one distribution parameter
         (
-            (0,),
+            (np.array([0, 2]),),
             True,
-            poisson,
-            (np.array([[1, 2], [3, 4]], dtype=config.floatX),),
-            (3, 2, 2),
+            multivariate_normal,
+            (
+                np.array(
+                    [[-100, -125, -150], [0, 0, 0], [200, 225, 250]],
+                    dtype=config.floatX,
+                ),
+                np.eye(3, dtype=config.floatX) * 1e-6,
+            ),
+            (),
+        ),
+        (
+            (np.array([True, False, True]), slice(None)),
+            True,
+            multivariate_normal,
+            (
+                np.array([200, 250], dtype=config.floatX),
+                # Second covariance is invalid, to test it is not chosen
+                np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
+                * 1e-6,
+            ),
+            (3,),
         ),
     ],
 )
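The empty boolean-indexing cases above exercise the `zip_longest`/`at_sum` branch in the rewrite: a boolean mask replaces the size of the dimension it indexes with the mask's sum, which can be zero. A quick plain-NumPy check of the shape arithmetic those cases expect:

    import numpy as np

    rng = np.random.default_rng(0)
    mask = np.array([False, False])  # sums to 0
    draws = rng.normal(np.array([[1.0, 0.0, 0.0]]), 1.0, size=(2, 3))[mask]
    print(draws.shape)  # -> (0, 3): the lifted `size` becomes (mask.sum(), 3)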