ENH: implement graph substitution that preserves independent parts #19

@ferrine

Description

Link to a discussion

An example implementation

https://gist.github.com/ferrine/70cbcf6d3b6f033ac070d70b10ac8d25

Before

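import pytensor.tensor as at
import pytensor
import numpy as np

# is_in_ancestors and independent_apply_nodes_between are the helpers
# defined in the example implementation (gist) linked above.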

a = at.scalar("a")
b = at.scalar("b")
b2 = b * 2
a2 = a * 2

d = (a2 ** 2 + b2 ** 2).flatten()

assert is_in_ancestors(b2, [b])
assert is_in_ancestors(d, [b])
assert not is_in_ancestors(a2, [b])
assert a in independent_apply_nodes_between([b], [d])
assert a2 in independent_apply_nodes_between([b], [d])
assert b2 not in independent_apply_nodes_between([b], [d])
d_clone = pytensor.clone_replace([d], {b: b.clone()})[0]
assert not is_in_ancestors(d_clone, [b2])
assert is_in_ancestors(d_clone, [a])
# Fails: clone_replace preserves only the graph inputs. The rest of the
# subgraph is copied, so the reference to a2 is lost.
assert is_in_ancestors(d_clone, [a2])

After

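import pytensor.tensor as at
import pytensor
import numpy as np

# is_in_ancestors and independent_apply_nodes_between are the helpers
# defined in the example implementation (gist) linked above.
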
a = at.scalar("a")
b = at.scalar("b")
b2 = b * 2
a2 = a * 2

d = (a2 ** 2 + b2 ** 2).flatten()

assert is_in_ancestors(b2, [b])
assert is_in_ancestors(d, [b])
assert not is_in_ancestors(a2, [b])
assert a in independent_apply_nodes_between([b], [d])
assert a2 in independent_apply_nodes_between([b], [d])
assert b2 not in independent_apply_nodes_between([b], [d])
# The name graph_substitute is only a placeholder; a better name can be chosen.
d_clone = pytensor.graph_substitute([d], {b: b.clone()})[0]
assert not is_in_ancestors(d_clone, [b2])
assert is_in_ancestors(d_clone, [a])
assert is_in_ancestors(d_clone, [a2])

Context for the issue:

@lucianopaz had an example where random variables lose their references after cloning, and the process of recreating them was cumbersome:

import pytensor.tensor as at
import pytensor
import numpy as np

a = at.random.normal(loc=3, scale=0.01, name="a", size=2)
b = at.random.normal(loc=1, scale=0.01, name="b", size=(2, 2))
c = at.random.normal(loc=100, scale=0.01, name="c", size=(2, 2, 2))
d = at.random.normal(loc=(a + b + c).flatten(), scale=0.01)
d.name = "d"
d_clone = pytensor.graph.basic.clone_replace(
    [d], replace={c: at.zeros(c.shape, dtype=c.dtype)}
)[0]
# on_unused_input="ignore" is needed because the original a and b can no
# longer be referenced in the cloned graph.
f = pytensor.function([a, b], d_clone, on_unused_input="ignore")
f(a=np.zeros(2), b=np.zeros((2, 2)))

The proposed alternative graph substitution solves this issue: nodes that do not depend on the replaced variable are preserved, so a and b stay referenceable.
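To make the intended behaviour concrete, here is a minimal pure-Python sketch on a toy graph. It is not the PyTensor API and not the gist's implementation: the `Node` class and the `ancestors`, `is_in_ancestors` and `substitute` functions are hypothetical names used only for illustration. The assumed idea is that a node is rebuilt only if one of its inputs changed, and is otherwise reused by identity.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity-based equality and hashing
class Node:
    """Toy graph node (hypothetical, stands in for a variable/apply pair)."""
    name: str
    inputs: tuple = ()


def ancestors(node):
    """Yield node itself and every node reachable through its inputs."""
    seen, stack = set(), [node]
    while stack:
        cur = stack.pop()
        if id(cur) in seen:
            continue
        seen.add(id(cur))
        yield cur
        stack.extend(cur.inputs)


def is_in_ancestors(node, targets):
    """True if any target is node itself or one of its ancestors."""
    return any(anc is t for anc in ancestors(node) for t in targets)


def substitute(outputs, replacements):
    """Rebuild only nodes that depend on a replaced node; reuse the rest."""
    memo = dict(replacements)  # maps old node -> new node

    def rebuild(node):
        if node in memo:
            return memo[node]
        new_inputs = tuple(rebuild(i) for i in node.inputs)
        if all(n is o for n, o in zip(new_inputs, node.inputs)):
            result = node  # nothing upstream changed: keep the original object
        else:
            result = Node(node.name, new_inputs)
        memo[node] = result
        return result

    return [rebuild(o) for o in outputs]


# Same shape as the example in this issue: d depends on a2 and b2.
a, b = Node("a"), Node("b")
a2, b2 = Node("a*2", (a,)), Node("b*2", (b,))
d = Node("d", (a2, b2))

(d_new,) = substitute([d], {b: Node("b_clone")})
assert is_in_ancestors(d_new, [a])       # input preserved
assert is_in_ancestors(d_new, [a2])      # independent part kept by identity
assert not is_in_ancestors(d_new, [b2])  # dependent part was rebuilt
```

The recursion in `rebuild` is fine for small graphs like this one; a real implementation would presumably walk the graph iteratively or in topological order.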
