From 5a30c7f0b5d99c0529d658ab6ddf21714871423c Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 19 Aug 2025 13:49:47 -0700 Subject: [PATCH 01/13] Support compile on any module (towards compilable vision transformer) Signed-off-by: Lucas Kabela --- vllm/compilation/decorators.py | 16 +++++++++++++--- vllm/model_executor/models/qwen2_5_vl.py | 13 +++++++++++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index fa38cfe49a91..5f14500d8ba5 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -13,7 +13,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import CompilationLevel, get_current_vllm_config from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import resolve_obj_by_qualname, supports_dynamo @@ -197,8 +197,18 @@ def _support_torch_compile( setattr(cls, IGNORE_COMPILE_KEY, False) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): - old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) + def __init__(self, **kwargs): + # NOTE: to support multimodal models (such as encoder), + # we may not have vllm_config so we only want to pass + # vllm_config when it is available. + sig = inspect.signature(old_init) + if 'vllm_config' in sig.parameters: + vllm_config = kwargs['vllm_config'] + else: + vllm_config = get_current_vllm_config() + + old_init(self, **kwargs) + self.vllm_config = vllm_config enable_compile = enable_if is None or enable_if(vllm_config) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index da3889d31a7d..db5705d80593 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -39,9 +39,12 @@ Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) from vllm.attention.layer import check_upstream_fa_availability +from vllm.compilation.backends import set_model_tag +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -498,6 +501,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +@set_model_tag("Qwen2_5_VisionPatchMerger") +@support_torch_compile(dynamic_arg_dims={ + "x": 0, +}) class Qwen2_5_VisionPatchMerger(nn.Module): def __init__( @@ -1001,6 +1008,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config + self.vllm_config = vllm_config self.multimodal_config = multimodal_config self.video_pruning_rate = multimodal_config.video_pruning_rate self.is_multimodal_pruning_enabled = ( @@ -1134,8 +1142,9 @@ def _process_image_input( grid_thw_list, rope_type="rope_3d") else: - image_embeds = self.visual(pixel_values, - grid_thw=grid_thw_list) + with set_forward_context(None, self.vllm_config): + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. 
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync From f5575946fff37f663c49ae622a9cd4a62ae529b5 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 26 Aug 2025 08:47:51 -0700 Subject: [PATCH 02/13] Changes for benchmarking Signed-off-by: Lucas Kabela --- examples/offline_inference/vision_language.py | 4 +- vllm/model_executor/layers/layernorm.py | 3 -- vllm/model_executor/models/mllama4.py | 52 ++++++++++++++----- vllm/model_executor/models/qwen2_5_vl.py | 19 +++++-- 4 files changed, 55 insertions(+), 23 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index f8ddb5a22b31..0d13f74cb4bf 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -740,10 +740,10 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model=model_name, - max_model_len=8192, + max_model_len=3128, max_num_seqs=4, tensor_parallel_size=8, - gpu_memory_utilization=0.4, + gpu_memory_utilization=0.7, limit_mm_per_prompt={modality: 1}, ) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 363245daa89d..c091fb6cb1a7 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -180,9 +180,6 @@ def forward_native( residual = x.to(orig_dtype) hidden_size = x.shape[-1] - if hidden_size != self.hidden_size: - raise ValueError("Expected hidden_size to be " - f"{self.hidden_size}, but found: {hidden_size}") if self.variance_size_override is None: x_var = x diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index db5a9fbc6a33..83a550933d50 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -30,8 +30,12 @@ find_supported_resolutions, get_best_fit) from vllm.attention.layer import MultiHeadAttention +from vllm.compilation.backends import set_model_tag +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import set_forward_context +from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, @@ -91,6 +95,10 @@ class Llama4ImagePatchInputs(TensorSchema): """ +@set_model_tag("Llama4VisionMLP") +@support_torch_compile(dynamic_arg_dims={ + "hidden_states": 0, +}) class Llama4VisionMLP(nn.Module): def __init__( @@ -133,6 +141,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +@set_model_tag("Llama4MultiModalProjector") +@support_torch_compile(dynamic_arg_dims={ + "image_features": 0, +}) class Llama4MultiModalProjector(nn.Module): def __init__( @@ -272,6 +284,7 @@ def __init__( prefix=f"{prefix}.o_proj", ) + # THIS IS WHAT IS BEING MODIFIED IN PLACE! 
self.rotary_emb = get_rope( head_size=self.head_dim, rotary_dim=config.hidden_size // config.num_attention_heads // 2, @@ -306,6 +319,11 @@ def forward( return attn_output +# @set_model_tag("Llama4VisionEncoderLayer") +# @support_torch_compile(dynamic_arg_dims={ +# "hidden_state": 0, +# "hidden_state": 1, +# }) class Llama4VisionEncoderLayer(nn.Module): def __init__( @@ -373,7 +391,7 @@ def __init__( self.config = config self.layers = nn.ModuleList([ Llama4VisionEncoderLayer( - config, + config=config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", use_data_parallel=use_data_parallel, @@ -401,6 +419,10 @@ def forward( return hidden_states +@set_model_tag("Llama4UnfoldConvolution") +@support_torch_compile(dynamic_arg_dims={ + "hidden_states": 0, +}) class Llama4UnfoldConvolution(nn.Module): def __init__( @@ -453,7 +475,7 @@ def __init__( self.scale = config.hidden_size**-0.5 self.patch_embedding = Llama4UnfoldConvolution( - config, + config=config, quant_config=quant_config, prefix=f"{prefix}.patch_embedding", use_data_parallel=use_data_parallel, @@ -470,7 +492,7 @@ def __init__( # encoders self.model = Llama4VisionEncoder( - config, + config=config, quant_config=quant_config, prefix=f"{prefix}.model", use_data_parallel=use_data_parallel, @@ -726,6 +748,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + self.vllm_config = vllm_config config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config @@ -736,14 +759,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multimodal_config = multimodal_config if multimodal_config.get_limit_per_prompt("image"): self.vision_model = Llama4VisionModel( - config.vision_config, - None, + config=config.vision_config, + quant_config=None, prefix=maybe_prefix(prefix, "vision_model"), use_data_parallel=self.use_data_parallel, ) self.multi_modal_projector = Llama4MultiModalProjector( - self.config, - None, + config=self.config, + quant_config=None, prefix=maybe_prefix(prefix, "multi_modal_projector")) else: self.vision_model = None @@ -788,14 +811,15 @@ def _process_image_input( patches_per_image = image_input["patches_per_image"].tolist() # shard image input - if self.use_data_parallel: - vision_embeddings_flat = run_dp_sharded_vision_model( - flat_data, self.vision_model) - else: - vision_embeddings_flat = self.vision_model(flat_data) + with set_forward_context(None, self.vllm_config): + if self.use_data_parallel: + vision_embeddings_flat = run_dp_sharded_vision_model( + flat_data, self.vision_model) + else: + vision_embeddings_flat = self.vision_model(flat_data) - vision_embeddings_flat = self.multi_modal_projector( - vision_embeddings_flat) + vision_embeddings_flat = self.multi_modal_projector( + vision_embeddings_flat) return [ img.flatten(0, 1) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index db5705d80593..488b7fff1a86 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -417,6 +417,12 @@ def forward( return output +# @set_model_tag("Qwen2_5_VisionBlock") +# @support_torch_compile(dynamic_arg_dims={ +# "x": 0, +# "cu_seqlens": 0, +# "rotary_pos_emb": 0, +# }) class Qwen2_5_VisionBlock(nn.Module): def __init__( @@ -472,6 +478,10 @@ def forward( return x +@set_model_tag("Qwen2_5_VisionPatchEmbed") 
+@support_torch_compile(dynamic_arg_dims={ + "x": 0, +}) class Qwen2_5_VisionPatchEmbed(nn.Module): def __init__( @@ -787,7 +797,7 @@ def forward( window_index.append(window_index_thw + window_index_id) window_index_id += (t * llm_h * llm_w) - + assert cu_seqlens_window_thw.size()[0] >= 2 # Just for compile cu_seqlens_window_thw = (cu_seqlens_window_thw + cu_window_seqlens_last) cu_window_seqlens_last = cu_seqlens_window_thw[-1] @@ -1017,7 +1027,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if multimodal_config.get_limit_per_prompt("image") or \ multimodal_config.get_limit_per_prompt("video"): self.visual = Qwen2_5_VisionTransformer( - config.vision_config, + vision_config=config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self.quant_config, prefix=maybe_prefix(prefix, "visual"), @@ -1202,8 +1212,9 @@ def _process_video_input( grid_thw_list, rope_type="rope_3d") else: - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + with set_forward_context(None, self.vllm_config): + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size From cd9e380fe43dd2b3b74c0a2f10d9eddd5a39d57a Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 26 Aug 2025 13:24:33 -0700 Subject: [PATCH 03/13] Minimize to just qwen2_5 again for consistency Signed-off-by: Lucas Kabela --- examples/offline_inference/vision_language.py | 4 +- vllm/model_executor/layers/layernorm.py | 3 ++ vllm/model_executor/models/mllama4.py | 51 +++++-------------- vllm/model_executor/models/qwen2_5_vl.py | 8 +-- 4 files changed, 20 insertions(+), 46 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 0d13f74cb4bf..f8ddb5a22b31 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -740,10 +740,10 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model=model_name, - max_model_len=3128, + max_model_len=8192, max_num_seqs=4, tensor_parallel_size=8, - gpu_memory_utilization=0.7, + gpu_memory_utilization=0.4, limit_mm_per_prompt={modality: 1}, ) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index c091fb6cb1a7..363245daa89d 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -180,6 +180,9 @@ def forward_native( residual = x.to(orig_dtype) hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: + raise ValueError("Expected hidden_size to be " + f"{self.hidden_size}, but found: {hidden_size}") if self.variance_size_override is None: x_var = x diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 83a550933d50..f06f7b6cacbd 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -30,11 +30,8 @@ find_supported_resolutions, get_best_fit) from vllm.attention.layer import MultiHeadAttention -from vllm.compilation.backends import set_model_tag -from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import set_forward_context from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ 
-95,10 +92,6 @@ class Llama4ImagePatchInputs(TensorSchema): """ -@set_model_tag("Llama4VisionMLP") -@support_torch_compile(dynamic_arg_dims={ - "hidden_states": 0, -}) class Llama4VisionMLP(nn.Module): def __init__( @@ -141,10 +134,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -@set_model_tag("Llama4MultiModalProjector") -@support_torch_compile(dynamic_arg_dims={ - "image_features": 0, -}) class Llama4MultiModalProjector(nn.Module): def __init__( @@ -284,7 +273,6 @@ def __init__( prefix=f"{prefix}.o_proj", ) - # THIS IS WHAT IS BEING MODIFIED IN PLACE! self.rotary_emb = get_rope( head_size=self.head_dim, rotary_dim=config.hidden_size // config.num_attention_heads // 2, @@ -319,11 +307,6 @@ def forward( return attn_output -# @set_model_tag("Llama4VisionEncoderLayer") -# @support_torch_compile(dynamic_arg_dims={ -# "hidden_state": 0, -# "hidden_state": 1, -# }) class Llama4VisionEncoderLayer(nn.Module): def __init__( @@ -391,7 +374,7 @@ def __init__( self.config = config self.layers = nn.ModuleList([ Llama4VisionEncoderLayer( - config=config, + config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", use_data_parallel=use_data_parallel, @@ -419,10 +402,6 @@ def forward( return hidden_states -@set_model_tag("Llama4UnfoldConvolution") -@support_torch_compile(dynamic_arg_dims={ - "hidden_states": 0, -}) class Llama4UnfoldConvolution(nn.Module): def __init__( @@ -475,7 +454,7 @@ def __init__( self.scale = config.hidden_size**-0.5 self.patch_embedding = Llama4UnfoldConvolution( - config=config, + config, quant_config=quant_config, prefix=f"{prefix}.patch_embedding", use_data_parallel=use_data_parallel, @@ -492,7 +471,7 @@ def __init__( # encoders self.model = Llama4VisionEncoder( - config=config, + config, quant_config=quant_config, prefix=f"{prefix}.model", use_data_parallel=use_data_parallel, @@ -748,7 +727,6 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - self.vllm_config = vllm_config config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config @@ -759,14 +737,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multimodal_config = multimodal_config if multimodal_config.get_limit_per_prompt("image"): self.vision_model = Llama4VisionModel( - config=config.vision_config, - quant_config=None, + config.vision_config, + None, prefix=maybe_prefix(prefix, "vision_model"), use_data_parallel=self.use_data_parallel, ) self.multi_modal_projector = Llama4MultiModalProjector( - config=self.config, - quant_config=None, + self.config, + None, prefix=maybe_prefix(prefix, "multi_modal_projector")) else: self.vision_model = None @@ -811,15 +789,14 @@ def _process_image_input( patches_per_image = image_input["patches_per_image"].tolist() # shard image input - with set_forward_context(None, self.vllm_config): - if self.use_data_parallel: - vision_embeddings_flat = run_dp_sharded_vision_model( - flat_data, self.vision_model) - else: - vision_embeddings_flat = self.vision_model(flat_data) + if self.use_data_parallel: + vision_embeddings_flat = run_dp_sharded_vision_model( + flat_data, self.vision_model) + else: + vision_embeddings_flat = self.vision_model(flat_data) - vision_embeddings_flat = self.multi_modal_projector( - vision_embeddings_flat) + vision_embeddings_flat = self.multi_modal_projector( + vision_embeddings_flat) return [ 
img.flatten(0, 1) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 488b7fff1a86..91bcc2041e1c 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -417,12 +417,6 @@ def forward( return output -# @set_model_tag("Qwen2_5_VisionBlock") -# @support_torch_compile(dynamic_arg_dims={ -# "x": 0, -# "cu_seqlens": 0, -# "rotary_pos_emb": 0, -# }) class Qwen2_5_VisionBlock(nn.Module): def __init__( @@ -797,7 +791,7 @@ def forward( window_index.append(window_index_thw + window_index_id) window_index_id += (t * llm_h * llm_w) - assert cu_seqlens_window_thw.size()[0] >= 2 # Just for compile + cu_seqlens_window_thw = (cu_seqlens_window_thw + cu_window_seqlens_last) cu_window_seqlens_last = cu_seqlens_window_thw[-1] From 3f4c3a867d7f4351fc3ce14a73cbbf22a669d702 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 26 Aug 2025 14:44:35 -0700 Subject: [PATCH 04/13] Update dep Signed-off-by: Lucas Kabela --- vllm/compilation/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 5f14500d8ba5..970f50accafb 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -13,7 +13,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.config import CompilationLevel, VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import resolve_obj_by_qualname, supports_dynamo From e5cd48799233b665c8a7ea47bf7bb54975d9be89 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 9 Sep 2025 14:33:25 -0700 Subject: [PATCH 05/13] Working version Signed-off-by: Lucas Kabela --- vllm/model_executor/models/qwen2_5_vl.py | 159 ++++++++++++++--------- 1 file changed, 99 insertions(+), 60 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 91bcc2041e1c..785190f4d5fd 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -338,8 +338,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: torch.Tensor, # Only used for Flash Attention + seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -356,67 +356,87 @@ def forward( qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func - - q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) - - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) - - context_layer = rearrange(output, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() - elif self.attn_backend == _Backend.XFORMERS: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask + # torch.library.opcheck(torch.ops.mylib.attn_executor, (q, k, v, cu_seqlens, max_seqlen, seqlens, batch_size, out_dim, self.is_flash_attn_backend, self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == _Backend.TORCH_SDPA, self.attn_backend == _Backend.XFORMERS), test_utils='test_faketensor') - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = torch.ops.mylib.attn_executor(q, k, v, cu_seqlens, max_seqlen, seqlens, batch_size, self.is_flash_attn_backend, self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == _Backend.TORCH_SDPA, self.attn_backend == _Backend.XFORMERS) output, _ = self.proj(context_layer) return output +@torch.library.custom_op("mylib::attn_executor", mutates_args=()) +def attn_executor(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, seqlens: torch.Tensor, batch_size: int, + is_flash_attn: bool, is_rocm_aiter: bool, is_sdpa: bool, is_xformers: bool) -> torch.Tensor: + if is_flash_attn: + if is_rocm_aiter: + from aiter import flash_attn_varlen_func + else: + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen.item(), + max_seqlen_k=max_seqlen.item(), + dropout_p=0.0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif is_sdpa: + # Execute attention entry by entry for speed & less VRAM. 
+ outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + elif is_xformers: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens.tolist(), + kv_seqlen=None, + device=q.device) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + else: + raise NotImplementedError("Attention type is not supported") + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + return context_layer + +@torch.library.register_fake("mylib::attn_executor") +def attn_executor_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, seqlens: torch.Tensor, batch_size: int, + is_flash_attn: bool, is_rocm_aiter: bool, is_sdpa: bool, is_xformers: bool) -> torch.Tensor: + size = torch.empty((q.shape[1], batch_size, q.shape[2] * q.shape[3],), + dtype=q.dtype, + device=q.device) + return size + +@set_model_tag("Qwen2_5_VisionBlock") +@support_torch_compile(dynamic_arg_dims={ + "x": 0, + "cu_seqlens": 0, + "rotary_pos_emb": 0, +}) class Qwen2_5_VisionBlock(nn.Module): def __init__( @@ -459,8 +479,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: torch.Tensor, # Only used for Flash Attention + seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn(self.norm1(x), cu_seqlens=cu_seqlens, @@ -586,6 +606,12 @@ def forward(self, seqlen: int) -> torch.Tensor: return self._freqs_cached[:seqlen] +# @set_model_tag("Qwen2_5_VisionBlock") +# @support_torch_compile(dynamic_arg_dims={ +# "x": 0, +# "cu_seqlens": 0, +# "rotary_pos_emb": 0, +# }) class Qwen2_5_VisionTransformer(nn.Module): def __init__( @@ -760,6 +786,19 @@ def invert_permutation(perm: torch.Tensor) -> torch.Tensor: dtype=perm.dtype) return inv + def compute_attn_mask_seqlen_tensor( + self, + cu_seqlens: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + max_seqlen, seqlens = torch.zeros(1, device=cu_seqlens.device), torch.zeros(1, device=cu_seqlens.device) + if (self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA): + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]) + return max_seqlen, seqlens + + def forward( self, x: torch.Tensor, @@ -813,9 +852,9 @@ def forward( # transformers # pre-compute seqlens for window/full attn to reduce cuMemcpy operations - max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen( + max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen_tensor( cu_seqlens) - max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( + max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen_tensor( cu_window_seqlens) cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) From 
e0d70bde9cdc0715ccd1a56fb6edd6369d2e4b23 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 9 Sep 2025 14:33:37 -0700 Subject: [PATCH 06/13] Working version2 Signed-off-by: Lucas Kabela --- vllm/model_executor/models/qwen2_5_vl.py | 59 +++++++++++++++--------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 785190f4d5fd..845f32219ca9 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -358,16 +358,23 @@ def forward( # torch.library.opcheck(torch.ops.mylib.attn_executor, (q, k, v, cu_seqlens, max_seqlen, seqlens, batch_size, out_dim, self.is_flash_attn_backend, self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == _Backend.TORCH_SDPA, self.attn_backend == _Backend.XFORMERS), test_utils='test_faketensor') - - context_layer = torch.ops.mylib.attn_executor(q, k, v, cu_seqlens, max_seqlen, seqlens, batch_size, self.is_flash_attn_backend, self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == _Backend.TORCH_SDPA, self.attn_backend == _Backend.XFORMERS) + context_layer = torch.ops.mylib.attn_executor( + q, k, v, cu_seqlens, max_seqlen, seqlens, batch_size, + self.is_flash_attn_backend, + self.attn_backend == _Backend.ROCM_AITER_FA, + self.attn_backend == _Backend.TORCH_SDPA, + self.attn_backend == _Backend.XFORMERS) output, _ = self.proj(context_layer) return output @torch.library.custom_op("mylib::attn_executor", mutates_args=()) -def attn_executor(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, seqlens: torch.Tensor, batch_size: int, - is_flash_attn: bool, is_rocm_aiter: bool, is_sdpa: bool, is_xformers: bool) -> torch.Tensor: +def attn_executor(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, + seqlens: torch.Tensor, batch_size: int, is_flash_attn: bool, + is_rocm_aiter: bool, is_sdpa: bool, + is_xformers: bool) -> torch.Tensor: if is_flash_attn: if is_rocm_aiter: from aiter import flash_attn_varlen_func @@ -386,9 +393,7 @@ def attn_executor(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: dropout_p=0.0, causal=False) - context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) elif is_sdpa: # Execute attention entry by entry for speed & less VRAM. 
outputs = [] @@ -399,11 +404,11 @@ def attn_executor(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) + for x in [q_i, k_i, v_i]) output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + k_i, + v_i, + dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) @@ -412,25 +417,34 @@ def attn_executor(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: from xformers.ops.fmha.attn_bias import BlockDiagonalMask attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens.tolist(), - kv_seqlen=None, - device=q.device) + kv_seqlen=None, + device=q.device) context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) else: raise NotImplementedError("Attention type is not supported") context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + "b s h d -> s b (h d)").contiguous() return context_layer + @torch.library.register_fake("mylib::attn_executor") -def attn_executor_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, seqlens: torch.Tensor, batch_size: int, - is_flash_attn: bool, is_rocm_aiter: bool, is_sdpa: bool, is_xformers: bool) -> torch.Tensor: - size = torch.empty((q.shape[1], batch_size, q.shape[2] * q.shape[3],), - dtype=q.dtype, - device=q.device) +def attn_executor_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, + seqlens: torch.Tensor, batch_size: int, + is_flash_attn: bool, is_rocm_aiter: bool, is_sdpa: bool, + is_xformers: bool) -> torch.Tensor: + size = torch.empty(( + q.shape[1], + batch_size, + q.shape[2] * q.shape[3], + ), + dtype=q.dtype, + device=q.device) return size + @set_model_tag("Qwen2_5_VisionBlock") @support_torch_compile(dynamic_arg_dims={ "x": 0, @@ -479,7 +493,7 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: torch.Tensor, # Only used for Flash Attention + max_seqlen: torch.Tensor, # Only used for Flash Attention seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn(self.norm1(x), @@ -790,7 +804,9 @@ def compute_attn_mask_seqlen_tensor( self, cu_seqlens: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - max_seqlen, seqlens = torch.zeros(1, device=cu_seqlens.device), torch.zeros(1, device=cu_seqlens.device) + max_seqlen, seqlens = torch.zeros( + 1, device=cu_seqlens.device), torch.zeros(1, + device=cu_seqlens.device) if (self.attn_backend == _Backend.FLASH_ATTN or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() @@ -798,7 +814,6 @@ def compute_attn_mask_seqlen_tensor( seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]) return max_seqlen, seqlens - def forward( self, x: torch.Tensor, From ea6b92e38ce4f4023ae179439ca8c697dd70779a Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Thu, 11 Sep 2025 09:16:08 -0700 Subject: [PATCH 07/13] Tidy changes to qwen Signed-off-by: Lucas Kabela --- vllm/model_executor/models/qwen2_5_vl.py | 43 ++++++++++-------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 845f32219ca9..e9d98bada94d 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ 
b/vllm/model_executor/models/qwen2_5_vl.py @@ -356,9 +356,7 @@ def forward( qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) q, k = torch.chunk(qk_rotated, 2, dim=0) - # torch.library.opcheck(torch.ops.mylib.attn_executor, (q, k, v, cu_seqlens, max_seqlen, seqlens, batch_size, out_dim, self.is_flash_attn_backend, self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == _Backend.TORCH_SDPA, self.attn_backend == _Backend.XFORMERS), test_utils='test_faketensor') - - context_layer = torch.ops.mylib.attn_executor( + context_layer = torch.ops.mylib.custom_vision_attention( q, k, v, cu_seqlens, max_seqlen, seqlens, batch_size, self.is_flash_attn_backend, self.attn_backend == _Backend.ROCM_AITER_FA, @@ -369,12 +367,12 @@ def forward( return output -@torch.library.custom_op("mylib::attn_executor", mutates_args=()) -def attn_executor(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, - seqlens: torch.Tensor, batch_size: int, is_flash_attn: bool, - is_rocm_aiter: bool, is_sdpa: bool, - is_xformers: bool) -> torch.Tensor: +@torch.library.custom_op("mylib::custom_vision_attention", mutates_args=()) +def custom_vision_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, + seqlens: torch.Tensor, batch_size: int, + is_flash_attn: bool, is_rocm_aiter: bool, + is_sdpa: bool, is_xformers: bool) -> torch.Tensor: if is_flash_attn: if is_rocm_aiter: from aiter import flash_attn_varlen_func @@ -429,20 +427,21 @@ def attn_executor(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return context_layer -@torch.library.register_fake("mylib::attn_executor") -def attn_executor_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, - seqlens: torch.Tensor, batch_size: int, - is_flash_attn: bool, is_rocm_aiter: bool, is_sdpa: bool, - is_xformers: bool) -> torch.Tensor: - size = torch.empty(( +@torch.library.register_fake("mylib::custom_vision_attention") +def custom_vision_attention_fake(q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + seqlens: torch.Tensor, batch_size: int, + is_flash_attn: bool, is_rocm_aiter: bool, + is_sdpa: bool, + is_xformers: bool) -> torch.Tensor: + return torch.empty(( q.shape[1], batch_size, q.shape[2] * q.shape[3], ), dtype=q.dtype, device=q.device) - return size @set_model_tag("Qwen2_5_VisionBlock") @@ -620,12 +619,6 @@ def forward(self, seqlen: int) -> torch.Tensor: return self._freqs_cached[:seqlen] -# @set_model_tag("Qwen2_5_VisionBlock") -# @support_torch_compile(dynamic_arg_dims={ -# "x": 0, -# "cu_seqlens": 0, -# "rotary_pos_emb": 0, -# }) class Qwen2_5_VisionTransformer(nn.Module): def __init__( @@ -867,9 +860,9 @@ def forward( # transformers # pre-compute seqlens for window/full attn to reduce cuMemcpy operations - max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen_tensor( + max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen( cu_seqlens) - max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen_tensor( + max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( cu_window_seqlens) cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) From a9c66a7189d4bd0b6b3e0d2564d2224152bede04 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Thu, 11 Sep 2025 09:22:32 -0700 Subject: [PATCH 08/13] Rebase and tidy Signed-off-by: Lucas Kabela --- vllm/model_executor/models/qwen2_5_vl.py | 14 
+------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index e9d98bada94d..743896493fc4 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -772,18 +772,6 @@ def get_rope_by_thw(self, t, h, w): return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, cu_seqlens_thw) - def compute_attn_mask_seqlen( - self, - cu_seqlens: torch.Tensor, - ) -> tuple[Optional[int], Optional[list[int]]]: - max_seqlen, seqlens = None, None - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - return max_seqlen, seqlens - @staticmethod def invert_permutation(perm: torch.Tensor) -> torch.Tensor: # building the inverse permutation in O(n) time @@ -793,7 +781,7 @@ def invert_permutation(perm: torch.Tensor) -> torch.Tensor: dtype=perm.dtype) return inv - def compute_attn_mask_seqlen_tensor( + def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: From c113fc60e6631f7874709109fe70d9eed7260f28 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 16 Sep 2025 15:37:56 -0700 Subject: [PATCH 09/13] Rebase on main, more benchmarking Signed-off-by: Lucas Kabela --- vllm/model_executor/models/qwen2_5_vl.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 743896493fc4..071fc45a47aa 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -361,7 +361,7 @@ def forward( self.is_flash_attn_backend, self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == _Backend.TORCH_SDPA, - self.attn_backend == _Backend.XFORMERS) + self.attn_backend == _Backend.XFORMERS, self.use_upstream_fa) output, _ = self.proj(context_layer) return output @@ -372,13 +372,16 @@ def custom_vision_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, seqlens: torch.Tensor, batch_size: int, is_flash_attn: bool, is_rocm_aiter: bool, - is_sdpa: bool, is_xformers: bool) -> torch.Tensor: + is_sdpa: bool, is_xformers: bool, + use_upstream_fa: bool) -> torch.Tensor: if is_flash_attn: if is_rocm_aiter: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func - + if use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) output = flash_attn_varlen_func(q, @@ -433,8 +436,8 @@ def custom_vision_attention_fake(q: torch.Tensor, k: torch.Tensor, max_seqlen: torch.Tensor, seqlens: torch.Tensor, batch_size: int, is_flash_attn: bool, is_rocm_aiter: bool, - is_sdpa: bool, - is_xformers: bool) -> torch.Tensor: + is_sdpa: bool, is_xformers: bool, + use_upstream_fa: bool) -> torch.Tensor: return torch.empty(( q.shape[1], batch_size, From 3e0bfc8bc92e7386af8e9e24040d06a99fd2dd26 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Thu, 25 Sep 2025 15:51:10 -0700 Subject: [PATCH 10/13] Mark seqlen dynamic too Signed-off-by: Lucas Kabela --- vllm/model_executor/models/qwen2_5_vl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 071fc45a47aa..bbf3562398cf 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -452,6 +452,7 @@ def custom_vision_attention_fake(q: torch.Tensor, k: torch.Tensor, "x": 0, "cu_seqlens": 0, "rotary_pos_emb": 0, + "seqlens": 0, }) class Qwen2_5_VisionBlock(nn.Module): From c4bb22b5b49900b2162d951473ed33ad9bf20f7e Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Thu, 25 Sep 2025 15:59:34 -0700 Subject: [PATCH 11/13] Remove errenously added import Signed-off-by: Lucas Kabela --- vllm/model_executor/models/mllama4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index f06f7b6cacbd..db5a9fbc6a33 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -32,7 +32,6 @@ from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, From 485c3f11af61eb70f6587877e615327afea19dcb Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 30 Sep 2025 10:53:39 -0700 Subject: [PATCH 12/13] Move custom ops to be granular in helper file Signed-off-by: Lucas Kabela --- vllm/attention/ops/vit_attn_wrappers.py | 106 +++++++++++++++ vllm/model_executor/models/qwen2_5_vl.py | 164 ++++++++--------------- 2 files changed, 160 insertions(+), 110 deletions(-) create mode 100644 vllm/attention/ops/vit_attn_wrappers.py diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py new file mode 100644 index 000000000000..ec8d0ae13eaa --- /dev/null +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file contains ops for ViT attention to be compatible with torch.compile +as there are operations here not supported by torch.compile (for instance, +`to_list` in xformers attn, or `.item()` in flash attention) + +Using these ops and wrapping vision blocks with `torch.compile` can speed up +throughput in vision models by ~5% relative on H100, and improve token +latencies by ~7% (see qwen2_5_vl for example usage) + +To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0) +""" +import torch +from einops import rearrange + +from vllm.utils import direct_register_custom_op + + +def xformers_attn_seqlens_wrapper(q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + seqlens: torch.Tensor) -> torch.Tensor: + from xformers 
import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens.tolist(), + kv_seqlen=None, + device=q.device) + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + return context_layer + + +def xformers_attn_seqlens_wrapper_fake(q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + seqlens: torch.Tensor) -> torch.Tensor: + b, s, h, d = q.shape + return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) + + +direct_register_custom_op( + op_name="xformers_attn_seqlens_wrapper", + op_func=xformers_attn_seqlens_wrapper, + fake_impl=xformers_attn_seqlens_wrapper_fake, +) + + +def vit_xformers_attn_wrapper(q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + seqlens: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens) + + +def flash_attn_maxseqlen_wrapper(q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, batch_size: int, + is_rocm_aiter: bool, + use_upstream_fa: bool) -> torch.Tensor: + if is_rocm_aiter: + from aiter import flash_attn_varlen_func + else: + if use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen.item(), + max_seqlen_k=max_seqlen.item(), + dropout_p=0.0, + causal=False) + context_layer = rearrange(output, "(b s) h d -> s b (h d)", + b=batch_size).contiguous() + return context_layer + + +def flash_attn_maxseqlen_wrapper_fake(q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + batch_size: int, is_rocm_aiter: bool, + use_upstream_fa: bool) -> torch.Tensor: + b, s, h, d = q.shape + return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) + + +direct_register_custom_op( + op_name="flash_attn_maxseqlen_wrapper", + op_func=flash_attn_maxseqlen_wrapper, + fake_impl=flash_attn_maxseqlen_wrapper_fake, +) + + +def vit_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, + batch_size: int, is_rocm_aiter: bool, + use_upstream_fa: bool) -> torch.Tensor: + return torch.ops.vllm.flash_attn_maxseqlen_wrapper(q, k, v, cu_seqlens, + max_seqlen, batch_size, + is_rocm_aiter, + use_upstream_fa) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index bbf3562398cf..a0cdce6f4855 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -39,6 +39,8 @@ Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.ops.vit_attn_wrappers import (vit_flash_attn_wrapper, + vit_xformers_attn_wrapper) from vllm.compilation.backends import set_model_tag from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -356,97 +358,38 @@ def forward( qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) q, k = torch.chunk(qk_rotated, 2, dim=0) - context_layer = torch.ops.mylib.custom_vision_attention( - q, k, v, cu_seqlens, max_seqlen, seqlens, batch_size, - 
self.is_flash_attn_backend, - self.attn_backend == _Backend.ROCM_AITER_FA, - self.attn_backend == _Backend.TORCH_SDPA, - self.attn_backend == _Backend.XFORMERS, self.use_upstream_fa) + if self.is_flash_attn_backend: + context_layer = vit_flash_attn_wrapper( + q, k, v, cu_seqlens, max_seqlen, batch_size, + self.attn_backend == _Backend.ROCM_AITER_FA, + self.use_upstream_fa) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + elif self.attn_backend == _Backend.XFORMERS: + context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) output, _ = self.proj(context_layer) return output -@torch.library.custom_op("mylib::custom_vision_attention", mutates_args=()) -def custom_vision_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, - seqlens: torch.Tensor, batch_size: int, - is_flash_attn: bool, is_rocm_aiter: bool, - is_sdpa: bool, is_xformers: bool, - use_upstream_fa: bool) -> torch.Tensor: - if is_flash_attn: - if is_rocm_aiter: - from aiter import flash_attn_varlen_func - else: - if use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen.item(), - max_seqlen_k=max_seqlen.item(), - dropout_p=0.0, - causal=False) - - context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) - elif is_sdpa: - # Execute attention entry by entry for speed & less VRAM. 
- outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - elif is_xformers: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens.tolist(), - kv_seqlen=None, - device=q.device) - - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - else: - raise NotImplementedError("Attention type is not supported") - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() - return context_layer - - -@torch.library.register_fake("mylib::custom_vision_attention") -def custom_vision_attention_fake(q: torch.Tensor, k: torch.Tensor, - v: torch.Tensor, cu_seqlens: torch.Tensor, - max_seqlen: torch.Tensor, - seqlens: torch.Tensor, batch_size: int, - is_flash_attn: bool, is_rocm_aiter: bool, - is_sdpa: bool, is_xformers: bool, - use_upstream_fa: bool) -> torch.Tensor: - return torch.empty(( - q.shape[1], - batch_size, - q.shape[2] * q.shape[3], - ), - dtype=q.dtype, - device=q.device) - - @set_model_tag("Qwen2_5_VisionBlock") @support_torch_compile(dynamic_arg_dims={ "x": 0, @@ -776,15 +719,6 @@ def get_rope_by_thw(self, t, h, w): return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, cu_seqlens_thw) - @staticmethod - def invert_permutation(perm: torch.Tensor) -> torch.Tensor: - # building the inverse permutation in O(n) time - inv = torch.empty_like(perm, pin_memory=is_pin_memory_available()) - inv[perm] = torch.arange(perm.numel(), - device=perm.device, - dtype=perm.dtype) - return inv - def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, @@ -799,6 +733,15 @@ def compute_attn_mask_seqlen( seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]) return max_seqlen, seqlens + @staticmethod + def invert_permutation(perm: torch.Tensor) -> torch.Tensor: + # building the inverse permutation in O(n) time + inv = torch.empty_like(perm, pin_memory=is_pin_memory_available()) + inv[perm] = torch.arange(perm.numel(), + device=perm.device, + dtype=perm.dtype) + return inv + def forward( self, x: torch.Tensor, @@ -1178,14 +1121,14 @@ def _process_image_input( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] - - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values, - grid_thw_list, - rope_type="rope_3d") - else: - with set_forward_context(None, self.vllm_config): + with set_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + else: image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) @@ -1239,13 +1182,14 @@ def _process_video_input( video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"] - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values_videos, - grid_thw_list, - rope_type="rope_3d") - else: - with set_forward_context(None, 
self.vllm_config): + with set_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d") + else: video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) From f6b2363f0015d90420bf5578b7465cdf04388da2 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Thu, 25 Sep 2025 16:09:46 -0700 Subject: [PATCH 13/13] See title; consolidate so that we only compile a class once Signed-off-by: Lucas Kabela --- vllm/compilation/backends.py | 1 - vllm/compilation/decorators.py | 23 ++++++++++++++++++----- vllm/compilation/wrapper.py | 9 +++++---- vllm/model_executor/models/gemma3n.py | 1 + vllm/model_executor/models/qwen2_5_vl.py | 16 ++++++++-------- 5 files changed, 32 insertions(+), 18 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 335bbda5e4eb..7a0119dd3b61 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -474,7 +474,6 @@ def configure_post_pass(self): inductor_config[PASS_KEY] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: - vllm_config = self.vllm_config if not self.compilation_config.cache_dir: # no provided cache dir, generate one based on the known factors diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 970f50accafb..b1ca05bec824 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -192,7 +192,6 @@ def _support_torch_compile( # make sure super().__init__ is called on the base class # other than TorchCompileWrapperWithCustomDispatcher cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) - old_init = cls.__init__ setattr(cls, IGNORE_COMPILE_KEY, False) @@ -222,12 +221,25 @@ def __init__(self, **kwargs): return compilation_counter.num_models_seen += 1 - TorchCompileWrapperWithCustomDispatcher.__init__( - self, compilation_level=vllm_config.compilation_config.level) + if not hasattr(self.__class__, "compiled_callable"): + print(f"init self for {self.__class__}") + # only compile the same model once + # NOTE: this is probably not right, since parameters can change + # and cause us to fall over + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) + self.__class__.compiled_callable = self.compiled_callable + else: + print("init reusing the callable") + TorchCompileWrapperWithCustomDispatcher.__init__( + self, + self.__class__.compiled_callable, + compilation_level=vllm_config.compilation_config.level) cls.__init__ = __init__ def __call__(self, *args, **kwargs): + print(f"Call to {self.__class__} forward") # torch.compiler.is_compiling() means we are inside the compilation # e.g. TPU has the compilation logic in model runner, so we don't # need to compile the model inside. @@ -235,7 +247,7 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) # the first compilation needs to have dynamic shapes marked - if len(self.compiled_codes) < 1: + if len(self.__class__.compiled_codes) < 1: sig = inspect.signature(self.__class__.forward) bound_args = sig.bind(self, *args, **kwargs) bound_args.apply_defaults() @@ -269,7 +281,8 @@ def __call__(self, *args, **kwargs): # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, # with the overhead of guard evaluation and recompilation. 
- if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: + if len(self.__class__.compiled_codes + ) < 1 or not self.use_custom_dispatcher: # it seems Dynamo reuse the compilation across instances, # while we need to make sure the compiled code is not reused. # we need to control all the compilation of the model. diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 062c9dc27017..3a5732197587 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -53,7 +53,7 @@ def __init__(self, self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ - self.compiled_codes: list[CodeType] = [] + self.__class__.compiled_codes = [] # type: ignore[attr-defined] torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) # read the env var to determine whether to use the custom dispatcher @@ -91,8 +91,8 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): if frame.f_locals["self"] is not self: return - self.compiled_codes.append(new_code) - + self.__class__.compiled_codes.append( # type: ignore[attr-defined] + new_code) path = self.vllm_config.compile_debug_dump_path() if path: decompiled_file = path / "transformed_code.py" @@ -130,6 +130,7 @@ def dispatch_to_code(self, index: int): See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. """ # noqa - self.__class__.forward.__code__ = self.compiled_codes[index] + self.__class__.forward.__code__ = self.__class__.compiled_codes[ # type: ignore[attr-defined] + index] yield self.__class__.forward.__code__ = self.original_code_object diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 0b6bccb33498..aa862b670f71 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -1048,6 +1048,7 @@ def load_weights(self, weights: Iterable[tuple[str, class Gemma3nForCausalLM(nn.Module): + packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index a0cdce6f4855..7775966f2780 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -452,10 +452,10 @@ def forward( return x -@set_model_tag("Qwen2_5_VisionPatchEmbed") -@support_torch_compile(dynamic_arg_dims={ - "x": 0, -}) +# @set_model_tag("Qwen2_5_VisionPatchEmbed") +# @support_torch_compile(dynamic_arg_dims={ +# "x": 0, +# }) class Qwen2_5_VisionPatchEmbed(nn.Module): def __init__( @@ -485,10 +485,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -@set_model_tag("Qwen2_5_VisionPatchMerger") -@support_torch_compile(dynamic_arg_dims={ - "x": 0, -}) +# @set_model_tag("Qwen2_5_VisionPatchMerger") +# @support_torch_compile(dynamic_arg_dims={ +# "x": 0, +# }) class Qwen2_5_VisionPatchMerger(nn.Module): def __init__(
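
Editor's note: the sketch below is not part of the patch series. It illustrates, under assumed names (mylib::sdpa_by_segments, sdpa_by_segments_fake, which do not exist in vLLM), the pattern the later patches converge on in vllm/attention/ops/vit_attn_wrappers.py: operations that torch.compile cannot trace (.item(), .tolist(), backend-specific varlen attention kernels) are hidden behind a torch.library custom op with a registered fake implementation, so the surrounding vision block can be compiled without graph breaks. It is a minimal sketch assuming PyTorch >= 2.4, not the vLLM implementation.

    import torch
    import torch.nn.functional as F


    @torch.library.custom_op("mylib::sdpa_by_segments", mutates_args=())
    def sdpa_by_segments(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Eager implementation: free to call .tolist(), loop per segment, etc.,
        # because torch.compile treats the custom op as an opaque call.
        # q, k, v are (b, s, h, d); cu_seqlens holds cumulative segment bounds.
        outputs = []
        bounds = cu_seqlens.tolist()
        for start, end in zip(bounds[:-1], bounds[1:]):
            q_i, k_i, v_i = (x[:, start:end].transpose(1, 2) for x in (q, k, v))
            o = F.scaled_dot_product_attention(q_i, k_i, v_i)
            outputs.append(o.transpose(1, 2))
        out = torch.cat(outputs, dim=1)                      # (b, s, h, d)
        return out.flatten(-2).transpose(0, 1).contiguous()  # (s, b, h*d)


    @torch.library.register_fake("mylib::sdpa_by_segments")
    def sdpa_by_segments_fake(q, k, v, cu_seqlens):
        # Shape/dtype-only implementation; this is what the compiler traces,
        # so no data-dependent control flow is allowed here.
        b, s, h, d = q.shape
        return q.new_empty((s, b, h * d))

With the graph breaks isolated behind such ops, the series then decorates Qwen2_5_VisionBlock with @set_model_tag and @support_torch_compile, marking x, cu_seqlens, rotary_pos_emb, and (in patch 10) seqlens as dynamic dims so one compiled artifact serves varying image and window sizes.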