3333 alloc ,
3434 get_scalar_constant_value ,
3535 nonzero ,
36- scalar_from_tensor ,
36+ )
37+ from pytensor .tensor .basic import (
38+ constant as tensor_constant ,
3739)
3840from pytensor .tensor .blockwise import vectorize_node_fallback
3941from pytensor .tensor .elemwise import DimShuffle
@@ -256,20 +258,20 @@ def get_idx_list(inputs, idx_list):
256258def get_canonical_form_slice (
257259 theslice : slice ,
258260 length : int | np .integer | ScalarVariable | TensorVariable ,
259- ) -> tuple [slice , int | ScalarConstant ]: ...
261+ ) -> tuple [slice , int | TensorVariable ]: ...
260262
261263
262264@overload
263265def get_canonical_form_slice (
264266 theslice : int | np .integer | ScalarVariable | TensorVariable ,
265267 length : int | np .integer | ScalarVariable | TensorVariable ,
266- ) -> tuple [ScalarVariable , int ]: ...
268+ ) -> tuple [TensorVariable , int ]: ...
267269
268270
269271def get_canonical_form_slice (
270272 theslice : slice | int | np .integer | ScalarVariable | TensorVariable ,
271273 length : int | np .integer | ScalarVariable | TensorVariable ,
272- ) -> tuple [slice | ScalarVariable , int | ScalarConstant ]:
274+ ) -> tuple [slice | TensorVariable , int | TensorVariable ]:
273275 """Convert indices or slices to canonical form.
274276
275277 Scalar integer indices or python Slices with Scalar/None attributes
@@ -296,30 +298,56 @@ def get_canonical_form_slice(
296298 """
297299 from pytensor .tensor import ge , lt , sign , switch
298300
299- # Other non-slice types are the scalar indexing case
300- if not isinstance (theslice , slice ):
301- if isinstance (theslice , int | np .integer | ScalarVariable ) or (
302- isinstance (theslice , TensorVariable ) and theslice .ndim == 0
303- ):
304- cano = switch (lt (theslice , 0 ), (theslice + length ), theslice )
305- return scalar_from_tensor (cano ), 1
306- raise ValueError (f"Slice { theslice } is not a supported slice type." )
301+ def undo_scalarization (x ):
302+ """Undo scalarization of a variable.
307303
308- # At this point we have a slice object. Possibly with symbolic inputs.
304+ PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
305+ But reasoning symbolically about the result of multiple indexing operations, we usually
306+ want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
307+
308+ This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
309+ """
310+ if isinstance (x , ScalarVariable ):
311+ if isinstance (x , ScalarConstant ):
312+ return tensor_constant (x .data , dtype = x .dtype )
313+ elif x .owner is not None and isinstance (x .owner .op , ScalarFromTensor ):
314+ return x .owner .inputs [0 ]
315+ else :
316+ return as_tensor_variable (x )
317+ return x
309318
310319 def analyze (x ):
311320 try :
312321 x_constant = as_index_literal (x )
313322 is_constant = True
314323 except NotScalarConstantError :
315- x_constant = x
324+ x_constant = undo_scalarization ( x )
316325 is_constant = False
317326 return x_constant , is_constant
318327
328+ length , is_length_constant = analyze (length )
329+
330+ # Other non-slice types are the scalar indexing case
331+ if not isinstance (theslice , slice ):
332+ if not (
333+ isinstance (theslice , int | np .integer | ScalarVariable )
334+ or (isinstance (theslice , TensorVariable ) and theslice .ndim == 0 )
335+ ):
336+ raise ValueError (f"Slice { theslice } is not a supported slice type." )
337+
338+ idx , is_index_constant = analyze (theslice )
339+ if is_index_constant :
340+ if idx >= 0 :
341+ return idx , 1
342+ else :
343+ return idx + length , 1
344+ else :
345+ return switch (lt (idx , 0 ), idx + length , idx ), 1
346+
347+ # At this point we have a slice object. Possibly with symbolic inputs.
319348 start , is_start_constant = analyze (theslice .start )
320349 stop , is_stop_constant = analyze (theslice .stop )
321350 step , is_step_constant = analyze (theslice .step )
322- length , is_length_constant = analyze (length )
323351
324352 if (
325353 is_start_constant
0 commit comments