
DimShuffle should be happy to drop dims if they have length 1 at runtime #914

@ricardoV94

Description

Right now we need to add `specify_shape` if we want to squeeze a dimension that the user (but not PyTensor) knows to be length 1. This complicates the graph slightly, and I don't see a reason for it. The only thing DimShuffle needs to know is the number of dimensions of the input, which is never ambiguous. Any input axis missing from the pattern is then a drop.
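To make this concrete, here is a minimal NumPy model of the proposed semantics. It assumes a pattern made of input axis indices plus `"x"` for new length-1 axes, and the helper name `dimshuffle_runtime` is hypothetical (it is not PyTensor's API or implementation). Only `x.ndim` is needed to decide which axes are dropped; whether each dropped axis really has length 1 is checked only when the function runs.

```python
import numpy as np


def dimshuffle_runtime(x, pattern):
    """Hypothetical model of DimShuffle that only needs x.ndim up front.

    `pattern` holds input axis indices and "x" for new length-1 axes.
    Any input axis missing from `pattern` is dropped.
    """
    kept = [p for p in pattern if p != "x"]
    if len(set(kept)) != len(kept) or any(not 0 <= p < x.ndim for p in kept):
        raise ValueError(f"invalid pattern {pattern} for ndim={x.ndim}")
    dropped = tuple(i for i in range(x.ndim) if i not in kept)

    # Runtime check: np.squeeze raises ValueError if a dropped axis has length != 1.
    y = np.squeeze(x, axis=dropped) if dropped else x

    # Renumber the kept axes after squeezing, then reorder them as in the pattern.
    remaining = [i for i in range(x.ndim) if i not in dropped]
    y = np.transpose(y, [remaining.index(p) for p in kept])

    # Insert new length-1 axes wherever the pattern has "x".
    for pos, p in enumerate(pattern):
        if p == "x":
            y = np.expand_dims(y, pos)
    return y


x = np.zeros((3, 1, 4))
print(dimshuffle_runtime(x, [2, 0]).shape)  # (4, 3)
```

Because the drop is validated by `np.squeeze` at call time, an input whose dropped axis is not length 1 raises instead of silently losing data, without any static shape information.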

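This is the workaround we currently need:

```python
from pytensor.tensor.shape import specify_broadcastable


def _drop_axes(_x, axis):  # hypothetical wrapper, added only so the excerpt runs
    # `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
    # We add a `specify_broadcastable` instead of raising.
    non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
    _x = specify_broadcastable(_x, *non_broadcastable_axis)
    return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
```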
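The difference between the current and the proposed rule can be modeled in plain Python. Both helpers below are hypothetical and only mirror the behavior described in this issue: today the drop is validated at graph-construction time against the static `broadcastable` flags, so an axis of unknown length (`False`) cannot be dropped until something like `specify_broadcastable` marks it as length 1; under the proposal only the number of input dimensions matters.

```python
def check_drop_static(broadcastable, new_order):
    """Hypothetical model of the current graph-time check."""
    kept = {p for p in new_order if p != "x"}
    for i, is_bcast in enumerate(broadcastable):
        if i not in kept and not is_bcast:
            raise ValueError(f"Cannot drop non-broadcastable axis {i}")


def mark_broadcastable(broadcastable, axes):
    """Hypothetical stand-in for the effect of specify_broadcastable on the static type."""
    return tuple(b or i in axes for i, b in enumerate(broadcastable))


def check_drop_proposed(ndim, new_order):
    """Hypothetical proposed check: only ndim is needed; lengths are checked at runtime."""
    kept = [p for p in new_order if p != "x"]
    if len(set(kept)) != len(kept) or any(not 0 <= p < ndim for p in kept):
        raise ValueError(f"invalid pattern {new_order} for ndim={ndim}")


matrix = (False, False)  # static type of a matrix with unknown shape
check_drop_proposed(len(matrix), [0])  # accepted: axis 1 is simply dropped
check_drop_static(mark_broadcastable(matrix, [1]), [0])  # needs the extra marking step
```

With the static rule, `check_drop_static(matrix, [0])` raises; the extra `mark_broadcastable` step corresponds to the `SpecifyShape` node this issue proposes to remove.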

We should check that nothing in the implementation goes wrong if an axis meant to be dropped turns out not to have length 1 at runtime. If nothing does, we can simplify DimShuffle and get rid of the redundant SpecifyShape.
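When auditing the implementation, the risk is any code path that drops an axis without checking its length. The NumPy comparison below only illustrates the two possible failure modes and is not PyTensor's code (the `drop_by_*` helpers are hypothetical): indexing silently keeps the first slice, while `np.squeeze` and a size-changing `reshape` both raise.

```python
import numpy as np


def drop_by_indexing(x, axis):
    # Silently keeps only the first slice along `axis` if its length is not 1.
    return np.take(x, 0, axis=axis)


def drop_by_squeeze(x, axis):
    # Raises ValueError if x.shape[axis] != 1.
    return np.squeeze(x, axis=axis)


def drop_by_reshape(x, axis):
    # Raises ValueError if x.shape[axis] != 1, because the total size would change.
    new_shape = x.shape[:axis] + x.shape[axis + 1:]
    return x.reshape(new_shape)


x = np.arange(6).reshape(2, 3)  # axis 1 has length 3, not 1
print(drop_by_indexing(x, 1))  # [0 3]: wrong result, no error
for drop in (drop_by_squeeze, drop_by_reshape):
    try:
        drop(x, 1)
    except ValueError as err:
        print(f"{drop.__name__}: {err}")
```

An implementation is safe to relax in the sense of this issue only if every drop path behaves like the last two helpers.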
