From 6b05645fda9a4e7b46917c26ec3675bbaf0b1f43 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 2 Dec 2022 12:56:04 +0100 Subject: [PATCH 1/6] Fix test name --- tests/tensor/test_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index b34b170f7f..ed964b6091 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -794,7 +794,7 @@ def test_full(self): assert np.array_equal(res, np.full((2, 3), 3, dtype="int64")) -def test_infer_broadcastable(): +def test_infer_shape(): with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): infer_static_shape([constant(1.0)]) From df2f8a500117cb27496a8c7f70fd457c00931d10 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 2 Dec 2022 12:52:18 +0100 Subject: [PATCH 2/6] Apply casting in as_tensor_variable in normalize_size_param This allows PyTensor to infer more broadcastable patterns, by placing the casting inside the MakeVector Op --- pytensor/tensor/random/utils.py | 2 +- tests/tensor/random/test_op.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index c6a9344b31..329581f48b 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -134,7 +134,7 @@ def normalize_size_param( "Parameter size must be None, an integer, or a sequence with integers." ) else: - size = cast(as_tensor_variable(size, ndim=1), "int64") + size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64") if not isinstance(size, Constant): # This should help ensure that the length of non-constant `size`s diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index a67fae83a8..16d54dddde 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -148,6 +148,9 @@ def test_RandomVariable_bcast(): res = rv(0, 1, size=at.as_tensor(1, dtype=np.int64)) assert res.broadcastable == (True,) + res = rv(0, 1, size=(at.as_tensor(1, dtype=np.int32), s3)) + assert res.broadcastable == (True, False) + def test_RandomVariable_bcast_specify_shape(): rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True) From 8ebd2c3891e493db99d8beefa31ffcad3a76b38a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 4 Dec 2022 11:55:44 +0100 Subject: [PATCH 3/6] Fix shape_inference of `ChoiceRV` when param_shapes are provided --- pytensor/tensor/random/basic.py | 5 +++-- tests/tensor/random/test_basic.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 23252c8796..9062727369 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1990,11 +1990,12 @@ def _supp_shape_from_params(self, *args, **kwargs): raise NotImplementedError() def _infer_shape(self, size, dist_params, param_shapes=None): - (a, p, _) = dist_params - + a, p, _ = dist_params if isinstance(p.type, pytensor.tensor.type_other.NoneTypeT): + param_shapes = param_shapes[:1] if param_shapes is not None else None shape = super()._infer_shape(size, (a,), param_shapes) else: + param_shapes = param_shapes[:2] if param_shapes is not None else None shape = super()._infer_shape(size, (a, p), param_shapes) return shape diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 4dfb67bddc..815321aa97 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -1390,6 +1390,19 @@ def test_choice_samples(): compare_sample_values(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True) +def test_choice_infer_shape(): + node = choice([0, 1]).owner + res = node.op._infer_shape((), node.inputs[3:], None) + assert tuple(res.eval()) == () + + node = choice([0, 1]).owner + # The param_shape of a NoneConst is None, during shape_inference + res = node.op._infer_shape( + (), node.inputs[3:], (node.inputs[3].shape, None, node.inputs[5].shape) + ) + assert tuple(res.eval()) == () + + def test_permutation_samples(): compare_sample_values( permutation, From eff4b6b1950e4e4fab4603ee92948a297da74424 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 2 Dec 2022 15:50:20 +0100 Subject: [PATCH 4/6] Fail early in RandomVariable.make_node when size is incompatible with parameters dimensionality --- pytensor/tensor/random/op.py | 13 +++++++++++++ tests/tensor/random/test_op.py | 14 ++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 6460da2e54..ff48a07b70 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -192,6 +192,19 @@ def _infer_shape( size_len = get_vector_length(size) if size_len > 0: + + # Fail early when size is incompatible with parameters + for i, (param, param_ndim_supp) in enumerate( + zip(dist_params, self.ndims_params) + ): + param_batched_dims = getattr(param, "ndim", 0) - param_ndim_supp + if param_batched_dims > size_len: + raise ValueError( + f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n" + f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. " + f"Size length must be 0 or >= {param_batched_dims}" + ) + if self.ndim_supp == 0: return size else: diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 16d54dddde..7b9695aded 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -217,3 +217,17 @@ def test_random_maker_ops_no_seed(): z = function(inputs=[], outputs=[default_rng()])() aes_res = z[0] assert isinstance(aes_res, np.random.Generator) + + +def test_RandomVariable_incompatible_size(): + rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True) + with pytest.raises( + ValueError, match="Size length is incompatible with batched dimensions" + ): + rv_op(np.zeros((1, 3)), 1, size=(3,)) + + rv_op = RandomVariable("dirichlet", 0, [1], config.floatX, inplace=True) + with pytest.raises( + ValueError, match="Size length is incompatible with batched dimensions" + ): + rv_op(np.zeros((2, 4, 3)), 1, size=(4,)) From 97c0b7259d22e2365689de4ac74dc46ea51c639a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 2 Dec 2022 15:17:32 +0100 Subject: [PATCH 5/6] Add drop property to Dimshuffle --- pytensor/tensor/elemwise.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 9e63399706..0db7871892 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -173,6 +173,7 @@ def __init__(self, input_broadcastable, new_order): # List of dimensions of the output that are broadcastable and were not # in the original input self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"]) + self.drop = drop if self.inplace: self.view_map = {0: [0]} From e59ddb71c09063846228486cb4e824ac65f3c295 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 2 Dec 2022 15:26:52 +0100 Subject: [PATCH 6/6] Expand and simplify `local_dimshuffle_rv_lift` * The rewrite no longer bails out when dimshuffle affects both unique param dimensions and repeated param dimensions from the size argument. This requires: 1) Adding broadcastable dimensions to the parameters, which should be "cost-free" and would need to be done in the `perform` method anyway. 2) Extend size to incorporate implicit batch dimensions coming from the parameters. This requires computing the shape resulting from broadcasting the parameters. It's unclear whether this is less performant, because the `perform` method can now simply broadcast each parameter to the size, instead of having to broadcast the parameters together. * The rewrite now works with Multivariate RVs * The rewrite bails out when dimensions are dropped by the Dimshuffle. This case was not correctly handled by the previous rewrite --- pytensor/tensor/random/rewriting.py | 178 ++++++++------------------ tests/tensor/random/test_rewriting.py | 63 ++++++++- 2 files changed, 112 insertions(+), 129 deletions(-) diff --git a/pytensor/tensor/random/rewriting.py b/pytensor/tensor/random/rewriting.py index 2e5194b727..bef13a9189 100644 --- a/pytensor/tensor/random/rewriting.py +++ b/pytensor/tensor/random/rewriting.py @@ -8,7 +8,7 @@ from pytensor.tensor.math import sum as at_sum from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.utils import broadcast_params -from pytensor.tensor.shape import Shape, Shape_i +from pytensor.tensor.shape import Shape, Shape_i, shape_padleft from pytensor.tensor.subtensor import ( AdvancedSubtensor, AdvancedSubtensor1, @@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node): For example, ``normal(mu, std).T == normal(mu.T, std.T)``. - The basic idea behind this rewrite is that we need to separate the - ``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two - distinct sub-spaces: the (set of independent) parameters and ``size`` - (i.e. replications) sub-spaces. - - If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we - don't do anything. - - Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of - those sub-spaces, we can break it apart and apply the parameter-space - ``DimShuffle`` to the distribution parameters, and then apply the - replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a - particularly simple rearranging of a tuple, but the former requires a - little more work. - - TODO: Currently, multivariate support for this rewrite is disabled. + This rewrite is only applicable when the Dimshuffle operation does + not affect support dimensions. + TODO: Support dimension dropping """ ds_op = node.op @@ -142,128 +129,67 @@ def local_dimshuffle_rv_lift(fgraph, node): base_rv = node.inputs[0] rv_node = base_rv.owner - if not ( - rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0 - ): + if not (rv_node and isinstance(rv_node.op, RandomVariable)): return False - # If no one else is using the underlying `RandomVariable`, then we can - # do this; otherwise, the graph would be internally inconsistent. - if is_rv_used_in_graph(base_rv, node, fgraph): + # Dimshuffle which drop dimensions not supported yet + if ds_op.drop: return False rv_op = rv_node.op rng, size, dtype, *dist_params = rv_node.inputs + rv = rv_node.default_output() - # We need to know the dimensions that were *not* added by the `size` - # parameter (i.e. the dimensions corresponding to independent variates with - # different parameter values) - num_ind_dims = None - if len(dist_params) == 1: - num_ind_dims = dist_params[0].ndim - else: - # When there is more than one distribution parameter, assume that all - # of them will broadcast to the maximum number of dimensions - num_ind_dims = max(d.ndim for d in dist_params) - - # If the indices in `ds_new_order` are entirely within the replication - # indices group or the independent variates indices group, then we can apply - # this rewrite. - - ds_new_order = ds_op.new_order - # Create a map from old index order to new/`DimShuffled` index order - dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)] - - # Find the index at which the replications/independents split occurs - reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp) - - ds_reps_new_dims = dim_orders[:reps_ind_split_idx] - ds_ind_new_dims = dim_orders[reps_ind_split_idx:] - ds_in_ind_space = ds_ind_new_dims and all( - d >= reps_ind_split_idx for n, d in ds_ind_new_dims - ) + # Check that Dimshuffle does not affect support dims + supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim)) + shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i} + augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment) + if (shuffled_dims | augmented_dims) & supp_dims: + return False - if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims): + # If no one else is using the underlying RandomVariable, then we can + # do this; otherwise, the graph would be internally inconsistent. + if is_rv_used_in_graph(base_rv, node, fgraph): + return False - # Update the `size` array to reflect the `DimShuffle`d dimensions, - # since the trailing dimensions in `size` represent the independent - # variates dimensions (for univariate distributions, at least) - has_size = get_vector_length(size) > 0 - new_size = ( - [constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order] - if has_size - else size + batched_dims = rv.ndim - rv_op.ndim_supp + batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims) + + # Make size explicit + missing_size_dims = batched_dims - get_vector_length(size) + if missing_size_dims > 0: + full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape) + size = full_size[:missing_size_dims] + tuple(size) + + # Update the size to reflect the DimShuffled dimensions + new_size = [ + constant(1, dtype="int64") if o == "x" else size[o] + for o in batched_dims_ds_order + ] + + # Updates the params to reflect the Dimshuffled dimensions + new_dist_params = [] + for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params): + # Add broadcastable dimensions to the parameters that would have been expanded by the size + padleft = batched_dims - (param.ndim - param_ndim_supp) + if padleft > 0: + param = shape_padleft(param, padleft) + + # Add the parameter support dimension indexes to the batched dimensions Dimshuffle + param_new_order = batched_dims_ds_order + tuple( + range(batched_dims, batched_dims + param_ndim_supp) ) + new_dist_params.append(param.dimshuffle(param_new_order)) - # Compute the new axes parameter(s) for the `DimShuffle` that will be - # applied to the `RandomVariable` parameters (they need to be offset) - if ds_ind_new_dims: - rv_params_new_order = [ - d - reps_ind_split_idx if isinstance(d, int) else d - for d in ds_new_order[ds_ind_new_dims[0][0] :] - ] - - if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0: - # Additional broadcast dimensions need to be added to the - # independent dimensions (i.e. parameters), since there's no - # `size` to which they can be added - rv_params_new_order = ( - list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order - ) - else: - # This case is reached when, for example, `ds_new_order` only - # consists of new broadcastable dimensions (i.e. `"x"`s) - rv_params_new_order = ds_new_order - - # Lift the `DimShuffle`s into the parameters - # NOTE: The parameters might not be broadcasted against each other, so - # we can only apply the parts of the `DimShuffle` that are relevant. - new_dist_params = [] - for d in dist_params: - if d.ndim < len(ds_ind_new_dims): - _rv_params_new_order = [ - o - for o in rv_params_new_order - if (isinstance(o, int) and o < d.ndim) or o == "x" - ] - else: - _rv_params_new_order = rv_params_new_order - - new_dist_params.append( - type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d) - ) - new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) - - if config.compute_test_value != "off": - compute_test_value(new_node) - - out = new_node.outputs[1] - if base_rv.name: - out.name = f"{base_rv.name}_lifted" - return [out] + new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) - ds_in_reps_space = ds_reps_new_dims and all( - d < reps_ind_split_idx for n, d in ds_reps_new_dims - ) - - if ds_in_reps_space: - # Update the `size` array to reflect the `DimShuffle`d dimensions. - # There should be no need to `DimShuffle` now. - new_size = [ - constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order - ] - - new_node = rv_op.make_node(rng, new_size, dtype, *dist_params) - - if config.compute_test_value != "off": - compute_test_value(new_node) - - out = new_node.outputs[1] - if base_rv.name: - out.name = f"{base_rv.name}_lifted" - return [out] + if config.compute_test_value != "off": + compute_test_value(new_node) - return False + out = new_node.outputs[1] + if base_rv.name: + out.name = f"{base_rv.name}_lifted" + return [out] @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor]) diff --git a/tests/tensor/random/test_rewriting.py b/tests/tensor/random/test_rewriting.py index 856d908501..eee1eb5d17 100644 --- a/tests/tensor/random/test_rewriting.py +++ b/tests/tensor/random/test_rewriting.py @@ -9,6 +9,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.tensor import constant from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.random.basic import ( dirichlet, @@ -42,7 +43,11 @@ def apply_local_rewrite_to_rv( size_at = [] for s in size: - s_at = iscalar() + # To test DimShuffle with dropping dims we need that size dimension to be constant + if s == 1: + s_at = constant(np.array(1, dtype="int32")) + else: + s_at = iscalar() s_at.tag.test_value = s size_at.append(s_at) @@ -314,7 +319,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ), ( ("x", 1, 0, 2, "x"), - False, + True, normal, ( np.array([[-1, 20], [300, -4000]], dtype=config.floatX), @@ -332,7 +337,30 @@ def test_local_rv_size_lift(dist_op, dist_params, size): (3, 2, 2), 1, ), - # A multi-dimensional case + # Supported multi-dimensional cases + ( + (1, 0, 2), + True, + multivariate_normal, + ( + np.array([[-1, 20], [300, -4000]], dtype=config.floatX), + np.eye(2).astype(config.floatX) * 1e-6, + ), + (3, 2), + 1e-3, + ), + ( + (1, 0, "x", 2), + True, + multivariate_normal, + ( + np.array([[-1, 20], [300, -4000]], dtype=config.floatX), + np.eye(2).astype(config.floatX) * 1e-6, + ), + (3, 2), + 1e-3, + ), + # Not supported multi-dimensional cases where dimshuffle affects the support dimensionality ( (0, 2, 1), False, @@ -344,6 +372,35 @@ def test_local_rv_size_lift(dist_op, dist_params, size): (3, 2), 1e-3, ), + ( + (0, 1, 2, "x"), + False, + multivariate_normal, + ( + np.array([[-1, 20], [300, -4000]], dtype=config.floatX), + np.eye(2).astype(config.floatX) * 1e-6, + ), + (3, 2), + 1e-3, + ), + pytest.param( + (1,), + True, + normal, + (0, 1), + (1, 2), + 1e-3, + marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"), + ), + pytest.param( + (1,), + True, + normal, + ([[0, 0]], 1), + (1, 2), + 1e-3, + marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"), + ), ], ) @config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")