From b448a42f78e57a4581b4759df9554c37782b5b65 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 24 Nov 2022 11:15:13 +0300 Subject: [PATCH 1/6] initial test passes --- pytensor/graph/basic.py | 54 +++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index fc85209236..53d8f7745e 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1615,8 +1615,15 @@ def expand(o: Apply) -> List[Apply]: ) -def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool: - """Determine if `f_node` is in the graph given by `l_apply`. +def is_in_ancestors( + l_apply: Apply, + f_apply: Union[Apply, Sequence[Apply]], + *, + known_dependent=None, + known_independent=None, + eager=True, +) -> bool: + """Determine if `f_apply` is in the graph given by `l_apply`. Parameters ---------- @@ -1630,19 +1637,36 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool: bool """ - computed = set() - todo = [l_apply] - while todo: - cur = todo.pop() - if cur.outputs[0] in computed: - continue - if all(i in computed or i.owner is None for i in cur.inputs): - computed.update(cur.outputs) - if cur is f_node: - return True - else: - todo.append(cur) - todo.extend(i.owner for i in cur.inputs if i.owner) + if known_dependent is None: + known_dependent = set() + if known_independent is None: + known_independent = set() + if not isinstance(f_apply, Sequence): + f_apply = [f_apply] + if l_apply in known_dependent: + return True + elif l_apply in f_apply: + known_dependent.add(l_apply) + return True + else: + search = ( + is_in_ancestors( + inp.owner, + f_apply, + known_dependent=known_dependent, + known_independent=known_independent, + eager=eager + ) + for inp in l_apply.inputs + if inp.owner + ) + if not eager: + search = list(search) + if any(search): + known_dependent.add(l_apply) + return True + + known_independent.add(l_apply) return False From 7655686e8740525b68f8e1589831e73615c31674 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 24 Nov 2022 11:26:47 +0300 Subject: [PATCH 2/6] extend is_in_ancestors to make the complete split --- pytensor/graph/basic.py | 14 ++++++++++---- tests/graph/test_basic.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 53d8f7745e..5fb95b52b7 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1619,8 +1619,8 @@ def is_in_ancestors( l_apply: Apply, f_apply: Union[Apply, Sequence[Apply]], *, - known_dependent=None, - known_independent=None, + known_dependent: Optional[Set[Apply]] = None, + known_independent: Optional[Set[Apply]] = None, eager=True, ) -> bool: """Determine if `f_apply` is in the graph given by `l_apply`. @@ -1629,8 +1629,14 @@ def is_in_ancestors( ---------- l_apply : Apply The node to walk. - f_apply : Apply + f_apply : Union[Apply, Sequence[Apply]] The node to find in `l_apply`. + known_dependent: Optional[Set[Apply]] + Cache information about intermediate Applys that depend on f_apply + known_independent: Optional[Set[Apply]] + Cache information about intermediate Applys that do not depend on f_apply + eager: bool + return on first match (True) or traverse the whole graph (False) Returns ------- @@ -1655,7 +1661,7 @@ def is_in_ancestors( f_apply, known_dependent=known_dependent, known_independent=known_independent, - eager=eager + eager=eager, ) for inp in l_apply.inputs if inp.owner diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index f9b4fed79a..78138ef25c 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -512,6 +512,42 @@ def test_is_in_ancestors(): assert is_in_ancestors(o2.owner, o1.owner) +def test_is_in_ancestors_complete(): + + r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) + o0 = MyOp(r2, r2) + o1 = MyOp(r1, r2) + o1.name = "o1" + o2 = MyOp(r3, o1) + o2.name = "o2" + o3 = MyOp(o2, o0) + o3.name = "o2" + dependent = set() + independent = set() + assert is_in_ancestors( + o3.owner, + o1.owner, + known_dependent=dependent, + known_independent=independent, + # o0 should not fall into independent with + # eager=True default + # because it is the second input in o3 + ) + assert o0.owner not in independent + dependent = set() + independent = set() + assert is_in_ancestors( + o3.owner, + o1.owner, + known_dependent=dependent, + known_independent=independent, + # o0 should not fall into independent with + eager=False + # because it is supposed to be the complete traverse + ) + assert o0.owner in independent + + @pytest.mark.xfail(reason="Not implemented") def test_io_connection_pattern(): raise AssertionError() From 7f90c07593181d80bb5a26a9a625497398bd99e9 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 24 Nov 2022 11:32:39 +0300 Subject: [PATCH 3/6] check every node is assigned dependent or independent --- tests/graph/test_basic.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 78138ef25c..b5f55c7f76 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -521,7 +521,7 @@ def test_is_in_ancestors_complete(): o2 = MyOp(r3, o1) o2.name = "o2" o3 = MyOp(o2, o0) - o3.name = "o2" + o3.name = "o3" dependent = set() independent = set() assert is_in_ancestors( @@ -533,6 +533,9 @@ def test_is_in_ancestors_complete(): # eager=True default # because it is the second input in o3 ) + assert o1.owner in dependent + assert o2.owner in dependent + assert o3.owner in dependent assert o0.owner not in independent dependent = set() independent = set() @@ -545,6 +548,9 @@ def test_is_in_ancestors_complete(): eager=False # because it is supposed to be the complete traverse ) + assert o1.owner in dependent + assert o2.owner in dependent + assert o3.owner in dependent assert o0.owner in independent From cc57f614c08cd662d6c5796da5e24d2742c4794f Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 24 Nov 2022 11:40:02 +0300 Subject: [PATCH 4/6] fix mypy --- pytensor/graph/basic.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 5fb95b52b7..f18a73046e 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1667,8 +1667,10 @@ def is_in_ancestors( if inp.owner ) if not eager: - search = list(search) - if any(search): + dependent = any(list(search)) + else: + dependent = any(search) + if dependent: known_dependent.add(l_apply) return True From 22d22d476d5ded4fe20136fc83db0d8fa8a72f1c Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 24 Nov 2022 11:52:23 +0300 Subject: [PATCH 5/6] add docstring --- pytensor/graph/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index f18a73046e..45b243da9c 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1623,7 +1623,7 @@ def is_in_ancestors( known_independent: Optional[Set[Apply]] = None, eager=True, ) -> bool: - """Determine if `f_apply` is in the graph given by `l_apply`. + """Determine if `f_apply` is in the graph given by (any of) `l_apply`. Parameters ---------- From 04f6d1b4da31a75b247f0151860f2f688b314abd Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 24 Nov 2022 17:29:21 +0300 Subject: [PATCH 6/6] less restrictive typing --- pytensor/graph/basic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 45b243da9c..42a117b690 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1617,7 +1617,7 @@ def expand(o: Apply) -> List[Apply]: def is_in_ancestors( l_apply: Apply, - f_apply: Union[Apply, Sequence[Apply]], + f_apply: Union[Apply, Collection[Apply]], *, known_dependent: Optional[Set[Apply]] = None, known_independent: Optional[Set[Apply]] = None, @@ -1629,7 +1629,7 @@ def is_in_ancestors( ---------- l_apply : Apply The node to walk. - f_apply : Union[Apply, Sequence[Apply]] + f_apply : Union[Apply, Collection[Apply]] The node to find in `l_apply`. known_dependent: Optional[Set[Apply]] Cache information about intermediate Applys that depend on f_apply @@ -1647,8 +1647,8 @@ def is_in_ancestors( known_dependent = set() if known_independent is None: known_independent = set() - if not isinstance(f_apply, Sequence): - f_apply = [f_apply] + if not isinstance(f_apply, Collection): + f_apply = {f_apply} if l_apply in known_dependent: return True elif l_apply in f_apply: