106 changes: 106 additions & 0 deletions vllm/attention/ops/vit_attn_wrappers.py
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
`to_list` in xformers attn, or `.item()` in flash attention)

Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
latencies by ~7% (see qwen2_5_vl for example usage)

To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
"""
import torch
from einops import rearrange

from vllm.utils import direct_register_custom_op


def xformers_attn_seqlens_wrapper(q: torch.Tensor, k: torch.Tensor,
v: torch.Tensor,
seqlens: torch.Tensor) -> torch.Tensor:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens.tolist(),
kv_seqlen=None,
device=q.device)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
return context_layer


def xformers_attn_seqlens_wrapper_fake(q: torch.Tensor, k: torch.Tensor,
v: torch.Tensor,
seqlens: torch.Tensor) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
op_name="xformers_attn_seqlens_wrapper",
op_func=xformers_attn_seqlens_wrapper,
fake_impl=xformers_attn_seqlens_wrapper_fake,
)


def vit_xformers_attn_wrapper(q: torch.Tensor, k: torch.Tensor,
v: torch.Tensor,
seqlens: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)


def flash_attn_maxseqlen_wrapper(q: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor, batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool) -> torch.Tensor:
if is_rocm_aiter:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen.item(),
max_seqlen_k=max_seqlen.item(),
dropout_p=0.0,
causal=False)
context_layer = rearrange(output, "(b s) h d -> s b (h d)",
b=batch_size).contiguous()
return context_layer


def flash_attn_maxseqlen_wrapper_fake(q: torch.Tensor, k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int, is_rocm_aiter: bool,
use_upstream_fa: bool) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
op_name="flash_attn_maxseqlen_wrapper",
op_func=flash_attn_maxseqlen_wrapper,
fake_impl=flash_attn_maxseqlen_wrapper_fake,
)


def vit_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, is_rocm_aiter: bool,
use_upstream_fa: bool) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(q, k, v, cu_seqlens,
max_seqlen, batch_size,
is_rocm_aiter,
use_upstream_fa)
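
For context, a minimal sketch of how a vision attention layer might call these wrappers so that `.tolist()`/`.item()` stay inside the custom ops and a `torch.compile`'d block does not graph-break on them. The function below and its flag values are illustrative, not part of this PR; the real integration lives in qwen2_5_vl:

# Hypothetical usage sketch (not part of this PR).
import torch

from vllm.attention.ops.vit_attn_wrappers import (vit_flash_attn_wrapper,
                                                  vit_xformers_attn_wrapper)


def vision_attention(q: torch.Tensor,  # (batch, seq, heads, head_dim)
                     k: torch.Tensor,
                     v: torch.Tensor,
                     cu_seqlens: torch.Tensor,
                     max_seqlen: torch.Tensor,
                     seqlens: torch.Tensor,
                     use_flash_attn: bool) -> torch.Tensor:
    batch_size = q.shape[0]
    if use_flash_attn:
        # max_seqlen stays a tensor here; `.item()` happens inside the op
        return vit_flash_attn_wrapper(q, k, v, cu_seqlens, max_seqlen,
                                      batch_size, is_rocm_aiter=False,
                                      use_upstream_fa=False)
    # seqlens stays a tensor; `.tolist()` happens inside the op
    return vit_xformers_attn_wrapper(q, k, v, seqlens)

Both paths return the attention output reshaped to (seq, batch, heads * head_dim), matching the wrappers' fake implementations above.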
1 change: 0 additions & 1 deletion vllm/compilation/backends.py
@@ -474,7 +474,6 @@ def configure_post_pass(self):
inductor_config[PASS_KEY] = self.post_grad_pass_manager

def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

vllm_config = self.vllm_config
if not self.compilation_config.cache_dir:
# no provided cache dir, generate one based on the known factors
39 changes: 31 additions & 8 deletions vllm/compilation/decorators.py
@@ -13,7 +13,7 @@

from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel, VllmConfig
from vllm.config import CompilationLevel, VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
@@ -192,13 +192,22 @@ def _support_torch_compile(
# make sure super().__init__ is called on the base class
# other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

old_init = cls.__init__

setattr(cls, IGNORE_COMPILE_KEY, False)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
def __init__(self, **kwargs):
# NOTE: to support multimodal models, submodules (such as vision
# encoders) may not accept vllm_config in __init__; read it from kwargs
# when the signature declares it, otherwise fall back to the current
# vllm config.
sig = inspect.signature(old_init)
if 'vllm_config' in sig.parameters:
vllm_config = kwargs['vllm_config']
else:
vllm_config = get_current_vllm_config()

old_init(self, **kwargs)

self.vllm_config = vllm_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
@@ -212,20 +221,33 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
return

compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
if not hasattr(self.__class__, "compiled_callable"):
print(f"init self for {self.__class__}")
# only compile the same model once
# NOTE: this is probably not right, since parameters can change
# and cause us to fall over
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
self.__class__.compiled_callable = self.compiled_callable
else:
print("init reusing the callable")
TorchCompileWrapperWithCustomDispatcher.__init__(
self,
self.__class__.compiled_callable,
compilation_level=vllm_config.compilation_config.level)

cls.__init__ = __init__

def __call__(self, *args, **kwargs):
print(f"Call to {self.__class__} forward")
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs)

# the first compilation needs to have dynamic shapes marked
if len(self.compiled_codes) < 1:
if len(self.__class__.compiled_codes) < 1:
sig = inspect.signature(self.__class__.forward)
bound_args = sig.bind(self, *args, **kwargs)
bound_args.apply_defaults()
@@ -259,7 +281,8 @@ def __call__(self, *args, **kwargs):
# if we don't use custom dispatcher, we can directly call the
# compiled function and let torch.compile handle the dispatching,
# with the overhead of guard evaluation and recompilation.
if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
if len(self.__class__.compiled_codes
) < 1 or not self.use_custom_dispatcher:
# it seems Dynamo reuse the compilation across instances,
# while we need to make sure the compiled code is not reused.
# we need to control all the compilation of the model.
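The decorator change above caches the compiled callable on the class so that repeated, identical submodules (e.g. stacked ViT blocks) compile once and share the result. A standalone sketch of that idea, independent of vLLM's wrapper machinery (class and function names are illustrative only):

# Illustrative sketch: share one compiled callable across all instances.
import torch
import torch.nn as nn


class SharedCompileBlock(nn.Module):
    compiled_forward = None  # class attribute, shared by every instance

    def __init__(self, dim: int = 64):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        if SharedCompileBlock.compiled_forward is None:
            # the first instance compiles the unbound forward once
            SharedCompileBlock.compiled_forward = torch.compile(
                SharedCompileBlock._forward)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # later instances reuse the shared compiled callable; Dynamo guards
        # may still recompile when per-instance state differs, mirroring the
        # "probably not right" caveat noted in the draft above
        return SharedCompileBlock.compiled_forward(self, x)


blocks = [SharedCompileBlock() for _ in range(3)]
x = torch.randn(2, 64)
outs = [blk(x) for blk in blocks]  # one compilation, reused by all blocks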
9 changes: 5 additions & 4 deletions vllm/compilation/wrapper.py
@@ -53,7 +53,7 @@ def __init__(self,

self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: list[CodeType] = []
self.__class__.compiled_codes = [] # type: ignore[attr-defined]
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

# read the env var to determine whether to use the custom dispatcher
@@ -91,8 +91,8 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
if frame.f_locals["self"] is not self:
return

self.compiled_codes.append(new_code)

self.__class__.compiled_codes.append( # type: ignore[attr-defined]
new_code)
path = self.vllm_config.compile_debug_dump_path()
if path:
decompiled_file = path / "transformed_code.py"
@@ -130,6 +130,7 @@ def dispatch_to_code(self, index: int):

See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
""" # noqa
self.__class__.forward.__code__ = self.compiled_codes[index]
self.__class__.forward.__code__ = self.__class__.compiled_codes[ # type: ignore[attr-defined]
index]
yield
self.__class__.forward.__code__ = self.original_code_object
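
compiled_codes moves to a class attribute above because swapping forward.__code__ is inherently class-wide: once the code object is replaced, every instance of the class dispatches through it. A minimal illustration of that mechanism on a toy class (not vLLM code):

# Toy demonstration: assigning to forward.__code__ affects all instances.
class Foo:
    def forward(self):
        return "original"


def _compiled(self):
    return "compiled"


original_code = Foo.forward.__code__
Foo.forward.__code__ = _compiled.__code__   # class-wide swap
a, b = Foo(), Foo()
assert a.forward() == "compiled" and b.forward() == "compiled"

Foo.forward.__code__ = original_code        # restoring is class-wide too
assert a.forward() == "original"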
1 change: 1 addition & 0 deletions vllm/model_executor/models/gemma3n.py
@@ -1048,6 +1048,7 @@ def load_weights(self, weights: Iterable[tuple[str,


class Gemma3nForCausalLM(nn.Module):

packed_modules_mapping = {
"qkv_proj": [
"q_proj",