
Conversation

@mickaelseznec (Contributor) commented Mar 10, 2025

This PR adds support for FP8 KV cache with FlashAttention 3 (related PR in flash-attn here). cc @LucasWilkinson. Please do not merge this PR until it references vllm-project/flash-attention.

FlashAttention (unlike FlashInfer) runs attention with Q, K and V all in FP8.
Performance is usually better than both FlashInfer with an FP8 KV cache and FlashAttention 3 in bf16.

I added support for v0 and v1 + some unit testing.

Note that I've added a trick for checkpoints that don't provide q_scale: the k_scale is reused instead (which is something TRT-LLM does, fwiw).

Also: I added a small quality-of-life improvement for debugging v1: workers send back their traceback when they raise an exception.
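
A minimal sketch of that q_scale fallback, for illustration only (the helper name and signature here are hypothetical, not vLLM's actual attention code):

```python
from typing import Optional

import torch


def resolve_qkv_descales(
    q_scale: Optional[torch.Tensor],
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pick the descaling factors handed to the FP8 attention kernel."""
    # Checkpoints quantized without an explicit q_scale: reuse k_scale as an
    # approximation (the TRT-LLM-style workaround mentioned above).
    if q_scale is None:
        q_scale = k_scale
    return q_scale, k_scale, v_scale
```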

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mickaelseznec force-pushed the mseznec/flash-attention-fp8 branch from bc909f9 to 2b985ed on March 10, 2025 15:41
Signed-off-by: Mickael Seznec <[email protected]>
@mickaelseznec (Contributor, Author):

CI is failing because vllm/tests/entrypoints/openai/test_accuracy.py (referenced here) doesn't exist.

@robertgshaw2-redhat any idea how I should fix this? Just rename it in run-tpu-test.sh? (@NickLucche you moved the file)

@NickLucche (Collaborator):

This is a known issue; the PR addressing it is #13898. It won't block your PR.

@NickLucche (Collaborator):

I see there's some other problem with building the image, but CI likely just needs another spin.

@LucasWilkinson (Collaborator):

@mickaelseznec apologies for the delay; vllm-project/flash-attention#50 (review) has been merged, so you can now point to vllm_flash_attn.

We will need to populate the sccache on the server to get it through CI; I can help with this once the tag is updated 👍

@LucasWilkinson (Collaborator) left a review comment:

Thanks for the contribution! Looks clean 😄, I'll approve once we can get it updated to use vllm_flash_attn. Added a couple of comments.


q_descale = q_scale.expand((num_seqs, num_kv_heads))
k_descale = k_scale.expand((num_seqs, num_kv_heads))
v_descale = v_scale.expand((num_seqs, num_kv_heads))
Collaborator:

nit: could we maybe test per-head scales here too, i.e. also test with non-zero strides?

Contributor (Author):

I can add tests here, but this type of scaling isn't supported by vLLM at the moment. I believe that whenever we add support for it, we can add tests as well.
Besides, there are already about 9k test combinations in here; I don't want to make the duration explode if it's not 100% needed :D
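
For context, a small standalone sketch of the distinction the nit points at (shapes and values are made up): the current test expands a single per-tensor scale, which yields zero-stride broadcast views, while per-head scales would be materialized tensors with non-zero strides.

```python
import torch

num_seqs, num_kv_heads = 4, 8

# Broadcast path exercised today: one scalar scale expanded to all heads.
k_scale = torch.tensor([0.02])
k_descale_broadcast = k_scale.expand((num_seqs, num_kv_heads))
print(k_descale_broadcast.stride())  # (0, 0) -> zero-stride broadcast view

# Per-head path the nit asks about: a distinct scale per KV head, stored
# contiguously, so the kernel would see non-zero strides.
k_descale_per_head = torch.rand(num_seqs, num_kv_heads) * 0.05
print(k_descale_per_head.stride())  # (8, 1) -> non-zero strides
```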

"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
target_backend = _Backend.XFORMERS
elif kv_cache_dtype is not None and \
Collaborator:

we should keep this check but restrict it to FA2, i.e. check get_flash_attn_version() != 2 (get_flash_attn_version() is in vllm/attention/backends/utils.py)

Contributor (Author):

Agreed that this could be improved, but I can't directly import get_flash_attn_version because of a circular dependency.

Do you prefer if I move that function to another file? vllm/attention/backends/versions.py, for example?

Collaborator:

> Do you prefer if I move that function to another file? vllm/attention/backends/versions.py, for example?

Sure, maybe move it to vllm/attention/utils/fa_support.py for now: there is an is_flash_attn_mla_supported() function that may come in #14258, so this could be a spot for both of those.

Contributor (Author):

I had to move it to vllm/fa_utils.py because of how vllm/attention/__init__.py imports a bunch of stuff for convenience.

This is needed to avoid circular dependencies now that we want to get the flash_attn_version directly in platforms/cuda.py to check if fp8 flash_attn is actually available.

Signed-off-by: Mickael Seznec <[email protected]>
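
Putting the two threads above together, a hedged sketch of the resulting layout (the function bodies and the fp8-support helper below are illustrative assumptions, not the exact vLLM code):

```python
# vllm/fa_utils.py (sketch): a home for the FA version helper that both the
# attention backends and platforms/cuda.py can import without creating a
# circular dependency.
from typing import Optional


def get_flash_attn_version() -> Optional[int]:
    # Illustrative body: the real helper inspects the installed
    # vllm_flash_attn build (and the GPU) to decide between FA2 and FA3.
    try:
        import vllm_flash_attn  # noqa: F401
    except ImportError:
        return None
    return 3  # assume an FA3-capable build for this sketch


def flash_attn_supports_fp8_kv_cache() -> bool:
    # Hypothetical helper: only FA3 runs attention with Q/K/V in FP8, so the
    # fp16/bf16-only restriction discussed above should apply to FA2 alone.
    return get_flash_attn_version() == 3
```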

mergify bot commented Mar 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mickaelseznec.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Mar 14, 2025
@mergify bot removed the needs-rebase label Mar 14, 2025
Signed-off-by: Mickael Seznec <[email protected]>
@LucasWilkinson (Collaborator):

Apologies for the delay, the CI should be working now. There appear to be failing kernel tests.

increase flash_attention unit test tolerance

Signed-off-by: Mickael Seznec <[email protected]>
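
For reference, the tolerance bump amounts to widening the comparison bounds on the fp8 path, roughly like this sketch (the helper name and the exact rtol/atol values are illustrative, not the committed test change):

```python
import torch


def assert_attention_close(out: torch.Tensor, ref: torch.Tensor,
                           kv_cache_dtype: str) -> None:
    # FP8 quantization of Q/K/V adds noise on top of the usual kernel
    # differences, so the fp8 path gets a looser tolerance than fp16/bf16.
    if "fp8" in kv_cache_dtype:
        atol, rtol = 5e-2, 5e-2
    else:
        atol, rtol = 2e-2, 1e-2
    torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)
```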
@LucasWilkinson added the ready (ONLY add when PR is ready to merge/full CI is needed) label Mar 17, 2025
@LucasWilkinson changed the title from "Mseznec/flash attention fp8" to "[Attention] Flash Attention 3 - fp8" Mar 19, 2025
@LucasWilkinson enabled auto-merge (squash) March 19, 2025 20:56
@LucasWilkinson merged commit a597a57 into vllm-project:main Mar 20, 2025
48 checks passed
@JaheimLee commented Mar 20, 2025

Hi, it seems there are no new nightly wheels since this PR. Is there anything wrong? @LucasWilkinson

erictang000 pushed a commit to erictang000/vllm that referenced this pull request Mar 25, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Mickael Seznec <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025

Labels: ci/build, ready, v1