
local_useless_slice wrong for negative steps #1054

@ricardoV94

Description


Apparently, in Python, the following are not equivalent:

```python
print(
    ["a", "b"][0:-3:-1],
    ["a", "b"][None:-3:-1],
)  # ['a'] ['b', 'a']
```
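Python's built-in `slice.indices(length)` makes the difference explicit. It returns the normalized `(start, stop, step)` a slice resolves to for a sequence of the given length, and `range(*s.indices(n))` lists the positions it visits. A minimal illustration, using only the standard library:

```python
# slice.indices(n) normalizes a slice against a sequence of length n,
# filling in defaults and clamping out-of-range values.
n = 2
explicit_zero = slice(0, -3, -1)
omitted_start = slice(None, -3, -1)

print(explicit_zero.indices(n))   # (0, -1, -1): starts at index 0
print(omitted_start.indices(n))   # (1, -1, -1): starts at index n - 1

# With a positive step an omitted start means 0; with a negative step
# it means n - 1, and an omitted stop means "just before index 0".
print(slice(None, None, 1).indices(n))   # (0, 2, 1)
print(slice(None, None, -1).indices(n))  # (1, -1, -1)

print(list(range(*explicit_zero.indices(n))))  # [0]
print(list(range(*omitted_start.indices(n))))  # [1, 0]
```

The `-1` stop in the normalized tuple is a sentinel meaning "stop before index 0", not the last element.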

So `local_useless_slice`, which was updated in fa0ab9d, is not correct :(

```python
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
@register_stabilize
@node_rewriter([Subtensor])
def local_useless_slice(fgraph, node):
    """
    Remove Subtensor of the form:
        1. X[0, :] -> X[0]
        2. X[:] -> X
    Also, rewrite Subtensor of the form:
        X[0:7:1] -> X[None:None:None]
        where X is a vector of length 7
    """
    idxs = get_idx_list(node.inputs, node.op.idx_list)
    x = node.inputs[0]
    if not idxs:
        return [node.inputs[0]]
    new_idxs = list(idxs)
    change_flag = False
    last_useful_idx = -1
    for dim, s in enumerate(new_idxs):
        if not isinstance(s, slice):
            last_useful_idx = dim
            continue
        if s == slice(None):
            continue
        start = s.start
        stop = s.stop
        step = s.step
        if (
            start is not None
            and extract_constant(start, only_process_constants=True) == 0
        ):
            change_flag = True
            start = None
        if (
            stop is not None
            and x.type.shape[dim] is not None
            and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
        ):
            change_flag = True
            stop = None
        if (
            step is not None
            and extract_constant(step, only_process_constants=True) == 1
        ):
            change_flag = True
            step = None
        if not (start is None and stop is None and step is None):
            last_useful_idx = dim
        new_idxs[dim] = slice(start, stop, step)
    if change_flag or ((last_useful_idx + 1) < len(idxs)):
        out = x[tuple(new_idxs[: last_useful_idx + 1])]
        # Copy over previous output stacktrace
        copy_stack_trace(node.outputs, out)
        return [out]
```
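To see how these checks produce the wrong result, here is a minimal pure-Python model of the per-dimension start/stop/step tests above. `buggy_simplify` is a hypothetical helper, not part of PyTensor; it assumes every slice entry is a plain Python int (or `None`) and the dimension length is known, standing in for `extract_constant` and `x.type.shape[dim]`:

```python
def buggy_simplify(s, length):
    """Hypothetical model of the checks in local_useless_slice, for
    constant slice entries and a known dimension length."""
    start, stop, step = s.start, s.stop, s.step
    if start == 0:        # dropped regardless of the sign of step
        start = None
    if stop == length:    # also dropped regardless of the sign of step
        stop = None
    if step == 1:
        step = None
    return slice(start, stop, step)


data = ["a", "b"]
s = slice(0, -3, -1)
rewritten = buggy_simplify(s, len(data))
print(rewritten)                 # slice(None, -3, -1)
print(data[s], data[rewritten])  # ['a'] ['b', 'a']
```

The stop check has the same flaw: on a length-2 sequence `data[:2:-1]` is empty, but the model turns it into `data[::-1]`, which is `['b', 'a']`.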

This is behind the last failure noted in pymc-devs/pymc#7554 (comment)

MRE:

```python
import pytensor.tensor as pt

x = pt.vector("x", dtype=int)
y = x[0:-3:-1]

assert list(y.eval({x: [0, 2]})) == [0, 2][0:-3:-1]  # AssertionError
```

The default start and stop differ when the step is negative: they change from `0:len(y)` to `-1:-len(y)-1`. We need to account for that in the rewrite!
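One way to account for it, sketched in plain Python under the assumption that all slice entries are constant ints and the length is known: only replace an entry by `None` when doing so leaves `slice.indices(length)` unchanged. `simplify_slice` is a hypothetical helper for illustration, not the fix adopted in PyTensor:

```python
def simplify_slice(s, length):
    """Hypothetical sketch: set slice entries to None only when that
    leaves the normalized slice unchanged for a sequence of `length`.

    Assumes every entry is a Python int (or None) and length is known.
    """
    target = s.indices(length)
    start, stop, step = s.start, s.stop, s.step
    # Try step first, since dropping it changes which defaults start/stop get.
    if step is not None and slice(start, stop, None).indices(length) == target:
        step = None
    if start is not None and slice(None, stop, step).indices(length) == target:
        start = None
    if stop is not None and slice(start, None, step).indices(length) == target:
        stop = None
    return slice(start, stop, step)


data = ["a", "b"]
print(simplify_slice(slice(0, -3, -1), 2))  # slice(0, None, -1)
print(simplify_slice(slice(0, 2, 1), 2))    # slice(None, None, None)

# Brute-force check against Python's own slicing semantics.
values = [None, *range(-4, 5)]
for start in values:
    for stop in values:
        for step in [None, -2, -1, 1, 2]:
            s = slice(start, stop, step)
            assert data[s] == data[simplify_slice(s, len(data))]
```

Every accepted change keeps the normalized slice equal to the original, so the result is equivalent by construction. In the symbolic rewrite the length is often unknown, so a fix would at least need to check the sign of a constant step before treating `start == 0` or `stop == length` as defaults.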

CC @Dhruvanshu-Joshi and @tomicapretto
