-
Notifications
You must be signed in to change notification settings - Fork 145
Closed
Labels
Description
Apparently in python, the following are not equivalent:
print(
["a", "b"][0:-3:-1],
["a", "b"][None:-3:-1],
) # ['a'] ['b', 'a']
So the local_useless_slice
which was updated in fa0ab9d is not correct :(
pytensor/pytensor/tensor/rewriting/subtensor.py
Lines 337 to 406 in b8dbd4c
@register_infer_shape | |
@register_useless | |
@register_canonicalize | |
@register_specialize | |
@register_stabilize | |
@node_rewriter([Subtensor]) | |
def local_useless_slice(fgraph, node): | |
""" | |
Remove Subtensor of the form: | |
1. X[0, :] -> X[0] | |
2. X[:] -> X | |
Also, rewrite Subtensor of the form: | |
X[0:7:1] -> X[None:None:None] | |
where X is a vector of length 7 | |
""" | |
idxs = get_idx_list(node.inputs, node.op.idx_list) | |
x = node.inputs[0] | |
if not idxs: | |
return [node.inputs[0]] | |
new_idxs = list(idxs) | |
change_flag = False | |
last_useful_idx = -1 | |
for dim, s in enumerate(new_idxs): | |
if not isinstance(s, slice): | |
last_useful_idx = dim | |
continue | |
if s == slice(None): | |
continue | |
start = s.start | |
stop = s.stop | |
step = s.step | |
if ( | |
start is not None | |
and extract_constant(start, only_process_constants=True) == 0 | |
): | |
change_flag = True | |
start = None | |
if ( | |
stop is not None | |
and x.type.shape[dim] is not None | |
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim] | |
): | |
change_flag = True | |
stop = None | |
if ( | |
step is not None | |
and extract_constant(step, only_process_constants=True) == 1 | |
): | |
change_flag = True | |
step = None | |
if not (start is None and stop is None and step is None): | |
last_useful_idx = dim | |
new_idxs[dim] = slice(start, stop, step) | |
if change_flag or ((last_useful_idx + 1) < len(idxs)): | |
out = x[tuple(new_idxs[: last_useful_idx + 1])] | |
# Copy over previous output stacktrace | |
copy_stack_trace(node.outputs, out) | |
return [out] |
This is behind the last failure noted in pymc-devs/pymc#7554 (comment)
MRE:
import pytensor.tensor as pt
x = pt.vector("x", dtype=int)
y = x[0:-3:-1]
assert list(y.eval({x: [0, 2]})) == [0, 2][0:-3:-1] # AssertionError
The defaults Start and Stop are different when we have negative steps, it changes from 0: len(y)
to -1:-len(y)-1
. We need to account for that in the rewrite!
CC @Dhruvanshu-Joshi and @tomicapretto