@Chillee Chillee commented Jan 21, 2022

Note to maintainers: We are using this PR to collaborate and there is no intention yet to merge anything, so please ignore unless you want to experiment with the latest auto-speedups.

We are experimenting with the latest https://github.com/pytorch/functorch against PyTorch nightly to automatically speed up execution and reduce memory usage.

The idea is this: given an existing model, you speed it up by doing just this:

import torch
from functorch.compile import memory_efficient_fusion

aot_model = memory_efficient_fusion(model)
with torch.jit.fuser("fuser2"):
    train(aot_model)

So, for example, the HF Trainer could automate this by just adding a new flag, like --fusion aot.
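
For illustration, a sketch of how such a flag could be wired up (the --fusion flag and the args.fusion attribute are hypothetical, not an existing HF Trainer API; model and train stand in for the existing model and training loop):

import contextlib
import torch
from functorch.compile import memory_efficient_fusion

if args.fusion == "aot":               # hypothetical flag
    model = memory_efficient_fusion(model)
    fuser = torch.jit.fuser("fuser2")  # route fused kernels through nvFuser
else:
    fuser = contextlib.nullcontext()   # no-op context for the eager path

with fuser:
    train(model)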

Notably, as long as the part being compiled with AOTAutograd is static, you can do whatever you want outside of the model, and autograd will still work with the AOTAutograd-compiled model.

So, things like this work fine:

foo = aot_model(*inps)   # forward pass through the compiled model
loss = foo.sum()
if loss < 0:             # arbitrary Python control flow outside the compiled region
    print("awesome")
loss.backward()          # autograd still flows through the compiled model

Here are some benchmarks:


A100 training (from @Chillee):

$ python  scripts/aot_albert.py
Current memory requirement: 5.69 GB
eager 0.08250337839126587
Current memory requirement: 4.00 GB
aot 0.05442763566970825

Maximum output error:  9.5367431640625e-07
Maximum gradient error:  1.5096273273229599e-05
| model | dtype | name | time (s) | mem (GB) | time % | mem % |
|---|---|---|---|---|---|---|
| AlbertForMaskedLM | torch.float32 | eager | 0.087 | 7.133 | 0 | 0 |
| AlbertForMaskedLM | torch.float32 | aot | 0.057 | 5.438 | -35 | -24 |
| AlbertForMaskedLM | torch.float16 | eager | 0.051 | 3.901 | 0 | 0 |
| AlbertForMaskedLM | torch.float16 | aot | 0.034 | 3.054 | -34 | -22 |
| AlbertForMaskedLM | torch.bfloat16 | eager | 0.053 | 3.931 | 0 | 0 |
| AlbertForMaskedLM | torch.bfloat16 | aot | 0.034 | 3.083 | -36 | -22 |
| GPT2LMHeadModel | torch.float32 | eager | 0.056 | 5.174 | 0 | 0 |
| GPT2LMHeadModel | torch.float32 | aot | 0.045 | 4.328 | -19 | -16 |
| GPT2LMHeadModel | torch.float16 | eager | 0.033 | 4.645 | 0 | 0 |
| GPT2LMHeadModel | torch.float16 | aot | 0.029 | 4.223 | -13 | -9 |
| GPT2LMHeadModel | torch.bfloat16 | eager | 0.034 | 4.965 | 0 | 0 |
| GPT2LMHeadModel | torch.bfloat16 | aot | 0.029 | 4.541 | -15 | -9 |
| BertForMaskedLM | torch.float32 | eager | 0.041 | 6.764 | 0 | 0 |
| BertForMaskedLM | torch.float32 | aot | 0.036 | 6.759 | -13 | 0 |
| BertForMaskedLM | torch.float16 | eager | 0.025 | 6.228 | 0 | 0 |
| BertForMaskedLM | torch.float16 | aot | 0.021 | 6.226 | -16 | 0 |
| BertForMaskedLM | torch.bfloat16 | eager | 0.026 | 6.505 | 0 | 0 |
| BertForMaskedLM | torch.bfloat16 | aot | 0.021 | 6.503 | -19 | 0 |
| LongformerForMaskedLM | torch.float32 | eager | 0.122 | 9.921 | 0 | 0 |
| LongformerForMaskedLM | torch.float32 | aot | 0.111 | 9.933 | -9 | 0 |

On an RTX 3090 (from @stas00):

| model | dtype | name | time (s) | mem (GB) | time % | mem % |
|---|---|---|---|---|---|---|
| AlbertForMaskedLM | torch.float32 | eager | 0.173 | 7.078 | 0 | 0 |
| AlbertForMaskedLM | torch.float32 | aot | 0.125 | 5.382 | -28 | -24 |
| AlbertForMaskedLM | torch.float16 | eager | 0.089 | 3.829 | 0 | 0 |
| AlbertForMaskedLM | torch.float16 | aot | 0.064 | 2.982 | -28 | -22 |
| AlbertForMaskedLM | torch.bfloat16 | eager | 0.092 | 3.852 | 0 | 0 |
| AlbertForMaskedLM | torch.bfloat16 | aot | 0.064 | 3.005 | -30 | -22 |
| GPT2LMHeadModel | torch.float32 | eager | 0.112 | 4.822 | 0 | 0 |
| GPT2LMHeadModel | torch.float32 | aot | 0.094 | 3.977 | -16 | -18 |
| GPT2LMHeadModel | torch.float16 | eager | 0.060 | 4.013 | 0 | 0 |
| GPT2LMHeadModel | torch.float16 | aot | 0.051 | 3.591 | -15 | -11 |
| GPT2LMHeadModel | torch.bfloat16 | eager | 0.061 | 4.736 | 0 | 0 |
| GPT2LMHeadModel | torch.bfloat16 | aot | 0.051 | 4.313 | -16 | -9 |
| BertForMaskedLM | torch.float32 | eager | 0.086 | 6.343 | 0 | 0 |
| BertForMaskedLM | torch.float32 | aot | 0.076 | 6.338 | -11 | 0 |
| BertForMaskedLM | torch.float16 | eager | 0.046 | 5.717 | 0 | 0 |
| BertForMaskedLM | torch.float16 | aot | 0.041 | 5.714 | -11 | 0 |
| BertForMaskedLM | torch.bfloat16 | eager | 0.046 | 5.952 | 0 | 0 |
| BertForMaskedLM | torch.bfloat16 | aot | 0.040 | 5.950 | -13 | 0 |
| LongformerForMaskedLM | torch.float32 | eager | 0.209 | 9.080 | 0 | 0 |
| LongformerForMaskedLM | torch.float32 | aot | 0.194 | 9.092 | -7 | 0 |

Instructions from @stas00 on how to build functorch to reproduce these results:

# install torch-nightly
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly

# install functorch (and reinstall after `git pull` later if need to sync up)
git clone https://github.com/pytorch/functorch
cd functorch
rm -rf build
pip install -e .[aot]

As this is a constantly evolving code base, make sure to git pull and rebuild as above if your checkout is more than a few days old. At the very least, if you try the code and it fails, the first thing to do is to update and rebuild functorch and then retry the benchmarks.
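
A typical refresh cycle is the same as the install steps above, minus the clone:

cd functorch
git pull
rm -rf build
pip install -e .[aot]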

Note that there is currently a correctness issue with one of the gradients on PyTorch nightly; the above was run with this patch (pytorch/pytorch#71542), which fixes it.


Notes:

  • AOT = Ahead of Time
  • eager = normal Python/PyTorch code, i.e. the way our models are written now

Q: What is the pytree registration here for?
A: AOTAutograd tries to present the simplest possible graphs for backends, and so it primarily works with lists of tensors for both the input and the output. So, these pytrees are needed so that we can flatten the input data structures into a list, and unflatten the output back into the correct data structure. PS: This is very much inspired by Jax <3
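
For illustration, a minimal sketch of what such a registration could look like for one HF output class (the use of the private torch.utils._pytree._register_pytree_node API and the flatten/unflatten helpers are assumptions for this sketch; the actual registration in this PR may differ):

import torch.utils._pytree as pytree  # private API; exact name may vary across versions
from transformers.modeling_outputs import MaskedLMOutput

def _flatten_mlm_output(output):
    # the tensor values become the flat leaves; the keys are the context used to rebuild
    return list(output.values()), list(output.keys())

def _unflatten_mlm_output(values, keys):
    return MaskedLMOutput(**dict(zip(keys, values)))

pytree._register_pytree_node(MaskedLMOutput, _flatten_mlm_output, _unflatten_mlm_output)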

Resources on AOTAutograd:
AOTAutograd: https://docs.google.com/presentation/d/1rTt0BR2KChDQQTks2hHUtvHxtHQKwgQHVNrmbhj0byk/edit?usp=sharing
Min-Cut recomputation: https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467

@HuggingFaceDocBuilder commented
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Comment on lines 109 to 110
print("Maximum output error: ", (out1 - out2).abs().max().item())
print("Maximum gradient error: ", max([(a-b).abs().max() for a, b in zip(grad1, grad2) if a is not None]).item())
@stas00 stas00 (Contributor) commented Jan 25, 2022

May I suggest actually setting the expected tolerances and doing a torch.testing.assert_close check, since it's hard to notice visually when something is wrong, especially now that we have a lot more of these and it just becomes noise. All that printing could then be removed altogether, unless there is a mismatch.
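
A minimal sketch of what that could look like, reusing the out1/out2 and grad1/grad2 names from the script above (the tolerance values here are placeholders, not a recommendation):

import torch

# fail loudly if the outputs or any gradient drift beyond an explicit tolerance
torch.testing.assert_close(out2, out1, rtol=1e-4, atol=1e-5)
for a, b in zip(grad1, grad2):
    if a is not None:
        torch.testing.assert_close(b, a, rtol=1e-4, atol=1e-5)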


Indeed. Just been there and torch.testing.assert_close is pretty handy.

@Chillee Chillee (Author) commented Jan 25, 2022

Hmm, any thoughts on rtol/atol values? It's kind of difficult to test these kinds of gradients on whole models, since catastrophic cancellation can lead to relatively large atol values and very large rtol values, especially for fp16.

@Chillee Chillee (Author):

Updated with some tolerance checks, but they're much looser than I would prefer lol.

@stas00 stas00 (Contributor):

I was hoping you had a much better sense of what the ideal values would be. Let's start with your proposed setup and then adapt as needed.

The next level of testing is to run a scoring function like GLUE on real data and check that the output is the same or very similar. But in other ways that can be less precise as well; let's keep experimenting.

@stas00 stas00 changed the title Added example of using AOTAutograd with HuggingFace Albert for training [Kernel Fusion] training benchmarks of AOTAutograd (multiple models) Jan 26, 2022
@stas00 stas00 (Contributor) commented Jan 26, 2022

At least for now we should probably not assert on failure, since the rest of the benchmark then doesn't run and requires re-running the whole thing - perhaps reporting a mismatch but continuing the execution?

This is from trying to re-run the updated script on my RTX 3090:

Traceback (most recent call last):
  File "scripts/aot_albert.py", line 119, in <module>
    torch.testing.assert_close(grad2, grad1, atol=atol, rtol=rtol)
  File "/home/stas/anaconda3/envs/py38-pt111/lib/python3.8/site-packages/torch/testing/_comparison.py", line 1255, in assert_close
    assert_equal(
  File "/home/stas/anaconda3/envs/py38-pt111/lib/python3.8/site-packages/torch/testing/_comparison.py", line 1030, in assert_equal
    raise error_metas[0].to_error()
AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 38597376 (0.0%)
Greatest absolute difference: 7.62038107495755e-05 at index (18699, 512) (up to 5e-05 allowed)
Greatest relative difference: 0.16173266498354133 at index (18699, 512) (up to 0.005 allowed)

Also, usability-wise, the error doesn't say which config it failed for, so we probably want to dump the setup info before running each check?

Thank you!

@Chillee Chillee (Author) commented Jan 27, 2022

@stas00 updated to accumulate the numerical failures and print them out at the end.

Numerical checking for whole models is super finicky, sadly.
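
For reference, the pattern is roughly this (a sketch with illustrative names, not necessarily what the script does verbatim):

import torch

failures = []

def check_grads(name, dtype, grads_aot, grads_eager, rtol, atol):
    # record the mismatch instead of raising, so the rest of the benchmark keeps running
    try:
        for a, b in zip(grads_aot, grads_eager):
            if a is not None:
                torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
    except AssertionError as e:
        failures.append((name, dtype, str(e)))

# ...after all model/dtype combos have run:
for name, dtype, msg in failures:
    print(f"Numerical mismatch for {name} / {dtype}:\n{msg}")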

@stas00 stas00 (Contributor) commented Jan 27, 2022

> @stas00 updated to accumulate the numerical failures and print them out at the end.
>
> Numerical checking for whole models is super finicky, sadly.

Great, thanks a lot! Probably still want to print out the combo as currently it prints:

The failure occurred for item [30]

which is meaningless to the user of the benchmark, since 30 doesn't appear anywhere in the code or output :) A small nit.

I updated the OP with the output on the RTX 3090.

@Chillee Chillee (Author) commented Jan 27, 2022

@stas00 It does print out the combo (of the name + dtype), although perhaps it could be more obvious.

Knowing that it's item [30] is actually useful - it's the 30th gradient value in the list of gradient checks.

@stas00 stas00 (Contributor) commented Jan 27, 2022

Oh, I see it now. I guess the double newline made me not look up, as it was all intermediary prints until then.

Please don't worry about it, we can polish it at the end; now that I know where to look, it's all good.

@stas00 stas00 added the WIP label Feb 20, 2022
@huggingface huggingface deleted a comment from github-actions bot Feb 20, 2022
Labels: Kernel Fusion, Performance, WIP