diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 69228bbf2294..f9048c7735eb 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -26,11 +26,8 @@ def run_aria(question: str, modality: str): # NOTE: Need L40 (or equivalent) to avoid OOM llm = LLM(model=model_name, - tokenizer_mode="slow", - dtype="bfloat16", max_model_len=4096, max_num_seqs=2, - trust_remote_code=True, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) prompt = (f"<|im_start|>user\n<|img|>\n{question}" diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index ca572cc39e53..14d9a739be31 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -10,7 +10,6 @@ import pytest from transformers import AutoModelForVision2Seq from transformers import __version__ as TRANSFORMERS_VERSION -from transformers.utils import is_flash_attn_2_available from vllm.platforms import current_platform from vllm.utils import identity @@ -140,9 +139,7 @@ #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], - tokenizer_mode="slow", test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - dtype="bfloat16", prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 img_idx_to_prompt=lambda idx: "<|img|>\n", max_model_len=4096, @@ -158,8 +155,8 @@ max_tokens=64, marks=[ pytest.mark.skipif( - not is_flash_attn_2_available(), - reason="Model needs flash-attn for numeric convergence.", + TRANSFORMERS_VERSION < "4.48.0", + reason="HF model requires transformers>=4.48.0", ), large_gpu_mark(min_gb=64), ], diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 1e3e7ea50b12..d6d3d3b34ad4 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -11,6 +11,7 @@ from vllm.multimodal.utils import cached_get_tokenizer from ....multimodal.utils import random_audio, random_image, random_video +from ...registry import HF_EXAMPLE_MODELS def _test_processing_correctness( @@ -20,12 +21,9 @@ def _test_processing_correctness( num_batches: int, simplify_rate: float, ): - if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3": - hf_overrides = {"architectures": ["MantisForConditionalGeneration"]} - elif model_id == "deepseek-ai/deepseek-vl2-tiny": - hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]} - else: - hf_overrides = {} + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") limit_mm_per_prompt = { modality: 3 if supports_multi else 1 @@ -41,7 +39,7 @@ def _test_processing_correctness( seed=0, dtype="float16", revision=None, - hf_overrides=hf_overrides, + hf_overrides=model_info.hf_overrides, limit_mm_per_prompt=limit_mm_per_prompt, ) diff --git a/tests/models/registry.py b/tests/models/registry.py index cb0521cfe80a..8bdb6d8f632d 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -1,5 +1,9 @@ from dataclasses import dataclass, field -from typing import AbstractSet, Mapping, Optional +from typing import AbstractSet, Any, Literal, Mapping, Optional + +import pytest +from packaging.version import Version +from transformers import 
__version__ as TRANSFORMERS_VERSION @dataclass(frozen=True) @@ -38,6 +42,50 @@ class _HfExamplesInfo: trust_remote_code: bool = False """The ``trust_remote_code`` level required to load the model.""" + hf_overrides: dict[str, Any] = field(default_factory=dict) + """The ``hf_overrides`` required to load the model.""" + + def check_transformers_version( + self, + *, + on_fail: Literal["error", "skip"], + ) -> None: + """ + If the installed transformers version does not meet the requirements, + perform the given action. + """ + if self.min_transformers_version is None: + return + + current_version = TRANSFORMERS_VERSION + required_version = self.min_transformers_version + if Version(current_version) < Version(required_version): + msg = ( + f"You have `transformers=={current_version}` installed, but " + f"`transformers>={required_version}` is required to run this " + "model") + + if on_fail == "error": + raise RuntimeError(msg) + else: + pytest.skip(msg) + + def check_available_online( + self, + *, + on_fail: Literal["error", "skip"], + ) -> None: + """ + If the model is not available online, perform the given action. + """ + if not self.is_available_online: + msg = "Model is not available online" + + if on_fail == "error": + raise RuntimeError(msg) + else: + pytest.skip(msg) + # yapf: disable _TEXT_GENERATION_EXAMPLE_MODELS = { @@ -48,8 +96,6 @@ class _HfExamplesInfo: trust_remote_code=True), "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", trust_remote_code=True), - "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria", - trust_remote_code=True), "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", @@ -175,6 +221,8 @@ class _HfExamplesInfo: _MULTIMODAL_EXAMPLE_MODELS = { # [Decoder-only] + "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria", + min_transformers_version="4.48"), "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501 "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501 "ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b", @@ -182,7 +230,8 @@ class _HfExamplesInfo: trust_remote_code=True), "ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b", is_available_online=False), - "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny"), # noqa: E501 + "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 + hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"), "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", @@ -193,7 +242,8 @@ class _HfExamplesInfo: "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 - "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501 + "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501 + hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501 "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True), 
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", @@ -246,5 +296,12 @@ def get_supported_archs(self) -> AbstractSet[str]: def get_hf_info(self, model_arch: str) -> _HfExamplesInfo: return self.hf_models[model_arch] + def find_hf_info(self, model_id: str) -> _HfExamplesInfo: + for info in self.hf_models.values(): + if info.default == model_id: + return info + + raise ValueError(f"No example model defined for {model_id}") + HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index daece7c93c0e..d3a3aaf670c2 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -1,9 +1,7 @@ from unittest.mock import patch import pytest -from packaging.version import Version from transformers import PretrainedConfig -from transformers import __version__ as TRANSFORMERS_VERSION from vllm import LLM @@ -13,16 +11,8 @@ @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) def test_can_initialize(model_arch): model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) - if not model_info.is_available_online: - pytest.skip("Model is not available online") - if model_info.min_transformers_version is not None: - current_version = TRANSFORMERS_VERSION - required_version = model_info.min_transformers_version - if Version(current_version) < Version(required_version): - pytest.skip( - f"You have `transformers=={current_version}` installed, but " - f"`transformers>={required_version}` is required to run this " - "model") + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") # Avoid OOM def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 73b70d65e8e0..ac0366847e33 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -21,6 +21,9 @@ @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) def test_registry_imports(model_arch): + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + model_info.check_transformers_version(on_fail="skip") + # Ensure all model classes can be imported successfully model_cls, _ = ModelRegistry.resolve_model_cls(model_arch) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 5b97eced62df..503d1a38d9ee 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -1,9 +1,11 @@ -from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import torch import torch.nn as nn -from transformers import BatchFeature, PretrainedConfig +from transformers import AriaConfig, AriaTextConfig, BatchFeature +from transformers.models.aria.modeling_aria import AriaCrossAttention +from transformers.models.aria.processing_aria import AriaProcessor from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, QuantizationConfig, VllmConfig @@ -26,10 +28,11 @@ BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, - AriaVisionConfig) -from .idefics2_vision_model import Idefics2VisionTransformer +# yapf: disable +from .idefics2_vision_model import ( + Idefics2VisionTransformer as 
Idefics3VisionTransformer) +# yapf: enable from .interfaces import SupportsMultiModal from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, @@ -47,87 +50,22 @@ class AriaImagePixelInputs(TypedDict): """ -class AriaVisionTransformer(Idefics2VisionTransformer): - """ - AriaVisionTransformer is a modified version of Idefics2VisionTransformer - that replaces the post-layernorm with an identity layer. - """ - - def __init__( - self, - config: AriaVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, quant_config, prefix) - self.post_layernorm = nn.Identity() - - -class AriaVisionModel(nn.Module): - config_class = AriaVisionConfig +class AriaProjectorMLP(nn.Module): def __init__( self, - config: AriaVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - *, - prefix: str = "", + in_features: int, + hidden_features: int, + output_dim: int, ) -> None: super().__init__() - self.vision_model = AriaVisionTransformer( - config, - quant_config, - prefix=f"{prefix}.vision_model", - ) - - def forward( - self, - pixel_values: torch.Tensor, - pixel_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - patch_attention_mask = self._create_patch_attention_mask(pixel_mask) - - vit_oup = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - image_atts = self._create_image_attention_mask(patch_attention_mask) - - return vit_oup, image_atts - - def _create_patch_attention_mask( - self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor: - if pixel_mask is None: - return None - - patches_subgrid = pixel_mask.unfold( - dimension=1, - size=self.vision_model.config.patch_size, - step=self.vision_model.config.patch_size, - ).unfold( - dimension=2, - size=self.vision_model.config.patch_size, - step=self.vision_model.config.patch_size, - ) - return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - def _create_image_attention_mask( - self, patch_attention_mask: torch.Tensor) -> torch.Tensor: - if patch_attention_mask is None: - return None - - flattened_mask = patch_attention_mask.flatten(1) - return torch.logical_not(flattened_mask) - - -class FFN(nn.Module): - - def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None: - super().__init__() - self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False) - self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False) + self.linear_in = ColumnParallelLinear(in_features, + hidden_features, + bias=False) + self.linear_out = RowParallelLinear(hidden_features, + output_dim, + bias=False) self.act = get_act_fn("gelu_new") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -137,46 +75,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class CrossAttention(nn.Module): - - def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None: - super().__init__() - self.num_heads = num_heads - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) - self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False) - self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False) - - self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - self.linear = nn.Linear(embed_dim, embed_dim) - - self.layer_norm = nn.LayerNorm(embed_dim) - self.ln_kv = nn.LayerNorm(kv_dim) - - def forward( - self, - x: torch.Tensor, - hidden_states: torch.Tensor, - attn_mask: 
Optional[torch.Tensor] = None, - ) -> torch.Tensor: - normed_hidden_states = self.layer_norm(hidden_states) - query = self.q_proj(normed_hidden_states).permute(1, 0, 2) - - x = self.ln_kv(x) - key = self.k_proj(x).permute(1, 0, 2) - value = self.v_proj(x).permute(1, 0, 2) - - attn_output, _ = self.multihead_attn(query, - key, - value, - attn_mask=attn_mask) - - attn_output = attn_output.permute(1, 0, 2) - - attn_output = self.linear(attn_output) - - return attn_output - - class AriaProjector(nn.Module): """ A projection module with one cross attention layer and one FFN layer, which @@ -198,42 +96,42 @@ class AriaProjector(nn.Module): A tensor with the shape of (batch_size, query_number, output_dim) """ - def __init__( - self, - patch_to_query_dict: dict[int, int], - embed_dim: int, - num_heads: int, - kv_dim: int, - ff_dim: int, - output_dim: int, - norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, - ) -> None: + def __init__(self, config: AriaConfig) -> None: super().__init__() - self.patch_to_query_dict = patch_to_query_dict - self.embed_dim = embed_dim - self.num_heads = num_heads + + self.patch_to_query_dict = config.projector_patch_to_query_dict + self.in_features = config.vision_config.hidden_size + self.num_heads = config.vision_config.num_attention_heads + self.kv_dim = config.vision_config.hidden_size + self.hidden_features = config.text_config.hidden_size + self.output_dim = config.text_config.hidden_size self.query = nn.Parameter( - torch.empty(max(patch_to_query_dict.values()), self.embed_dim)) + torch.empty(config.max_value_projector_patch_to_query_dict, + self.in_features)) - self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads) + self.cross_attn = AriaCrossAttention(config) - self.ln_ffn = norm_layer(embed_dim) - self.ffn = FFN(embed_dim, ff_dim, output_dim) + self.layer_norm = nn.LayerNorm(self.in_features) + self.feed_forward = AriaProjectorMLP(self.in_features, + self.hidden_features, + self.output_dim) def forward( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - bs = x.shape[0] - queries = self.query.unsqueeze(0).repeat(bs, 1, 1) + batch_size, num_patches = x.shape[0], x.shape[1] - query_num = self.patch_to_query_dict.get(x.shape[1], None) - assert (query_num is not None - ), f"Query number for {x.shape[1]} patches is not provided" + if num_patches not in self.patch_to_query_dict: + raise KeyError(f"Number of patches {num_patches} not found in " + "patch_to_query_dict amongst possible values " + f"{self.patch_to_query_dict.keys()}.") - queries = queries[:, :query_num, :] + query_num = self.patch_to_query_dict[num_patches] + + queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1) if attn_mask is not None: attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) @@ -241,7 +139,7 @@ def forward( attention_out = self.cross_attn(x, queries, attn_mask=attn_mask) - out = self.ffn(self.ln_ffn(attention_out)) + out = self.feed_forward(self.layer_norm(attention_out)) return out @@ -278,7 +176,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, param.data.copy_(loaded_weight.transpose(1, 2)) -class MoELayer(nn.Module): +class AriaTextMoELayer(nn.Module): """ Mixture of Experts (MoE) Layer for the AriaMoE model. 
@@ -289,7 +187,7 @@ class MoELayer(nn.Module): def __init__( self, - config: AriaMoELMConfig, + config: AriaTextConfig, quant_config: Optional[QuantizationConfig], ) -> None: super().__init__() @@ -303,15 +201,16 @@ def __init__( num_experts=config.moe_num_experts, top_k=config.moe_topk, hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, + intermediate_size=config.intermediate_size, quant_config=quant_config, reduce_results=True, ) self.shared_experts = LlamaMLP( config.hidden_size, - config.moe_intermediate_size * config.moe_num_shared_experts, + config.intermediate_size * config.moe_num_shared_experts, "silu", quant_config=quant_config, + bias=config.mlp_bias, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -329,13 +228,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_output = torch.nn.functional.linear(hidden_states, self.router_weight) - shared_expert_output = self.shared_experts(hidden_states) sparse_expert_output = self.experts(hidden_states, router_output) + shared_expert_output = self.shared_experts(hidden_states) return sparse_expert_output + shared_expert_output -class MoEDecoderLayer(LlamaDecoderLayer): +class AriaTextDecoderLayer(LlamaDecoderLayer): """ Custom Decoder Layer for the AriaMoE model which modifies the standard `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of @@ -344,16 +243,16 @@ class MoEDecoderLayer(LlamaDecoderLayer): def __init__( self, - config: AriaMoELMConfig, + config: AriaTextConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__(config, cache_config, quant_config, prefix) - self.mlp = MoELayer(config, quant_config=quant_config) + self.mlp = AriaTextMoELayer(config, quant_config=quant_config) -class AriaMoELMModel(LlamaModel): +class AriaTextModel(LlamaModel): """ Custom LlamaModel for the AriaMoE model which modifies the standard LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. 
@@ -362,7 +261,7 @@ class AriaMoELMModel(LlamaModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix, - layer_type=MoEDecoderLayer) + layer_type=AriaTextDecoderLayer) # Adapted from LlamaModel.load_weights with the modification of adding # the expert weights mapping to `stacked_params_mapping` @@ -434,25 +333,23 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -def build_mm_projector(config: PretrainedConfig): - return AriaProjector( - patch_to_query_dict=config.projector_patch_to_query_dict, - embed_dim=config.vision_config.hidden_size, - num_heads=config.vision_config.num_attention_heads, - kv_dim=config.vision_config.hidden_size, - ff_dim=config.text_config.hidden_size, - output_dim=config.text_config.hidden_size, - ) - - class AriaProcessingInfo(BaseProcessingInfo): def get_hf_config(self): - return self.ctx.get_hf_config() + return self.ctx.get_hf_config(AriaConfig) - def get_vision_config(self) -> AriaVisionConfig: + def get_vision_config(self): return self.get_hf_config().vision_config + def get_hf_processor(self): + processor = self.ctx.get_hf_processor(AriaProcessor) + + # Patch for https://github.com/huggingface/transformers/issues/35768 + processor.tokenizer.image_token = "<|img|>" + processor.image_token = "<|img|>" + + return processor + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} @@ -554,10 +451,14 @@ def __init__( quant_config = vllm_config.quant_config self.config = config - self.vision_tower = AriaVisionModel(config.vision_config) - self.multi_modal_projector = build_mm_projector(config) + self.vision_tower = Idefics3VisionTransformer( + config.vision_config, + quant_config, + prefix=f"{prefix}.vision_tower", + ) + self.multi_modal_projector = AriaProjector(config) self.vocab_size = config.text_config.vocab_size - self.language_model = AriaMoELMModel( + self.language_model = AriaTextModel( vllm_config=vllm_config.with_hf_config(config.text_config), prefix=maybe_prefix(prefix, "language_model.model"), ) @@ -608,6 +509,22 @@ def _parse_and_validate_image_input( pixel_mask=pixel_mask, ) + def _create_patch_attention_mask( + self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor: + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ).unfold( + dimension=2, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + def _process_image_input( self, image_input: AriaImagePixelInputs ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -616,9 +533,18 @@ def _process_image_input( pixel_values = image_input['pixel_values'] pixel_mask = image_input['pixel_mask'] - image_feature, image_attn_mask = self.vision_tower( - pixel_values, pixel_mask=pixel_mask) - return self.multi_modal_projector(image_feature, image_attn_mask) + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + + image_outputs = self.vision_tower( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + image_attn_mask = None + if patch_attention_mask is not None: + flattened_mask = patch_attention_mask.flatten(1) + image_attn_mask = torch.logical_not(flattened_mask) + + return self.multi_modal_projector(image_outputs, image_attn_mask) def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: 
image_input = self._parse_and_validate_image_input(**kwargs) @@ -683,6 +609,5 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader(self) loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f57dfded0a62..c97acffa1a71 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -22,10 +22,10 @@ from vllm.logger import init_logger # yapf conflicts with isort for this block # yapf: disable -from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig, - Cohere2Config, DbrxConfig, - DeepseekVLV2Config, EAGLEConfig, - ExaoneConfig, H2OVLChatConfig, +from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, + DbrxConfig, DeepseekVLV2Config, + EAGLEConfig, ExaoneConfig, + H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -52,7 +52,6 @@ } _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { - "aria": AriaConfig, "chatglm": ChatGLMConfig, "cohere2": Cohere2Config, "dbrx": DbrxConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 807ef4fbfd0c..f065c5612460 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,4 +1,3 @@ -from vllm.transformers_utils.configs.aria import AriaConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -24,7 +23,6 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ - "AriaConfig", "ChatGLMConfig", "Cohere2Config", "DbrxConfig", diff --git a/vllm/transformers_utils/configs/aria.py b/vllm/transformers_utils/configs/aria.py deleted file mode 100644 index f4b531225b5d..000000000000 --- a/vllm/transformers_utils/configs/aria.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright 2024 Rhymes AI. All rights reserved. -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import Mapping - -from transformers import PretrainedConfig -from transformers.models.idefics2.configuration_idefics2 import ( - Idefics2VisionConfig) -from transformers.models.llama.configuration_llama import LlamaConfig - -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class AriaVisionConfig(Idefics2VisionConfig): - model_type = "aria_vision_model" - - -class AriaMoELMConfig(LlamaConfig): - """ - Configuration class for AriaMoE language model. 
- - This class extends the LlamaConfig to include additional parameters specific - to the Mixture of Experts (MoE) architecture. - """ - - model_type = "aria_moe_lm" - - def __init__( - self, - moe_intermediate_size: int = 4096, - moe_num_experts: int = 8, - moe_topk: int = 2, - moe_num_shared_experts: int = 2, - **kwargs, - ): - """ - Initialize the AriaMoELMConfig. - - Args: - moe_intermediate_size (int): The intermediate size for MoE layers. - Default is 4096. - moe_num_experts (int): The number of experts in the MoE layer. - Default is 8. - moe_topk (int): The number of top experts to route to for each - token. Default is 2. - moe_num_shared_experts (int): The number of shared experts. Default - is 2. - **kwargs: Additional keyword arguments to be passed to the parent - LlamaConfig. - """ - super().__init__(**kwargs) - self.moe_intermediate_size = moe_intermediate_size - self.moe_num_experts = moe_num_experts - self.moe_topk = moe_topk - self.moe_num_shared_experts = moe_num_shared_experts - - -class AriaConfig(PretrainedConfig): - """ - Configuration class for Aria model. - This class handles the configuration for both vision and text components of - the Aria model, - as well as additional parameters for image token handling and projector - mapping. - - Args: - vision_config (AriaVisionConfig or dict): Configuration for the vision - component. - text_config (AriaMoELMConfig or dict): Configuration for the text - component. - projector_patch_to_query_dict (dict): Mapping of patch sizes to query - dimensions. - ignore_index (int): Index to ignore in loss calculation. - image_token_index (int): Index used to represent image tokens. - **kwargs: Additional keyword arguments passed to the parent class. - Attributes: - model_type (str): Type of the model, set to "aria". - is_composition (bool): Whether the model is a composition of multiple - components. - ignore_index (int): Index to ignore in loss calculation. - image_token_index (int): Index used to represent image tokens. - projector_patch_to_query_dict (dict): Mapping of patch sizes to query - dimensions. - vision_config (AriaVisionConfig): Configuration for the vision - component. - text_config (AriaMoELMConfig): Configuration for the text component. 
- """ - - model_type = "aria" - is_composition = False - - def __init__( - self, - vision_config: AriaVisionConfig = AriaVisionConfig(), # noqa: B008 - text_config: AriaMoELMConfig = AriaMoELMConfig(), # noqa: B008 - projector_patch_to_query_dict: Mapping[int, int] = { - 1225: 128, - 4900: 256, - }, - ignore_index=-100, - image_token_index=32000, - tie_word_embeddings=False, - **kwargs, - ): - super().__init__(**kwargs) - self.ignore_index = ignore_index - self.image_token_index = image_token_index - self.tie_word_embeddings = tie_word_embeddings - attn_implementation = kwargs.pop("attn_implementation", None) - - # Set the default attention implementation to flash_attention_2 if not - # specified - self._attn_implementation = ("flash_attention_2" - if attn_implementation is None else - attn_implementation) - - # Convert the keys and values of projector_patch_to_query_dict to - # integers - # This ensures consistency even if they were provided as strings - self.projector_patch_to_query_dict = { - int(k): int(v) - for k, v in projector_patch_to_query_dict.items() - } - - if isinstance(vision_config, dict) and "model_type" in vision_config: - vision_config = AriaVisionConfig(**vision_config) - if attn_implementation is None: - vision_attn_implementation = "flash_attention_2" - elif attn_implementation == "sdpa": - logger.warning("SDPA is not supported for vit, using " - "flash_attention_2 instead") - vision_attn_implementation = "flash_attention_2" - else: - vision_attn_implementation = attn_implementation - vision_config._attn_implementation = vision_attn_implementation - - self.vision_config = vision_config - - if isinstance(text_config, dict) and "model_type" in text_config: - text_attn_implementation = ("sdpa" if attn_implementation is None - else attn_implementation) - text_config = AriaMoELMConfig(**text_config) - text_config._attn_implementation = text_attn_implementation - - self.text_config = text_config - - # This is needed for the static kv cache - self.num_hidden_layers = self.text_config.num_hidden_layers
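
Note (not part of the diff): the gating that individual tests previously open-coded now lives on `_HfExamplesInfo` as `check_available_online(on_fail=...)` and `check_transformers_version(on_fail=...)`, and tests look entries up via `HF_EXAMPLE_MODELS.get_hf_info(...)` / `find_hf_info(...)`. Below is a minimal, self-contained sketch of that version-gating pattern for readers who want to reuse it outside vLLM's test suite. `require_transformers` and `MIN_TRANSFORMERS_VERSION` are hypothetical stand-ins for the registry helper and the `min_transformers_version` field added in `tests/models/registry.py`; they are not names from this PR.

```python
# Illustrative sketch only -- mirrors the logic of
# _HfExamplesInfo.check_transformers_version() introduced above.
from typing import Literal, Optional

import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION

# Stand-in for _HfExamplesInfo.min_transformers_version
# (e.g. "4.48" for the upstream Aria implementation).
MIN_TRANSFORMERS_VERSION: Optional[str] = "4.48"


def require_transformers(min_version: Optional[str],
                         *,
                         on_fail: Literal["error", "skip"] = "skip") -> None:
    """Skip or fail when the installed transformers is older than required."""
    if min_version is None:
        return

    if Version(TRANSFORMERS_VERSION) < Version(min_version):
        msg = (f"You have `transformers=={TRANSFORMERS_VERSION}` installed, "
               f"but `transformers>={min_version}` is required to run this "
               "model")
        if on_fail == "error":
            raise RuntimeError(msg)
        pytest.skip(msg)


def test_aria_style_model():
    # In vLLM's tests the equivalent calls are:
    #   model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    #   model_info.check_available_online(on_fail="skip")
    #   model_info.check_transformers_version(on_fail="skip")
    require_transformers(MIN_TRANSFORMERS_VERSION, on_fail="skip")
    # ... the actual model test body would go here ...
```

A design note grounded in the diff: the registry helper compares versions with `packaging.version.Version`, which is robust to multi-digit components, whereas the `skipif` condition in `tests/models/decoder_only/vision_language/test_models.py` uses a plain string comparison (`TRANSFORMERS_VERSION < "4.48.0"`); the former is the safer pattern to copy.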