🚨 [Clip] Fix masking and enable flash attention on all model types #41750
Conversation
cc @yonigozlan when you come across models like these in the vision refactors
zucchini-nlp left a comment:
Didn't see that we're changing only the text model. LGTM, as long as the slow tests are passing
run-slow: clip

This comment contains run-slow, running the specified jobs: models: ['models/clip']
@molbap @zucchini-nlp I changed a few things to align the kwargs with our modern practices, see 764e63f. This makes the kwargs easy to type properly; otherwise we would probably need to type them as FA kwargs 🤔
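A minimal sketch of the kwargs-typing pattern being discussed, assuming a `TypedDict` plus `Unpack` approach; the names `AttentionKwargs` and `attention_forward` are illustrative, not the actual transformers API:

```python
# Sketch only: typing **kwargs precisely with a TypedDict instead of an
# untyped dict. Names here are hypothetical, not the transformers internals.
try:
    from typing import TypedDict, Unpack  # Unpack is in typing on Python >= 3.11
except ImportError:  # pragma: no cover
    from typing import TypedDict
    from typing_extensions import Unpack


class AttentionKwargs(TypedDict, total=False):
    # Hypothetical flash-attention options; the real keys may differ.
    is_causal: bool
    dropout_p: float


def attention_forward(hidden_states, **kwargs: Unpack[AttentionKwargs]):
    # A type checker now knows exactly which keys **kwargs may contain,
    # instead of treating it as an opaque dict of FA options.
    return {
        "is_causal": kwargs.get("is_causal", False),
        "length": len(hidden_states),
    }


out = attention_forward([1, 2, 3], is_causal=True)
```

With this pattern, passing an unknown key (e.g. `attention_forward(x, causal=True)`) is flagged statically by mypy/pyright rather than failing silently at runtime.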
molbap left a comment:
Much better with the typing, thanks!
CI has issues today, will probably check tomorrow again (and propagate the changes to metaclip 2 + mlcd) and merge then
run-slow: clip, metaclip_2, mlcd, llava

This comment contains run-slow, running the specified jobs: models: ['models/clip', 'models/llava', 'models/metaclip_2', 'models/mlcd']
Yeah, the CI is not having a good day :D Locally all the relevant tests passed, especially the integration tests; checking tomorrow
[For maintainers] Suggested jobs to run (before merge): run-slow: clip, metaclip_2, mlcd
ArthurZucker
left a comment
thanks!
| ) | ||
|
|
||
| attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() | ||
| attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() |
Usually we use -1 for the batch size as text can be ragged, but not an issue
…uggingface#41750, squashed commits:
* fix
* make kwargs fully passed and adjust with outputs xxx
* propogate metaclip 2
* propogate mlcd and fix test
* style
* fix repo consistency, need to add ignore rules as those are building blocks
* style
* oops
* fix mlcd
Clip used old mask APIs, leading to confused usage: the old pattern works only for interfaces with support for 4D masks, which disabled FA usage in general. This PR now correctly changes this to the new API, which handles padding automatically. We additionally have to pass the `is_causal` kwarg to dynamically switch between modality types (text == causal, image == full). This is only enabled through recent PRs (fa #39707, sdpa #41692).

Closes #41673
Fixes #41668
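A minimal sketch (not the transformers implementation) of what the `is_causal` switch described above amounts to: the text tower gets a lower-triangular (causal) mask, while the vision tower attends bidirectionally with a full mask.

```python
import numpy as np

def build_mask(seq_len: int, is_causal: bool) -> np.ndarray:
    """Illustrative boolean attention mask: True = position may be attended."""
    if is_causal:
        # Text (causal): position i may attend only to positions <= i.
        return np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # Image (full): every position attends to every other position.
    return np.ones((seq_len, seq_len), dtype=bool)

text_mask = build_mask(4, is_causal=True)    # causal, for the text modality
image_mask = build_mask(4, is_causal=False)  # full, for the image modality
```

In the actual kernels, passing `is_causal` directly lets flash attention / SDPA skip materializing a 4D mask entirely, which is what re-enables FA here.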