[Attention] Flash Attention 3 - fp8 #14570
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Signed-off-by: Mickael Seznec <[email protected]>
Signed-off-by: Mickael Seznec <[email protected]>
Signed-off-by: Mickael Seznec <[email protected]>
Signed-off-by: Mickael Seznec <[email protected]>
Force-pushed from bc909f9 to 2b985ed.
Signed-off-by: Mickael Seznec <[email protected]>
CI is failing. @robertgshaw2-redhat any idea how I should fix it? Just rename it in run-tpu-test.sh? (@NickLucche you moved the file)
This is a known issue; there's a PR addressing it here: #13898. It won't block your PR.
I see there's some other problem with building the image, but the CI likely just needs another spin.
@mickaelseznec apologies for the delay, vllm-project/flash-attention#50 (review) has been merged, so you can now point to vllm_flash_attn. We will need to populate the sccache on the server to get it through the CI; I can help with this once the tag is updated 👍
Thanks for the contribution! Looks clean 😄. I'll approve once we can get it updated to use vllm_flash_attn; I've added a couple of comments.
tests/kernels/test_flash_attn.py
Outdated
q_descale = q_scale.expand((num_seqs, num_kv_heads))
k_descale = k_scale.expand((num_seqs, num_kv_heads))
v_descale = v_scale.expand((num_seqs, num_kv_heads))
nit: could we maybe test per-head scales here too, i.e. also test with non-zero strides?
I can add tests here, but this type of scaling isn't supported by vLLM at the moment. I believe that whenever we add support for it, we can add tests as well.
Besides, there are already around 9k test combinations here; I don't want to make the duration explode if it's not 100% needed :D
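For reference, a minimal sketch of what such a per-head variant might exercise, assuming the kernel accepts (num_seqs, num_kv_heads)-shaped descale tensors as in the snippet above; the point is that .expand() yields zero strides while a materialized per-head tensor has non-zero strides (names and shapes below are illustrative, not the test's actual parametrization):

```python
import torch

# Shapes mirroring the snippet quoted above (values are illustrative).
num_seqs, num_kv_heads = 4, 8

# Per-tensor scale, as used today: expand() broadcasts it, so both strides are 0.
k_scale = torch.rand(1)
k_descale_broadcast = k_scale.expand((num_seqs, num_kv_heads))

# Hypothetical per-head variant: a materialized (num_seqs, num_kv_heads) tensor
# with distinct values, giving a non-zero stride along the head dimension.
k_descale_per_head = torch.rand(num_seqs, num_kv_heads)

print(k_descale_broadcast.stride())  # (0, 0)
print(k_descale_per_head.stride())   # (8, 1)
```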
| "Cannot use FlashAttention-2 backend for dtype other than " | ||
| "torch.float16 or torch.bfloat16.") | ||
| target_backend = _Backend.XFORMERS | ||
| elif kv_cache_dtype is not None and \ |
we should keep this check but restrict it to FA2, i.e. check get_flash_attn_version() != 2 (get_flash_attn_version() is in vllm/attention/backends/utils.py)
Agree that this might be improved, but I can't directly import get_flash_attn_version because of a circular dependency.
Would you prefer if I moved that function to another file? vllm/attention/backends/versions.py, for example?
Would you prefer if I moved that function to another file? vllm/attention/backends/versions.py, for example?
sure, maybe move it to:
vllm/attention/utils/fa_support.py
for now, since there is an is_flash_attn_mla_supported() function that may come in #14258, so this could be a spot for both of those
I had to move it to vllm/fa_utils.py because of how vllm/attention/__init__.py imports a bunch of stuff for convenience.
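For context, here is a condensed, hypothetical sketch of the version-gated fallback being discussed, with the FlashAttention version passed in explicitly instead of imported (in the PR it comes from get_flash_attn_version(), moved to vllm/fa_utils.py as noted above); the function and backend names are illustrative, not vLLM's actual selector code:

```python
from typing import Optional


def select_attention_backend(kv_cache_dtype: Optional[str], fa_version: int) -> str:
    """Condensed, hypothetical sketch of the selector branch discussed above.

    fa_version stands in for get_flash_attn_version(); it is passed as an
    argument here to keep the sketch self-contained.
    """
    is_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8")
    if is_fp8_kv_cache and fa_version == 2:
        # FA2 cannot consume an fp8 KV cache, so fall back to another backend;
        # FA3 keeps the FlashAttention backend and uses the fp8 path.
        return "XFORMERS"
    return "FLASH_ATTN"


assert select_attention_backend("fp8_e4m3", 2) == "XFORMERS"
assert select_attention_backend("fp8_e4m3", 3) == "FLASH_ATTN"
```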
Signed-off-by: Mickael Seznec <[email protected]>
This is needed to avoid circular dependencies now that we want to get the flash_attn_version directly in platforms/cuda.py to check if fp8 flash_attn is actually available. Signed-off-by: Mickael Seznec <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Mickael Seznec <[email protected]>
Apologies for the delay, the CI should be working now. There appear to be failing kernel tests.
increase flash_attention unit test tolerance Signed-off-by: Mickael Seznec <[email protected]>
Hi, it seems there have been no new nightly wheels since this PR. Is anything wrong? @LucasWilkinson
Signed-off-by: Mickael Seznec <[email protected]>
Signed-off-by: Mickael Seznec <[email protected]> Signed-off-by: Louis Ulmer <[email protected]>
Signed-off-by: Mickael Seznec <[email protected]>
Signed-off-by: Mickael Seznec <[email protected]> Signed-off-by: Mu Huai <[email protected]>
This PR adds support for FP8 KV cache with FlashAttention 3 (related PR in flash-attn here). cc @LucasWilkinson. Please do not merge this PR while it does not yet reference vllm-project/flash-attention.
FlashAttention (unlike FlashInfer) runs attention with all of Q, K and V in FP8.
Performance is usually better than both FlashInfer with an FP8 KV cache and FlashAttention 3 with bf16.
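To make the "all of Q, K and V in FP8" point concrete, here is a small reference-style sketch in plain PyTorch of how per-tensor descale factors recover the original magnitudes inside attention; the real FA3 kernel fuses this into its fp8 matmuls, and the shapes, scale choices, and names below are illustrative only:

```python
import torch


def fp8_ref_attention(q8, k8, v8, q_descale, k_descale, v_descale):
    """Reference-style attention over fp8 inputs (illustrative, not the kernel).

    q8, k8, v8: [num_tokens, num_heads, head_dim] tensors in torch.float8_e4m3fn;
    the *_descale scalars map fp8 values back to their original range.
    """
    # Dequantize for the reference math; the fused FA3 kernel instead folds
    # the descale factors into its fp8 matmuls.
    q = q8.float() * q_descale
    k = k8.float() * k_descale
    v = v8.float() * v_descale
    scale = q.shape[-1] ** -0.5
    attn = torch.einsum("qhd,khd->hqk", q, k) * scale  # [heads, q_len, k_len]
    attn = attn.softmax(dim=-1)
    return torch.einsum("hqk,khd->qhd", attn, v)       # [q_len, heads, head_dim]


torch.manual_seed(0)
q = torch.randn(16, 8, 64)
k = torch.randn(16, 8, 64)
v = torch.randn(16, 8, 64)
# Per-tensor quantization scales (448 ~ max magnitude of float8_e4m3fn).
q_scale = q.abs().max() / 448.0
k_scale = k.abs().max() / 448.0
v_scale = v.abs().max() / 448.0
q8 = (q / q_scale).to(torch.float8_e4m3fn)
k8 = (k / k_scale).to(torch.float8_e4m3fn)
v8 = (v / v_scale).to(torch.float8_e4m3fn)
out = fp8_ref_attention(q8, k8, v8, q_scale, k_scale, v_scale)
print(out.shape)  # torch.Size([16, 8, 64])
```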
I added support for both v0 and v1, plus some unit tests.
Note that I've added a trick for checkpoints that don't provide q_scale: we reuse k_scale instead (which is something TRTLLM does, fwiw).
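A hedged sketch of what that fallback could look like when resolving scales from a checkpoint; the helper name and dictionary layout are illustrative, not vLLM's actual weight-loading code:

```python
from typing import Dict, Tuple


def resolve_attn_scales(ckpt_scales: Dict[str, float]) -> Tuple[float, float, float]:
    """Hypothetical scale resolution: reuse k_scale when q_scale is missing."""
    k_scale = ckpt_scales.get("k_scale", 1.0)
    v_scale = ckpt_scales.get("v_scale", 1.0)
    # Many fp8 checkpoints only ship k_scale/v_scale; fall back to k_scale
    # as the query descale, mirroring the trick described above.
    q_scale = ckpt_scales.get("q_scale", k_scale)
    return q_scale, k_scale, v_scale


q, k, v = resolve_attn_scales({"k_scale": 0.02, "v_scale": 0.03})
assert q == 0.02  # q_scale reused from k_scale
```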
Also, I added a small quality-of-life improvement for debugging v1: workers now send back their traceback when they raise an exception.
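A rough, self-contained illustration of that pattern using a plain multiprocessing queue; the actual v1 worker plumbing differs, this only shows the idea of returning the formatted traceback to the driver instead of failing silently:

```python
import traceback
from multiprocessing import Process, Queue


def worker(task_queue, result_queue):
    """Toy worker loop: on failure, ship the formatted traceback to the driver."""
    func, args = task_queue.get()
    try:
        result_queue.put(("ok", func(*args)))
    except Exception:
        # Send the full worker-side traceback back instead of only raising
        # locally, so the driver can show *where* the worker failed.
        result_queue.put(("error", traceback.format_exc()))


def flaky_task(x):
    raise ValueError(f"bad input: {x}")


if __name__ == "__main__":
    tasks, results = Queue(), Queue()
    tasks.put((flaky_task, (42,)))
    p = Process(target=worker, args=(tasks, results))
    p.start()
    status, payload = results.get()
    p.join()
    print(status)   # "error"
    print(payload)  # full traceback from the worker process
```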