| 
10 | 10 | from collections.abc import Sequence  | 
11 | 11 | from functools import partial  | 
12 | 12 | from numbers import Number  | 
13 |  | -from typing import TYPE_CHECKING  | 
 | 13 | +from typing import TYPE_CHECKING, Union  | 
14 | 14 | from typing import cast as type_cast  | 
15 | 15 | 
 
  | 
16 | 16 | import numpy as np  | 
 | 
33 | 33 | from pytensor.link.c.op import COp  | 
34 | 34 | from pytensor.link.c.params_type import ParamsType  | 
35 | 35 | from pytensor.printing import Printer, min_informative_str, pprint, set_precedence  | 
36 |  | -from pytensor.raise_op import CheckAndRaise, assert_op  | 
 | 36 | +from pytensor.raise_op import CheckAndRaise  | 
37 | 37 | from pytensor.scalar import int32  | 
38 | 38 | from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable  | 
39 | 39 | from pytensor.tensor import (  | 
@@ -3084,87 +3084,72 @@ def flatten(x, ndim=1):  | 
3084 | 3084 |     return x_reshaped  | 
3085 | 3085 | 
 
  | 
3086 | 3086 | 
 
  | 
3087 |  | -def tile(x, reps, ndim=None):  | 
 | 3087 | +def tile(  | 
 | 3088 | +    A: "TensorLike", reps: Union[Sequence[int, "TensorLike"], "TensorLike"]  | 
 | 3089 | +) -> TensorVariable:  | 
3088 | 3090 |     """  | 
3089 |  | -    Tile input array `x` according to `reps`.  | 
 | 3091 | +    Tile input array `A` according to `reps`.  | 
3090 | 3092 | 
  | 
3091 | 3093 |     See the docstring of `numpy.tile` for details.  | 
3092 | 3094 | 
  | 
3093 |  | -    'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]),  | 
3094 |  | -    symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector())  | 
3095 |  | -    or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]).  | 
3096 |  | -
  | 
3097 |  | -    ndim is the number of the dimensions of the output, if it is provided, ndim  | 
3098 |  | -    should be equal or larger than x.ndim and len(reps), otherwise, we will use  | 
3099 |  | -    max(x.ndim, len(reps)) as ndim. If reps is symbolic vector, the ndim has to  | 
3100 |  | -    be provided.  | 
3101 |  | -
  | 
 | 3095 | +    If `reps` is a PyTensor vector, it's length must be statically known.  | 
 | 3096 | +    You can use `specify_shape` to set the length.  | 
3102 | 3097 |     """  | 
3103 |  | -    from pytensor.tensor.math import ge  | 
3104 | 3098 | 
 
  | 
3105 |  | -    _x = as_tensor_variable(x)  | 
3106 |  | -    if ndim is not None and ndim < _x.ndim:  | 
3107 |  | -        raise ValueError("ndim should be equal or larger than _x.ndim")  | 
 | 3099 | +    A = as_tensor_variable(A)  | 
3108 | 3100 | 
 
  | 
3109 |  | -    # If reps is a scalar, integer or vector, we convert it to a list.  | 
 | 3101 | +    # Convert symbolic reps to a tuple  | 
3110 | 3102 |     if not isinstance(reps, list | tuple):  | 
3111 |  | -        reps_astensor = as_tensor_variable(reps)  | 
3112 |  | -        ndim_check = reps_astensor.ndim  | 
3113 |  | -        if reps_astensor.dtype not in discrete_dtypes:  | 
3114 |  | -            raise ValueError("elements of reps must be integer dtype")  | 
3115 |  | - | 
3116 |  | -        # The scalar/integer case  | 
3117 |  | -        if ndim_check == 0:  | 
3118 |  | -            reps = [reps]  | 
3119 |  | - | 
3120 |  | -        # The vector case  | 
3121 |  | -        elif ndim_check == 1:  | 
3122 |  | -            if ndim is None:  | 
 | 3103 | +        reps = as_tensor_variable(reps)  | 
 | 3104 | +        if reps.type.ndim == 0:  | 
 | 3105 | +            reps = (reps,)  | 
 | 3106 | +        elif reps.type.ndim == 1:  | 
 | 3107 | +            try:  | 
 | 3108 | +                reps = tuple(reps)  | 
 | 3109 | +            except ValueError:  | 
3123 | 3110 |                 raise ValueError(  | 
3124 |  | -                    "if reps is tensor.vector, you should specify the ndim"  | 
 | 3111 | +                    "Length of repetitions tensor cannot be determined. Use specify_shape to set the length."  | 
3125 | 3112 |                 )  | 
3126 |  | -            else:  | 
3127 |  | -                offset = ndim - reps.shape[0]  | 
3128 |  | - | 
3129 |  | -                # assert that reps.shape[0] does not exceed ndim  | 
3130 |  | -                offset = assert_op(offset, ge(offset, 0))  | 
 | 3113 | +        else:  | 
 | 3114 | +            raise ValueError(  | 
 | 3115 | +                f"Repetitions tensor must be a scalar or a vector, got ndim={reps.type.ndim}"  | 
 | 3116 | +            )  | 
3131 | 3117 | 
 
  | 
3132 |  | -                # if reps.ndim is less than _x.ndim, we pad the reps with  | 
3133 |  | -                # "1" so that reps will have the same ndim as _x.  | 
3134 |  | -                reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)]  | 
3135 |  | -                reps = reps_  | 
 | 3118 | +    reps = [as_tensor_variable(rep) for rep in reps]  | 
 | 3119 | +    if not all(  | 
 | 3120 | +        rep.type.ndim == 0 and rep.type.dtype in discrete_dtypes for rep in reps  | 
 | 3121 | +    ):  | 
 | 3122 | +        raise ValueError(  | 
 | 3123 | +            f"All reps entries shoud be scalar integers, got {reps} of type {[rep.type for rep in reps]}"  | 
 | 3124 | +        )  | 
