Merged
Changes from all commits
27 commits
1c03c85
Support compile on any module (towards compilable vision transformer)
Lucaskabela Aug 19, 2025
af86a0d
Changes for benchmarking
Lucaskabela Aug 26, 2025
319d584
Minimize to just qwen2_5 again for consistency
Lucaskabela Aug 26, 2025
1be4959
Working version
Lucaskabela Sep 9, 2025
9850035
Working version2
Lucaskabela Sep 9, 2025
a692789
Tidy changes to qwen
Lucaskabela Sep 11, 2025
03e503f
Rebase and tidy
Lucaskabela Sep 11, 2025
408c704
Rebase on main, more benchmarking
Lucaskabela Sep 16, 2025
2c3374f
Mark seqlen dynamic too
Lucaskabela Sep 25, 2025
f58d94b
Remove erroneously added import
Lucaskabela Sep 25, 2025
bdfad3c
Move custom ops to be granular in helper file
Lucaskabela Sep 30, 2025
578ec11
fix formatting
Lucaskabela Oct 7, 2025
75dd21d
Format all files
Lucaskabela Oct 7, 2025
aedf9c3
Fix import vs main
Lucaskabela Oct 10, 2025
40644fe
Fix formatting locally
Lucaskabela Oct 10, 2025
fe6967f
Rebase and fix tests
Lucaskabela Oct 13, 2025
6bddb4d
Fix failing test by avoiding specialization
Lucaskabela Oct 15, 2025
7ca7fe9
Fix comments/tests
Lucaskabela Oct 15, 2025
741e395
Fix import
Lucaskabela Oct 15, 2025
edf0b64
Fix failing eagle test and add todos
Lucaskabela Oct 16, 2025
fcb13d7
Fix collect_env added
Lucaskabela Oct 20, 2025
c2d155a
Fix imports for vit_attn_wrappers.py
Lucaskabela Oct 20, 2025
ab6d2e1
Add unit test
Lucaskabela Oct 27, 2025
4927e55
Merge branch 'main' into lucaskabela/compile_nn_module
Lucaskabela Oct 28, 2025
309fb19
Merge branch 'main' into lucaskabela/compile_nn_module
ywang96 Oct 28, 2025
5c7829e
Merge branch 'main' into lucaskabela/compile_nn_module
Lucaskabela Oct 28, 2025
b035b09
Merge branch 'main' into lucaskabela/compile_nn_module
Lucaskabela Oct 28, 2025
36 changes: 36 additions & 0 deletions tests/compile/test_multimodal_compile.py
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.compilation.counter import compilation_counter
from vllm.config.compilation import CompilationMode


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
    """Test that Qwen2.5-VL vision submodules are compiled.

    This test verifies that the 3 vision submodules (Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionBlock, and Qwen2_5_VisionPatchMerger) are properly tagged
    for compilation by checking that num_models_seen increases by at least 3.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        # NOTE: Qwen2.5-VL has 35 models in total - the LLM backend,
        # Vision Patch Embed, Vision Patch Merger, and then 32 Vision Blocks
        # (one for each layer) - in the future, we should fix vLLM compilation
        # logic to handle this case and only compile the Vision submodules once
        # and reuse the compiled code for all layers
        # See https://github.com/vllm-project/vllm/issues/27590
        compilation_counter.expect(num_models_seen=35),
        vllm_runner(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            max_model_len=2048,
            gpu_memory_utilization=0.7,
            compilation_config={"mode": CompilationMode.VLLM_COMPILE},
        ) as _,
    ):
        pass
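The assertion style here relies on compilation_counter.expect, which appears to snapshot the counters on entry and check how much they changed on exit. A minimal toy sketch of that delta-checking pattern (illustrative only; vLLM's real helper lives in vllm/compilation/counter.py and may differ in detail):

# Illustrative only: a toy delta-asserting counter in the spirit of
# compilation_counter.expect; not vLLM's actual implementation.
import contextlib
from dataclasses import dataclass


@dataclass
class ToyCompilationCounter:
    num_models_seen: int = 0

    @contextlib.contextmanager
    def expect(self, **expected_deltas: int):
        # Snapshot on entry, assert the observed change on exit.
        before = {name: getattr(self, name) for name in expected_deltas}
        yield
        for name, delta in expected_deltas.items():
            observed = getattr(self, name) - before[name]
            assert observed == delta, f"{name}: expected +{delta}, got +{observed}"


counter = ToyCompilationCounter()
with counter.expect(num_models_seen=2):
    counter.num_models_seen += 2  # stands in for two submodules being compiled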
125 changes: 125 additions & 0 deletions vllm/attention/ops/vit_attn_wrappers.py
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
`tolist` in xformers attn, or `.item()` in flash attention).
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
latencies by ~7% (see qwen2_5_vl for example usage).
To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0).
"""

import einops
import torch

from vllm.utils.torch_utils import direct_register_custom_op


def xformers_attn_seqlens_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    attn_bias = BlockDiagonalMask.from_seqlens(
        q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
    )
    context_layer = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=attn_bias, p=0, scale=None
    )
    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
    return context_layer


def xformers_attn_seqlens_wrapper_fake(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="xformers_attn_seqlens_wrapper",
    op_func=xformers_attn_seqlens_wrapper,
    fake_impl=xformers_attn_seqlens_wrapper_fake,
)


def vit_xformers_attn_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)


def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    else:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen.item(),
        max_seqlen_k=max_seqlen.item(),
        dropout_p=0.0,
        causal=False,
    )
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer


def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="flash_attn_maxseqlen_wrapper",
    op_func=flash_attn_maxseqlen_wrapper,
    fake_impl=flash_attn_maxseqlen_wrapper_fake,
)


def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
    )
60 changes: 56 additions & 4 deletions vllm/compilation/decorators.py
@@ -18,7 +18,12 @@
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
from vllm.config import (
    CompilationMode,
    VllmConfig,
    get_current_vllm_config,
    set_current_vllm_config,
)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -74,6 +79,21 @@ def support_torch_compile(
) -> Callable[[_T], _T]: ...


@overload
def support_torch_compile(
    *,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...
Comment on lines +82 to +86

This arg is used with dynamic_arg_dims for Qwen2_5_VisionBlock. This overload does not provide a dynamic_arg_dims arg, and my type checker complains about this. Should this be merged into the overload above?

@Lucaskabela (Contributor, Author) Oct 15, 2025

cc @ProExpertProg @zou3519 @ywang96 can y'all have another look at this? I needed to add some way to mark a particular dimension as not only dynamic but unbacked (so I did that here); without this, compile 0/1-specialized during tracing when dummy_multimodal_input in gpu_runner returns 1.

I tried to follow the convention of the other args for support_torch_compile.
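For context on the reply above, here is a minimal, stand-alone sketch of the dynamo primitive involved, assuming a PyTorch build that ships torch._dynamo.decorators.mark_unbacked (the backend="eager" choice and the toy function are illustrative). Marking a dimension unbacked prevents the tracer from guarding on, and specializing to, sizes 0 or 1, which is the failure mode the dummy multimodal input triggered.

# Illustrative sketch, not part of the PR: mark_unbacked is the primitive that
# the new mark_unbacked_dims argument forwards to inside __call__.
import torch


@torch.compile(backend="eager")  # eager backend is enough to exercise dynamo tracing
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


x = torch.randn(1, 8)  # a size-1 leading dim would normally be 0/1-specialized
torch._dynamo.decorators.mark_unbacked(x, 0)  # treat dim 0 as unbacked
print(double(x).shape)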
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...


@overload
def support_torch_compile(cls: _T) -> _T: ...

@@ -82,6 +102,7 @@ def support_torch_compile(
    cls: _T | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T] | _T:
    """
@@ -135,6 +156,11 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
    dim to be decorated with `mark_unbacked`. This is useful if we would like to
    enforce that dynamo does not specialize on 0/1 values in the case of dummy
    inputs, such as for vision model compilation.
    """

    def cls_decorator_helper(cls: _T) -> _T:
@@ -172,7 +198,9 @@ def cls_decorator_helper(cls: _T) -> _T:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}"
                )
        return _support_torch_compile(
            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
        )

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
@@ -212,6 +240,7 @@ def _verify_source_unchanged(source_info, vllm_config) -> None:
def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, int | list[int]],
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> _T:
    """
@@ -230,8 +259,22 @@

    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
    def __init__(
        self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
    ):
        if vllm_config is None:
            vllm_config = get_current_vllm_config()

        # NOTE: to support multimodal models (such as encoder),
        # we may not have vllm_config so we may need to patch
        # it
        sig = inspect.signature(old_init)
        if "vllm_config" in sig.parameters:
            kwargs["vllm_config"] = vllm_config
        if "prefix" in sig.parameters:
            kwargs["prefix"] = prefix
        old_init(self, **kwargs)

        self.vllm_config = vllm_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
@@ -344,6 +387,15 @@ def __call__(self, *args, **kwargs):
                            "Unsupported dynamic dimensions"
                            f" {dims} for argument {k} with type {type(arg)}."
                        )
            if mark_unbacked_dims:
                for k, dims in mark_unbacked_dims.items():
                    arg = bound_args.arguments.get(k)
                    if arg is not None:
                        dims = [dims] if isinstance(dims, int) else dims
                        if isinstance(arg, torch.Tensor):
                            # In case dims is specified with negative indexing
                            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                            torch._dynamo.decorators.mark_unbacked(arg, dims)
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s", self.original_code_object)
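Putting the new argument together with the existing ones, here is a hedged sketch of how a vision submodule could be decorated; ToyVisionBlock and its sizes are illustrative, not the PR's actual Qwen2_5_VisionBlock changes.

# Illustrative sketch only; assumes vLLM is installed and a VllmConfig is in
# scope (or set as the current config) when the module is instantiated.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile(
    # The sequence dimension of `x` is dynamic across calls...
    dynamic_arg_dims={"x": 0},
    # ...and additionally unbacked, so dynamo cannot specialize it to 0/1
    # when profiling feeds a dummy input with a single patch.
    mark_unbacked_dims={"x": 0},
)
class ToyVisionBlock(nn.Module):
    def __init__(self, *, vllm_config=None, prefix: str = "", hidden_size: int = 64):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)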
2 changes: 2 additions & 0 deletions vllm/config/compilation.py
@@ -684,6 +684,8 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:

        from vllm.compilation.backends import VllmBackend

        # TODO[@lucaskabela]: See if we can forward prefix
        # https://github.com/vllm-project/vllm/issues/27045
        return VllmBackend(vllm_config)

    def post_init_cudagraph_sizes(self) -> None:
8 changes: 6 additions & 2 deletions vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -45,6 +45,7 @@

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
@@ -759,7 +760,8 @@ def _process_image_input(
        assert grid_thw.ndim == 2

        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
        with set_forward_context(None, self.vllm_config):
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
@@ -779,7 +781,8 @@
        assert grid_thw.ndim == 2

        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
        with set_forward_context(None, self.vllm_config):
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
@@ -839,6 +842,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> str | None:

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.vllm_config = vllm_config
        thinker_config: Qwen2_5OmniThinkerConfig = (
            vllm_config.model_config.hf_config.thinker_config
        )
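The pattern added above, running the vision encoder inside set_forward_context(None, self.vllm_config), is what lets the compiled vision submodules find a forward context even though no attention metadata applies to the encoder pass. A stand-alone sketch of the call shape follows; encode_images is an illustrative name, not a vLLM API.

# Hypothetical helper mirroring the wrapping pattern in the diff above.
import torch

from vllm.forward_context import set_forward_context


def encode_images(
    visual: torch.nn.Module,
    pixel_values: torch.Tensor,
    grid_thw: torch.Tensor,
    vllm_config,
) -> torch.Tensor:
    # No attention metadata applies to the ViT encoder, hence None.
    with set_forward_context(None, vllm_config):
        return visual(pixel_values, grid_thw=grid_thw)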