Graph replace #66
Merged
Commits (5)
8e86d93 move clone_replace to a separate file (ferrine)
49f1898 refactor is_in_ancestors to support multiple inputs (ferrine)
8d9d75e add variable_depends_on (ferrine)
db551b0 add truncated_graph_inputs function (ferrine)
7b4cb73 add graph_replace function (ferrine)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -970,6 +970,136 @@ def applys_between( | |
) | ||
|
||
|
||
def truncated_graph_inputs( | ||
outputs: Sequence[Variable], | ||
ancestors_to_include: Optional[Collection[Variable]] = None, | ||
) -> List[Variable]: | ||
"""Get the truncated graph inputs. | ||
|
||
Unlike :func:`graph_inputs`, this function returns | ||
the nodes closest to ``outputs`` that do not depend on | ||
``ancestors_to_include``. Given all the returned | ||
variables, no further node is needed to | ||
compute the outputs, and all returned nodes are independent | ||
of each other. | ||
|
||
Parameters | ||
---------- | ||
outputs : Sequence[Variable] | ||
Variables to get conditions for | ||
ancestors_to_include : Optional[Collection[Variable]] | ||
Additional ancestors to assume, by default None | ||
|
||
Returns | ||
------- | ||
List[Variable] | ||
Variables required to compute ``outputs`` | ||
|
||
Examples | ||
-------- | ||
Returned nodes are marked in (parentheses), ancestor nodes to include are ``c``, output nodes are ``o``, and other nodes are ``n``. | ||
|
||
* No ancestors to include | ||
|
||
.. code-block:: | ||
|
||
n - n - (o) | ||
|
||
* One ancestor to include | ||
|
||
.. code-block:: | ||
|
||
n - (c) - o | ||
|
||
* Two ancestors to include where one depends on the other, both returned | ||
|
||
.. code-block:: | ||
|
||
(c) - (c) - o | ||
|
||
* Additional nodes are present | ||
|
||
.. code-block:: | ||
|
||
(c) - n - o | ||
n - (n) -' | ||
|
||
* Disconnected ancestors to include are not returned | ||
|
||
.. code-block:: | ||
|
||
(c) - n - o | ||
c | ||
|
||
* Disconnected output is present and returned | ||
|
||
.. code-block:: | ||
|
||
(c) - (c) - o | ||
(o) | ||
|
||
* An output that is itself among the ancestors to include is returned as well | ||
|
||
.. code-block:: | ||
|
||
n - (c) - (o/c) | ||
|
||
""" | ||
# simple case, no additional ancestors to include | ||
truncated_inputs = list() | ||
# blockers have known independent nodes and ancestors to include | ||
candidates = list(outputs) | ||
if not ancestors_to_include: # None or empty | ||
# just filter out unique variables | ||
for node in candidates: | ||
if node not in truncated_inputs: | ||
truncated_inputs.append(node) | ||
# no more actions are needed | ||
return truncated_inputs | ||
blockers: Set[Variable] = set(ancestors_to_include) | ||
# enforce O(1) check for node in ancestors to include | ||
ancestors_to_include = blockers.copy() | ||
|
||
while candidates: | ||
# on any new candidate | ||
node = candidates.pop() | ||
# check if the node is independent, never go above blockers | ||
# blockers are independent nodes and ancestors to include | ||
if node in ancestors_to_include: | ||
# The case where node is in ancestors to include so we check if it depends on others | ||
# it should be removed from the blockers to check against the rest | ||
dependent = variable_depends_on(node, blockers - {node}) | ||
# ancestors to include that are present in the graph (not disconnected) | ||
# should be added to truncated_inputs | ||
truncated_inputs.append(node) | ||
@ferrine I think we can end up with duplicated inputs here. If another node has duplicated truncated inputs? |
||
if dependent: | ||
# if the ancestors to include is still dependent we need to go above, | ||
# the search is not yet finished | ||
# the node _has_ to have owner to be dependent | ||
# so we do not check it | ||
# and populate search to go above | ||
# owner can never be None for a dependent node | ||
candidates.extend(node.owner.inputs) | ||
ferrine marked this conversation as resolved.
|
||
else: | ||
# A regular node to check | ||
dependent = variable_depends_on(node, blockers) | ||
# all regular nodes fall to blockers | ||
# 1. it is dependent - further search irrelevant | ||
# 2. it is independent - the search node is inside the closure | ||
blockers.add(node) | ||
# if we've found an independent node and it is not in blockers so far | ||
# it is a new independent node not present in ancestors to include | ||
if not dependent: | ||
# we've found an independent node | ||
# do not search beyond | ||
truncated_inputs.append(node) | ||
else: | ||
# populate search otherwise | ||
# owner can never be None for a dependent node | ||
candidates.extend(node.owner.inputs) | ||
ricardoV94 marked this conversation as resolved.
|
||
return truncated_inputs | ||
|
||
|
||
def clone( | ||
inputs: List[Variable], | ||
outputs: List[Variable], | ||
|
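The traversal in `truncated_graph_inputs` can be tried out without PyTensor. The following is only a minimal sketch that mirrors the logic of the hunk above on a toy graph: `Var`, `depends_on`, and `truncated_inputs_sketch` are hypothetical stand-ins invented for illustration, not PyTensor APIs, and the toy `depends_on` assumes that a node counts as its own ancestor.

```python
from typing import Collection, List, Optional, Sequence


class Var:
    """Hypothetical stand-in for a graph variable: a name plus its direct inputs."""

    def __init__(self, name: str, inputs: Sequence["Var"] = ()):
        self.name = name
        self.inputs = list(inputs)  # empty for root variables

    def __repr__(self) -> str:
        return self.name


def depends_on(var: Var, targets: Collection[Var]) -> bool:
    """True if ``var`` itself or any of its ancestors is in ``targets``."""
    stack, seen = [var], set()
    while stack:
        cur = stack.pop()
        if cur in targets:
            return True
        if cur not in seen:
            seen.add(cur)
            stack.extend(cur.inputs)
    return False


def truncated_inputs_sketch(
    outputs: Sequence[Var],
    ancestors_to_include: Optional[Collection[Var]] = None,
) -> List[Var]:
    result: List[Var] = []
    candidates = list(outputs)
    if not ancestors_to_include:
        # no ancestors given: the unique outputs are the truncated inputs
        for node in candidates:
            if node not in result:
                result.append(node)
        return result
    blockers = set(ancestors_to_include)
    include = blockers.copy()
    while candidates:
        node = candidates.pop()
        if node in include:
            # a requested ancestor is always returned; keep climbing only
            # if it still depends on another blocker
            dependent = depends_on(node, blockers - {node})
            result.append(node)
            if dependent:
                candidates.extend(node.inputs)
        else:
            dependent = depends_on(node, blockers)
            blockers.add(node)
            if not dependent:
                result.append(node)  # independent node: stop here
            else:
                candidates.extend(node.inputs)
    return result


# The "Additional nodes are present" case from the docstring:
#   (c) - n1 - o
#   n2 - (n3) -'
c = Var("c")
n1 = Var("n1", [c])
n2 = Var("n2")
n3 = Var("n3", [n2])
o = Var("o", [n1, n3])
found = truncated_inputs_sketch([o], [c])  # contains c and n3; order follows the traversal
```

The sketch keeps the same append-without-deduplication behaviour as the diff, so the duplicated-inputs question raised in the review comment applies to it as well.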
@@ -1151,53 +1281,6 @@ def clone_get_equiv( | |
return memo | ||
|
||
|
||
def clone_replace( | ||
output: Collection[Variable], | ||
replace: Optional[ | ||
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]] | ||
] = None, | ||
**rebuild_kwds, | ||
) -> List[Variable]: | ||
"""Clone a graph and replace subgraphs within it. | ||
|
||
It returns a copy of the initial subgraph with the corresponding | ||
substitutions. | ||
|
||
Parameters | ||
---------- | ||
output | ||
PyTensor expression that represents the computational graph. | ||
replace | ||
Dictionary describing which subgraphs should be replaced by what. | ||
rebuild_kwds | ||
Keywords to `rebuild_collect_shared`. | ||
|
||
""" | ||
from pytensor.compile.function.pfunc import rebuild_collect_shared | ||
|
||
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]] | ||
if isinstance(replace, dict): | ||
items = list(replace.items()) | ||
elif isinstance(replace, (list, tuple)): | ||
items = replace | ||
elif replace is None: | ||
items = [] | ||
else: | ||
raise ValueError( | ||
"replace is neither a dictionary, list, " | ||
f"tuple or None ! The value provided is {replace}," | ||
f"of type {type(replace)}" | ||
) | ||
tmp_replace = [(x, x.type()) for x, y in items] | ||
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)] | ||
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds) | ||
|
||
# TODO Explain why we call it twice ?! | ||
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds) | ||
|
||
return cast(List[Variable], outs) | ||
|
||
|
||
def general_toposort( | ||
outputs: Iterable[T], | ||
deps: Callable[[T], Union[OrderedSet, List[T]]], | ||
|
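The `tmp_replace` / `new_replace` pairing in `clone_replace` routes every replacement through a fresh placeholder. The source leaves the reason for the double `rebuild_collect_shared` call as an open TODO, so the following is only a toy illustration of what staging through placeholders can buy in general, not an explanation of PyTensor's behaviour. Expressions are nested tuples of strings, and `replace_one`, `replace_sequential`, and `replace_two_stage` are hypothetical helpers.

```python
from typing import List, Tuple, Union

Expr = Union[str, tuple]
Pairs = List[Tuple[Expr, Expr]]


def replace_one(expr: Expr, old: Expr, new: Expr) -> Expr:
    """Replace every occurrence of ``old`` inside ``expr`` with ``new``."""
    if expr == old:
        return new
    if isinstance(expr, tuple):
        return tuple(replace_one(e, old, new) for e in expr)
    return expr


def replace_sequential(expr: Expr, pairs: Pairs) -> Expr:
    """Apply the replacements one after another (later pairs see earlier results)."""
    for old, new in pairs:
        expr = replace_one(expr, old, new)
    return expr


def replace_two_stage(expr: Expr, pairs: Pairs) -> Expr:
    # stage 1: every key becomes a fresh placeholder (analogous to tmp_replace)
    tmp = [(old, f"__tmp{i}__") for i, (old, _) in enumerate(pairs)]
    # stage 2: every placeholder becomes the requested value (analogous to new_replace)
    final = [(placeholder, new) for (_, placeholder), (_, new) in zip(tmp, pairs)]
    return replace_sequential(replace_sequential(expr, tmp), final)


expr = ("add", "x", "y")
swap = [("x", "y"), ("y", "x")]
naive = replace_sequential(expr, swap)   # the second pair also rewrites the first result
staged = replace_two_stage(expr, swap)   # keys and values never collide
```

In this toy setting the naive sequential swap collapses both operands into `"x"`, while the staged version swaps them as requested.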
@@ -1615,37 +1698,63 @@ def expand(o: Apply) -> List[Apply]: | |
) | ||
|
||
|
||
def is_in_ancestors(l_apply: Apply, f_apply: Apply) -> bool: | ||
"""Determine if `f_apply` is in the graph given by `l_apply`. | ||
def apply_depends_on(apply: Apply, depends_on: Union[Apply, Collection[Apply]]) -> bool: | ||
"""Determine if any of ``depends_on`` is in the graph given by ``apply``. | ||
|
||
Parameters | ||
---------- | ||
l_apply : Apply | ||
The node to walk. | ||
f_apply : Apply | ||
The node to find in `l_apply`. | ||
apply : Apply | ||
The Apply node to check. | ||
depends_on : Union[Apply, Collection[Apply]] | ||
Apply nodes to check dependency on | ||
|
||
Returns | ||
------- | ||
bool | ||
|
||
""" | ||
computed = set() | ||
todo = [l_apply] | ||
todo = [apply] | ||
if not isinstance(depends_on, Collection): | ||
depends_on = {depends_on} | ||
else: | ||
depends_on = set(depends_on) | ||
while todo: | ||
cur = todo.pop() | ||
if cur.outputs[0] in computed: | ||
continue | ||
if all(i in computed or i.owner is None for i in cur.inputs): | ||
computed.update(cur.outputs) | ||
if cur is f_apply: | ||
if cur in depends_on: | ||
return True | ||
else: | ||
todo.append(cur) | ||
todo.extend(i.owner for i in cur.inputs if i.owner) | ||
return False | ||
|
||
|
||
def variable_depends_on( | ||
variable: Variable, depends_on: Union[Variable, Collection[Variable]] | ||
) -> bool: | ||
"""Determine if any of ``depends_on`` is in the graph given by ``variable``. | ||
|
||
Parameters | ||
---------- | ||
variable : Variable | ||
Node to check | ||
depends_on : Union[Variable, Collection[Variable]] | ||
Nodes to check dependency on | ||
|
||
Returns | ||
------- | ||
bool | ||
""" | ||
if not isinstance(depends_on, Collection): | ||
depends_on = {depends_on} | ||
else: | ||
depends_on = set(depends_on) | ||
return any(interim in depends_on for interim in ancestors([variable])) | ||
|
||
|
||
def equal_computations( | ||
xs: List[Union[np.ndarray, Variable]], | ||
ys: List[Union[np.ndarray, Variable]], | ||
|
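Both `apply_depends_on` and `variable_depends_on` accept either a single node or a collection and normalize it to a set before walking the graph. A minimal, self-contained sketch of that pattern follows; `Var`, `toy_ancestors`, and `variable_depends_on_sketch` are hypothetical names on a toy graph, and the sketch assumes the ancestor walk yields the starting variable itself.

```python
from collections.abc import Collection
from typing import Iterator, Sequence, Union


class Var:
    """Hypothetical stand-in for a graph variable."""

    def __init__(self, name: str, inputs: Sequence["Var"] = ()):
        self.name = name
        self.inputs = list(inputs)


def toy_ancestors(var: Var) -> Iterator[Var]:
    """Yield ``var`` and everything upstream of it, each node once."""
    stack, seen = [var], set()
    while stack:
        cur = stack.pop()
        if cur in seen:
            continue
        seen.add(cur)
        yield cur
        stack.extend(cur.inputs)


def variable_depends_on_sketch(
    variable: Var, depends_on: Union[Var, Collection[Var]]
) -> bool:
    # a single node is wrapped in a set; any collection is copied into one,
    # so the membership test below is O(1) per visited node
    if not isinstance(depends_on, Collection):
        depends_on = {depends_on}
    else:
        depends_on = set(depends_on)
    return any(node in depends_on for node in toy_ancestors(variable))


x = Var("x")
y = Var("y")
z = Var("z", [x])  # z is computed from x only
single = variable_depends_on_sketch(z, x)        # single node
several = variable_depends_on_sketch(z, [y, x])  # collection of nodes
```

Because `any` consumes the generator lazily, the walk stops at the first match instead of enumerating every ancestor.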