3333 alloc ,
3434 get_scalar_constant_value ,
3535 nonzero ,
36- scalar_from_tensor ,
36+ )
37+ from pytensor .tensor .basic import (
38+ constant as tensor_constant ,
3739)
3840from pytensor .tensor .blockwise import vectorize_node_fallback
3941from pytensor .tensor .elemwise import DimShuffle
@@ -296,13 +298,30 @@ def get_canonical_form_slice(
296298 """
297299 from pytensor .tensor import ge , lt , sign , switch
298300
301+ def undo_scalarization (x ):
302+ """Undo scalarization of a variable.
303+
304+ PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
305+ But reasoning symbolically about the result of multiple indexing operations, we usually
306+ want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
307+
308+ This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
309+ """
310+ if isinstance (x , ScalarVariable ):
311+ if isinstance (x , ScalarConstant ):
312+ return tensor_constant (x .data , dtype = x .dtype )
313+ elif x .owner is not None and isinstance (x .owner .op , ScalarFromTensor ):
314+ return x .owner .inputs [0 ]
315+ return x
316+
299317 # Other non-slice types are the scalar indexing case
300318 if not isinstance (theslice , slice ):
319+ theslice = undo_scalarization (theslice )
301320 if isinstance (theslice , int | np .integer | ScalarVariable ) or (
302321 isinstance (theslice , TensorVariable ) and theslice .ndim == 0
303322 ):
304323 cano = switch (lt (theslice , 0 ), (theslice + length ), theslice )
305- return scalar_from_tensor ( cano ) , 1
324+ return cano , 1
306325 raise ValueError (f"Slice { theslice } is not a supported slice type." )
307326
308327 # At this point we have a slice object. Possibly with symbolic inputs.
@@ -312,7 +331,7 @@ def analyze(x):
312331 x_constant = as_index_literal (x )
313332 is_constant = True
314333 except NotScalarConstantError :
315- x_constant = x
334+ x_constant = undo_scalarization ( x )
316335 is_constant = False
317336 return x_constant , is_constant
318337
0 commit comments