33from collections .abc import Callable , Iterable
44from itertools import chain , groupby
55from textwrap import dedent
6+ from typing import cast , overload
67
78import numpy as np
89
1920from pytensor .link .c .params_type import ParamsType
2021from pytensor .misc .safe_asarray import _asarray
2122from pytensor .printing import Printer , pprint , set_precedence
22- from pytensor .scalar .basic import ScalarConstant
23- from pytensor .tensor import _get_vector_length , as_tensor_variable , get_vector_length
23+ from pytensor .scalar .basic import ScalarConstant , ScalarVariable
24+ from pytensor .tensor import (
25+ TensorLike ,
26+ _get_vector_length ,
27+ as_tensor_variable ,
28+ get_vector_length ,
29+ )
2430from pytensor .tensor .basic import (
2531 ScalarFromTensor ,
2632 alloc ,
2733 get_underlying_scalar_constant_value ,
2834 nonzero ,
35+ scalar_from_tensor ,
2936)
3037from pytensor .tensor .blockwise import vectorize_node_fallback
3138from pytensor .tensor .elemwise import DimShuffle
5158 wscalar ,
5259 zscalar ,
5360)
54- from pytensor .tensor .type_other import NoneConst , NoneTypeT , SliceType , make_slice
55- from pytensor .tensor .variable import TensorVariable
61+ from pytensor .tensor .type_other import (
62+ NoneConst ,
63+ NoneTypeT ,
64+ SliceConstant ,
65+ SliceType ,
66+ make_slice ,
67+ )
68+ from pytensor .tensor .variable import TensorConstant , TensorVariable
5669
5770
5871_logger = logging .getLogger ("pytensor.tensor.subtensor" )
@@ -134,7 +147,7 @@ def convert_indices(indices, entry):
134147
135148
136149def as_index_constant (
137- a : slice | int | np .integer | Variable | None ,
150+ a : slice | int | np .integer | Variable | None | TensorLike ,
138151) -> Variable | slice | None :
139152 r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments.
140153
@@ -150,15 +163,41 @@ def as_index_constant(
150163 )
151164 elif isinstance (a , int | np .integer ):
152165 return ps .ScalarConstant (ps .int64 , a )
153- elif not isinstance (a , Variable ):
154- return as_tensor_variable (a )
155- else :
166+ elif isinstance (a , Variable ):
156167 return a
168+ return as_tensor_variable (a )
169+
170+
171+ @overload
172+ def as_index_literal (idx : int | np .integer ) -> int | np .integer : ...
173+
174+
175+ @overload
176+ def as_index_literal (idx : None ) -> None : ...
177+
178+
179+ @overload
180+ def as_index_literal (idx : slice | SliceConstant ) -> slice : ...
181+
182+
183+ @overload
184+ def as_index_literal (idx : ScalarConstant | TensorConstant ) -> int | np .integer : ...
185+
186+
187+ @overload
188+ def as_index_literal (idx : Variable ): ...
157189
158190
159191def as_index_literal (
160- idx : Variable | slice | None ,
161- ) -> int | slice | None :
192+ idx : None
193+ | int
194+ | np .integer
195+ | slice
196+ | SliceConstant
197+ | ScalarConstant
198+ | TensorConstant
199+ | Variable ,
200+ ) -> int | np .integer | slice | None :
162201 """Convert a symbolic index element to its Python equivalent.
163202
164203 This is like the inverse of `as_index_constant`
@@ -167,22 +206,8 @@ def as_index_literal(
167206 ------
168207 NotScalarConstantError
169208 """
170- if idx == np .newaxis or isinstance (getattr (idx , "type" , None ), NoneTypeT ):
171- return np .newaxis
172-
173- if isinstance (idx , Constant ):
174- return idx .data .item () if isinstance (idx , np .ndarray ) else idx .data
175-
176- if isinstance (idx , Variable ):
177- if (
178- isinstance (idx .type , ps .ScalarType )
179- and idx .owner
180- and isinstance (idx .owner .op , ScalarFromTensor )
181- ):
182- return as_index_literal (idx .owner .inputs [0 ])
183-
184- if isinstance (idx .type , SliceType ):
185- idx = slice (* idx .owner .inputs )
209+ if idx is None or isinstance (idx , int | np .integer ):
210+ return idx
186211
187212 if isinstance (idx , slice ):
188213 return slice (
@@ -191,17 +216,64 @@ def as_index_literal(
191216 as_index_literal (idx .step ),
192217 )
193218
219+ if not isinstance (idx , Variable ):
220+ raise TypeError (f"Not an index element: { idx } " )
221+
222+ if isinstance (idx .type , NoneTypeT ):
223+ return None
224+
225+ if isinstance (idx , ScalarConstant ):
226+ return cast (int , idx .data )
227+
228+ if (
229+ isinstance (idx .type , ps .ScalarType )
230+ and idx .owner
231+ and isinstance (idx .owner .op , ScalarFromTensor )
232+ ):
233+ return cast (int | np .integer , as_index_literal (idx .owner .inputs [0 ]))
234+
235+ if isinstance (idx , TensorConstant ):
236+ return cast (int , idx .data .item ())
237+
238+ if isinstance (idx , SliceConstant ):
239+ return cast (slice , idx .data )
240+
241+ if isinstance (idx .type , SliceType ):
242+ assert idx .owner is not None
243+ return slice (* map (as_index_literal , idx .owner .inputs ))
244+
245+ # Other kinds of variables are not supported
194246 raise NotScalarConstantError ()
195247
196248
197249def get_idx_list (inputs , idx_list ):
198250 return indices_from_subtensor (inputs [1 :], idx_list )
199251
200252
253+ @overload
254+ def get_canonical_form_slice (
255+ theslice : slice ,
256+ length : int | np .integer | ScalarVariable | TensorVariable ,
257+ ) -> tuple [slice , int | ScalarConstant ]: ...
258+
259+
260+ @overload
261+ def get_canonical_form_slice (
262+ theslice : int | np .integer | ScalarVariable | TensorVariable ,
263+ length : int | np .integer | ScalarVariable | TensorVariable ,
264+ ) -> tuple [ScalarVariable , int ]: ...
265+
266+
201267def get_canonical_form_slice (
202- theslice : slice | Variable , length : Variable
203- ) -> tuple [Variable , int ]:
204- """Convert slices to canonical form.
268+ theslice : slice | int | np .integer | ScalarVariable | TensorVariable ,
269+ length : int | np .integer | ScalarVariable | TensorVariable ,
270+ ) -> tuple [slice | ScalarVariable , int | ScalarConstant ]:
271+ """Convert indices or slices to canonical form.
272+
273+ Scalar integer indices or python Slices with Scalar/None attributes
274+ used in basic Subtensor Ops are supported.
275+ Symbolic slices (of SliceType) or vector indices
276+ used in advanced Subtensor Ops are not supported.
205277
206278 Given a slice [start:stop:step] transform it into a canonical form
207279 that respects the conventions imposed by python and numpy.
@@ -210,18 +282,28 @@ def get_canonical_form_slice(
210282 in which 0 <= start <= stop <= length and step > 0, and a flag which says
211283 if the resulting set of numbers needs to be reversed or not.
212284
285+ Given a scalar index `idx` that may or not be negative, convert it to
286+ a certainly positive form `idx if idx >= 0 else length + idx`.
287+
288+ Returns
289+ -------
290+ slc
291+ Canonical form slice or scalar variable.
292+ direction
293+ Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
213294 """
214295 from pytensor .tensor import ge , lt , sign , switch
215296
297+ # Other non-slice types are the scalar indexing case
216298 if not isinstance (theslice , slice ):
217- try :
218- value = as_index_literal (theslice )
219- except NotScalarConstantError :
220- value = theslice
221-
222- value = switch ( lt ( value , 0 ), ( value + length ), value )
299+ if isinstance ( theslice , int | np . integer | ScalarVariable ) or (
300+ isinstance (theslice , TensorVariable ) and theslice . ndim == 0
301+ ) :
302+ cano = switch ( lt ( theslice , 0 ), ( theslice + length ), theslice )
303+ return scalar_from_tensor ( cano ), 1
304+ raise ValueError ( f"Slice { theslice } is not a supported slice type." )
223305
224- return value , 1
306+ # At this point we have a slice object. Possibly with symbolic inputs.
225307
226308 def analyze (x ):
227309 try :
@@ -243,6 +325,7 @@ def analyze(x):
243325 and is_step_constant
244326 and is_length_constant
245327 ):
328+ assert isinstance (length , int )
246329 _start , _stop , _step = slice (start , stop , step ).indices (length )
247330 if _start <= _stop and _step >= 1 :
248331 return slice (_start , _stop , _step ), 1
@@ -2917,7 +3000,7 @@ def take(a, indices, axis=None, mode="raise"):
29173000 return a [full_indices ]
29183001
29193002
2920- @_get_vector_length .register (Subtensor )
3003+ @_get_vector_length .register (Subtensor ) # type: ignore
29213004def _get_vector_length_Subtensor (op , var ):
29223005 # If we take a slice, we know how many elements it will result in
29233006 # TODO: We can cover more `*Subtensor` cases.
0 commit comments