|  | 
| 35 | 35 |     nonzero, | 
| 36 | 36 |     scalar_from_tensor, | 
| 37 | 37 | ) | 
|  | 38 | +from pytensor.tensor.basic import ( | 
|  | 39 | +    constant as tensor_constant, | 
|  | 40 | +) | 
| 38 | 41 | from pytensor.tensor.blockwise import vectorize_node_fallback | 
| 39 | 42 | from pytensor.tensor.elemwise import DimShuffle | 
| 40 | 43 | from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError | 
| @@ -252,6 +255,23 @@ def get_idx_list(inputs, idx_list): | 
| 252 | 255 |     return indices_from_subtensor(inputs[1:], idx_list) | 
| 253 | 256 | 
 | 
| 254 | 257 | 
 | 
|  | 258 | +def undo_scalarization(x): | 
|  | 259 | +    """Undo scalarization of a variable. | 
|  | 260 | +
 | 
|  | 261 | +    PyTensor Basic index operations use ScalarVariables for the indices/slice arguments. | 
|  | 262 | +    When reason symbolically about the result of multiple indexing operations, we usually | 
|  | 263 | +    want to work on TensorVariables, since rewrites work on those and not ScalarVariables. | 
|  | 264 | +
 | 
|  | 265 | +    This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants. | 
|  | 266 | +    """ | 
|  | 267 | +    if isinstance(x, ScalarVariable): | 
|  | 268 | +        if isinstance(x, ScalarConstant): | 
|  | 269 | +            return tensor_constant(x.data, dtype=x.dtype) | 
|  | 270 | +        elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor): | 
|  | 271 | +            return x.owner.inputs[0] | 
|  | 272 | +    return x | 
|  | 273 | + | 
|  | 274 | + | 
| 255 | 275 | @overload | 
| 256 | 276 | def get_canonical_form_slice( | 
| 257 | 277 |     theslice: slice, | 
| @@ -298,6 +318,7 @@ def get_canonical_form_slice( | 
| 298 | 318 | 
 | 
| 299 | 319 |     # Other non-slice types are the scalar indexing case | 
| 300 | 320 |     if not isinstance(theslice, slice): | 
|  | 321 | +        theslice = undo_scalarization(theslice) | 
| 301 | 322 |         if isinstance(theslice, int | np.integer | ScalarVariable) or ( | 
| 302 | 323 |             isinstance(theslice, TensorVariable) and theslice.ndim == 0 | 
| 303 | 324 |         ): | 
| @@ -381,6 +402,7 @@ def analyze(x): | 
| 381 | 402 |         elif is_stop_length: | 
| 382 | 403 |             # start:length:1 | 
| 383 | 404 |             if is_start_constant and start >= 0: | 
|  | 405 | +                length = undo_scalarization(length) | 
| 384 | 406 |                 return slice(switch(lt(start, length), start, length), length, 1), 1 | 
| 385 | 407 |             start_plus_len = start + length | 
| 386 | 408 |             start = switch( | 
|  | 
0 commit comments