Conversation

@ricardoV94 (Member) commented Apr 20, 2024

Description

PyTensor uses DisconnectedType and NullType variables to raise informative errors when users request gradients with respect to inputs for which they can't be computed. This is a problem for OpFromGraph, which may contain parallel graphs, some of which are disconnected/null while others are not. We don't want to fail when the user only needs the gradient that is supported.

There was already some special logic to handle cases where NullType and DisconnectedType arise from the OFG inner graph. Instead of outputting those types (which OFG cannot produce out of thin air, as they are root variables), we were outputting dummy zeros and then masking those with the original NullType or DisconnectedType variables created in the internal call to grad/Rop. This seems reasonable, if a bit tedious. This PR first refactors this code to avoid the dummy outputs altogether (there's no reason for them!).

Then it extends this logic to also handle cases where DisconnectedType (but not NullType) variables arise outside the inner graph of OpFromGraph, i.e., in the output gradients passed to its L_op. This was the case behind one of the issues described in #1. When an OFG has multiple outputs and the requested gradient only uses a subset of them, PyTensor will feed DisconnectedType variables in place of the unused output_gradients passed to the L_op. The solution to this problem is to filter out these unused input variables. This should be safe: if the inner graph of the OFG needs these variables and we don't provide them, it will create new DisconnectedTypes on the fly, and the pre-existing filtering will then kick in.
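
As a concrete illustration (a minimal sketch written for this description, not code from the PR's test suite): when only one output of a multi-output OFG enters the cost, the output gradient of the other output is a DisconnectedType by the time the OFG's L_op is called.

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x, y = pt.dscalars("x", "y")
op = OpFromGraph([x, y], [x + y, x * y])
out_sum, out_prod = op(x, y)

# Only out_sum enters the cost, so the output gradient corresponding to
# out_prod is a DisconnectedType when the OFG's L_op is called. With this
# PR, such disconnected output gradients are filtered out before the inner
# gradient graph is built, instead of triggering an error.
gx, gy = pytensor.grad(out_sum, [x, y])
```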

This, however, means we may need distinct gradient OFGs for different patterns of disconnected output gradients. Accordingly, caching is now done per pattern.
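
For illustration, one way such a per-pattern cache key could be computed (a sketch under my own assumptions, not necessarily the PR's exact implementation):

```python
from pytensor.gradient import DisconnectedType


def disconnected_pattern(output_grads):
    # Key the cached gradient OpFromGraph on which output gradients are
    # disconnected, so each pattern gets its own specialized gradient graph.
    return tuple(isinstance(g.type, DisconnectedType) for g in output_grads)
```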

I suspect this is the issue behind #652

Question: Do we really need to cache stuff?

This PR also deprecates grad_overrides and some options of the lop/rop overrides, as well as the custom logic for invalid connection_patterns. Hopefully this helps make OpFromGraph more maintainable.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94 added the bug (Something isn't working), maintenance, and OpFromGraph labels on Apr 20, 2024
@ricardoV94 changed the title from "Fix multiple OpFromGrad gradient issues" to "Fix OpFromGraph L_op when output gradients are disconnected" on Apr 20, 2024
@ricardoV94 changed the title from "Fix OpFromGraph L_op when output gradients are disconnected" to "Fix OpFromGraph with disconnected output gradients" on Apr 20, 2024
@ricardoV94 force-pushed the fix_OFG_grad branch 2 times, most recently from ad37756 to 869cd5f on April 20, 2024 at 20:23

codecov bot commented Apr 20, 2024

Codecov Report

Attention: Patch coverage is 79.50820% with 25 lines in your changes missing coverage. Please review.

Project coverage is 80.94%. Comparing base (fc21336) to head (a96e5a1).
Report is 208 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/compile/builders.py | 79.50% | 14 Missing and 11 partials ⚠️ |
Additional details and impacted files

Impacted file tree graph

```
@@            Coverage Diff             @@
##             main     #723      +/-   ##
==========================================
+ Coverage   80.85%   80.94%   +0.08%
==========================================
  Files         162      162
  Lines       47043    46945      -98
  Branches    11514    11481      -33
==========================================
- Hits        38038    37998      -40
+ Misses       6750     6706      -44
+ Partials     2255     2241      -14
```
| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/gradient.py | 77.37% <ø> (+0.54%) ⬆️ |
| pytensor/compile/builders.py | 88.38% <79.50%> (+10.93%) ⬆️ |

... and 4 files with indirect coverage changes

@jessegrabowski (Member) left a comment

Superficial first pass across the PR. I cannot make an informed comment about the actual meat of the changes until I fire up a debugger and try to grok what OpFromGraph is actually doing. I will make an effort to do this in the next 48 hours and give a more meaningful review.

```python
connected_output_grads = [
    out_grad
    for out_grad in output_grads
    if not isinstance(out_grad.type, DisconnectedType)
]
```
Member

Why don't output_grads need to check for NullType?

Member Author

Honestly, because I am not sure when NullType actually arises.

I prefer the special logic to be as specific as possible; we can reassess if NullTypes also show up in the future.

Member Author

But the same logic should apply, yeah; let's put Null here as well.

Member Author

Actually, I don't want to filter those out. If we know that an output gradient is Null, we shouldn't omit that information from the inner gradient graph generation. If we omit it, grad assumes it's simply disconnected.

Omitting disconnected ones makes sense, because they will default again to DisconnectedTypes if needed by grad.

@ricardoV94 (Member Author) commented Apr 21, 2024

> Superficial first pass across the PR. I cannot make an informed comment about the actual meat of the changes until I fire up a debugger and try to grok what OpFromGraph is actually doing. I will make an effort to do this in the next 48 hours and give a more meaningful review.

Thanks! It may help to convince yourself that no behavior was changed until the second-to-last commit, where the bug fix is done (other than the deprecations and the removal of the special behavior in the connection pattern).

@ricardoV94 force-pushed the fix_OFG_grad branch 2 times, most recently from a0272f7 to 71ec299 on May 29, 2024 at 10:28
@ricardoV94 (Member Author) commented May 29, 2024

I found another issue: if the outputs of an OpFromGraph are not independent, the existing logic fails. Instead of adding up the contributions coming from each output, it overrides them, due to how the known_grads we use internally behaves.

The new test cases in the last commit illustrate this. Any case that depends on out3 fails numerically because we ignore/mask the contributions coming from it.

```python
import numpy as np

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import verify_grad
from pytensor.tensor import dscalars

x, y = dscalars("x", "y")
rng = np.random.default_rng(594)
point = list(rng.normal(size=(2,)))

out1 = x + y
out2 = x * y
out3 = out1 + out2  # Create dependency between outputs
op = OpFromGraph([x, y], [out1, out2, out3])

# Any gradient that depends on out3 fails numerically, because its
# contributions to out1/out2 are masked by known_grads.
verify_grad(lambda x, y: pt.add(*op(x, y)), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[:-1]), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[1:]), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[::2]), point, rng=rng)
verify_grad(lambda x, y: op(x, y)[0], point, rng=rng)
verify_grad(lambda x, y: op(x, y)[1], point, rng=rng)
verify_grad(lambda x, y: op(x, y)[2], point, rng=rng)
```

If instead we define out3 explicitly, without reusing out1 and out2 (e.g., out3 = (x + y) * (x * y)), it works fine again.

@aseyboldt any idea how we could handle this? In an outer function, I think this would be handled by adding the direct contributions to out1/out2 to the indirect ones coming from out3.

It seems like I want to initialize those variables' grads to the output_grad values but still allow them to be updated, rather than setting them as known, which doesn't allow any further updates.

@ricardoV94 (Member Author) commented May 29, 2024

Found a nice(?) hack. Instead of calling Lop internally with known_grads=dict(zip(inner_outputs, output_gradients)), I do it with known_grads=dict(zip(identity_inner_outputs, output_gradients)), where identity_inner_outputs is each inner_output wrapped in a dummy Identity operation. This way we correctly accumulate the direct and indirect contributions coming from other inner outputs.
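
A rough sketch of that idea (illustrative only, not the PR's code; here .copy(), which wraps a variable in an elemwise identity, stands in for the dummy Identity operation, and the variable names are mine):

```python
import pytensor.tensor as pt
from pytensor.gradient import grad

x, y = pt.dscalars("x", "y")
out1 = x + y
out2 = x * y
out3 = out1 + out2  # inner output that depends on the other two

inner_inputs = [x, y]
inner_outputs = [out1, out2, out3]
output_gradients = [pt.dscalar(f"g{i}") for i in range(3)]

# Attaching known_grads to identity-wrapped copies of the inner outputs lets
# grad keep traversing through out1/out2, so their direct output gradients
# and the indirect contributions coming through out3 are accumulated instead
# of the known value overriding everything else.
identity_inner_outputs = [out.copy() for out in inner_outputs]
input_grads = grad(
    cost=None,
    wrt=inner_inputs,
    known_grads=dict(zip(identity_inner_outputs, output_gradients)),
)
```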

@ricardoV94 force-pushed the fix_OFG_grad branch 2 times, most recently from 90d9601 to 335d2e0 on May 29, 2024 at 10:55
@ricardoV94 changed the title from "Fix OpFromGraph with disconnected output gradients" to "Fix OpFromGraph with disconnected/related outputs" on May 29, 2024
@ricardoV94 changed the title from "Fix OpFromGraph with disconnected/related outputs" to "Fix gradient of OpFromGraph with disconnected/related outputs" on May 29, 2024
@aseyboldt (Member) left a comment

I can't really say that I fully understand the implications of the changes, but it certainly seems like an improvement, so unless someone wants to do a more thorough review, I think we should merge this.

@ricardoV94 merged commit 2143d85 into pymc-devs:main on May 29, 2024

Development

Successfully merging this pull request may close these issues.

Gradient of OpFromGraph fails

3 participants