[Kernel Fusion] training benchmarks of AOTAutograd (multiple models) #15264
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
scripts/aot_albert.py
| print("Maximum output error: ", (out1 - out2).abs().max().item()) | ||
| print("Maximum gradient error: ", max([(a-b).abs().max() for a, b in zip(grad1, grad2) if a is not None]).item()) |
May I suggest actually setting an expected tolerance and doing a torch.testing.assert_close check, since it's hard to notice visually when something is wrong, especially now that we have a lot more of these, and then it just becomes noise. All that printing could then be removed altogether, unless there is a mismatch.
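For illustration, a minimal sketch of the suggested check; the rtol/atol values are placeholders (not the ones eventually used in the PR), and check_close is a hypothetical helper standing in for the script's out1/out2/grad1/grad2 comparison:

import torch

# Hedged sketch: compare reference vs. AOT outputs and gradients with explicit
# tolerances instead of printing max errors. Tolerances below are placeholders.
def check_close(out_ref, out_aot, grads_ref, grads_aot, rtol=1e-3, atol=1e-3):
    torch.testing.assert_close(out_aot, out_ref, rtol=rtol, atol=atol)
    for g_ref, g_aot in zip(grads_ref, grads_aot):
        if g_ref is not None:
            torch.testing.assert_close(g_aot, g_ref, rtol=rtol, atol=atol)

# tiny self-contained smoke test standing in for the script's actual tensors
ref = torch.randn(4, 4)
check_close(ref, ref + 1e-5, [ref], [ref + 1e-5])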
Indeed. Just been there and torch.testing.assert_close is pretty handy.
Hmm, any thoughts on rtol/atol values? It's kind of difficult to test these kinds of gradients on whole models, since catastrophic cancellation can lead to relatively large atol values and very large rtol values, especially for fp16.
Updated with some tolerance checks, but they're much looser than I would prefer lol.
I was hoping you had a much better sense of what the ideal values would be. Let's start with your proposed setup and then adapt as needed.
The next level of testing would be to run a scoring function such as GLUE on real data and check that the output is the same or very similar. But that can be less precise in other ways as well, so let's keep on experimenting.
At least for now we should probably not assert on failure, since the rest of the benchmark then doesn't run and requires re-running the whole thing - perhaps report a mismatch but continue the execution? This is from trying to re-run the updated script on my rtx-3090: And also, usability-wise, the error doesn't say for which config it failed, so we probably want to dump the setup info before running it? Thank you!
@stas00 updated to accumulate the numerical failures and print them out at the end. Numerical checking for whole models is super finicky, sadly.
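A sketch of the accumulate-and-report pattern being described; run_config and the config list are made-up stand-ins for the benchmark's own loop, and the tolerances are placeholders:

import torch

# Record numerical mismatches per configuration instead of asserting (and
# aborting) mid-benchmark; report everything at the end.
def run_config(name, dtype):
    ref = torch.randn(16, dtype=dtype)
    return ref, ref + (1e-2 if name == "flaky" else 1e-6)

failures = []
for name, dtype in [("model_a", torch.float32), ("flaky", torch.float32)]:
    out_ref, out_aot = run_config(name, dtype)
    try:
        torch.testing.assert_close(out_aot, out_ref, rtol=1e-3, atol=1e-3)
    except AssertionError as err:
        failures.append(f"{name} / {dtype}: {err}")

if failures:
    print("Numerical mismatches:\n" + "\n".join(failures))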
Great, thanks a lot! We probably still want to print out the combo, as currently it prints: which is meaningless to the user of the benchmark, since 30 doesn't appear anywhere in the code or output :) A small nit. I updated the OP with the output on RTX-3090.
@stas00 It does print out the combo (of the name + dtype), although perhaps it could be more obvious. Knowing that it's
Oh, I see it now. I guess the double newline made me not look up, as it was all the intermediary prints until then. Please don't worry about it, we can polish it at the end; now that I know where to look, it's all good.
Note to maintainers: We are using this PR to collaborate and there is no intention yet to merge anything, so please ignore unless you want to experiment with the latest auto-speedups.
We are experimenting with the latest https://github.com/pytorch/functorch against pytorch nightly to automatically speed up the execution and reduce memory usage:
So the idea is this. Given an existing model, you speed it up by doing just this (see the sketch below). So, for example, the HF Trainer could automate this by just adding a new flag, like --fusion aot. Notably, as long as the part being compiled with AOTAutograd is static, you can do whatever you want outside of the model, and autograd will still work with the AOTAutograd-compiled model.
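As a rough sketch of what the wrapping could look like, assuming the functorch.compile entry points available around this time (aot_module plus the ts_compile TorchScript backend; the exact names have moved between functorch versions):

import torch
from torch import nn
# Assumption: this functorch build exposes aot_module and ts_compile here.
from functorch.compile import aot_module, ts_compile

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10)).to(device)

# Wrap the (shape-static) model once; the result is a drop-in nn.Module replacement.
aot_model = aot_module(model, fw_compiler=ts_compile, bw_compiler=ts_compile)

x = torch.randn(8, 64, device=device)
out = aot_model(x)      # compiled forward
out.sum().backward()    # autograd still works through the compiled module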
So, things like this work fine
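The original snippet isn't reproduced here; purely as an illustrative guess at the kind of code meant (reusing aot_model, x, and device from the sketch above), anything computed outside the compiled module composes with autograd as usual:

import torch.nn.functional as F

target = torch.randint(0, 10, (8,), device=device)
logits = aot_model(x)                              # AOT-compiled, static part
# arbitrary eager-mode work outside the compiled model
loss = F.cross_entropy(logits * 2.0, target) + logits.square().mean()
loss.backward()                                    # gradients flow through both parts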
Here are some benchmarks:
A100 training (from @Chillee):
On rtx3090 (from @stas00):
Instructions from @stas00 on how to build functorch to reproduce these results.
As this is a constantly evolving code base, make sure to git pull and rebuild as above if your version is a few days old. Or at least, if you try the code and it fails, the first thing to do is to update and rebuild functorch and then retry the benchmarks.
Note that there is currently a correctness issue on one of the gradients on PyTorch nightly; the above was run with this patch (pytorch/pytorch#71542), which fixes the correctness issue.
Notes:
Q: What is the pytree registration here for?
A: AOTAutograd tries to present the simplest possible graphs for backends, and so it primarily works with lists of tensors for both the input and the output. So, these pytrees are needed so that we can flatten the input data structures into a list, and unflatten the output back into the correct data structure. PS: This is very much inspired by Jax <3
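As a hedged illustration of that registration (SimpleOutput is a made-up stand-in for an HF-style ModelOutput class, and the private _register_pytree_node helper may differ across PyTorch versions):

from dataclasses import dataclass
import torch
from torch.utils._pytree import _register_pytree_node, tree_flatten, tree_unflatten

@dataclass
class SimpleOutput:
    logits: torch.Tensor
    hidden: torch.Tensor

def _flatten(out):
    # children (the tensors) plus any static context needed to rebuild the object
    return [out.logits, out.hidden], None

def _unflatten(values, context):
    return SimpleOutput(*values)

_register_pytree_node(SimpleOutput, _flatten, _unflatten)

# The container can now be flattened to a list of tensors and rebuilt afterwards,
# which is what AOTAutograd needs around the compiled graph.
flat, spec = tree_flatten(SimpleOutput(torch.ones(2), torch.zeros(2)))
rebuilt = tree_unflatten(flat, spec)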
Resources on AOTAutograd:
AOTAutograd: https://docs.google.com/presentation/d/1rTt0BR2KChDQQTks2hHUtvHxtHQKwgQHVNrmbhj0byk/edit?usp=sharing
Min-Cut recomputation: https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467