
clone_replace doesn't try to rebuild nodes when input types don't change #253

@ricardoV94

Description


```python
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 100))
y = pt.zeros(x.shape[-1])  # static length 100 is inferred from x's static shape
new_y = pytensor.clone_replace(y, replace={x: pt.zeros((2, 50))}, rebuild_strict=False)

print(y.type)  # TensorType(float64, (100,))
print(new_y.type)  # TensorType(float64, (100,)) <- INCORRECT!
print(new_y.eval().shape)  # (50,)
```

This happens because `Apply.clone_with_new_inputs` does not rebuild the node (it does not call `make_node` again) when the new inputs have the same types as the old ones: https://github.com/ricardoV94/pytensor/blob/8606498fa92e86ec10f7357e4da228cfd86bced7/pytensor/graph/basic.py#L269-L280

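This is incorrect, because an Op's output type can depend on the input values, not just on the input types. `Alloc` is one such case: in the example above, the shape input `x.shape[-1]` is an int64 scalar both before and after the replacement, so the `Alloc` node is copied with its old static shape `(100,)`, even though the new length is 50.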
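To make the failure mode concrete without depending on PyTensor internals, here is a minimal pure-Python sketch. `ToyType`, `ToyVar`, `ToyApply` and `ToyAlloc` are hypothetical stand-ins, not PyTensor classes. Two simplifications are assumed: the shape input carries a known constant `value` (PyTensor instead infers `Alloc`'s static shape from the graph that produces the shape input), and `strict=True` simply raises on a type change instead of attempting a conversion. The clone shortcut mirrors the type check in the linked code.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class ToyType:
    """Hypothetical stand-in for a TensorType: dtype plus static shape."""
    dtype: str
    shape: Tuple[Optional[int], ...]  # None = length unknown at graph-build time


@dataclass
class ToyVar:
    """Hypothetical variable; `value` is a known constant used for shape inference."""
    type: ToyType
    value: Optional[int] = None


class ToyApply:
    def __init__(self, op, inputs: List[ToyVar], outputs: List[ToyVar]):
        self.op, self.inputs, self.outputs = op, inputs, outputs

    def clone_with_new_inputs(self, new_inputs: List[ToyVar], strict: bool = True) -> "ToyApply":
        same_types = all(old.type == new.type for old, new in zip(self.inputs, new_inputs))
        if same_types:
            # Shortcut described in the issue: copy the node and its output
            # types without calling make_node, even when strict=False.
            return ToyApply(self.op, list(new_inputs), [ToyVar(o.type) for o in self.outputs])
        if strict:
            # Toy simplification: no conversion attempt, just fail.
            raise TypeError("input types changed and strict=True")
        # Only reached when some input type differs.
        return self.op.make_node(*new_inputs)


class ToyAlloc:
    """Toy Alloc: the static output length comes from the shape input's value."""

    def make_node(self, length: ToyVar) -> ToyApply:
        out = ToyVar(ToyType("float64", (length.value,)))
        return ToyApply(self, [length], [out])


int_scalar = ToyType("int64", ())
node = ToyAlloc().make_node(ToyVar(int_scalar, value=100))
print(node.outputs[0].type.shape)  # (100,)

# Same input type, different value: the shortcut keeps the stale output type.
clone = node.clone_with_new_inputs([ToyVar(int_scalar, value=50)], strict=False)
print(clone.outputs[0].type.shape)  # (100,), although make_node would give (50,)
```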

I think we should always rebuild the node when `strict=False`.
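A minimal sketch of the proposed behavior, using hypothetical toy classes rather than PyTensor's real `Apply` and `Alloc`, and assuming the shape input is a known constant so the toy `make_node` can read the length directly. With `strict=False` the node is always rebuilt through `make_node`, so the output type is re-derived from the new inputs; `strict=True` keeps the copy-the-node path (and, as a toy simplification, raises on any type change).

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class ToyType:
    """Hypothetical stand-in for a TensorType: dtype plus static shape."""
    dtype: str
    shape: Tuple[Optional[int], ...]


@dataclass
class ToyVar:
    """Hypothetical variable; `value` is a known constant used for shape inference."""
    type: ToyType
    value: Optional[int] = None


class ToyApply:
    def __init__(self, op, inputs: List[ToyVar], outputs: List[ToyVar]):
        self.op, self.inputs, self.outputs = op, inputs, outputs

    def clone_with_new_inputs(self, new_inputs: List[ToyVar], strict: bool = True) -> "ToyApply":
        if not strict:
            # Proposed: always call make_node so output types are re-derived
            # from the new inputs, even when the input types are unchanged.
            return self.op.make_node(*new_inputs)
        # strict=True is left as before: input types must match and the node
        # is copied together with its existing output types.
        for old, new in zip(self.inputs, new_inputs):
            if old.type != new.type:
                raise TypeError(f"expected {old.type}, got {new.type}")
        return ToyApply(self.op, list(new_inputs), [ToyVar(o.type) for o in self.outputs])


class ToyAlloc:
    """Toy Alloc: the static output length comes from the shape input's value."""

    def make_node(self, length: ToyVar) -> ToyApply:
        return ToyApply(self, [length], [ToyVar(ToyType("float64", (length.value,)))])


int_scalar = ToyType("int64", ())
node = ToyAlloc().make_node(ToyVar(int_scalar, value=100))

rebuilt = node.clone_with_new_inputs([ToyVar(int_scalar, value=50)], strict=False)
print(rebuilt.outputs[0].type.shape)  # (50,)

copied = node.clone_with_new_inputs([ToyVar(int_scalar, value=50)], strict=True)
print(copied.outputs[0].type.shape)  # (100,)
```

The sketch leaves `strict=True` unchanged, matching the scope of the proposal; only non-strict cloning pays the cost of re-running `make_node`.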

Metadata


- Assignees: No one assigned
- Type: No type
- Projects: No projects
- Milestone: No milestone
- Relationships: None yet
- Development: No branches or pull requests
