Skip to content

Commit d1ad9df

Browse files
committed
Simplify implementation of tile
Deprecate obscure ndim kwarg
1 parent 884dee9 commit d1ad9df

File tree

2 files changed

+147
-236
lines changed

2 files changed

+147
-236
lines changed

pytensor/tensor/basic.py

Lines changed: 55 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections.abc import Sequence
1111
from functools import partial
1212
from numbers import Number
13-
from typing import TYPE_CHECKING
13+
from typing import TYPE_CHECKING, Union
1414
from typing import cast as type_cast
1515

1616
import numpy as np
@@ -33,7 +33,7 @@
3333
from pytensor.link.c.op import COp
3434
from pytensor.link.c.params_type import ParamsType
3535
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
36-
from pytensor.raise_op import CheckAndRaise, assert_op
36+
from pytensor.raise_op import CheckAndRaise
3737
from pytensor.scalar import int32
3838
from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
3939
from pytensor.tensor import (
@@ -3084,87 +3084,72 @@ def flatten(x, ndim=1):
30843084
return x_reshaped
30853085

30863086

3087-
def tile(x, reps, ndim=None):
3087+
def tile(
3088+
A: "TensorLike", reps: Union[Sequence[int, "TensorLike"], "TensorLike"]
3089+
) -> TensorVariable:
30883090
"""
3089-
Tile input array `x` according to `reps`.
3091+
Tile input array `A` according to `reps`.
30903092
30913093
See the docstring of `numpy.tile` for details.
30923094
3093-
'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]),
3094-
symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector())
3095-
or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]).
3096-
3097-
ndim is the number of the dimensions of the output, if it is provided, ndim
3098-
should be equal or larger than x.ndim and len(reps), otherwise, we will use
3099-
max(x.ndim, len(reps)) as ndim. If reps is symbolic vector, the ndim has to
3100-
be provided.
3101-
3095+
If `reps` is a PyTensor vector, it's length must be statically known.
3096+
You can use `specify_shape` to set the length.
31023097
"""
3103-
from pytensor.tensor.math import ge
31043098

3105-
_x = as_tensor_variable(x)
3106-
if ndim is not None and ndim < _x.ndim:
3107-
raise ValueError("ndim should be equal or larger than _x.ndim")
3099+
A = as_tensor_variable(A)
31083100

3109-
# If reps is a scalar, integer or vector, we convert it to a list.
3101+
# Convert symbolic reps to a tuple
31103102
if not isinstance(reps, list | tuple):
3111-
reps_astensor = as_tensor_variable(reps)
3112-
ndim_check = reps_astensor.ndim
3113-
if reps_astensor.dtype not in discrete_dtypes:
3114-
raise ValueError("elements of reps must be integer dtype")
3115-
3116-
# The scalar/integer case
3117-
if ndim_check == 0:
3118-
reps = [reps]
3119-
3120-
# The vector case
3121-
elif ndim_check == 1:
3122-
if ndim is None:
3103+
reps = as_tensor_variable(reps)
3104+
if reps.type.ndim == 0:
3105+
reps = (reps,)
3106+
elif reps.type.ndim == 1:
3107+
try:
3108+
reps = tuple(reps)
3109+
except ValueError:
31233110
raise ValueError(
3124-
"if reps is tensor.vector, you should specify the ndim"
3111+
"Length of repetitions tensor cannot be determined. Use specify_shape to set the length."
31253112
)
3126-
else:
3127-
offset = ndim - reps.shape[0]
3128-
3129-
# assert that reps.shape[0] does not exceed ndim
3130-
offset = assert_op(offset, ge(offset, 0))
3113+
else:
3114+
raise ValueError(
3115+
f"Repetitions tensor must be a scalar or a vector, got ndim={reps.type.ndim}"
3116+
)
31313117

3132-
# if reps.ndim is less than _x.ndim, we pad the reps with
3133-
# "1" so that reps will have the same ndim as _x.
3134-
reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)]
3135-
reps = reps_
3118+
reps = [as_tensor_variable(rep) for rep in reps]
3119+
if not all(
3120+
rep.type.ndim == 0 and rep.type.dtype in discrete_dtypes for rep in reps
3121+
):
3122+
raise ValueError(
3123+
f"All reps entries shoud be scalar integers, got {reps} of type {[rep.type for rep in reps]}"
3124+
)
31363125

3137-
# For others, raise an error
3138-
else:
3139-
raise ValueError("the dimension of reps should not exceed 1")
3140-
else:
3141-
if ndim is not None and len(reps) > ndim:
3142-
raise ValueError("len(reps) should be equal or less than ndim")
3143-
if not all(
3144-
isinstance(r, int)
3145-
or (isinstance(r, TensorVariable) and r.dtype in discrete_dtypes)
3146-
for r in reps
3147-
):
3148-
raise ValueError("elements of reps must be scalars of integer dtype")
3126+
len_reps = len(reps)
3127+
out_ndim = builtins.max(len_reps, A.type.ndim)
3128+
3129+
# Pad reps on the left (if needed)
3130+
if len_reps < out_ndim:
3131+
reps = (*((1,) * (out_ndim - len_reps)), *reps)
3132+
3133+
# Pad A's shape on the left (if needed)
3134+
elif A.type.ndim < out_ndim:
3135+
A = shape_padleft(A, out_ndim - A.type.ndim)
3136+
3137+
# Expand every other dim of A and expand n-reps via Alloc
3138+
# A_replicated = alloc(A[None, :, ..., None, :], reps[0], A.shape[0], ..., reps[-1], A.shape[-1])
3139+
A_shape = A.shape
3140+
interleaved_reps_shape = [
3141+
d for pair in zip(reps, A.shape, strict=True) for d in pair
3142+
]
3143+
every_other_axis = tuple(range(0, out_ndim * 2, 2))
3144+
A_replicated = alloc(
3145+
expand_dims(A, every_other_axis),
3146+
*interleaved_reps_shape,
3147+
)
31493148

3150-
# If reps.ndim is less than _x.ndim, we pad the reps with
3151-
# "1" so that reps will have the same ndim as _x
3152-
reps = list(reps)
3153-
if ndim is None:
3154-
ndim = builtins.max(len(reps), _x.ndim)
3155-
if len(reps) < ndim:
3156-
reps = [1] * (ndim - len(reps)) + reps
3157-
3158-
_shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)]
3159-
alloc_shape = reps + _shape
3160-
y = alloc(_x, *alloc_shape)
3161-
shuffle_ind = np.arange(ndim * 2).reshape(2, ndim)
3162-
shuffle_ind = shuffle_ind.transpose().flatten()
3163-
y = y.dimshuffle(*shuffle_ind)
3164-
new_shapes = [sh * reps[i] for i, sh in enumerate(_shape)]
3165-
y = y.reshape(new_shapes)
3166-
3167-
return y
3149+
# Combine replicate and original dimensions via reshape
3150+
# A_tiled = A_replicated.reshape(reps[0] * A.shape[0], ..., reps[-1] * A.shape[-1])
3151+
tiled_shape = tuple(rep * A_dim for rep, A_dim in zip(reps, A_shape, strict=True))
3152+
return A_replicated.reshape(tiled_shape)
31683153

31693154

31703155
class ARange(Op):

0 commit comments

Comments
 (0)