Replace RNG update in RV lift rewrites #870
Otherwise we end up with multiple RVs if the RNGs are an output / used elsewhere in the function
Codecov Report
Additional details and impacted files

@@            Coverage Diff            @@
##             main     #870    +/-   ##
=========================================
  Coverage   80.98%   80.99%
=========================================
  Files         169      169
  Lines       46985    46988     +3
  Branches    11494    11495     +1
=========================================
+ Hits        38052    38057     +5
- Misses       6716     6718     +2
+ Partials     2217     2213     -4
I think this is Ok. But in all honesty, someone else should also have a look. I couldn't spend the necessary time to dig deep into the stuff that I didn't fully understand. I left a few questions though
# We replace uses of the dimshuffled RV by the new RV
# And uses of the old RNG update by the new RNG update
return {
I'm not familiar with this type of output. I've always seen lists as return types. What happens with dictionaries? Are they treated as a sort of updates?
A dictionary lets you replace any variable (the keys) with another (the values), not only the outputs of the node being rewritten. It's an alternative to the default you're used to, returning a list, which amounts to something like dict(zip(node.outputs, new_outputs)).
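As a rough illustration of that equivalence, here is a plain-Python sketch. The strings are hypothetical stand-ins for graph variables, not PyTensor objects, and the names are made up for the example.

```python
# Stand-ins for graph variables; in a real rewrite these would be Variables.
node_outputs = ["old_next_rng", "old_rv"]
new_outputs = ["new_next_rng", "new_rv"]

# Returning a list is shorthand for pairing each output of the rewritten
# node with its replacement, in order.
list_style = dict(zip(node_outputs, new_outputs))

# Returning a dict can also name variables that are NOT outputs of the
# rewritten node (here a hypothetical "dimshuffled_rv").
dict_style = {
    "dimshuffled_rv": "new_rv",
    "old_next_rng": "new_next_rng",
}
```

The dict form is what makes it possible to replace the RNG update alongside the lifted RV in a single rewrite step.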
    constant(1, dtype="int64") if o == "x" else size[o]
    for o in batched_dims_ds_order
]
new_size = size
Why were you able to remove all of the original code that was visited in this branch?
I realized there's no need to introduce size if it was not defined in the original RV
rv_op = rv_node.op
rng, size, *dist_params = rv_node.inputs
next_rng, rv = rv_node.outputs
Why could you remove the shape_feature stuff?
I didn't remove it; I restricted it to the only case where it's needed, which is when the RV had a size defined. It shows up in the branch below.
else:
    shape_feature = getattr(fgraph, "shape_feature", None)
    if not shape_feature:
        return None
Should this really return None? If yes, add a # pragma: no cover to ignore the missing line coverage.
Yes, it should return None; we need the feature for this rewrite case to work properly.
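For readers unfamiliar with the pattern, the guard can be sketched in isolation as follows. This is a simplified stand-alone sketch: `lookup_shape` and the `SimpleNamespace` graphs are hypothetical helpers for illustration, not part of the PR or of PyTensor.

```python
from types import SimpleNamespace


def lookup_shape(fgraph, var):
    """Return the recorded shape of `var`, or None to signal "skip the rewrite"."""
    # The feature is optional, so look it up defensively.
    shape_feature = getattr(fgraph, "shape_feature", None)
    if not shape_feature:
        # Without shape information this rewrite case cannot proceed.
        return None
    return shape_feature.shape_of.get(var, None)


# Hypothetical graphs: one without the feature, one with a recorded shape.
plain_graph = SimpleNamespace()
graph_with_shapes = SimpleNamespace(
    shape_feature=SimpleNamespace(shape_of={"rv": ("n",)})
)
```

Returning None here means "make no change", so the graph is left untouched when the feature is missing.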
# Check that neither the RV nor the old Subtensor are in the shape graph.
output_shape = fgraph.shape_feature.shape_of.get(indexed_rv, None)
if output_shape is None or {indexed_rv, rv} & set(ancestors(output_shape)):
    return None
Same as above.
Both were the behavior before so it's not a new thing
Description
By only replacing the output variable (dimshuffle, subtensor), we were leaving the second output of the lifted RV (the rng update) linked to the old node. In a normal function this RNG update is an output of the function and we want to also replace it to link to the new RV node.
Not doing this would result in "invalid random graphs", with the same RNG leading to two RVs of different sizes/shapes.
Also simplified the rewrites a bit in the first commit, now that RandomVariable always adds expand dims to the batch dimensions of the parameters.
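The RNG-update problem described above can be illustrated with a toy model. This is a sketch only: outputs are modelled as (node, output name) pairs, and the node names and helper functions are hypothetical, not PyTensor API.

```python
def apply_replacements(outputs, replacements):
    """Swap each output for its replacement, leaving unlisted outputs alone."""
    return [replacements.get(out, out) for out in outputs]


def rv_nodes_in_outputs(outputs):
    """Collect the distinct RV nodes that the function outputs depend on."""
    return {node for node, _ in outputs}


# Before the rewrite, the draw and the RNG update come from the same node.
outputs = [("old_rv_node", "rv"), ("old_rv_node", "next_rng")]

# Old behaviour: only the draw is replaced by the lifted RV.
only_rv = {("old_rv_node", "rv"): ("new_rv_node", "rv")}
# New behaviour: the RNG update is replaced as well.
both = {**only_rv, ("old_rv_node", "next_rng"): ("new_rv_node", "next_rng")}

partial = apply_replacements(outputs, only_rv)
full = apply_replacements(outputs, both)
```

In this toy model `partial` still depends on both the old and the new node, which mirrors the "same RNG leading to two RVs" situation, while `full` depends only on the new node.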