[Misc][qwen2_5_vl][torch.compile] Enable supports_torch_compile on generic nn.Module and demonstrate speedup on Qwen Vision model #23207
Changes from all commits
New file (+36 lines): a compilation test for the Qwen2.5-VL vision submodules.

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.compilation.counter import compilation_counter
from vllm.config.compilation import CompilationMode


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
    """Test that Qwen2.5-VL vision submodules are compiled.

    This test verifies that the 3 vision submodules (Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionBlock, and Qwen2_5_VisionPatchMerger) are properly tagged
    for compilation by checking that num_models_seen increases by at least 3.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        # NOTE: Qwen2.5-VL has 35 models in total: the LLM backend,
        # Vision Patch Embed, Vision Patch Merger, and then 32 Vision Blocks
        # (one for each layer). In the future, we should fix vLLM compilation
        # logic to handle this case and only compile the Vision submodules once
        # and reuse the compiled code for all layers.
        # See https://github.com/vllm-project/vllm/issues/27590
        compilation_counter.expect(num_models_seen=35),
        vllm_runner(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            max_model_len=2048,
            gpu_memory_utilization=0.7,
            compilation_config={"mode": CompilationMode.VLLM_COMPILE},
        ) as _,
    ):
        pass
```
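The test leans on `compilation_counter.expect(...)` as a context manager that checks how much the named counters grew inside the `with` block. A minimal sketch of that counter-delta pattern, using a hypothetical `Counter` class rather than vLLM's actual implementation:

```python
# Minimal sketch of a counter-delta assertion context manager.
# Hypothetical example; not vLLM's actual compilation_counter code.
import contextlib
from dataclasses import dataclass, fields


@dataclass
class Counter:
    num_models_seen: int = 0

    @contextlib.contextmanager
    def expect(self, **expected_deltas):
        # Snapshot the counters before the block runs.
        before = {f.name: getattr(self, f.name) for f in fields(self)}
        yield
        # After the block, each named counter must have grown by the expected amount.
        for name, delta in expected_deltas.items():
            actual = getattr(self, name) - before[name]
            assert actual == delta, f"{name}: expected +{delta}, got +{actual}"


counter = Counter()
with counter.expect(num_models_seen=2):
    counter.num_models_seen += 2  # e.g. two modules tagged for compilation
```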
New file (+125 lines): torch.compile-safe wrapper ops for ViT attention.

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains ops for ViT attention to be compatible with torch.compile,
as there are operations here not supported by torch.compile (for instance,
`.tolist()` in the xformers attention path, or `.item()` in flash attention).
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
latencies by ~7% (see qwen2_5_vl for example usage).
To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0).
"""

import einops
import torch

from vllm.utils.torch_utils import direct_register_custom_op


def xformers_attn_seqlens_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    attn_bias = BlockDiagonalMask.from_seqlens(
        q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
    )
    context_layer = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=attn_bias, p=0, scale=None
    )
    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
    return context_layer


def xformers_attn_seqlens_wrapper_fake(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="xformers_attn_seqlens_wrapper",
    op_func=xformers_attn_seqlens_wrapper,
    fake_impl=xformers_attn_seqlens_wrapper_fake,
)


def vit_xformers_attn_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)


def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    else:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen.item(),
        max_seqlen_k=max_seqlen.item(),
        dropout_p=0.0,
        causal=False,
    )
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer


def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="flash_attn_maxseqlen_wrapper",
    op_func=flash_attn_maxseqlen_wrapper,
    fake_impl=flash_attn_maxseqlen_wrapper_fake,
)


def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
    )
```
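The point of these wrappers is that the compile-unfriendly calls (`seqlens.tolist()`, `max_seqlen.item()`) happen inside a registered custom op, so Dynamo only traces the op call. A rough sketch of a caller, with a hypothetical `vision_attention` helper and shapes; this is illustrative, not the actual Qwen2.5-VL attention code, and it assumes the wrapper functions defined above are importable from wherever this file lands:

```python
# Hypothetical caller for illustration only; the real call sites live in the
# Qwen2.5-VL vision attention layers, which are not part of this excerpt.
import torch

# from <path of the new ops file> import vit_xformers_attn_wrapper  # assumed import


def vision_attention(qkv: torch.Tensor, seqlens: torch.Tensor, num_heads: int) -> torch.Tensor:
    # qkv: (batch, seq, 3 * num_heads * head_dim), the packed QKV projection.
    b, s, three_h_d = qkv.shape
    head_dim = three_h_d // (3 * num_heads)
    q, k, v = qkv.reshape(b, s, 3, num_heads, head_dim).unbind(dim=2)
    # seqlens stays a tensor here; the .tolist() it needs happens inside the
    # registered custom op, so torch.compile never traces through it.
    return vit_xformers_attn_wrapper(q, k, v, seqlens)  # -> (seq, batch, num_heads * head_dim)
```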
Changes to an existing file, the `support_torch_compile` decorator module:

```diff
@@ -18,7 +18,12 @@
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
+from vllm.config import (
+    CompilationMode,
+    VllmConfig,
+    get_current_vllm_config,
+    set_current_vllm_config,
+)
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -74,6 +79,21 @@ def support_torch_compile(
 ) -> Callable[[_T], _T]: ...
 
 
+@overload
+def support_torch_compile(
+    *,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
```
Review comment on lines +82 to +86 (the new `mark_unbacked_dims`-only overload):

Reviewer: This arg is used with ...

Author (Contributor): cc @ProExpertProg @zou3519 @ywang96 can y'all have another look at this? I needed to add some way to mark a particular dimension as not only dynamic but also unbacked, so I did that here; without this, compile 0/1-specializes during tracing when the dummy multimodal input in the GPU runner returns 1. I tried to follow the convention of the other args of `support_torch_compile`.
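To make the thread's point concrete, `mark_unbacked` is the Dynamo decorator that forbids the tracer from assuming a dimension is 0 or 1, which is exactly the risk when the profiling pass feeds a dummy input whose marked dimension happens to be 1. A standalone sketch in plain PyTorch, independent of this PR:

```python
# Plain-PyTorch illustration of 0/1 specialization and mark_unbacked; this is
# not vLLM code, just the underlying Dynamo behavior the thread refers to.
import torch


def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


x = torch.randn(1, 8)  # a "dummy" input whose leading dim happens to be 1
# Marking the dim dynamic alone still lets the compiler specialize the size-1
# case; mark_unbacked tells Dynamo it may not assume the dim is 0 or 1.
torch._dynamo.decorators.mark_unbacked(x, 0)
compiled = torch.compile(double, dynamic=True)
compiled(x)
```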
The hunk continues with a combined `dynamic_arg_dims`/`mark_unbacked_dims` overload ahead of the pre-existing bare-class overload, and the remaining hunks thread the new argument through to the point where `mark_unbacked` is applied:

```diff
+@overload
+def support_torch_compile(
+    *,
+    dynamic_arg_dims: dict[str, int | list[int]] | None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
+
+
 @overload
 def support_torch_compile(cls: _T) -> _T: ...
 
 
@@ -82,6 +102,7 @@ def support_torch_compile(
     cls: _T | None = None,
     *,
     dynamic_arg_dims: dict[str, int | list[int]] | None = None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
 ) -> Callable[[_T], _T] | _T:
     """
@@ -135,6 +156,11 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
     returns a boolean value indicating whether to compile the model or not.
     This is useful if you want to compile the model only when certain
     conditions are met.
+
+    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
+    dim to be decorated with `mark_unbacked`. This is useful if we would like to
+    enforce that dynamo does not specialize on 0/1 values in the case of dummy
+    inputs, such as for vision model compilation.
     """
 
     def cls_decorator_helper(cls: _T) -> _T:
@@ -172,7 +198,9 @@ def cls_decorator_helper(cls: _T) -> _T:
                 raise ValueError(
                     f"Argument {k} not found in the forward method of {cls}"
                 )
-        return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
+        return _support_torch_compile(
+            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
+        )
 
     if cls is not None:
         # use `support_torch_compile` as a decorator without arguments
@@ -212,6 +240,7 @@ def _verify_source_unchanged(source_info, vllm_config) -> None:
 def _support_torch_compile(
     cls: _T,
     dynamic_arg_dims: dict[str, int | list[int]],
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
 ) -> _T:
     """
@@ -230,8 +259,22 @@
     setattr(cls, IGNORE_COMPILE_KEY, False)
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
-        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+    def __init__(
+        self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
+    ):
+        if vllm_config is None:
+            vllm_config = get_current_vllm_config()
+
+        # NOTE: to support multimodal models (such as encoders),
+        # we may not have a vllm_config, so we may need to
+        # patch it in
+        sig = inspect.signature(old_init)
+        if "vllm_config" in sig.parameters:
+            kwargs["vllm_config"] = vllm_config
+        if "prefix" in sig.parameters:
+            kwargs["prefix"] = prefix
+        old_init(self, **kwargs)
 
         self.vllm_config = vllm_config
         enable_compile = enable_if is None or enable_if(vllm_config)
         # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
@@ -344,6 +387,15 @@ def __call__(self, *args, **kwargs):
                                 "Unsupported dynamic dimensions"
                                 f" {dims} for argument {k} with type {type(arg)}."
                             )
+            if mark_unbacked_dims:
+                for k, dims in mark_unbacked_dims.items():
+                    arg = bound_args.arguments.get(k)
+                    if arg is not None:
+                        dims = [dims] if isinstance(dims, int) else dims
+                        if isinstance(arg, torch.Tensor):
+                            # In case dims is specified with negative indexing
+                            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                            torch._dynamo.decorators.mark_unbacked(arg, dims)
             # here, it is the starting point of the `torch.compile` process
             start_monitoring_torch_compile(self.vllm_config)
             logger.debug("Start compiling function %s", self.original_code_object)
```
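Taken together, the decorator changes allow tagging a plain `nn.Module` whose `__init__` knows nothing about `vllm_config` or `prefix`: the patched `__init__` falls back to `get_current_vllm_config()` and only forwards the keywords the original `__init__` actually accepts. A minimal sketch of the intended usage, with a toy module and assumed argument names (the real application is the Qwen2.5-VL vision tower, which is not shown in this excerpt):

```python
# Hedged sketch, not the actual Qwen2.5-VL wiring: a generic nn.Module tagged
# for compilation using the new mark_unbacked_dims keyword.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile(
    dynamic_arg_dims={"x": 0},    # the flattened patch/sequence dim is dynamic
    mark_unbacked_dims={"x": 0},  # and must not be 0/1-specialized when the
                                  # dummy profiling input has size 1
)
class ToyVisionBlock(nn.Module):
    # No vllm_config or prefix parameters here; the patched __init__ injected by
    # the decorator supplies the current config via get_current_vllm_config()
    # and only forwards kwargs this signature actually accepts.
    def __init__(self, hidden_size: int = 64) -> None:
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


# Construction typically happens while a VllmConfig is current (e.g. during
# vLLM model loading under set_current_vllm_config), with keyword arguments
# only, since the patched __init__ is keyword-only:
# block = ToyVisionBlock(hidden_size=64)
```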