-
Notifications
You must be signed in to change notification settings - Fork 143
Closed
Labels
Description
Link to a discussion
An example implementation
https://gist.github.com/ferrine/70cbcf6d3b6f033ac070d70b10ac8d25
Before (current behavior with `clone_replace`):
# Current behavior: clone_replace keeps only the graph's root inputs and copies
# every other node, so identity-based ancestry checks against intermediate
# variables (e.g. a2) no longer hold on the cloned graph.
# NOTE(review): is_in_ancestors and independent_apply_nodes_between are not
# defined in this snippet — presumably they come from the linked gist; confirm.
import pytensor.tensor as at
import pytensor
import numpy as np
a = at.scalar("a")
b = at.scalar("b")
b2 = b * 2
a2 = a * 2
d = (a2 ** 2 + b2 ** 2).flatten()
assert is_in_ancestors(b2, [b])
assert is_in_ancestors(d, [b])
assert not is_in_ancestors(a2, [b])
# NOTE(review): despite its name, the membership checks below look up variables
# (a, a2, b2), so the helper presumably returns variables — verify against the gist.
assert a in independent_apply_nodes_between([b], [d])
assert a2 in independent_apply_nodes_between([b], [d])
assert b2 not in independent_apply_nodes_between([b], [d])
# Swap b for a fresh clone of itself: the part of the graph that depends on b
# (b2) should be rebuilt, while the part independent of b (a, a2) should be kept.
d_clone = pytensor.clone_replace([d], {b: b.clone()})[0]
assert not is_in_ancestors(d_clone, [b2])
assert is_in_ancestors(d_clone, [a])
assert is_in_ancestors(d_clone, [a2]) # fails: only root inputs (such as a) are preserved; the rest of the subgraph, including a2, is copied, so references to intermediate nodes are broken
After (proposed behavior):
# Proposed behavior: a substitution that rebuilds only the nodes depending on the
# replaced variable(s) and reuses every other node unchanged, so intermediate
# variables independent of b (such as a2) remain ancestors of the result.
# NOTE(review): is_in_ancestors and independent_apply_nodes_between are not
# defined in this snippet — presumably they come from the linked gist; confirm.
import pytensor.tensor as at
import pytensor
import numpy as np
a = at.scalar("a")
b = at.scalar("b")
b2 = b * 2
a2 = a * 2
d = (a2 ** 2 + b2 ** 2).flatten()
assert is_in_ancestors(b2, [b])
assert is_in_ancestors(d, [b])
assert not is_in_ancestors(a2, [b])
assert a in independent_apply_nodes_between([b], [d])
assert a2 in independent_apply_nodes_between([b], [d])
assert b2 not in independent_apply_nodes_between([b], [d])
# graph_substitute is a hypothetical, not-yet-existing API; the name is only an example and can be improved
d_clone = pytensor.graph_substitute([d], {b: b.clone()})[0]
assert not is_in_ancestors(d_clone, [b2])
assert is_in_ancestors(d_clone, [a])
# Unlike the clone_replace version, this holds because a2 does not depend on b.
assert is_in_ancestors(d_clone, [a2])
Context for the issue:
@lucianopaz had an example in which random variables lost their references after cloning, and the process of recreating them was cumbersome:
# Motivating example: three random variables a, b, c feed into d, and we want a
# copy of d in which c is replaced by zeros while a and b stay usable as inputs.
import pytensor.tensor as at
import pytensor
import numpy as np
a = at.random.normal(loc=3, scale=0.01, name="a", size=2)
b = at.random.normal(loc=1, scale=0.01, name="b", size=(2, 2))
c = at.random.normal(loc=100, scale=0.01, name="c", size=(2, 2, 2))
d = at.random.normal(loc=(a + b + c).flatten(), scale=0.01)
d.name = "d"
# Replace c with zeros of the same shape and dtype.
d_clone = pytensor.graph.basic.clone_replace(
[d], replace={c: at.zeros(c.shape, dtype=c.dtype)}
)[0]
# NOTE(review): a and b are not root inputs here, so clone_replace presumably
# copies them too; d_clone then does not reference the original a and b, which
# is why on_unused_input="ignore" is needed to compile this function.
f = pytensor.function([a, b], d_clone, on_unused_input="ignore") # because impossible to reference a and b
# The values supplied for a and b are presumably ignored, since d_clone does not depend on them.
f(a=np.zeros(2), b=np.zeros((2, 2)))
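The proposed alternative graph substitution would solve this issue: only the replaced variable and the nodes that depend on it would be rebuilt, so references to untouched nodes such as `a` and `b` would be preserved and could still be used as function inputs.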