[SFTTrainer] Flash attention support for SFTTrainer
#656
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Are trl users using padding?
adding

)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import ContextManagers
I made this 😂
hahahah nice!
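A minimal sketch of how the `ContextManagers` import above could be combined with PyTorch 2.x's SDPA backend selection to force the flash kernel during training. This is an assumption about the approach, not this PR's final API: the helper name `get_train_contexts` and the `use_flash_attention` flag are hypothetical; `transformers.utils.ContextManagers` (which wraps a list of context managers) and `torch.backends.cuda.sdp_kernel` are real.

```python
import torch
from transformers.utils import ContextManagers

def get_train_contexts(use_flash_attention: bool):
    # Hypothetical helper: only request the flash SDPA kernel when asked;
    # an empty list makes ContextManagers a no-op.
    contexts = []
    if use_flash_attention:
        contexts.append(
            torch.backends.cuda.sdp_kernel(
                enable_flash=True, enable_math=False, enable_mem_efficient=False
            )
        )
    return ContextManagers(contexts)

# Usage: wrap the training call so SDPA ops dispatch to the flash kernel.
# with get_train_contexts(use_flash_attention=True):
#     trainer.train()
```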
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Closing for now as huggingface/transformers#25598 might be merged
Related to huggingface/transformers#25265
Users can easily benefit from flash attention. This PR adds a new argument to SFTTrainer to take care of that, documents it properly, and raises errors when relevant. This leads to some interesting speedups and memory savings that I will detail after running some experiments, hence putting this PR in draft for now.
Only available if you use PyTorch nightlies and `packing=True`, as SDPA + flash attention does not support padding.

cc @lvwerra @fxmarty @vwxyzjn
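For reference, a hedged usage sketch: `packing=True` is the existing SFTTrainer option this PR requires, while `use_flash_attention` is a placeholder for the new argument (its final name is not fixed in this thread).

```python
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    packing=True,               # required: the SDPA flash kernel rejects padding
    use_flash_attention=True,   # placeholder for the argument this PR adds
)
trainer.train()
```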