diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index fc85209236..42a117b690 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1615,34 +1615,66 @@ def expand(o: Apply) -> List[Apply]: ) -def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool: - """Determine if `f_node` is in the graph given by `l_apply`. +def is_in_ancestors( + l_apply: Apply, + f_apply: Union[Apply, Collection[Apply]], + *, + known_dependent: Optional[Set[Apply]] = None, + known_independent: Optional[Set[Apply]] = None, + eager=True, +) -> bool: + """Determine if `f_apply` is in the graph given by (any of) `l_apply`. Parameters ---------- l_apply : Apply The node to walk. - f_apply : Apply + f_apply : Union[Apply, Collection[Apply]] The node to find in `l_apply`. + known_dependent: Optional[Set[Apply]] + Cache information about intermediate Applys that depend on f_apply + known_independent: Optional[Set[Apply]] + Cache information about intermediate Applys that do not depend on f_apply + eager: bool + return on first match (True) or traverse the whole graph (False) Returns ------- bool """ - computed = set() - todo = [l_apply] - while todo: - cur = todo.pop() - if cur.outputs[0] in computed: - continue - if all(i in computed or i.owner is None for i in cur.inputs): - computed.update(cur.outputs) - if cur is f_node: - return True + if known_dependent is None: + known_dependent = set() + if known_independent is None: + known_independent = set() + if not isinstance(f_apply, Collection): + f_apply = {f_apply} + if l_apply in known_dependent: + return True + elif l_apply in f_apply: + known_dependent.add(l_apply) + return True + else: + search = ( + is_in_ancestors( + inp.owner, + f_apply, + known_dependent=known_dependent, + known_independent=known_independent, + eager=eager, + ) + for inp in l_apply.inputs + if inp.owner + ) + if not eager: + dependent = any(list(search)) else: - todo.append(cur) - todo.extend(i.owner for i in cur.inputs if i.owner) + dependent = any(search) + if dependent: + known_dependent.add(l_apply) + return True + + known_independent.add(l_apply) return False diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index f9b4fed79a..b5f55c7f76 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -512,6 +512,48 @@ def test_is_in_ancestors(): assert is_in_ancestors(o2.owner, o1.owner) +def test_is_in_ancestors_complete(): + + r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) + o0 = MyOp(r2, r2) + o1 = MyOp(r1, r2) + o1.name = "o1" + o2 = MyOp(r3, o1) + o2.name = "o2" + o3 = MyOp(o2, o0) + o3.name = "o3" + dependent = set() + independent = set() + assert is_in_ancestors( + o3.owner, + o1.owner, + known_dependent=dependent, + known_independent=independent, + # o0 should not fall into independent with + # eager=True default + # because it is the second input in o3 + ) + assert o1.owner in dependent + assert o2.owner in dependent + assert o3.owner in dependent + assert o0.owner not in independent + dependent = set() + independent = set() + assert is_in_ancestors( + o3.owner, + o1.owner, + known_dependent=dependent, + known_independent=independent, + # o0 should not fall into independent with + eager=False + # because it is supposed to be the complete traverse + ) + assert o1.owner in dependent + assert o2.owner in dependent + assert o3.owner in dependent + assert o0.owner in independent + + @pytest.mark.xfail(reason="Not implemented") def test_io_connection_pattern(): raise AssertionError()