
SDPA/FA2 Attention for the Wav2Vec2 Family of Models #30073

@sanchit-gandhi


Feature request

Addition of PyTorch SDPA and Flash Attention 2 to the Wav2Vec2 modelling code.

Motivation

Wav2Vec2 and its derived models remain some of the most popular speech recognition and audio classification models in the library. However, only one attention implementation is currently available to users: the slowest and most memory-consuming "eager" mode. We should update the modelling code to provide two newer attention implementations, SDPA and FA2, both of which are faster and more memory efficient.
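For reference, once supported, users should be able to opt into the faster implementations via the standard `attn_implementation` argument to `from_pretrained`, as for other models in the library. A minimal sketch (the checkpoint is illustrative):

```python
import torch
from transformers import Wav2Vec2ForCTC

# SDPA: dispatches to torch.nn.functional.scaled_dot_product_attention
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    attn_implementation="sdpa",
)

# Flash Attention 2: requires the flash-attn package, a CUDA device, and a half-precision dtype
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")
```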

Since Wav2Vec2 copies its attention implementation from BART, and SDPA & FA2 were already added to BART in this PR, this should be a fairly straightforward change: largely porting the logic from the BART PR over to Wav2Vec2. We should then be sure to add two fast tests (one each for SDPA and FA2), e.g. in the style of the test here, and two slow integration tests, e.g. in the style of the tests here.
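As a rough illustration of what the fast test could check (a hedged sketch, not the exact test from the linked PR): load the same checkpoint with eager and SDPA attention, run both on identical dummy inputs, and assert the output logits agree within tolerance.

```python
import torch
from transformers import Wav2Vec2ForCTC

checkpoint = "facebook/wav2vec2-base-960h"  # illustrative checkpoint

# Load the same weights with the two attention implementations
eager_model = Wav2Vec2ForCTC.from_pretrained(checkpoint, attn_implementation="eager").eval()
sdpa_model = Wav2Vec2ForCTC.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

# Dummy input: 1 second of 16 kHz audio
dummy_input = torch.randn(1, 16000)

with torch.no_grad():
    eager_logits = eager_model(dummy_input).logits
    sdpa_logits = sdpa_model(dummy_input).logits

# The two implementations should be numerically equivalent up to small tolerances
assert torch.allclose(eager_logits, sdpa_logits, atol=1e-4)
```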

Your contribution

Want to take this one @kamilakesbi?
