Description
Feature request
Addition of PyTorch SDPA and Flash Attention 2 to the Wav2Vec2 modelling code.
Motivation
Wav2Vec2 and its derived models remain some of the most popular speech recognition and audio classification models in the library. However, only one attention implementation is currently available to users: the slowest and most memory-consuming "eager" mode. We should update the modelling code to provide two newer attention implementations, SDPA and FA2, both of which are faster and more memory efficient.
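As a quick illustration (a toy sketch, not code from the issue) of why this helps: PyTorch's `torch.nn.functional.scaled_dot_product_attention` computes the same result as the explicit "eager" formulation, but through a fused kernel that avoids materialising the full attention-score matrix in high-bandwidth memory:

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch=2, heads=4, seq_len=8, head_dim=16
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# "Eager" attention: explicit softmax(QK^T / sqrt(d)) @ V,
# which materialises the full (seq_len x seq_len) score matrix
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
eager_out = scores.softmax(dim=-1) @ v

# SDPA: same math, dispatched to a fused (often memory-efficient
# or flash) kernel chosen by PyTorch at runtime
sdpa_out = F.scaled_dot_product_attention(q, k, v)

# Both paths produce numerically equivalent outputs
assert torch.allclose(eager_out, sdpa_out, atol=1e-5)
```

On the modelling side, users would then opt in the same way as for other models that already support this, e.g. `from_pretrained(..., attn_implementation="sdpa")` or `"flash_attention_2"`.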
Since Wav2Vec2 copies its attention from BART, and SDPA & FA2 were added for BART in this PR, this should be quite a straightforward PR, mostly copying out the logic from the BART PR and pasting it into Wav2Vec2. We should then be sure to add two fast tests (one for each of SDPA and FA2), e.g. in the style of the test here, and two slow integration tests, e.g. in the style of the tests here.
Your contribution
Want to take this one @kamilakesbi?