[Whisper, Bart, MBart] Add Flash Attention 2 #27203
Conversation
The documentation is not available anymore as the PR was closed or merged.
amyeroberts
left a comment
Amazing piece of work! 🔥
Main comment is about the tests - I think some might be indexing on outputs when it should be using outputs_fa
tests/test_modeling_common.py
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
This... isn't that close. I can see it's the tolerance used elsewhere but seems like quite a big difference
Yeah, Flash Attention does lead to quite different results though. I think 0.04 is good enough tbh
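(For reference, a minimal sketch of the comparison being discussed; the helper itself is illustrative and not part of the test suite, only the tensor names and tolerances come from the snippet above:)

import torch

def assert_fa2_close(logits_fa: torch.Tensor, logits: torch.Tensor, atol: float = 4e-2, rtol: float = 4e-2):
    # fp16/bf16 kernels accumulate in a different order, so exact equality is not expected;
    # reporting the max absolute difference makes tolerance choices easier to judge.
    max_diff = (logits_fa.float() - logits.float()).abs().max().item()
    assert torch.allclose(logits_fa.float(), logits.float(), atol=atol, rtol=rtol), f"max abs diff: {max_diff}"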
input_dtype = query_states.dtype
if input_dtype == torch.float32:
    # Handle the case where the model is quantized
    if hasattr(self.config, "_pre_quantization_dtype"):
We need to have access to the config here
Co-authored-by: amyeroberts <[email protected]>
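(For context, a sketch of how the dtype handling around the quoted snippet continues, modeled on the existing decoder-only Flash Attention 2 classes; attribute names like q_proj are assumptions for illustration, the quoted diff is the source of truth:)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
    # Flash Attention 2 only supports fp16/bf16 inputs. If the states were upcasted to fp32
    # (e.g. because the model is quantized), cast them back to the original compute dtype.
    if hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.q_proj.weight.dtype

    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)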
LysandreJik
left a comment
Great. I appreciate the # Copied from statements which make the code simpler to review.
Very clean, it's ok for me to merge
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
we should eventually move this to an enum to be cleaner (out of scope for this PR)
(brainstorming, still out of scope) It would be cleaner to eventually have the config return the appropriate attention name for all models:
self.self_attn = BART_ATTENTION_CLASSES[config.attention_type](
    ...
)

with

class PreTrainedConfig():
    ...

    @property
    def attention_type(self):
        return AttentionTypes.FA2 if getattr(self, "_flash_attn_2_enabled", False) else AttentionTypes.DEFAULT
Yeah, attention_type as a property is a good idea I think! We should then probably also allow users to change it even after the model has been loaded
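(Sketch of what that mapping could look like, purely illustrative and out of scope here; the enum and the BART_ATTENTION_CLASSES dict are hypothetical names from the brainstorm above, and the class names are assumed to match the ones this PR adds to modeling_bart.py:)

from enum import Enum

# Assumed class names from this PR's modeling_bart.py; adjust if they differ.
from transformers.models.bart.modeling_bart import BartAttention, BartFlashAttention2

class AttentionTypes(str, Enum):
    DEFAULT = "default"
    FA2 = "flash_attention_2"

BART_ATTENTION_CLASSES = {
    AttentionTypes.DEFAULT: BartAttention,
    AttentionTypes.FA2: BartFlashAttention2,
}

# In the layer __init__, assuming the attention_type property sketched above:
# self.self_attn = BART_ATTENTION_CLASSES[config.attention_type](...)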
tests/test_modeling_common.py
# make sure that all models have at least 40 position ids
if hasattr(config, "max_position_embeddings"):
    config.max_position_embeddings = 40
Why that minimum?
Ok, I ran some more tests and it should be good now. I'm getting some flaky behavior with the flash attention tests on my RTX 4090 (especially extreme for Whisper). We should maybe think about how we can make them more robust now that we've added some more models (cc @younesbelkada)
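(One brainstorm for the robustness question, not a proposal for this PR: seed the generators and tolerate a small fraction of outlier elements instead of failing a strict element-wise allclose on a single element, roughly:)

import torch

def mostly_close(logits_fa: torch.Tensor, logits: torch.Tensor, atol: float = 4e-2, max_outlier_frac: float = 0.01) -> bool:
    # Allow a small fraction of elements to exceed the absolute tolerance, since single
    # outliers in fp16 comparisons are a common source of flakiness.
    outlier_frac = ((logits_fa.float() - logits.float()).abs() > atol).float().mean().item()
    return outlier_frac <= max_outlier_frac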
* add whisper fa2
* correct
* change all
* correct
* correct
* fix more
* fix more
* fix more
* fix more
* fix more
* fix more
* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* fix more
* fix more
* fix more
* fix more
* fix more

---------

Co-authored-by: amyeroberts <[email protected]>
What does this PR do?
This PR adds Flash Attention 2 for Whisper, Bart & MBart.
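For reference, enabling it at load time should look roughly like this (assuming the same use_flash_attention_2 flag as the existing decoder-only integrations; the checkpoint name is just an example):

import torch
from transformers import WhisperForConditionalGeneration

# Flash Attention 2 requires fp16/bf16 weights and the flash-attn package to be installed.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
)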
Like 20+ other model architectures, Whisper depends quite a bit on Bart and MBart for its attention code, Flash Attention included.
As this is the first PR that adds Flash Attention 2 to an encoder-decoder model, I wanted to make sure it's done for the two template models (Bart and MBart) as well, so that Whisper (and all other encoder-decoder models that follow) don't lose their "# Copied from" statements.
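For illustration, this is roughly what such a statement looks like on the Whisper side (the exact source paths referenced in the PR may differ):

from torch import nn

# The "# Copied from" marker ties a class to its template so that make fix-copies
# keeps the copies in sync automatically.

# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Whisper
class WhisperAttention(nn.Module):
    ...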
Note that while this PR changes 27 files, only 4 files are really relevant to review because all other files are just consequences of the "# Copied from" mechanism:
The following three files fully implement Flash Attention 2:
The test file is restructured so that Flash Attention 2 tests can nicely run for different kinds of models (audio & NLP, as well as decoder-only and encoder-decoder).
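As a rough idea of what that restructuring enables (illustrative sketch only; the helper name and details are not the actual test code), a single test can build the right dummy inputs per model type:

import torch

def prepare_fa2_inputs(config, dummy_input: torch.Tensor, dummy_attention_mask: torch.Tensor) -> dict:
    # Decoder-only models just need the input ids/features and a mask; encoder-decoder models
    # (Bart, MBart, Whisper) additionally need decoder_input_ids to run a forward pass.
    inputs = {"attention_mask": dummy_attention_mask, "output_hidden_states": True}
    if config.is_encoder_decoder:
        inputs["decoder_input_ids"] = torch.full(
            (dummy_input.shape[0], 4), config.decoder_start_token_id, dtype=torch.long
        )
    return inputs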
I ran the following tests to make sure everything works as expected:
as well as:
All tests that also pass on "main" pass here as well. The only failures are related to disk offloading, which should be fixed in #27204.
There are some "error not raised" failures for Flash Attention and Mistral, but they are also present on "main" and seem to be related to PR #27125 (cc @younesbelkada); I'd suggest fixing those in a separate PR as well.
Other CI test failures are unrelated.