Closed
Description
Right now we need to add `specify_shape` if we want to squeeze a dimension that the user (but not PyTensor) knows to be length 1. This complicates the graph slightly, and I don't see a reason for it. The only thing `DimShuffle` needs to know is the number of dimensions of the input, which is never ambiguous. Then any axis missing from the pattern is a drop.
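The rule above can be sketched in plain Python. This is only an illustration of the idea, not PyTensor's implementation: `dropped_axes` is a hypothetical helper, and `new_order` is assumed to follow the `DimShuffle` convention of integer input axes plus `"x"` for a newly inserted broadcastable axis.

```python
def dropped_axes(ndim, new_order):
    """Return the input axes that do not appear in ``new_order``.

    Hypothetical helper, not part of PyTensor. ``new_order`` holds input
    axis indices and the string "x" for newly inserted axes.
    """
    kept = {axis for axis in new_order if axis != "x"}
    return [axis for axis in range(ndim) if axis not in kept]


# A 3-d input with pattern (0, 2): axis 1 is missing, so it is a drop.
print(dropped_axes(3, (0, 2)))
# Inserting a new axis with "x" does not change which axes are dropped.
print(dropped_axes(3, (2, "x", 0)))
```

Only `ndim` and the pattern are needed here; no static shape or broadcastable information about the input is consulted.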
pytensor/pytensor/tensor/extra_ops.py
Lines 602 to 607 in ee4d4f7
    # `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
    # We add a `specify_broadcastable` instead of raising.
    non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
    _x = specify_broadcastable(_x, *non_broadcastable_axis)
    return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
We should check that nothing in the implementation fails silently if an axis was meant to be dropped but was not length 1 at runtime. If nothing fails, we can simplify `DimShuffle` and get rid of the useless `SpecifyShape`.
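As a rough sketch of the runtime behaviour one would want to verify, the following plain-Python function works on shape tuples only. It is a hypothetical model, not the actual `DimShuffle` code: `dimshuffle_shape` and its error message are assumptions made for illustration.

```python
def dimshuffle_shape(shape, new_order):
    """Compute the output shape of a dimshuffle, checking drops at runtime.

    Hypothetical model of the desired behaviour, not PyTensor code.
    Axes absent from ``new_order`` are dropped and must have length 1.
    """
    kept = [axis for axis in new_order if axis != "x"]
    for axis, length in enumerate(shape):
        if axis not in kept and length != 1:
            raise ValueError(
                f"Cannot drop axis {axis} with length {length}; expected 1"
            )
    # "x" inserts a new length-1 axis; integers pick up the input length.
    return tuple(1 if axis == "x" else shape[axis] for axis in new_order)


print(dimshuffle_shape((3, 1, 4), (0, 2)))       # axis 1 has length 1: ok
print(dimshuffle_shape((3, 1, 4), (2, "x", 0)))  # drop axis 1, add a new axis

try:
    dimshuffle_shape((3, 2, 4), (0, 2))          # axis 1 has length 2
except ValueError as err:
    print(err)
```

The point of the check is that a wrong drop should raise at runtime rather than silently produce a reshaped or truncated result.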