3136 | 3125 | 
 
  | 
3137 |  | -        # For others, raise an error  | 
3138 |  | -        else:  | 
3139 |  | -            raise ValueError("the dimension of reps should not exceed 1")  | 
3140 |  | -    else:  | 
3141 |  | -        if ndim is not None and len(reps) > ndim:  | 
3142 |  | -            raise ValueError("len(reps) should be equal or less than ndim")  | 
3143 |  | -        if not all(  | 
3144 |  | -            isinstance(r, int)  | 
3145 |  | -            or (isinstance(r, TensorVariable) and r.dtype in discrete_dtypes)  | 
3146 |  | -            for r in reps  | 
3147 |  | -        ):  | 
3148 |  | -            raise ValueError("elements of reps must be scalars of integer dtype")  | 
 | 3126 | +    len_reps = len(reps)  | 
 | 3127 | +    out_ndim = builtins.max(len_reps, A.type.ndim)  | 
 | 3128 | + | 
 | 3129 | +    # Pad reps on the left (if needed)  | 
 | 3130 | +    if len_reps < out_ndim:  | 
 | 3131 | +        reps = (*((1,) * (out_ndim - len_reps)), *reps)  | 
 | 3132 | + | 
 | 3133 | +    # Pad A's shape on the left (if needed)  | 
 | 3134 | +    elif A.type.ndim < out_ndim:  | 
 | 3135 | +        A = shape_padleft(A, out_ndim - A.type.ndim)  | 
 | 3136 | + | 
 | 3137 | +    # Expand every other dim of A and expand n-reps via Alloc  | 
 | 3138 | +    # A_replicated = alloc(A[None, :, ..., None, :], reps[0], A.shape[0], ..., reps[-1], A.shape[-1])  | 
 | 3139 | +    A_shape = A.shape  | 
 | 3140 | +    interleaved_reps_shape = [  | 
 | 3141 | +        d for pair in zip(reps, A.shape, strict=True) for d in pair  | 
 | 3142 | +    ]  | 
 | 3143 | +    every_other_axis = tuple(range(0, out_ndim * 2, 2))  | 
 | 3144 | +    A_replicated = alloc(  | 
 | 3145 | +        expand_dims(A, every_other_axis),  | 
 | 3146 | +        *interleaved_reps_shape,  | 
 | 3147 | +    )  | 
3149 | 3148 | 
 
  | 
3150 |  | -    # If reps.ndim is less than _x.ndim, we pad the reps with  | 
3151 |  | -    # "1" so that reps will have the same ndim as _x  | 
3152 |  | -    reps = list(reps)  | 
3153 |  | -    if ndim is None:  | 
3154 |  | -        ndim = builtins.max(len(reps), _x.ndim)  | 
3155 |  | -    if len(reps) < ndim:  | 
3156 |  | -        reps = [1] * (ndim - len(reps)) + reps  | 
3157 |  | - | 
3158 |  | -    _shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)]  | 
3159 |  | -    alloc_shape = reps + _shape  | 
3160 |  | -    y = alloc(_x, *alloc_shape)  | 
3161 |  | -    shuffle_ind = np.arange(ndim * 2).reshape(2, ndim)  | 
3162 |  | -    shuffle_ind = shuffle_ind.transpose().flatten()  | 
3163 |  | -    y = y.dimshuffle(*shuffle_ind)  | 
3164 |  | -    new_shapes = [sh * reps[i] for i, sh in enumerate(_shape)]  | 
3165 |  | -    y = y.reshape(new_shapes)  | 
3166 |  | - | 
3167 |  | -    return y  | 
 | 3149 | +    # Combine replicate and original dimensions via reshape  | 
 | 3150 | +    # A_tiled = A_replicated.reshape(reps[0] * A.shape[0], ..., reps[-1] * A.shape[-1])  | 
 | 3151 | +    tiled_shape = tuple(rep * A_dim for rep, A_dim in zip(reps, A_shape, strict=True))  | 
 | 3152 | +    return A_replicated.reshape(tiled_shape)  | 
3168 | 3153 | 
 
  | 
3169 | 3154 | 
 
  | 
3170 | 3155 | class ARange(Op):  | 
 | 
0 commit comments