Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pytensor/tensor/rewriting/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,17 @@ def update_shape(self, r, other_r):
# This mean the shape is equivalent
# We do not want to do the ancestor check in those cases
merged_shape.append(r_shape[i])
elif r_shape[i] in ancestors([other_shape[i]]):
elif any(
(
r_shape[i] == anc
or (
anc.owner
and isinstance(anc.owner.op, Shape)
and anc.owner.inputs[0] == r
)
)
for anc in ancestors([other_shape[i]])
):
# Another case where we want to use r_shape[i] is when
# other_shape[i] actually depends on r_shape[i]. In that case,
# we do not want to substitute an expression with another that
Expand Down
21 changes: 20 additions & 1 deletion tests/tensor/rewriting/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pytensor.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.type import Type
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.basic import alloc, as_tensor_variable
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import add, exp, maximum
from pytensor.tensor.rewriting.basic import register_specialize
Expand Down Expand Up @@ -239,6 +239,25 @@ def test_no_shapeopt(self):
# FIXME: This is not a good test.
f([[1, 2], [2, 3]])

def test_shape_of_useless_alloc(self):
"""Test that local_shape_to_shape_i does not create circular graph.

Regression test for #565
"""
alpha = vector(shape=(None,), dtype="float64")
channel = vector(shape=(None,), dtype="float64")

broadcast_channel = alloc(
channel,
maximum(
shape(alpha)[0],
shape(channel)[0],
),
)
out = shape(broadcast_channel)
fn = function([alpha, channel], out)
assert fn([1.0, 2, 3], [1.0, 2, 3]) == (3,)


class TestReshape:
def setup_method(self):
Expand Down