25 changes: 24 additions & 1 deletion vllm/model_executor/models/llava.py
@@ -5,9 +5,11 @@

import torch
import torch.nn as nn
from packaging.version import Version
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
SiglipVisionConfig)
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor

@@ -716,6 +718,27 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loader.load_weights(weights)


class MantisProcessingInfo(LlavaProcessingInfo):

def get_hf_processor(self):
hf_config = self.get_hf_config()
vision_info = self.get_vision_encoder_info()

if Version(TRANSFORMERS_VERSION) < Version("4.48"):
# BUG: num_additional_image_tokens = 0 but treated as 1,
# so we set vision_feature_select_strategy to None to offset this
vision_feature_select_strategy = None
else:
# FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
vision_feature_select_strategy = hf_config.vision_feature_select_strategy # noqa: E501

return self.ctx.get_hf_processor(
LlavaProcessor,
patch_size=vision_info.get_patch_size(),
vision_feature_select_strategy=vision_feature_select_strategy,
)


class MantisMultiModalProcessor(LlavaMultiModalProcessor):

def apply(
@@ -794,7 +817,7 @@ def get_replacement_mantis(item_idx: int):
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
info=LlavaProcessingInfo,
info=MantisProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass
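
As a side note, here is a minimal sketch of how the `--hf_overrides` hint above could be applied through the offline Python API; the checkpoint name is an illustrative placeholder, not taken from this diff.

```python
from vllm import LLM

# Hypothetical example: the checkpoint name is a placeholder; the key point is
# overriding the reported architecture so vLLM routes the model to
# MantisForConditionalGeneration instead of plain LlavaForConditionalGeneration.
llm = LLM(
    model="TIGER-Lab/Mantis-8B-siglip-llama3",  # placeholder checkpoint
    hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
)
```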
72 changes: 49 additions & 23 deletions vllm/model_executor/models/qwen2_audio.py
@@ -36,8 +36,9 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -153,29 +154,24 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, Any],
) -> BatchFeature:
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])

if audios:
mm_data["audios"] = audios

feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
else:
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
pass
# Text-only input not supported in composite processor
if not mm_data or not mm_data.get("audios", []):
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)

processed_outputs = super()._call_hf_processor(
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

return processed_outputs

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@@ -192,8 +188,14 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
placeholder = hf_config.audio_token_index
processor = self.info.get_hf_processor()

# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_bos_token = getattr(processor, "audio_bos_token",
"<|audio_bos|>")
audio_eos_token = getattr(processor, "audio_eos_token",
"<|audio_eos|>")

feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
if feature_attention_mask is None:
@@ -214,12 +216,16 @@ def get_replacement_qwen2_audio(item_idx: int):
f"The audio {audio} (len={len(audio)}) is too short "
"to be represented inside the model")

return [placeholder] * num_placeholders
return "".join([
audio_bos_token,
audio_token * num_placeholders,
audio_eos_token,
])

return [
PromptReplacement(
modality="audio",
target=[placeholder],
target=audio_token,
replacement=get_replacement_qwen2_audio,
)
]
@@ -234,6 +240,26 @@ def _always_apply_prompt_replacements(self) -> bool:
# tokens than the number of audio items)
return not hasattr(self.info.get_hf_processor(), "audio_token")

def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

# Only <|AUDIO|> tokens should be considered as placeholders,
# so we ignore the audio_bos_token and audio_eos_token
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"] + 1,
length=p["length"] - 2) for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}

return result


@MULTIMODAL_REGISTRY.register_processor(
Qwen2AudioMultiModalProcessor,
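
To make the offset arithmetic in the `apply` override above concrete, here is an illustrative sketch with made-up numbers; plain dicts stand in for the real `PlaceholderRange` objects.

```python
# Suppose the expanded prompt contains
# <|audio_bos|> <|AUDIO|> x 5 <|audio_eos|> starting at token offset 10,
# i.e. a placeholder span covering 7 tokens in total.
full_span = {"offset": 10, "length": 7}

trimmed = {
    "offset": full_span["offset"] + 1,   # skip <|audio_bos|>
    "length": full_span["length"] - 2,   # drop <|audio_bos|> and <|audio_eos|>
}

# Only the five <|AUDIO|> tokens remain as the placeholder range.
assert trimmed == {"offset": 11, "length": 5}
```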
9 changes: 1 addition & 8 deletions vllm/model_executor/models/ultravox.py
@@ -137,7 +137,7 @@ def _call_hf_processor(
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# Text-only input not supported in composite processor
if not mm_data:
if not mm_data or not mm_data.get("audios", []):
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
@@ -146,13 +146,6 @@ def _call_hf_processor(
audios = mm_data.pop("audios", [])
assert isinstance(audios, list)

if not audios:
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

feature_extractor = self.info.get_feature_extractor()
mm_kwargs = dict(
**mm_kwargs,
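
For illustration, a tiny sketch (hypothetical inputs) of what the broadened text-only guard above accepts: both a missing `audios` key and an explicitly empty list now take the plain-tokenization path.

```python
def is_text_only(mm_data: dict) -> bool:
    # Mirrors the guard in _call_hf_processor: empty mm_data or an empty
    # "audios" entry both count as text-only input.
    return not mm_data or not mm_data.get("audios", [])

assert is_text_only({})
assert is_text_only({"audios": []})
assert not is_text_only({"audios": [b"raw-audio-bytes"]})
```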
9 changes: 5 additions & 4 deletions vllm/transformers_utils/config.py
@@ -22,10 +22,10 @@
from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
DbrxConfig, DeepseekVLV2Config,
EAGLEConfig, ExaoneConfig,
H2OVLChatConfig,
from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig,
Cohere2Config, DbrxConfig,
DeepseekVLV2Config, EAGLEConfig,
ExaoneConfig, H2OVLChatConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
@@ -52,6 +52,7 @@
}

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"aria": AriaConfig,
"chatglm": ChatGLMConfig,
"cohere2": Cohere2Config,
"dbrx": DbrxConfig,
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
@@ -1,3 +1,4 @@
from vllm.transformers_utils.configs.aria import AriaConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -23,6 +24,7 @@
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

__all__ = [
"AriaConfig",
"ChatGLMConfig",
"Cohere2Config",
"DbrxConfig",
118 changes: 118 additions & 0 deletions vllm/transformers_utils/configs/aria.py
@@ -1,7 +1,32 @@
# Copyright 2024 Rhymes AI. All rights reserved.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Mapping

from transformers import PretrainedConfig
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2VisionConfig)
from transformers.models.llama.configuration_llama import LlamaConfig

from vllm.logger import init_logger

logger = init_logger(__name__)


class AriaVisionConfig(Idefics2VisionConfig):
model_type = "aria_vision_model"
@@ -45,3 +70,96 @@ def __init__(
self.moe_num_experts = moe_num_experts
self.moe_topk = moe_topk
self.moe_num_shared_experts = moe_num_shared_experts


class AriaConfig(PretrainedConfig):
"""
Configuration class for Aria model.
This class handles the configuration for both vision and text components of
the Aria model,
as well as additional parameters for image token handling and projector
mapping.
Comment on lines +75 to +81

@DarkLight1337 (Member, Author): We need to override this in vLLM because the HF repo attempts to import transformers.models.llama.modeling_llama.LLAMA_ATTENTION_CLASSES, which no longer exists in transformers v4.48.

@DarkLight1337 (Member, Author): cc @xffxff

@Isotr0py (Member, Jan 18, 2025): But Aria has been integrated into transformers in 4.48, so I think there is no need to override the model config; we can directly use the transformers implementation, which should be compatible with 4.48. We can just ask users to disable trust_remote_code in 4.48 (if the multimodal processor still works with the transformers processor without any modification).

@DarkLight1337 (Member, Author): OK, let me try.

@DarkLight1337 (Member, Author, Jan 19, 2025): Actually, let's keep this for now for backwards compatibility, so vLLM can work on both transformers v4.47 and v4.48 without any changes required by the user.

Args:
vision_config (AriaVisionConfig or dict): Configuration for the vision
component.
text_config (AriaMoELMConfig or dict): Configuration for the text
component.
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
dimensions.
ignore_index (int): Index to ignore in loss calculation.
image_token_index (int): Index used to represent image tokens.
**kwargs: Additional keyword arguments passed to the parent class.
Attributes:
model_type (str): Type of the model, set to "aria".
is_composition (bool): Whether the model is a composition of multiple
components.
ignore_index (int): Index to ignore in loss calculation.
image_token_index (int): Index used to represent image tokens.
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
dimensions.
vision_config (AriaVisionConfig): Configuration for the vision
component.
text_config (AriaMoELMConfig): Configuration for the text component.
"""

model_type = "aria"
is_composition = False

def __init__(
self,
vision_config: AriaVisionConfig = AriaVisionConfig(), # noqa: B008
text_config: AriaMoELMConfig = AriaMoELMConfig(), # noqa: B008
projector_patch_to_query_dict: Mapping[int, int] = {
1225: 128,
4900: 256,
},
ignore_index=-100,
image_token_index=32000,
tie_word_embeddings=False,
**kwargs,
):
super().__init__(**kwargs)
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.tie_word_embeddings = tie_word_embeddings
attn_implementation = kwargs.pop("attn_implementation", None)

# Set the default attention implementation to flash_attention_2 if not
# specified
self._attn_implementation = ("flash_attention_2"
if attn_implementation is None else
attn_implementation)

# Convert the keys and values of projector_patch_to_query_dict to
# integers
# This ensures consistency even if they were provided as strings
self.projector_patch_to_query_dict = {
int(k): int(v)
for k, v in projector_patch_to_query_dict.items()
}

if isinstance(vision_config, dict) and "model_type" in vision_config:
vision_config = AriaVisionConfig(**vision_config)
if attn_implementation is None:
vision_attn_implementation = "flash_attention_2"
elif attn_implementation == "sdpa":
logger.warning("SDPA is not supported for vit, using "
"flash_attention_2 instead")
vision_attn_implementation = "flash_attention_2"
else:
vision_attn_implementation = attn_implementation
vision_config._attn_implementation = vision_attn_implementation

self.vision_config = vision_config

if isinstance(text_config, dict) and "model_type" in text_config:
text_attn_implementation = ("sdpa" if attn_implementation is None
else attn_implementation)
text_config = AriaMoELMConfig(**text_config)
text_config._attn_implementation = text_attn_implementation

self.text_config = text_config

# This is needed for the static kv cache
self.num_hidden_layers = self.text_config.num_hidden_layers
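
As a quick usage sketch of the config class added above (a hedged illustration, assuming the class as defined in this diff), string keys in `projector_patch_to_query_dict` are normalized to integers on construction:

```python
from vllm.transformers_utils.configs.aria import AriaConfig

# Keys may arrive as strings (e.g. when loaded from JSON); __init__ converts
# both keys and values to int so downstream lookups stay consistent.
config = AriaConfig(projector_patch_to_query_dict={"1225": 128, "4900": 256})

assert config.model_type == "aria"
assert config.projector_patch_to_query_dict == {1225: 128, 4900: 256}
assert config.image_token_index == 32000  # default from the signature above
```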