
Conversation

TKONIY
Contributor

@TKONIY TKONIY commented Oct 4, 2025

What does this PR do?

The attention-type argument of `_flash_attention_forward()` is named `implementation`, but the call site currently passes it as `attn_implementation`. The misnamed keyword never reaches the parameter, so the intended attention implementation is never actually specified.
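
A minimal sketch of the failure mode, using simplified stand-in signatures rather than the actual transformers code: a keyword passed under the wrong name falls into `**kwargs` and the named parameter silently keeps its default.

```python
# Illustrative only: simplified stand-in for the real _flash_attention_forward signature.
def _flash_attention_forward(query, key, value, implementation=None, **kwargs):
    # A misnamed keyword ends up in **kwargs instead of binding to `implementation`.
    print(f"implementation={implementation!r}, stray kwargs={list(kwargs)}")

# Before the fix: the intended value never reaches the parameter.
_flash_attention_forward("q", "k", "v", attn_implementation="flash_attention_2")
# -> implementation=None, stray kwargs=['attn_implementation']

# After the fix: the value binds to the named parameter as intended.
_flash_attention_forward("q", "k", "v", implementation="flash_attention_2")
# -> implementation='flash_attention_2', stray kwargs=[]
```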

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

CC

@vasqu @ArthurZucker @Cyrilvallez

@Cyrilvallez
Member

Oh, good catch, there is indeed the wrong kwarg being passed on! I would prefer if you could change the name of the kwarg in `_flash_attention_forward`, though; it would be more explicit!

@TKONIY
Contributor Author

TKONIY commented Oct 6, 2025

> Oh, good catch, there is indeed the wrong kwarg being passed on! I would prefer if you could change the name of the kwarg in `_flash_attention_forward`, though; it would be more explicit!

Thank you. I have changed that. Please check.

Contributor

@vasqu vasqu left a comment


Can you also fix the docs naming then at

implementation (`str`, *optional*):
The attention implementation to use. If None, will default to the one based on the environment.

LGTM otherwise 🤗

TKONIY added 3 commits October 6, 2025 21:28
The name of the attn type argument for `_flash_attention_forward()` should be `implementation`, instead of `attn_implementation`, which is currently used in the function call. This results in the wrong type specification.
@TKONIY
Contributor Author

TKONIY commented Oct 6, 2025

> Can you also fix the docs naming then at
>
> implementation (`str`, *optional*):
> The attention implementation to use. If None, will default to the one based on the environment.
>
> LGTM otherwise 🤗

Thank you! Done!

Member

@Cyrilvallez Cyrilvallez left a comment


Thanks!! 🤗

@Cyrilvallez Cyrilvallez merged commit ae60c77 into huggingface:main Oct 6, 2025
5 of 8 checks passed
@ArthurZucker
Collaborator

https://github.com/huggingface/transformers/blob/update-from-pretrained/src/transformers/integrations/hub_kernels.py#L214-L214 is where we use `implementation`. If you do this it won't fall back to kernels; we need to make sure we use the one passed in load and register

@TKONIY
Contributor Author

TKONIY commented Oct 6, 2025

> https://github.com/huggingface/transformers/blob/update-from-pretrained/src/transformers/integrations/hub_kernels.py#L214-L214 is where we use `implementation`. If you do this it won't fall back to kernels; we need to make sure we use the one passed in load and register

So would it be better if I roll back to the fix that simply changes `attn_implementation=` to `implementation=`?

@vasqu
Contributor

vasqu commented Oct 6, 2025

The link is broken; I think Arthur meant:

kernel_function = partial(attention_wrapper, implementation=kernel)
lazy_import_flash_attention(kernel, force_import=True)

Yes, we should also change the kwarg there, totally forgot about that, e.g. `kernel_function = partial(attention_wrapper, attn_implementation=kernel)`. But it's a bit messier tbh, and I don't think we actually use the kwarg much at all except on the first call (which would happen if someone does something custom with the FA interface we have) --> the forced lazy import should load the correct kernel (checking in a second), and as we already loaded it, we never change it again there.

Edit: Still loads the correct kernel implementation, checked with kernels-community/flash-attn3
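
A minimal sketch of what the discussion above is getting at, using stand-in names rather than the actual hub_kernels.py code: `functools.partial` binds the kernel under a fixed keyword name, so if the wrapper's parameter is renamed without updating the partial, the value falls into `**kwargs` instead of binding.

```python
# Illustrative stand-ins only; not the real transformers attention_wrapper.
from functools import partial

def attention_wrapper(*args, attn_implementation=None, **kwargs):
    # After the rename, the wrapper expects `attn_implementation`.
    return attn_implementation

# A partial still bound to the old keyword name misses the renamed parameter:
stale = partial(attention_wrapper, implementation="kernels-community/flash-attn3")
print(stale())  # None

# Rebinding under the new name restores the intended behaviour:
fixed = partial(attention_wrapper, attn_implementation="kernels-community/flash-attn3")
print(fixed())  # kernels-community/flash-attn3
```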
