diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 5b97eced62df..d8b94788b18e 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn from transformers import BatchFeature, PretrainedConfig +from transformers.models.aria import AriaTextConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, QuantizationConfig, VllmConfig @@ -26,8 +27,6 @@ BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, - AriaVisionConfig) from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import SupportsMultiModal @@ -39,7 +38,7 @@ class AriaImagePixelInputs(TypedDict): pixel_values: torch.Tensor - pixel_mask: Optional[torch.Tensor] + patch_attention_mask: Optional[torch.Tensor] """ Shape: pixel_values: `(batch_size * num_images, num_channels, height, width)` @@ -47,81 +46,6 @@ class AriaImagePixelInputs(TypedDict): """ -class AriaVisionTransformer(Idefics2VisionTransformer): - """ - AriaVisionTransformer is a modified version of Idefics2VisionTransformer - that replaces the post-layernorm with an identity layer. - """ - - def __init__( - self, - config: AriaVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, quant_config, prefix) - self.post_layernorm = nn.Identity() - - -class AriaVisionModel(nn.Module): - config_class = AriaVisionConfig - - def __init__( - self, - config: AriaVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - *, - prefix: str = "", - ) -> None: - super().__init__() - - self.vision_model = AriaVisionTransformer( - config, - quant_config, - prefix=f"{prefix}.vision_model", - ) - - def forward( - self, - pixel_values: torch.Tensor, - pixel_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - patch_attention_mask = self._create_patch_attention_mask(pixel_mask) - - vit_oup = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - image_atts = self._create_image_attention_mask(patch_attention_mask) - - return vit_oup, image_atts - - def _create_patch_attention_mask( - self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor: - if pixel_mask is None: - return None - - patches_subgrid = pixel_mask.unfold( - dimension=1, - size=self.vision_model.config.patch_size, - step=self.vision_model.config.patch_size, - ).unfold( - dimension=2, - size=self.vision_model.config.patch_size, - step=self.vision_model.config.patch_size, - ) - return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - def _create_image_attention_mask( - self, patch_attention_mask: torch.Tensor) -> torch.Tensor: - if patch_attention_mask is None: - return None - - flattened_mask = patch_attention_mask.flatten(1) - return torch.logical_not(flattened_mask) - - class FFN(nn.Module): def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None: @@ -150,7 +74,7 @@ def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None: self.linear = nn.Linear(embed_dim, embed_dim) self.layer_norm = nn.LayerNorm(embed_dim) - self.ln_kv = nn.LayerNorm(kv_dim) + self.layer_norm_kv = nn.LayerNorm(kv_dim) def forward( self, @@ -161,7 +85,7 @@ def forward( normed_hidden_states = self.layer_norm(hidden_states) query = self.q_proj(normed_hidden_states).permute(1, 0, 2) - x = self.ln_kv(x) + x = self.layer_norm_kv(x) key = self.k_proj(x).permute(1, 0, 2) value = self.v_proj(x).permute(1, 0, 2) @@ -218,8 +142,8 @@ def __init__( self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads) - self.ln_ffn = norm_layer(embed_dim) - self.ffn = FFN(embed_dim, ff_dim, output_dim) + self.layer_norm = norm_layer(embed_dim) + self.feed_forward = FFN(embed_dim, ff_dim, output_dim) def forward( self, @@ -241,7 +165,7 @@ def forward( attention_out = self.cross_attn(x, queries, attn_mask=attn_mask) - out = self.ffn(self.ln_ffn(attention_out)) + out = self.feed_forward(self.layer_norm(attention_out)) return out @@ -289,7 +213,7 @@ class MoELayer(nn.Module): def __init__( self, - config: AriaMoELMConfig, + config: AriaTextConfig, quant_config: Optional[QuantizationConfig], ) -> None: super().__init__() @@ -303,13 +227,13 @@ def __init__( num_experts=config.moe_num_experts, top_k=config.moe_topk, hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, + intermediate_size=config.intermediate_size, quant_config=quant_config, reduce_results=True, ) self.shared_experts = LlamaMLP( config.hidden_size, - config.moe_intermediate_size * config.moe_num_shared_experts, + config.intermediate_size * config.moe_num_shared_experts, "silu", quant_config=quant_config, ) @@ -344,7 +268,7 @@ class MoEDecoderLayer(LlamaDecoderLayer): def __init__( self, - config: AriaMoELMConfig, + config: AriaTextConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -450,7 +374,7 @@ class AriaProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config() - def get_vision_config(self) -> AriaVisionConfig: + def get_vision_config(self): return self.get_hf_config().vision_config def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: @@ -483,8 +407,8 @@ def get_dummy_processor_inputs( num_images=num_images) } - hf_processor = self.info.get_hf_processor() - image_token: str = hf_processor.image_token # type: ignore + # hf_processor = self.info.get_hf_processor() + image_token: str = '<|img|>' return ProcessorInputs( prompt_text=image_token * num_images, @@ -554,7 +478,7 @@ def __init__( quant_config = vllm_config.quant_config self.config = config - self.vision_tower = AriaVisionModel(config.vision_config) + self.vision_tower = Idefics2VisionTransformer(config.vision_config) self.multi_modal_projector = build_mm_projector(config) self.vocab_size = config.text_config.vocab_size self.language_model = AriaMoELMModel( @@ -581,6 +505,30 @@ def _validate_image_sizes( raise ValueError("All images must be the same size") return images + def _create_patch_attention_mask( + self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor: + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.config.vision_config.patch_size, + step=self.config.vision_config.patch_size, + ).unfold( + dimension=2, + size=self.config.vision_config.patch_size, + step=self.config.vision_config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + def _create_image_attention_mask( + self, patch_attention_mask: torch.Tensor) -> torch.Tensor: + if patch_attention_mask is None: + return None + + flattened_mask = patch_attention_mask.flatten(1) + return torch.logical_not(flattened_mask) + def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[AriaImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -596,6 +544,7 @@ def _parse_and_validate_image_input( pixel_values = self._validate_image_sizes(pixel_values) pixel_values = flatten_bn(pixel_values, concat=True) + patch_attention_mask = None if pixel_mask is not None: if not isinstance(pixel_mask, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel mask. " @@ -603,9 +552,12 @@ def _parse_and_validate_image_input( pixel_mask = flatten_bn(pixel_mask, concat=True) + patch_attention_mask = self._create_patch_attention_mask( + pixel_mask) + return AriaImagePixelInputs( pixel_values=pixel_values, - pixel_mask=pixel_mask, + patch_attention_mask=patch_attention_mask, ) def _process_image_input( @@ -614,10 +566,12 @@ def _process_image_input( assert self.vision_tower is not None pixel_values = image_input['pixel_values'] - pixel_mask = image_input['pixel_mask'] + patch_attention_mask = image_input['patch_attention_mask'] - image_feature, image_attn_mask = self.vision_tower( - pixel_values, pixel_mask=pixel_mask) + image_feature = self.vision_tower( + pixel_values, patch_attention_mask=patch_attention_mask) + image_attn_mask = self._create_image_attention_mask( + patch_attention_mask) return self.multi_modal_projector(image_feature, image_attn_mask) def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f57dfded0a62..c97acffa1a71 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -22,10 +22,10 @@ from vllm.logger import init_logger # yapf conflicts with isort for this block # yapf: disable -from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig, - Cohere2Config, DbrxConfig, - DeepseekVLV2Config, EAGLEConfig, - ExaoneConfig, H2OVLChatConfig, +from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, + DbrxConfig, DeepseekVLV2Config, + EAGLEConfig, ExaoneConfig, + H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -52,7 +52,6 @@ } _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { - "aria": AriaConfig, "chatglm": ChatGLMConfig, "cohere2": Cohere2Config, "dbrx": DbrxConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 807ef4fbfd0c..2e8c4832a58f 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,4 +1,3 @@ -from vllm.transformers_utils.configs.aria import AriaConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig diff --git a/vllm/transformers_utils/configs/aria.py b/vllm/transformers_utils/configs/aria.py deleted file mode 100644 index f4b531225b5d..000000000000 --- a/vllm/transformers_utils/configs/aria.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright 2024 Rhymes AI. All rights reserved. -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import Mapping - -from transformers import PretrainedConfig -from transformers.models.idefics2.configuration_idefics2 import ( - Idefics2VisionConfig) -from transformers.models.llama.configuration_llama import LlamaConfig - -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class AriaVisionConfig(Idefics2VisionConfig): - model_type = "aria_vision_model" - - -class AriaMoELMConfig(LlamaConfig): - """ - Configuration class for AriaMoE language model. - - This class extends the LlamaConfig to include additional parameters specific - to the Mixture of Experts (MoE) architecture. - """ - - model_type = "aria_moe_lm" - - def __init__( - self, - moe_intermediate_size: int = 4096, - moe_num_experts: int = 8, - moe_topk: int = 2, - moe_num_shared_experts: int = 2, - **kwargs, - ): - """ - Initialize the AriaMoELMConfig. - - Args: - moe_intermediate_size (int): The intermediate size for MoE layers. - Default is 4096. - moe_num_experts (int): The number of experts in the MoE layer. - Default is 8. - moe_topk (int): The number of top experts to route to for each - token. Default is 2. - moe_num_shared_experts (int): The number of shared experts. Default - is 2. - **kwargs: Additional keyword arguments to be passed to the parent - LlamaConfig. - """ - super().__init__(**kwargs) - self.moe_intermediate_size = moe_intermediate_size - self.moe_num_experts = moe_num_experts - self.moe_topk = moe_topk - self.moe_num_shared_experts = moe_num_shared_experts - - -class AriaConfig(PretrainedConfig): - """ - Configuration class for Aria model. - This class handles the configuration for both vision and text components of - the Aria model, - as well as additional parameters for image token handling and projector - mapping. - - Args: - vision_config (AriaVisionConfig or dict): Configuration for the vision - component. - text_config (AriaMoELMConfig or dict): Configuration for the text - component. - projector_patch_to_query_dict (dict): Mapping of patch sizes to query - dimensions. - ignore_index (int): Index to ignore in loss calculation. - image_token_index (int): Index used to represent image tokens. - **kwargs: Additional keyword arguments passed to the parent class. - Attributes: - model_type (str): Type of the model, set to "aria". - is_composition (bool): Whether the model is a composition of multiple - components. - ignore_index (int): Index to ignore in loss calculation. - image_token_index (int): Index used to represent image tokens. - projector_patch_to_query_dict (dict): Mapping of patch sizes to query - dimensions. - vision_config (AriaVisionConfig): Configuration for the vision - component. - text_config (AriaMoELMConfig): Configuration for the text component. - """ - - model_type = "aria" - is_composition = False - - def __init__( - self, - vision_config: AriaVisionConfig = AriaVisionConfig(), # noqa: B008 - text_config: AriaMoELMConfig = AriaMoELMConfig(), # noqa: B008 - projector_patch_to_query_dict: Mapping[int, int] = { - 1225: 128, - 4900: 256, - }, - ignore_index=-100, - image_token_index=32000, - tie_word_embeddings=False, - **kwargs, - ): - super().__init__(**kwargs) - self.ignore_index = ignore_index - self.image_token_index = image_token_index - self.tie_word_embeddings = tie_word_embeddings - attn_implementation = kwargs.pop("attn_implementation", None) - - # Set the default attention implementation to flash_attention_2 if not - # specified - self._attn_implementation = ("flash_attention_2" - if attn_implementation is None else - attn_implementation) - - # Convert the keys and values of projector_patch_to_query_dict to - # integers - # This ensures consistency even if they were provided as strings - self.projector_patch_to_query_dict = { - int(k): int(v) - for k, v in projector_patch_to_query_dict.items() - } - - if isinstance(vision_config, dict) and "model_type" in vision_config: - vision_config = AriaVisionConfig(**vision_config) - if attn_implementation is None: - vision_attn_implementation = "flash_attention_2" - elif attn_implementation == "sdpa": - logger.warning("SDPA is not supported for vit, using " - "flash_attention_2 instead") - vision_attn_implementation = "flash_attention_2" - else: - vision_attn_implementation = attn_implementation - vision_config._attn_implementation = vision_attn_implementation - - self.vision_config = vision_config - - if isinstance(text_config, dict) and "model_type" in text_config: - text_attn_implementation = ("sdpa" if attn_implementation is None - else attn_implementation) - text_config = AriaMoELMConfig(**text_config) - text_config._attn_implementation = text_attn_implementation - - self.text_config = text_config - - # This is needed for the static kv cache - self.num_hidden_layers = self.text_config.num_hidden_layers