Description
```python
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 100))
y = pt.zeros(x.shape[-1])
# Replace x with a constant whose static shape differs, in particular in the last dimension
new_y = pytensor.clone_replace(y, replace={x: pt.zeros((2, 50))}, rebuild_strict=False)

print(y.type)  # TensorType(float64, (100,))
print(new_y.type)  # TensorType(float64, (100,)) <- INCORRECT!
print(new_y.eval().shape)  # (50,)
```
This happens because `clone_with_new_inputs` doesn't try to rebuild the node when the new inputs have the same types as the old ones: https://github.com/ricardoV94/pytensor/blob/8606498fa92e86ec10f7357e4da228cfd86bced7/pytensor/graph/basic.py#L269-L280
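This is incorrect, because an output's type can change when the input values change, not just their types, as happens with `Alloc`. Here the scalar `x.shape[-1]` keeps the same type after the replacement, but its statically known value goes from 100 to 50, and `Alloc` derives its static output shape from that value.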
I think we should always rebuild the node if `strict=False`.
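A toy model can make the failure mode concrete without depending on PyTensor internals. The sketch below is a hypothetical, simplified stand-in: `TType`, `Var`, `Apply`, `LastDim`, `Zeros`, `clone_with_new_inputs` and `replace_input` are illustrative names, not PyTensor's API. `Zeros` plays the role of `Alloc` by deriving its static output shape from the *value* of its scalar input, and the `always_rebuild` flag models the proposal of always calling `make_node` when `strict=False`.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


# Toy stand-ins for PyTensor's Type/Variable/Apply/Op machinery.
# All names here are hypothetical and are NOT PyTensor's real API.
@dataclass(frozen=True)
class TType:
    dtype: str
    shape: Tuple[Optional[int], ...]  # None = length unknown at graph-build time


@dataclass(eq=False)
class Var:
    type: TType
    owner: Optional["Apply"] = None
    static_value: Optional[int] = None  # compile-time value, if one is known


@dataclass(eq=False)
class Apply:
    op: object
    inputs: List[Var]
    outputs: List[Var] = field(default_factory=list)


def _apply(op, inputs, outputs):
    # Build a node and point each output back at it.
    node = Apply(op, list(inputs), outputs)
    for out in outputs:
        out.owner = node
    return node


class LastDim:
    """x -> x.shape[-1]; the scalar's value is known if x's static shape knows it."""

    def make_node(self, x):
        out = Var(TType("int64", ()), static_value=x.type.shape[-1])
        return _apply(self, [x], [out])


class Zeros:
    """n -> zeros(n); like Alloc, the output's static shape comes from n's value."""

    def make_node(self, n):
        return _apply(self, [n], [Var(TType("float64", (n.static_value,)))])


def clone_with_new_inputs(node, new_inputs, always_rebuild=False):
    same_types = all(o.type == n.type for o, n in zip(node.inputs, new_inputs))
    if same_types and not always_rebuild:
        # Shortcut: reuse the OLD output types (the behavior reported above).
        outs = [Var(o.type, static_value=o.static_value) for o in node.outputs]
        return _apply(node.op, new_inputs, outs)
    # Rebuild: the Op recomputes output types from the new inputs.
    return node.op.make_node(*new_inputs)


def replace_input(out, old, new, always_rebuild=False):
    """Swap `old` for `new` below `out`, assuming a single-input chain of nodes."""
    chain = []
    var = out
    while var is not old:
        chain.append(var.owner)
        var = var.owner.inputs[0]
    current = new
    for node in reversed(chain):
        current = clone_with_new_inputs(node, [current], always_rebuild).outputs[0]
    return current


x = Var(TType("float64", (None, 100)))
y = Zeros().make_node(LastDim().make_node(x).outputs[0]).outputs[0]
x_new = Var(TType("float64", (2, 50)))

print(y.type.shape)  # (100,)
print(replace_input(y, x, x_new).type.shape)  # (100,): stale, as in the issue
print(replace_input(y, x, x_new, always_rebuild=True).type.shape)  # (50,)
```

With the shortcut, `LastDim` is rebuilt (its input type changed), but `Zeros` is not, because the scalar feeding it is still an `int64` scalar even though its known value went from 100 to 50. Forcing the rebuild lets `Zeros` recompute `(50,)`.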