Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,34 +1615,66 @@ def expand(o: Apply) -> List[Apply]:
)


def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
"""Determine if `f_node` is in the graph given by `l_apply`.
def is_in_ancestors(
l_apply: Apply,
f_apply: Union[Apply, Collection[Apply]],
*,
known_dependent: Optional[Set[Apply]] = None,
known_independent: Optional[Set[Apply]] = None,
eager=True,
) -> bool:
"""Determine if `f_apply` is in the graph given by (any of) `l_apply`.

Parameters
----------
l_apply : Apply
The node to walk.
f_apply : Apply
f_apply : Union[Apply, Collection[Apply]]
The node to find in `l_apply`.
known_dependent: Optional[Set[Apply]]
Cache information about intermediate Applys that depend on f_apply
known_independent: Optional[Set[Apply]]
Cache information about intermediate Applys that do not depend on f_apply
eager: bool
return on first match (True) or traverse the whole graph (False)

Returns
-------
bool

"""
computed = set()
todo = [l_apply]
while todo:
cur = todo.pop()
if cur.outputs[0] in computed:
continue
if all(i in computed or i.owner is None for i in cur.inputs):
computed.update(cur.outputs)
if cur is f_node:
return True
if known_dependent is None:
known_dependent = set()
if known_independent is None:
known_independent = set()
if not isinstance(f_apply, Collection):
f_apply = {f_apply}
if l_apply in known_dependent:
return True
elif l_apply in f_apply:
known_dependent.add(l_apply)
return True
else:
search = (
is_in_ancestors(
inp.owner,
f_apply,
known_dependent=known_dependent,
known_independent=known_independent,
eager=eager,
)
for inp in l_apply.inputs
if inp.owner
)
if not eager:
dependent = any(list(search))
else:
todo.append(cur)
todo.extend(i.owner for i in cur.inputs if i.owner)
dependent = any(search)
if dependent:
known_dependent.add(l_apply)
return True

known_independent.add(l_apply)
return False


Expand Down
42 changes: 42 additions & 0 deletions tests/graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,48 @@ def test_is_in_ancestors():
assert is_in_ancestors(o2.owner, o1.owner)


def test_is_in_ancestors_complete():

r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o0 = MyOp(r2, r2)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2.name = "o2"
o3 = MyOp(o2, o0)
o3.name = "o3"
dependent = set()
independent = set()
assert is_in_ancestors(
o3.owner,
o1.owner,
known_dependent=dependent,
known_independent=independent,
# o0 should not fall into independent with
# eager=True default
# because it is the second input in o3
)
assert o1.owner in dependent
assert o2.owner in dependent
assert o3.owner in dependent
assert o0.owner not in independent
dependent = set()
independent = set()
assert is_in_ancestors(
o3.owner,
o1.owner,
known_dependent=dependent,
known_independent=independent,
# o0 should not fall into independent with
eager=False
# because it is supposed to be the complete traverse
)
assert o1.owner in dependent
assert o2.owner in dependent
assert o3.owner in dependent
assert o0.owner in independent


@pytest.mark.xfail(reason="Not implemented")
def test_io_connection_pattern():
raise AssertionError()
Expand Down