Skip to content

Conversation

@DarkLight1337
Copy link
Member

@DarkLight1337 DarkLight1337 commented Jul 4, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Currently, models with head_size in [80, 112, 120] are supported in V0 but not in V1. The reason is that the FlexAttention backend is marked as only supporting head_size in [32, 64, 80, 96, 112, 120, 128, 192, 256] when it should actually support all head sizes.

This PR updates the FlexAttention backend to allow any head size, and updates the unsupported head size message of other backends to ask users to set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION instead of VLLM_USE_V1=0.

Based on the discussion in: https://vllm-dev.slack.com/archives/C087RA55P0D/p1750313174942849

cc @mgoin @WoosukKwon

FIX #14524

Test Plan

Re-enable affected models in multimodal and quantization tests.

Test Result

(Optional) Documentation Update

Removed note about head_size = 80 not being supported.

@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 4, 2025
@github-actions
Copy link

github-actions bot commented Jul 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added documentation Improvements or additions to documentation multi-modality Related to multi-modality (#4194) v1 labels Jul 4, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @DarkLight1337, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the flexibility of the vLLM library by enabling the FlexAttention backend to support any head size. This change allows previously unsupported models to run in V1 mode and updates error messages to guide users towards using FlexAttention when necessary. The PR also includes updates to the documentation and test configurations to reflect these changes.

Highlights

  • FlexAttention Backend: Updated the FlexAttention backend to allow any head size, removing the previous restriction to specific sizes.
  • Error Messaging: Modified the unsupported head size message for other backends to guide users to set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION instead of VLLM_USE_V1=0.
  • Model Support: Re-enabled affected models in multimodal and quantization tests by leveraging the FlexAttention backend for models with head sizes like 80.
  • Documentation: Removed the note in supported_models.md about head_size = 80 not being supported.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively implements support for arbitrary head sizes in the FlexAttention backend, which is a significant improvement for model compatibility. The changes are well-structured, with the introduction of a validate_head_size static method across various attention backends. This centralizes validation logic and provides clear guidance to users about using the FlexAttention backend for unsupported head sizes. Documentation, examples, and tests have been updated appropriately to reflect these changes, ensuring correctness and consistency. The refactoring improves code maintainability and clarity without introducing any new issues.

Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
`--mm-processor-kwargs '{"use_audio_in_video": true}'`.
For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`)
is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the out-of-date documentation

REQUIRES_V0_MODELS = [
# V1 Test: no way to fall back for head_dim = 80
# https://github.com/vllm-project/vllm/issues/14524
"qwen_vl",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This model actually has head size 128

@DarkLight1337
Copy link
Member Author

FlexAttention doesn't seem to work in the CI environment:

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588] EngineCore encountered a fatal error.

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588] Traceback (most recent call last):

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 579, in run_engine_core

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     engine_core.run_busy_loop()

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 606, in run_busy_loop

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     self._process_engine_step()

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 631, in _process_engine_step

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     outputs, model_executed = self.step_fn()

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                               ^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 235, in step

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     model_output = self.execute_model(scheduler_output)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 221, in execute_model

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     raise err

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 212, in execute_model

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return self.model_executor.execute_model(scheduler_output)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 87, in execute_model

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     output = self.collective_rpc("execute_model",

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     answer = run_method(self.driver_worker, method, args, kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils/__init__.py", line 2736, in run_method

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return func(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return func(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 308, in execute_model

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     output = self.model_runner.execute_model(scheduler_output,

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return func(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1370, in execute_model

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     model_output = self.model(

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                    ^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return self._call_impl(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return forward_call(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/stablelm.py", line 327, in forward

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     hidden_states = self.model(input_ids, positions, intermediate_tensors,

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return self._call_impl(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return forward_call(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/stablelm.py", line 251, in forward

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     hidden_states, residual = layer(positions, hidden_states)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return self._call_impl(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return forward_call(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/stablelm.py", line 189, in forward

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     hidden_states = self.self_attn(

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                     ^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return self._call_impl(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return forward_call(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/stablelm.py", line 155, in forward

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     attn_output = self.attn(q, k, v)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                   ^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return self._call_impl(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return forward_call(*args, **kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layer.py", line 243, in forward

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     torch.ops.vllm.unified_attention_with_output(

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1158, in __call__

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return self._op(*args, **(kwargs or {}))

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layer.py", line 451, in unified_attention_with_output

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     self.impl.forward(self,

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/flex_attention.py", line 461, in forward

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     out = flex_attention_compiled(

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]           ^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 663, in _fn

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 760, in _compile_fx_inner

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     raise InductorError(e, currentframe()).with_traceback(

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 745, in _compile_fx_inner

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     mb_compiled_graph = fx_codegen_and_compile(

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                         ^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 1295, in fx_codegen_and_compile

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 1197, in codegen_and_compile

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     compiled_fn = graph.compile_to_module().call

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                   ^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py", line 2083, in compile_to_module

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     return self._compile_to_module()

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]            ^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py", line 2130, in _compile_to_module

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     mod = PyCodeCache.load_by_key_path(

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py", line 2747, in load_by_key_path

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     mod = _reload_python_module(key, path)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/runtime/compile_tasks.py", line 36, in _reload_python_module

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     exec(code, mod.__dict__, mod.__dict__)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/tmp/torchinductor_root/3j/c3js4bfn6s2coox5is67eklodd33fedrkolv7aflrp2rcfrmj3an.py", line 42, in <module>

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/async_compile.py", line 346, in triton

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     kernel.precompile(warm_cache_only=False)

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/runtime/triton_heuristics.py", line 277, in precompile

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     self._make_launchers()

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/runtime/triton_heuristics.py", line 439, in _make_launchers

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588]     raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")

[2025-07-04T05:34:37Z] ERROR 07-03 22:34:37 [core.py:588] torch._inductor.exc.InductorError: RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 135168, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

@DarkLight1337
Copy link
Member Author

DarkLight1337 commented Jul 4, 2025

cc @youkaichao @drisspg I think it's related to pytorch/pytorch#133254

Any ideas on how to fix/work around this in the meantime?

Signed-off-by: DarkLight1337 <[email protected]>
Comment on lines 170 to 172
# As of this writing, head_size=80 is only supported by FlexAttention in V1
if model_name.endswith("-2b"):
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
Copy link
Member

@Isotr0py Isotr0py Jul 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a check in get_attn_backend_cls to automatically fall back to FlexAttn?

vllm/vllm/platforms/cuda.py

Lines 191 to 194 in ffe00ef

@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:

The check should be very simple to fit FlashInfer and FlashAttn:

if head_size % 32:
    return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is not complete though. I am not sure whether we can import the attention backend classes to call validate_head_size beforehand.

Copy link
Member Author

@DarkLight1337 DarkLight1337 Jul 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a quick look at the files, it's very likely that importing these files will initialize CUDA. Perhaps we can move validate_head_size to a separate file.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WoosukKwon how should we solve this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, let's see if importing the metadata directly can pass the tests as is first...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we can safely import the attention backends

Copy link
Collaborator

@jeejeelee jeejeelee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Isotr0py
Copy link
Member

Isotr0py commented Jul 4, 2025

FlexAttention doesn't seem to work in the CI environment:
RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 135168, Hardware limit: 101376. Reducing block sizes or num_stages may help.

The default FlexAttention config can run out of shared memory on L4, perhaps we need to reduce the M/N block_size when running FlexAttn like the fp32 PR:

# default M=64, N=64 may run out of shared memory on
# some GPUs with fp32, so we use smaller M and N.
extra_kernel_options = {
"BLOCK_M": 64,
"BLOCK_N": 32
} if query.dtype == torch.float32 else {}
out = flex_attention_compiled(
query,
key_cache,
value_cache,
attn_metadata.score_mod,
attn_metadata.block_mask,
self.scale,
enable_gqa=enable_gqa,
kernel_options={
"FORCE_USE_FLEX_ATTENTION": True,
**extra_kernel_options
},
)

The default FlexAttn config is here: https://github.com/pytorch/pytorch/blob/47c8aa80905f8cc9ea5488b5a3a209bee62d8409/torch/_inductor/kernel/flex_attention.py#L595-L618

@DarkLight1337 DarkLight1337 marked this pull request as ready for review July 5, 2025 11:35
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337
Copy link
Member Author

@DarkLight1337 DarkLight1337 marked this pull request as draft July 5, 2025 14:47
@DarkLight1337
Copy link
Member Author

DarkLight1337 commented Jul 5, 2025

It appears that none of the V1 attention backends currently support blocksparse attention which is used in Phi-3-Small. I'll continue to skip that test until FlexAttention supports it then.

Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337 DarkLight1337 marked this pull request as ready for review July 5, 2025 15:51
Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337
Copy link
Member Author

Also cc @LucasWilkinson

@DarkLight1337
Copy link
Member Author

All tests pass apart from the ones that are already failing on main

Copy link
Member

@Isotr0py Isotr0py left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now!

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Comment on lines +46 to +47
def validate_head_size(cls, head_size: int) -> None:
return # FlexAttention supports any head size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if @drisspg could confirm

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just seeing this now, yeah we support any head size. You will start to run out of shared memory around 256 + w/ some tricks you can get 512/576 working but for all intents and purposes and accept for some vision models this is true

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we calculate the amount of shared memory required? The current way to reduce BLOCK_M and BLOCK_N in L463 is very rough.

@vllm-bot vllm-bot merged commit 9fb52e5 into vllm-project:main Jul 6, 2025
72 of 74 checks passed
@DarkLight1337 DarkLight1337 deleted the v1-head-size branch July 6, 2025 16:54
@jeejeelee jeejeelee mentioned this pull request Jul 7, 2025
4 tasks
vaibhavjainwiz pushed a commit to red-hat-data-services/vllm that referenced this pull request Jul 15, 2025
Sync to v0.9.2 + remove libsodium + [fix
cachetokeziner](neuralmagic/nm-vllm-ent@1423512)

git log:
```
commit 7b94527 (HEAD -> sync-v0.9.2, nm-fork/sync-v0.9.2)
Merge: 1423512 d07be8a
Author: Selbi Nuryyeva <[email protected]>
Date:   Fri Jul 11 07:03:51 2025 -0400

    Merge remote-tracking branch 'nm-fork/main' into sync-v0.9.2

commit 1423512
Author: Isotr0py <[email protected]>
Date:   Mon Jun 30 18:16:16 2025 +0800

    disable using CacheTokenizer for transformers >= 4.53.0
    
    fixes vllm-project#20224
    
    addendum to vllm-project#20244

commit d07be8a (nm-fork/main, nm-fork/HEAD)
Merge: bbccdbe 02152ad
Author: Daniele <[email protected]>
Date:   Wed Jul 9 15:18:56 2025 +0200

    Dockerfile*.ubi: remove libsodium (opendatahub-io#245)
    
    It's not needed anymore
    
    https://issues.redhat.com/browse/INFERENG-848

commit 7dd12da
Merge: bbccdbe a5dd03c
Author: Selbi Nuryyeva <[email protected]>
Date:   Tue Jul 8 10:08:37 2025 -0400

    Merge branch 'v0.9.2-upstream' into sync-v0.9.2

commit a5dd03c (tag: v0.9.2rc2, tag: v0.9.2, upstream/releases/v0.9.2, v0.9.2-upstream, upstream-v0.9.2)
Author: simon-mo <[email protected]>
Date:   Sun Jul 6 14:02:36 2025 -0700

    Revert "[V0 deprecation] Remove V0 CPU/XPU/TPU backends (vllm-project#20412)"
    
    This reverts commit e202dd2.

commit c18b3b8
Author: Cyrus Leung <[email protected]>
Date:   Mon Jul 7 05:01:48 2025 +0800

    [Bugfix] Add `use_cross_encoder` flag to use correct activation in `ClassifierPooler` (vllm-project#20527)
    
    Signed-off-by: DarkLight1337 <[email protected]>

commit 9528e3a
Author: Woosuk Kwon <[email protected]>
Date:   Sun Jul 6 12:44:52 2025 -0700

    [BugFix][Spec Decode] Fix spec token ids in model runner (vllm-project#20530)
    
    Signed-off-by: Woosuk Kwon <[email protected]>

commit 9fb52e5
Author: Cyrus Leung <[email protected]>
Date:   Mon Jul 7 00:54:36 2025 +0800

    [V1] Support any head size for FlexAttention backend (vllm-project#20467)
    
    Signed-off-by: DarkLight1337 <[email protected]>
```

Test:
CUDA: https://github.com/neuralmagic/nm-cicd/actions/runs/16218517666
ROCM: https://github.com/neuralmagic/nm-cicd/actions/runs/16218578391
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build documentation Improvements or additions to documentation multi-modality Related to multi-modality (#4194) ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: [V1] Support Fallback For Unsupported Head Dim on FA

7 participants