
Conversation

@Akababa (Contributor) commented Feb 2, 2020

Instead of multiplying by a 1.0 float mask, use torch.where with a bool mask for increased performance.
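For illustration, here is a minimal before/after sketch of the change (tensor names and shapes are illustrative; the actual code in modeling_gpt2.py works on the registered causal `bias` buffer inside `Attention._attn`):

```python
import torch

# Toy stand-ins for the attention scores `w` and the causal mask `b`
w = torch.randn(1, 1, 4, 4)                           # attention scores
b = torch.tril(torch.ones(4, 4)).view(1, 1, 4, 4)     # lower-triangular causal mask

# Old approach: two elementwise multiplications and a subtraction with a float mask
w_old = w * b - 1e4 * (1 - b)

# Proposed approach: a single select with a bool mask
w_new = torch.where(b.bool(), w, torch.full_like(w, -1e4))

assert torch.allclose(w_old, w_new)
```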

@codecov-io commented Feb 2, 2020

Codecov Report

Merging #2715 into master will decrease coverage by 0.01%.
The diff coverage is 100%.


@@            Coverage Diff             @@
##           master    #2715      +/-   ##
==========================================
- Coverage    77.8%   77.79%   -0.02%     
==========================================
  Files         100      100              
  Lines       17051    17052       +1     
==========================================
- Hits        13267    13266       -1     
- Misses       3784     3786       +2
| Impacted Files | Coverage Δ |
| --- | --- |
| src/transformers/modeling_gpt2.py | 86.2% <100%> (+0.04%) ⬆️ |
| src/transformers/modeling_tf_utils.py | 88.15% <0%> (-0.18%) ⬇️ |
| src/transformers/modeling_utils.py | 91.81% <0%> (-0.14%) ⬇️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 33ef700...a54a418.

@julien-c (Member) commented Feb 3, 2020

Thanks – what's the PyTorch compatibility on this?

@Akababa (Contributor, Author) commented Feb 4, 2020

Not sure about that; where can I find more info on compatibility? I think this only relies on torch.where (available since at least 1.0.0) and tensors of dtype torch.bool (introduced in 1.2.0). Does the None (newaxis) slicing introduce compatibility issues?

If we want to maintain compatibility with 1.0.0, I think we can use torch.uint8 instead of torch.bool.
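As a sketch of that fallback (an illustrative helper, not code from the PR), the mask dtype could be chosen based on whether the installed PyTorch exposes torch.bool:

```python
import torch

def make_causal_mask(n_ctx):
    # torch.bool was added in PyTorch 1.2; earlier releases only offer uint8
    # (byte) masks, which torch.where accepted at the time.
    mask_dtype = torch.bool if hasattr(torch, "bool") else torch.uint8
    return torch.tril(torch.ones((n_ctx, n_ctx), dtype=mask_dtype)).view(1, 1, n_ctx, n_ctx)

mask = make_causal_mask(4)  # shape (1, 1, 4, 4)
```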

@nikita-smetanin commented Mar 25, 2020

Hi, I'd recommend making the following changes:

  1. Keep the original shape of the bias buffer (otherwise it breaks loading of already-trained models) and set its dtype to torch.uint8, so it stays compatible with PyTorch 1.0.0, which has no torch.bool type:
     `self.register_buffer("bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx))`
  2. Keep the -1e4 constant in a buffer to avoid an allocation on every _attn call and to make it move automatically between devices (CPU and CUDA):
     `self.register_buffer("masked_bias", torch.tensor(-1e4))`
  3. Keep the `b = self.bias[:, :, ns - nd : ns, :ns]` line, since the bias buffer now has its original shape.
  4. The where statement then becomes `w = torch.where(b, w, self.masked_bias)`.

As I measured, the overall speedup here is around 10-15%, and the code should be 100% compatible with PyTorch 1.0.0 (a sketch combining these points follows below).
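Put together, a sketch of how these four points could look in the GPT-2 `Attention` module (heavily simplified; the real `__init__` and `_attn` take more arguments and also handle scaling, attention masks, and head masks):

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, n_ctx):
        super().__init__()
        # (1) keep the original (1, 1, n_ctx, n_ctx) shape, stored as uint8
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx),
        )
        # (2) keep the fill value in a buffer so it follows .to(device)
        self.register_buffer("masked_bias", torch.tensor(-1e4))

    def _attn(self, w):
        nd, ns = w.size(-2), w.size(-1)
        # (3) slice the buffer exactly as before
        b = self.bias[:, :, ns - nd : ns, :ns]
        # (4) select with torch.where instead of multiplying by a float mask;
        # the .bool() cast is only needed on newer PyTorch versions that no
        # longer accept a uint8 condition
        return torch.where(b.bool(), w, self.masked_bias)

# Usage with toy shapes: (batch, heads, query_len, key_len)
attn = Attention(n_ctx=8)
masked_scores = attn._attn(torch.randn(2, 12, 8, 8))
```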

@patrickvonplaten (Contributor) commented

Hi @Akababa,

Thanks for the PR. I think this is a great change. I checked and it does lead to a significant speed-up :-)

Could you fix the tests? I think we can merge then (see https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md):

  1. Fetch the master branch and rebase your branch on top of it.
  2. Run `make style` in the root folder before pushing so the "check_code_quality" test passes.

Akababa added 4 commits March 29, 2020 15:02
Instead of multiplying by 1.0 float mask, use torch.where with a bool mask for increased performance.
@patrickvonplaten (Contributor) commented

Great work @Akababa - this looks good to me!

@LysandreJik @thomwolf - could you check and merge?

@LysandreJik (Member) left a comment

LGTM

@patrickvonplaten patrickvonplaten merged commit 05deb52 into huggingface:master Apr 7, 2020
@patrickvonplaten (Contributor) commented

Checked the slow hardcoded GPT2 tests and it all looks good!

@patrickvonplaten patrickvonplaten mentioned this pull request May 14, 2020
