Optimize causal mask using torch.where #2715
Conversation
| Codecov Report
 @@            Coverage Diff             @@
##           master    #2715      +/-   ##
==========================================
- Coverage    77.8%   77.79%   -0.02%     
==========================================
  Files         100      100              
  Lines       17051    17052       +1     
==========================================
- Hits        13267    13266       -1     
- Misses       3784     3786       +2
 Continue to review full report at Codecov. 
 | 
| Thanks – what's the PyTorch compatibility on this? | 
| Not sure about that — where can I find more info on compatibility? I think it only relies on torch.where (introduced in or before 1.0.0) and tensors of dtype torch.bool (introduced in 1.2.0). Does the None (newaxis) slicing introduce compatibility issues? If we want to maintain compatibility with 1.0.0, I think we can use torch.uint8 instead of torch.bool. | 
| Hi, I'd recommend making the following changes: 
 As a result, the overall speedup is 10-15% as I measured, and the code should be fully compatible with PyTorch 1.0.0. | 
| Hi @Akababa, thanks for the PR. I think this is a great change. I checked and it does lead to a significant speed-up :-) Could you fix the tests? I think then we can merge (see https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md) | 
Instead of multiplying by 1.0 float mask, use torch.where with a bool mask for increased performance.
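The change described above can be sketched roughly as follows. This is a hypothetical illustration (the function name and tensor shapes are assumptions, not the PR's actual code): rather than computing `w * mask - 1e4 * (1 - mask)` with a float mask, it selects between the attention scores and a large negative fill value using `torch.where` and a boolean mask.

```python
import torch

def causal_mask_where(w: torch.Tensor) -> torch.Tensor:
    # w: attention scores of shape (batch, heads, seq, seq)
    seq = w.size(-1)
    # Lower-triangular bool mask: True where attention to a position is allowed
    mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool))[None, None]
    # Keep allowed scores; replace disallowed positions with a large
    # negative value so they vanish after softmax
    return torch.where(mask, w, torch.full_like(w, -1e4))
```

Avoiding the float multiply-and-subtract saves an elementwise arithmetic pass over the score tensor, which is where the reported speed-up comes from; for pre-1.2.0 compatibility the mask dtype could be `torch.uint8` instead of `torch.bool`, as discussed above.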
| Great work @Akababa - this looks good to me! @LysandreJik @thomwolf - could you check and merge? | 
LGTM
| Checked slow hardcoded GPT2 tests and it looks all good! | 