Rewrite to remove useless transpose followed by reduction #1006

@ricardoV94

Description

```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.sum(x.T)  # full reduction over a transposed matrix
fn = pytensor.function([x], y)
fn.dprint()
```

```
Sum{axes=None} [id A] 1
 └─ Transpose{axes=[1, 0]} [id B] 'x.T' 0
    └─ x [id C]
```

If a transpose (dimshuffle) can be removed without affecting the output, sometimes by remapping the reduction axes and sometimes (as in the full reduction above) without changing anything, we could rewrite it away. This yields more succinct graphs and enables more extensive canonicalization.
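A minimal sketch of the axis bookkeeping such a rewrite would need follows. It assumes the transpose is described as a permutation in `numpy.transpose` convention and that the reduction is a sum. `push_reduction_through_transpose` is a hypothetical helper for illustration, not a PyTensor function. It maps the reduction axes of the transposed tensor back onto the original tensor and reports whether a residual transpose is still needed on the reduced result. When the returned permutation is `None`, the transpose can be dropped entirely.

```python
import numpy as np


def push_reduction_through_transpose(perm, axes, ndim):
    """Rewrite sum(transpose(x, perm), axis=axes) as
    transpose(sum(x, axis=new_axes), new_perm).

    Hypothetical helper, not part of PyTensor. `perm` follows np.transpose:
    axis i of the transposed array is axis perm[i] of x. `axes=None` means
    reduce over every axis. Returns (new_axes, new_perm); new_perm is None
    when no transpose is needed after reducing x directly.
    """
    if axes is None:
        axes = range(ndim)
    axes = {a % ndim for a in axes}  # normalize negative axes
    # Reducing axis a of the transposed array reduces axis perm[a] of x.
    new_axes = tuple(sorted(perm[a] for a in axes))
    # Surviving axes of x, in the order the transpose would place them.
    kept_in_out_order = [perm[i] for i in range(ndim) if i not in axes]
    # Reducing x directly keeps the survivors in their original order.
    kept_in_x_order = sorted(kept_in_out_order)
    new_perm = tuple(kept_in_x_order.index(ax) for ax in kept_in_out_order)
    if new_perm == tuple(range(len(new_perm))):
        new_perm = None  # identity: the transpose is useless
    return new_axes, new_perm


# Numerical check against the original transpose-then-sum graph.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
for perm, axes in [((1, 0, 2), None), ((2, 0, 1), (0,)), ((2, 1, 0), (1,))]:
    expected = np.transpose(x, perm).sum(axis=axes)
    new_axes, new_perm = push_reduction_through_transpose(perm, axes, x.ndim)
    got = x.sum(axis=new_axes)
    if new_perm is not None:
        got = np.transpose(got, new_perm)
    assert np.allclose(got, expected)
```

The first case matches the issue's example (a full reduction needs no change at all), the second only needs the reduction axis remapped, and the third shows a case where the surviving axes are reordered, so the transpose cannot simply be removed. For floating-point sums the rewritten graph may accumulate in a different order, so results can differ by rounding, which is why the check uses `np.allclose`.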
