diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 9e63399706..0db7871892 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -173,6 +173,7 @@ def __init__(self, input_broadcastable, new_order):
         # List of dimensions of the output that are broadcastable and were not
         # in the original input
         self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"])
+        self.drop = drop
 
         if self.inplace:
             self.view_map = {0: [0]}
diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py
index 23252c8796..9062727369 100644
--- a/pytensor/tensor/random/basic.py
+++ b/pytensor/tensor/random/basic.py
@@ -1990,11 +1990,12 @@ def _supp_shape_from_params(self, *args, **kwargs):
         raise NotImplementedError()
 
     def _infer_shape(self, size, dist_params, param_shapes=None):
-        (a, p, _) = dist_params
-
+        a, p, _ = dist_params
         if isinstance(p.type, pytensor.tensor.type_other.NoneTypeT):
+            param_shapes = param_shapes[:1] if param_shapes is not None else None
             shape = super()._infer_shape(size, (a,), param_shapes)
         else:
+            param_shapes = param_shapes[:2] if param_shapes is not None else None
             shape = super()._infer_shape(size, (a, p), param_shapes)
 
         return shape
diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index 6460da2e54..ff48a07b70 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -192,6 +192,19 @@ def _infer_shape(
         size_len = get_vector_length(size)
 
         if size_len > 0:
+
+            # Fail early when size is incompatible with parameters
+            for i, (param, param_ndim_supp) in enumerate(
+                zip(dist_params, self.ndims_params)
+            ):
+                param_batched_dims = getattr(param, "ndim", 0) - param_ndim_supp
+                if param_batched_dims > size_len:
+                    raise ValueError(
+                        f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n"
+                        f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. "
+                        f"Size length must be 0 or >= {param_batched_dims}"
+                    )
+
             if self.ndim_supp == 0:
                 return size
             else:
diff --git a/pytensor/tensor/random/rewriting.py b/pytensor/tensor/random/rewriting.py
index 2e5194b727..bef13a9189 100644
--- a/pytensor/tensor/random/rewriting.py
+++ b/pytensor/tensor/random/rewriting.py
@@ -8,7 +8,7 @@
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
-from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor,
     AdvancedSubtensor1,
@@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
 
     For example, ``normal(mu, std).T == normal(mu.T, std.T)``.
 
-    The basic idea behind this rewrite is that we need to separate the
-    ``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two
-    distinct sub-spaces: the (set of independent) parameters and ``size``
-    (i.e. replications) sub-spaces.
-
-    If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we
-    don't do anything.
-
-    Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of
-    those sub-spaces, we can break it apart and apply the parameter-space
-    ``DimShuffle`` to the distribution parameters, and then apply the
-    replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a
-    particularly simple rearranging of a tuple, but the former requires a
-    little more work.
-
-    TODO: Currently, multivariate support for this rewrite is disabled.
+    This rewrite is only applicable when the ``DimShuffle`` operation does
+    not affect support dimensions.
+    TODO: Support dimension dropping
 
     """
 
     ds_op = node.op
@@ -142,128 +129,67 @@
     base_rv = node.inputs[0]
     rv_node = base_rv.owner
 
-    if not (
-        rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
-    ):
+    if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
+    # DimShuffles that drop dimensions are not supported yet
+    if ds_op.drop:
         return False
 
     rv_op = rv_node.op
     rng, size, dtype, *dist_params = rv_node.inputs
+    rv = rv_node.default_output()
 
-    # We need to know the dimensions that were *not* added by the `size`
-    # parameter (i.e. the dimensions corresponding to independent variates with
-    # different parameter values)
-    num_ind_dims = None
-    if len(dist_params) == 1:
-        num_ind_dims = dist_params[0].ndim
-    else:
-        # When there is more than one distribution parameter, assume that all
-        # of them will broadcast to the maximum number of dimensions
-        num_ind_dims = max(d.ndim for d in dist_params)
-
-    # If the indices in `ds_new_order` are entirely within the replication
-    # indices group or the independent variates indices group, then we can apply
-    # this rewrite.
-
-    ds_new_order = ds_op.new_order
-    # Create a map from old index order to new/`DimShuffled` index order
-    dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]
-
-    # Find the index at which the replications/independents split occurs
-    reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)
-
-    ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
-    ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
-    ds_in_ind_space = ds_ind_new_dims and all(
-        d >= reps_ind_split_idx for n, d in ds_ind_new_dims
-    )
+    # Check that the DimShuffle does not affect support dims
+    supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
+    shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
+    augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
+    if (shuffled_dims | augmented_dims) & supp_dims:
+        return False
 
-    if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims):
+    # If no one else is using the underlying RandomVariable, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(base_rv, node, fgraph):
+        return False
 
-        # Update the `size` array to reflect the `DimShuffle`d dimensions,
-        # since the trailing dimensions in `size` represent the independent
-        # variates dimensions (for univariate distributions, at least)
-        has_size = get_vector_length(size) > 0
-        new_size = (
-            [constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order]
-            if has_size
-            else size
+    batched_dims = rv.ndim - rv_op.ndim_supp
+    batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
+
+    # Make size explicit
+    missing_size_dims = batched_dims - get_vector_length(size)
+    if missing_size_dims > 0:
+        full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
+        size = full_size[:missing_size_dims] + tuple(size)
+
+    # Update the size to reflect the DimShuffled dimensions
+    new_size = [
+        constant(1, dtype="int64") if o == "x" else size[o]
+        for o in batched_dims_ds_order
+    ]
+
+    # Update the params to reflect the DimShuffled dimensions
+    new_dist_params = []
+    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
+        # Add broadcastable dimensions to the parameters that would have been expanded by the size
+        padleft = batched_dims - (param.ndim - param_ndim_supp)
+        if padleft > 0:
+            param = shape_padleft(param, padleft)
+
+        # Add the parameter support dimension indexes to the batched dimensions DimShuffle
+        param_new_order = batched_dims_ds_order + tuple(
+            range(batched_dims, batched_dims + param_ndim_supp)
         )
+        new_dist_params.append(param.dimshuffle(param_new_order))
 
-        # Compute the new axes parameter(s) for the `DimShuffle` that will be
-        # applied to the `RandomVariable` parameters (they need to be offset)
-        if ds_ind_new_dims:
-            rv_params_new_order = [
-                d - reps_ind_split_idx if isinstance(d, int) else d
-                for d in ds_new_order[ds_ind_new_dims[0][0] :]
-            ]
-
-            if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0:
-                # Additional broadcast dimensions need to be added to the
-                # independent dimensions (i.e. parameters), since there's no
-                # `size` to which they can be added
-                rv_params_new_order = (
-                    list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order
-                )
-        else:
-            # This case is reached when, for example, `ds_new_order` only
-            # consists of new broadcastable dimensions (i.e. `"x"`s)
-            rv_params_new_order = ds_new_order
-
-        # Lift the `DimShuffle`s into the parameters
-        # NOTE: The parameters might not be broadcasted against each other, so
-        # we can only apply the parts of the `DimShuffle` that are relevant.
-        new_dist_params = []
-        for d in dist_params:
-            if d.ndim < len(ds_ind_new_dims):
-                _rv_params_new_order = [
-                    o
-                    for o in rv_params_new_order
-                    if (isinstance(o, int) and o < d.ndim) or o == "x"
-                ]
-            else:
-                _rv_params_new_order = rv_params_new_order
-
-            new_dist_params.append(
-                type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
-            )
-        new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
-
-        if config.compute_test_value != "off":
-            compute_test_value(new_node)
-
-        out = new_node.outputs[1]
-        if base_rv.name:
-            out.name = f"{base_rv.name}_lifted"
-        return [out]
+    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
 
-    ds_in_reps_space = ds_reps_new_dims and all(
-        d < reps_ind_split_idx for n, d in ds_reps_new_dims
-    )
-
-    if ds_in_reps_space:
-        # Update the `size` array to reflect the `DimShuffle`d dimensions.
-        # There should be no need to `DimShuffle` now.
-        new_size = [
-            constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
-        ]
-
-        new_node = rv_op.make_node(rng, new_size, dtype, *dist_params)
-
-        if config.compute_test_value != "off":
-            compute_test_value(new_node)
-
-        out = new_node.outputs[1]
-        if base_rv.name:
-            out.name = f"{base_rv.name}_lifted"
-        return [out]
+    if config.compute_test_value != "off":
+        compute_test_value(new_node)
 
-    return False
+    out = new_node.outputs[1]
+    if base_rv.name:
+        out.name = f"{base_rv.name}_lifted"
+    return [out]
 
 
 @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index c6a9344b31..329581f48b 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -134,7 +134,7 @@ def normalize_size_param(
                 "Parameter size must be None, an integer, or a sequence with integers."
             )
     else:
-        size = cast(as_tensor_variable(size, ndim=1), "int64")
+        size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64")
 
         if not isinstance(size, Constant):
             # This should help ensure that the length of non-constant `size`s
diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py
index 4dfb67bddc..815321aa97 100644
--- a/tests/tensor/random/test_basic.py
+++ b/tests/tensor/random/test_basic.py
@@ -1390,6 +1390,19 @@ def test_choice_samples():
     compare_sample_values(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True)
 
 
+def test_choice_infer_shape():
+    node = choice([0, 1]).owner
+    res = node.op._infer_shape((), node.inputs[3:], None)
+    assert tuple(res.eval()) == ()
+
+    node = choice([0, 1]).owner
+    # The param_shape of a NoneConst is None during shape inference
+    res = node.op._infer_shape(
+        (), node.inputs[3:], (node.inputs[3].shape, None, node.inputs[5].shape)
+    )
+    assert tuple(res.eval()) == ()
+
+
 def test_permutation_samples():
     compare_sample_values(
         permutation,
diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py
index a67fae83a8..7b9695aded 100644
--- a/tests/tensor/random/test_op.py
+++ b/tests/tensor/random/test_op.py
@@ -148,6 +148,9 @@ def test_RandomVariable_bcast():
     res = rv(0, 1, size=at.as_tensor(1, dtype=np.int64))
     assert res.broadcastable == (True,)
 
+    res = rv(0, 1, size=(at.as_tensor(1, dtype=np.int32), s3))
+    assert res.broadcastable == (True, False)
+
 
 def test_RandomVariable_bcast_specify_shape():
     rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
@@ -214,3 +217,17 @@ def test_random_maker_ops_no_seed():
     z = function(inputs=[], outputs=[default_rng()])()
     aes_res = z[0]
     assert isinstance(aes_res, np.random.Generator)
+
+
+def test_RandomVariable_incompatible_size():
+    rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
+    with pytest.raises(
+        ValueError, match="Size length is incompatible with batched dimensions"
+    ):
+        rv_op(np.zeros((1, 3)), 1, size=(3,))
+
+    rv_op = RandomVariable("dirichlet", 0, [1], config.floatX, inplace=True)
+    with pytest.raises(
+        ValueError, match="Size length is incompatible with batched dimensions"
+    ):
+        rv_op(np.zeros((2, 4, 3)), 1, size=(4,))
diff --git a/tests/tensor/random/test_rewriting.py b/tests/tensor/random/test_rewriting.py
index 856d908501..eee1eb5d17 100644
--- a/tests/tensor/random/test_rewriting.py
+++ b/tests/tensor/random/test_rewriting.py
@@ -9,6 +9,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
+from pytensor.tensor import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.random.basic import (
     dirichlet,
@@ -42,7 +43,11 @@ def apply_local_rewrite_to_rv(
 
     size_at = []
    for s in size:
-        s_at = iscalar()
+        # To test DimShuffles that drop dims, that size dimension needs to be constant
+        if s == 1:
+            s_at = constant(np.array(1, dtype="int32"))
+        else:
+            s_at = iscalar()
         s_at.tag.test_value = s
         size_at.append(s_at)
 
@@ -314,7 +319,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
         ),
         (
             ("x", 1, 0, 2, "x"),
-            False,
+            True,
             normal,
             (
                 np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
@@ -332,7 +337,30 @@
             (3, 2, 2),
             1,
         ),
-        # A multi-dimensional case
+        # Supported multi-dimensional cases
+        (
+            (1, 0, 2),
+            True,
+            multivariate_normal,
+            (
+                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
+                np.eye(2).astype(config.floatX) * 1e-6,
+            ),
+            (3, 2),
+            1e-3,
+        ),
+        (
+            (1, 0, "x", 2),
+            True,
+            multivariate_normal,
+            (
+                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
+                np.eye(2).astype(config.floatX) * 1e-6,
+            ),
+            (3, 2),
+            1e-3,
+        ),
+        # Unsupported multi-dimensional cases, where the DimShuffle affects the support dimensionality
         (
             (0, 2, 1),
             False,
@@ -344,6 +372,35 @@
             (3, 2),
             1e-3,
         ),
+        (
+            (0, 1, 2, "x"),
+            False,
+            multivariate_normal,
+            (
+                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
+                np.eye(2).astype(config.floatX) * 1e-6,
+            ),
+            (3, 2),
+            1e-3,
+        ),
+        pytest.param(
+            (1,),
+            True,
+            normal,
+            (0, 1),
+            (1, 2),
+            1e-3,
+            marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
+        ),
+        pytest.param(
+            (1,),
+            True,
+            normal,
+            ([[0, 0]], 1),
+            (1, 2),
+            1e-3,
+            marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
+        ),
     ],
 )
 @config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index b34b170f7f..ed964b6091 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -794,7 +794,7 @@ def test_full(self):
         assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
 
 
-def test_infer_broadcastable():
+def test_infer_shape():
     with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
         infer_static_shape([constant(1.0)])
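
Note (illustrative only, not part of the patch): a minimal sketch of the case the lifted rewrite now covers, mirroring the shapes used by the new multivariate test cases; variable names and the standalone usage are assumptions, not code from the diff.

# Illustrative sketch, assuming the public pytensor RandomVariable API as
# exercised by the tests above; not part of the patch.
import numpy as np
from pytensor.tensor.random.basic import multivariate_normal

mu = np.array([[-1.0, 20.0], [300.0, -4000.0]])  # batch shape (2,), support shape (2,)
cov = np.eye(2) * 1e-6

x = multivariate_normal(mu, cov, size=(3, 2))  # draw shape (3, 2, 2): batch (3, 2) + support (2,)
y = x.dimshuffle(1, 0, 2)  # permutes only batch dims, so it is eligible for the DimShuffle lift
z = x.dimshuffle(0, 2, 1)  # touches the support (last) dim, so the rewrite returns False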