Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pytensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def disable_log_handler(logger=pytensor_logger, handler=logging_default_handler)
__api_version__ = 1

# isort: off
from pytensor.graph.basic import Variable, clone_replace
from pytensor.graph.basic import Variable
from pytensor.graph.replace import clone_replace, graph_replace

# isort: on

Expand Down
2 changes: 1 addition & 1 deletion pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
Constant,
NominalVariable,
Variable,
clone_replace,
graph_inputs,
io_connection_pattern,
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError
from pytensor.tensor.rewriting.shape import ShapeFeature
Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
Constant,
graph_inputs,
clone,
clone_replace,
ancestors,
)
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph
Expand Down
219 changes: 164 additions & 55 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,136 @@ def applys_between(
)


def truncated_graph_inputs(
outputs: Sequence[Variable],
ancestors_to_include: Optional[Collection[Variable]] = None,
) -> List[Variable]:
"""Get the truncate graph inputs.

Unlike :func:`graph_inputs` this function will return
the closest nodes to outputs that do not depend on
``ancestors_to_include``. So given all the returned
variables provided there is no missing node to
compute the output and all nodes are independent
from each other.

Parameters
----------
outputs : Collection[Variable]
Variable to get conditions for
ancestors_to_include : Optional[Collection[Variable]]
Additional ancestors to assume, by default None

Returns
-------
List[Variable]
Variables required to compute ``outputs``

Examples
--------
The returned nodes marked in (parenthesis), ancestors nodes are ``c``, output nodes are ``o``

* No ancestors to include

.. code-block::

n - n - (o)

* One ancestors to include

.. code-block::

n - (c) - o

* Two ancestors to include where on depends on another, both returned

.. code-block::

(c) - (c) - o

* Additional nodes are present

.. code-block::

(c) - n - o
n - (n) -'

* Disconnected ancestors to include not returned

.. code-block::

(c) - n - o
c

* Disconnected output is present and returned

.. code-block::

(c) - (c) - o
(o)

* ancestors to include that include itself adds itself

.. code-block::

n - (c) - (o/c)

"""
# simple case, no additional ancestors to include
truncated_inputs = list()
# blockers have known independent nodes and ancestors to include
candidates = list(outputs)
if not ancestors_to_include: # None or empty
# just filter out unique variables
for node in candidates:
if node not in truncated_inputs:
truncated_inputs.append(node)
# no more actions are needed
return truncated_inputs
blockers: Set[Variable] = set(ancestors_to_include)
# enforce O(1) check for node in ancestors to include
ancestors_to_include = blockers.copy()

while candidates:
# on any new candidate
node = candidates.pop()
# check if the node is independent, never go above blockers
# blockers are independent nodes and ancestors to include
if node in ancestors_to_include:
# The case where node is in ancestors to include so we check if it depends on others
# it should be removed from the blockers to check against the rest
dependent = variable_depends_on(node, blockers - {node})
# ancestors to include that are present in the graph (not disconnected)
# should be added to truncated_inputs
truncated_inputs.append(node)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ferrine I think we can end up with duplicated inputs here. If another node has duplicated truncated inputs?

if dependent:
# if the ancestors to include is still dependent we need to go above,
# the search is not yet finished
# the node _has_ to have owner to be dependent
# so we do not check it
# and populate search to go above
# owner can never be None for a dependent node
candidates.extend(node.owner.inputs)
else:
# A regular node to check
dependent = variable_depends_on(node, blockers)
# all regular nodes fall to blockes
# 1. it is dependent - further search irrelevant
# 2. it is independent - the search node is inside the closure
blockers.add(node)
# if we've found an independent node and it is not in blockers so far
# it is a new indepenent node not present in ancestors to include
if not dependent:
# we've found an independent node
# do not search beyond
truncated_inputs.append(node)
else:
# populate search otherwise
# owner can never be None for a dependent node
candidates.extend(node.owner.inputs)
return truncated_inputs


def clone(
inputs: List[Variable],
outputs: List[Variable],
Expand Down Expand Up @@ -1151,53 +1281,6 @@ def clone_get_equiv(
return memo


def clone_replace(
output: Collection[Variable],
replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
] = None,
**rebuild_kwds,
) -> List[Variable]:
"""Clone a graph and replace subgraphs within it.

It returns a copy of the initial subgraph with the corresponding
substitutions.

Parameters
----------
output
PyTensor expression that represents the computational graph.
replace
Dictionary describing which subgraphs should be replaced by what.
rebuild_kwds
Keywords to `rebuild_collect_shared`.

"""
from pytensor.compile.function.pfunc import rebuild_collect_shared

items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)

# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)

return cast(List[Variable], outs)


def general_toposort(
outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, List[T]]],
Expand Down Expand Up @@ -1615,37 +1698,63 @@ def expand(o: Apply) -> List[Apply]:
)


def is_in_ancestors(l_apply: Apply, f_apply: Apply) -> bool:
"""Determine if `f_apply` is in the graph given by `l_apply`.
def apply_depends_on(apply: Apply, depends_on: Union[Apply, Collection[Apply]]) -> bool:
"""Determine if any `depends_on` is in the graph given by ``apply``.

Parameters
----------
l_apply : Apply
The node to walk.
f_apply : Apply
The node to find in `l_apply`.
apply : Apply
The Apply node to check.
depends_on : Union[Apply, Collection[Apply]]
Apply nodes to check dependency on

Returns
-------
bool

"""
computed = set()
todo = [l_apply]
todo = [apply]
if not isinstance(depends_on, Collection):
depends_on = {depends_on}
else:
depends_on = set(depends_on)
while todo:
cur = todo.pop()
if cur.outputs[0] in computed:
continue
if all(i in computed or i.owner is None for i in cur.inputs):
computed.update(cur.outputs)
if cur is f_apply:
if cur in depends_on:
return True
else:
todo.append(cur)
todo.extend(i.owner for i in cur.inputs if i.owner)
return False


def variable_depends_on(
variable: Variable, depends_on: Union[Variable, Collection[Variable]]
) -> bool:
"""Determine if any `depends_on` is in the graph given by ``variable``.
Parameters
----------
variable: Variable
Node to check
depends_on: Collection[Variable]
Nodes to check dependency on

Returns
-------
bool
"""
if not isinstance(depends_on, Collection):
depends_on = {depends_on}
else:
depends_on = set(depends_on)
return any(interim in depends_on for interim in ancestors([variable]))


def equal_computations(
xs: List[Union[np.ndarray, Variable]],
ys: List[Union[np.ndarray, Variable]],
Expand Down
Loading