2 changes: 1 addition & 1 deletion pytensor/__init__.py
@@ -108,7 +108,7 @@ def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable:
 
 
 @singledispatch
-def _as_symbolic(x, **kwargs) -> Variable:
+def _as_symbolic(x: Any, **kwargs) -> Variable:
     from pytensor.tensor import as_tensor_variable
 
     return as_tensor_variable(x, **kwargs)
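Note: a minimal sketch of the `singledispatch` pattern used above, and why the fallback benefits from an explicit `x: Any` annotation (mypy only type-checks the bodies of annotated functions). The `dict` handler below is hypothetical, purely for illustration:

    from functools import singledispatch
    from typing import Any

    @singledispatch
    def _as_symbolic(x: Any, **kwargs):
        # Fallback: types without a registered handler take the generic path.
        return ("generic", x)

    @_as_symbolic.register(dict)
    def _as_symbolic_dict(x: dict, **kwargs):
        # A registered handler takes precedence for its type.
        return ("dict", x)

    print(_as_symbolic(3.0))       # ('generic', 3.0) -- falls back
    print(_as_symbolic({"a": 1}))  # ('dict', {'a': 1}) -- dispatched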
4 changes: 2 additions & 2 deletions pytensor/graph/basic.py
@@ -1302,8 +1302,8 @@ def clone_node_and_cache(
 
 
 def clone_get_equiv(
-    inputs: Sequence[Variable],
-    outputs: Sequence[Variable],
+    inputs: Iterable[Variable],
+    outputs: Reversible[Variable],
     copy_inputs: bool = True,
     copy_orphans: bool = True,
     memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
3 changes: 2 additions & 1 deletion pytensor/tensor/random/op.py
@@ -1,7 +1,7 @@
 import warnings
 from collections.abc import Sequence
 from copy import copy
-from typing import cast
+from typing import Any, cast
 
 import numpy as np
 
@@ -218,6 +218,7 @@ def _infer_shape(
 
         from pytensor.tensor.extra_ops import broadcast_shape_iter
 
+        supp_shape: tuple[Any]
         if self.ndim_supp == 0:
             supp_shape = ()
         else:
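Note: pre-declaring a variable's type is a common mypy idiom when a name is assigned in several branches; a standalone sketch of the pattern (not pytensor code), assuming mypy would otherwise infer an empty-tuple type from the first assignment and reject the `else` branch:

    from typing import Any

    def pick(flag: bool) -> tuple[Any, ...]:
        supp_shape: tuple[Any, ...]
        if flag:
            supp_shape = ()
        else:
            supp_shape = (1, 2, 3)
        return supp_shape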
4 changes: 3 additions & 1 deletion pytensor/tensor/random/utils.py
@@ -147,7 +147,9 @@ def explicit_expand_dims(
     return new_params
 
 
-def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
+def compute_batch_shape(
+    params: Sequence[TensorVariable], ndims_params: Sequence[int]
+) -> TensorVariable:
     params = explicit_expand_dims(params, ndims_params)
     batch_params = [
         param[(..., *(0,) * core_ndim)]
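Note: a NumPy analogue of what `compute_batch_shape` computes, i.e. the broadcasted shape of each parameter's batch dimensions (everything to the left of its core dimensions). The helper below is illustrative only, not the pytensor implementation:

    import numpy as np

    def batch_shape(params, ndims_params):
        batch_shapes = [
            p.shape[: p.ndim - core_ndim]
            for p, core_ndim in zip(params, ndims_params)
        ]
        return np.broadcast_shapes(*batch_shapes)

    mu = np.zeros((5, 1, 3))   # one core dim  -> batch shape (5, 1)
    cov = np.eye(3)[None]      # shape (1, 3, 3), two core dims -> batch shape (1,)
    print(batch_shape([mu, cov], [1, 2]))  # (5, 1)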
6 changes: 3 additions & 3 deletions pytensor/tensor/shape.py
@@ -144,14 +144,14 @@ def c_code_cache_version(self):
 _shape = Shape()
 
 
-def shape(x: np.ndarray | Number | Variable) -> Variable:
+def shape(x: np.ndarray | Number | Variable) -> TensorVariable:
     """Return the shape of `x`."""
     if not isinstance(x, Variable):
         # The following is a type error in Python 3.9 but not 3.12.
         # Thus we need to ignore unused-ignore on 3.12.
         x = ptb.as_tensor_variable(x)  # type: ignore[arg-type,unused-ignore]
 
-    return cast(Variable, _shape(x))
+    return cast(TensorVariable, _shape(x))
 
 
 @_get_vector_length.register(Shape)  # type: ignore
@@ -195,7 +195,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
             # TODO: Why not use uint64?
             res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),)
         else:
-            res += (symbolic_shape[i],)  # type: ignore
+            res += (symbolic_shape[i],)
 
     return res
 
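Note: narrowing the cast from `Variable` to `TensorVariable` is purely a static-typing change; `typing.cast` returns its argument unchanged at runtime, as this standalone snippet illustrates:

    from typing import cast

    x: object = [1, 2, 3]
    y = cast("list[int]", x)
    assert y is x  # no runtime conversion happened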
157 changes: 120 additions & 37 deletions pytensor/tensor/subtensor.py
@@ -3,6 +3,7 @@
 from collections.abc import Callable, Iterable
 from itertools import chain, groupby
 from textwrap import dedent
+from typing import cast, overload
 
 import numpy as np
 
@@ -19,13 +20,19 @@
 from pytensor.link.c.params_type import ParamsType
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.printing import Printer, pprint, set_precedence
-from pytensor.scalar.basic import ScalarConstant
-from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
+from pytensor.scalar.basic import ScalarConstant, ScalarVariable
+from pytensor.tensor import (
+    TensorLike,
+    _get_vector_length,
+    as_tensor_variable,
+    get_vector_length,
+)
 from pytensor.tensor.basic import (
     ScalarFromTensor,
     alloc,
     get_underlying_scalar_constant_value,
     nonzero,
+    scalar_from_tensor,
 )
 from pytensor.tensor.blockwise import vectorize_node_fallback
 from pytensor.tensor.elemwise import DimShuffle
@@ -51,8 +58,14 @@
     wscalar,
     zscalar,
 )
-from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
-from pytensor.tensor.variable import TensorVariable
+from pytensor.tensor.type_other import (
+    NoneConst,
+    NoneTypeT,
+    SliceConstant,
+    SliceType,
+    make_slice,
+)
+from pytensor.tensor.variable import TensorConstant, TensorVariable
 
 
 _logger = logging.getLogger("pytensor.tensor.subtensor")
@@ -134,7 +147,7 @@ def convert_indices(indices, entry):
 
 
 def as_index_constant(
-    a: slice | int | np.integer | Variable | None,
+    a: slice | int | np.integer | Variable | None | TensorLike,
 ) -> Variable | slice | None:
     r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments.
 
@@ -150,15 +163,41 @@ def as_index_literal(
         )
     elif isinstance(a, int | np.integer):
         return ps.ScalarConstant(ps.int64, a)
-    elif not isinstance(a, Variable):
-        return as_tensor_variable(a)
-    else:
+    elif isinstance(a, Variable):
         return a
+    return as_tensor_variable(a)
 
 
+@overload
+def as_index_literal(idx: int | np.integer) -> int | np.integer: ...
+
+
+@overload
+def as_index_literal(idx: None) -> None: ...
+
+
+@overload
+def as_index_literal(idx: slice | SliceConstant) -> slice: ...
+
+
+@overload
+def as_index_literal(idx: ScalarConstant | TensorConstant) -> int | np.integer: ...
+
+
+@overload
+def as_index_literal(idx: Variable): ...
+
+
 def as_index_literal(
-    idx: Variable | slice | None,
-) -> int | slice | None:
+    idx: None
+    | int
+    | np.integer
+    | slice
+    | SliceConstant
+    | ScalarConstant
+    | TensorConstant
+    | Variable,
+) -> int | np.integer | slice | None:
     """Convert a symbolic index element to its Python equivalent.
 
     This is like the inverse of `as_index_constant`
@@ -167,22 +206,8 @@ def as_index_literal(
     ------
     NotScalarConstantError
     """
-    if idx == np.newaxis or isinstance(getattr(idx, "type", None), NoneTypeT):
-        return np.newaxis
-
-    if isinstance(idx, Constant):
-        return idx.data.item() if isinstance(idx, np.ndarray) else idx.data
-
-    if isinstance(idx, Variable):
-        if (
-            isinstance(idx.type, ps.ScalarType)
-            and idx.owner
-            and isinstance(idx.owner.op, ScalarFromTensor)
-        ):
-            return as_index_literal(idx.owner.inputs[0])
-
-        if isinstance(idx.type, SliceType):
-            idx = slice(*idx.owner.inputs)
+    if idx is None or isinstance(idx, int | np.integer):
+        return idx
 
     if isinstance(idx, slice):
         return slice(
@@ -191,17 +216,64 @@
             as_index_literal(idx.step),
         )
 
+    if not isinstance(idx, Variable):
+        raise TypeError(f"Not an index element: {idx}")
+
+    if isinstance(idx.type, NoneTypeT):
+        return None
+
+    if isinstance(idx, ScalarConstant):
+        return cast(int, idx.data)
+
+    if (
+        isinstance(idx.type, ps.ScalarType)
+        and idx.owner
+        and isinstance(idx.owner.op, ScalarFromTensor)
+    ):
+        return cast(int | np.integer, as_index_literal(idx.owner.inputs[0]))
+
+    if isinstance(idx, TensorConstant):
+        return cast(int, idx.data.item())
+
+    if isinstance(idx, SliceConstant):
+        return cast(slice, idx.data)
+
+    if isinstance(idx.type, SliceType):
+        assert idx.owner is not None
+        return slice(*map(as_index_literal, idx.owner.inputs))
+
+    # Other kinds of variables are not supported
     raise NotScalarConstantError()
 
 
 def get_idx_list(inputs, idx_list):
     return indices_from_subtensor(inputs[1:], idx_list)
 
 
+@overload
+def get_canonical_form_slice(
+    theslice: slice,
+    length: int | np.integer | ScalarVariable | TensorVariable,
+) -> tuple[slice, int | ScalarConstant]: ...
+
+
+@overload
+def get_canonical_form_slice(
+    theslice: int | np.integer | ScalarVariable | TensorVariable,
+    length: int | np.integer | ScalarVariable | TensorVariable,
+) -> tuple[ScalarVariable, int]: ...
+
+
 def get_canonical_form_slice(
-    theslice: slice | Variable, length: Variable
-) -> tuple[Variable, int]:
-    """Convert slices to canonical form.
+    theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
+    length: int | np.integer | ScalarVariable | TensorVariable,
+) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
+    """Convert indices or slices to canonical form.
 
+    Scalar integer indices or python Slices with Scalar/None attributes
+    used in basic Subtensor Ops are supported.
+    Symbolic slices (of SliceType) or vector indices
+    used in advanced Subtensor Ops are not supported.
+
     Given a slice [start:stop:step] transform it into a canonical form
     that respects the conventions imposed by python and numpy.
@@ -210,18 +282,28 @@ def get_canonical_form_slice(
     in which 0 <= start <= stop <= length and step > 0, and a flag which says
     if the resulting set of numbers needs to be reversed or not.
 
+    Given a scalar index `idx` that may or not be negative, convert it to
+    a certainly positive form `idx if idx >= 0 else length + idx`.
+
+    Returns
+    -------
+    slc
+        Canonical form slice or scalar variable.
+    direction
+        Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
     """
     from pytensor.tensor import ge, lt, sign, switch
 
     # Other non-slice types are the scalar indexing case
     if not isinstance(theslice, slice):
-        try:
-            value = as_index_literal(theslice)
-        except NotScalarConstantError:
-            value = theslice
-
-        value = switch(lt(value, 0), (value + length), value)
+        if isinstance(theslice, int | np.integer | ScalarVariable) or (
+            isinstance(theslice, TensorVariable) and theslice.ndim == 0
+        ):
+            cano = switch(lt(theslice, 0), (theslice + length), theslice)
+            return scalar_from_tensor(cano), 1
+        raise ValueError(f"Slice {theslice} is not a supported slice type.")
 
-        return value, 1
+    # At this point we have a slice object. Possibly with symbolic inputs.
 
     def analyze(x):
         try:
@@ -243,6 +325,7 @@ def analyze(x):
         and is_step_constant
         and is_length_constant
     ):
+        assert isinstance(length, int)
         _start, _stop, _step = slice(start, stop, step).indices(length)
         if _start <= _stop and _step >= 1:
             return slice(_start, _stop, _step), 1
@@ -2917,7 +3000,7 @@ def take(a, indices, axis=None, mode="raise"):
     return a[full_indices]
 
 
-@_get_vector_length.register(Subtensor)
+@_get_vector_length.register(Subtensor)  # type: ignore
 def _get_vector_length_Subtensor(op, var):
     # If we take a slice, we know how many elements it will result in
     # TODO: We can cover more `*Subtensor` cases.
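Note: a pure-Python sketch of the contract promised by the new `get_canonical_form_slice` signature, evaluated on constant inputs (the helper below is hypothetical, not the pytensor function):

    def canonicalize(index, length):
        if not isinstance(index, slice):
            # Scalar case: a possibly negative index becomes a certainly
            # positive one, paired with direction 1.
            return (index if index >= 0 else index + length), 1
        # Slice case: slice.indices() applies the Python/NumPy clipping rules.
        start, stop, step = index.indices(length)
        if start <= stop and step >= 1:
            return slice(start, stop, step), 1
        # Negative steps additionally flip the direction flag to -1 (omitted).
        raise NotImplementedError

    print(canonicalize(-2, 10))              # (8, 1)
    print(canonicalize(slice(None, 5), 10))  # (slice(0, 5, 1), 1)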
2 changes: 1 addition & 1 deletion pytensor/tensor/type.py
@@ -138,7 +138,7 @@ def clone(
             shape = self.shape
         return type(self)(dtype, shape, name=self.name)
 
-    def filter(self, data, strict=False, allow_downcast=None):
+    def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:
         """Convert `data` to something which can be associated to a `TensorVariable`.
 
         This function is not meant to be called in user code. It is for
1 change: 0 additions & 1 deletion scripts/mypy-failing.txt
@@ -25,7 +25,6 @@ pytensor/tensor/random/op.py
 pytensor/tensor/random/utils.py
 pytensor/tensor/rewriting/basic.py
 pytensor/tensor/slinalg.py
-pytensor/tensor/subtensor.py
 pytensor/tensor/type.py
 pytensor/tensor/type_other.py
 pytensor/tensor/variable.py