Fix flash_attention.py: wrong argument passing for attn_implementation #41347
Conversation
Oh good catch, there is indeed the wrong kwarg being passed on! I would prefer if you can change the name of the kwarg in
Thank you. I have changed that. Please check.
Can you also fix the docs naming then at
transformers/src/transformers/modeling_flash_attention_utils.py, lines 567 to 568 in 0452f28:
    implementation (`str`, *optional*):
        The attention implementation to use. If None, will default to the one based on the environment.
LGTM otherwise 🤗
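As a rough sketch of the docs fix being requested, assuming the parameter ends up named `attn_implementation` after the rename (an assumption; the signature and docstring below are simplified illustrations, not the actual file contents):

```python
# Illustrative stand-in only: the real _flash_attention_forward has many more
# parameters and a much longer docstring. The point is that the docstring entry
# should use the same name as the (assumed) renamed parameter.
def _flash_attention_forward(query_states, key_states, value_states, attention_mask,
                             attn_implementation=None, **kwargs):
    """
    Args:
        attn_implementation (`str`, *optional*):
            The attention implementation to use. If None, will default to the one based on the environment.
    """
```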
The name of the attn type argument for `_flash_attention_forward()` should be `implementation`, instead of `attn_implementation`, which is currently used in the function call. This results in the attention type being specified under the wrong name.
Thank you! Done!
Thanks!! 🤗
https://github.com/huggingface/transformers/blob/update-from-pretrained/src/transformers/integrations/hub_kernels.py#L214-L214 is where we use
So would it be better if I roll back to the fix that simply changes
The link is broken; I think Arthur meant transformers/src/transformers/integrations/hub_kernels.py, lines 214 to 215 in caa14e7.
Yes, we should also change the kwarg there, totally forgot about that one, e.g.
Edit: Still loads the correct kernel implementation, checked with
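For context, a hypothetical sketch of the kind of forwarding call that needs the same kwarg rename; the function names and signatures below are made up for illustration and are not the actual hub_kernels.py code:

```python
# Stub standing in for the renamed _flash_attention_forward (assumption: the PR
# renames its attention-type parameter to `attn_implementation`).
def _flash_attention_forward(*args, attn_implementation=None, **kwargs):
    return attn_implementation


# Hypothetical hub-kernel wrapper; name and signature are illustrative only.
# The point is that every forwarding site must use the same kwarg name as the
# renamed parameter, otherwise the value is silently swallowed by **kwargs again.
def kernel_flash_attention_forward(*args, attn_implementation=None, **kwargs):
    return _flash_attention_forward(*args, attn_implementation=attn_implementation, **kwargs)


assert kernel_flash_attention_forward(attn_implementation="flash_attention_2") == "flash_attention_2"
```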
What does this PR do?
The name of the attn type argument for `_flash_attention_forward()` should be `implementation`, instead of `attn_implementation`, which is currently used in the function call. This results in the attention type being specified under the wrong name.

Fixes # (issue)
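For illustration, a minimal sketch of the mismatch described above; the signature and arguments are heavily simplified stand-ins, not the actual `_flash_attention_forward` code:

```python
# Simplified stand-in for _flash_attention_forward (the real function takes many
# more parameters); it declares the attention type as `implementation`.
def _flash_attention_forward(query, key, value, implementation=None, **kwargs):
    # With a mismatched kwarg name, `implementation` stays None and the value is
    # silently absorbed into **kwargs.
    print("implementation =", implementation, "| extra kwargs =", kwargs)


q = k = v = None  # placeholder tensors for the sketch

# Buggy call: `attn_implementation` does not match the declared parameter name.
_flash_attention_forward(q, k, v, attn_implementation="flash_attention_2")
# implementation = None | extra kwargs = {'attn_implementation': 'flash_attention_2'}

# Call as proposed in the description above: use the declared name.
_flash_attention_forward(q, k, v, implementation="flash_attention_2")
# implementation = flash_attention_2 | extra kwargs = {}
```

Because `**kwargs` accepts the stray name without raising an error, the mismatch fails silently, which is why it can go unnoticed.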
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
CC @vasqu @ArthurZucker @Cyrilvallez