3 changes: 0 additions & 3 deletions examples/offline_inference/vision_language.py
@@ -26,11 +26,8 @@ def run_aria(question: str, modality: str):

# NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM(model=model_name,
tokenizer_mode="slow",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
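With transformers >= 4.48, Aria loads through the standard HF code path, so the example no longer needs a slow tokenizer, an explicit dtype, or `trust_remote_code`. A minimal sketch of the resulting call (the `disable_mm_preprocessor_cache` flag from the surrounding example is omitted here for brevity):

# Minimal sketch of the simplified Aria setup after this change.
from vllm import LLM

llm = LLM(model="rhymes-ai/Aria",
          max_model_len=4096,
          max_num_seqs=2)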
7 changes: 2 additions & 5 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -10,7 +10,6 @@
import pytest
from transformers import AutoModelForVision2Seq
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.utils import is_flash_attn_2_available

from vllm.platforms import current_platform
from vllm.utils import identity
@@ -140,9 +139,7 @@
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
tokenizer_mode="slow",
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
dtype="bfloat16",
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
max_model_len=4096,
@@ -158,8 +155,8 @@
max_tokens=64,
marks=[
pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="Model needs flash-attn for numeric convergence.",
TRANSFORMERS_VERSION < "4.48.0",
reason="HF model requires transformers>=4.48.0",
),
large_gpu_mark(min_gb=64),
],
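The skip condition now keys on the installed transformers version instead of flash-attn availability. Note that the registry below gates the same requirement with `packaging.version.Version`; a hypothetical sketch of that parse-based form, which avoids lexicographic string-compare pitfalls (e.g. "4.9" sorting after "4.48"):

# Hypothetical sketch: compare parsed versions rather than raw strings.
import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION

requires_transformers_4_48 = pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) < Version("4.48"),
    reason="HF model requires transformers>=4.48.0",
)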
12 changes: 5 additions & 7 deletions tests/models/multimodal/processing/test_common.py
@@ -11,6 +11,7 @@
from vllm.multimodal.utils import cached_get_tokenizer

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS


def _test_processing_correctness(
@@ -20,12 +21,9 @@ def _test_processing_correctness(
num_batches: int,
simplify_rate: float,
):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
elif model_id == "deepseek-ai/deepseek-vl2-tiny":
hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
else:
hf_overrides = {}
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")

limit_mm_per_prompt = {
modality: 3 if supports_multi else 1
@@ -41,7 +39,7 @@
seed=0,
dtype="float16",
revision=None,
hf_overrides=hf_overrides,
hf_overrides=model_info.hf_overrides,
limit_mm_per_prompt=limit_mm_per_prompt,
)

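The per-model `if`/`elif` override table is gone; the overrides now live on the registry entry and are resolved by model ID. A minimal sketch of that lookup, using an entry visible in the registry diff below (absolute import path assumed):

# Sketch: resolve the example-model entry by its HF model ID and read the
# overrides it declares in tests/models/registry.py.
from tests.models.registry import HF_EXAMPLE_MODELS

info = HF_EXAMPLE_MODELS.find_hf_info("TIGER-Lab/Mantis-8B-siglip-llama3")
assert info.hf_overrides == {"architectures": ["MantisForConditionalGeneration"]}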
67 changes: 62 additions & 5 deletions tests/models/registry.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import AbstractSet, Mapping, Optional
from typing import AbstractSet, Any, Literal, Mapping, Optional

import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION


@dataclass(frozen=True)
@@ -38,6 +42,50 @@ class _HfExamplesInfo:
trust_remote_code: bool = False
"""The ``trust_remote_code`` level required to load the model."""

hf_overrides: dict[str, Any] = field(default_factory=dict)
"""The ``hf_overrides`` required to load the model."""

def check_transformers_version(
self,
*,
on_fail: Literal["error", "skip"],
) -> None:
"""
If the installed transformers version does not meet the requirements,
perform the given action.
"""
if self.min_transformers_version is None:
return

current_version = TRANSFORMERS_VERSION
required_version = self.min_transformers_version
if Version(current_version) < Version(required_version):
msg = (
f"You have `transformers=={current_version}` installed, but "
f"`transformers>={required_version}` is required to run this "
"model")

if on_fail == "error":
raise RuntimeError(msg)
else:
pytest.skip(msg)

def check_available_online(
self,
*,
on_fail: Literal["error", "skip"],
) -> None:
"""
If the model is not available online, perform the given action.
"""
if not self.is_available_online:
msg = "Model is not available online"

if on_fail == "error":
raise RuntimeError(msg)
else:
pytest.skip(msg)


# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
@@ -48,8 +96,6 @@ class _HfExamplesInfo:
trust_remote_code=True),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
@@ -175,14 +221,17 @@ class _HfExamplesInfo:

_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
min_transformers_version="4.48"),
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
extras={"text_only": "THUDM/chatglm3-6b"},
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
Expand All @@ -193,7 +242,8 @@ class _HfExamplesInfo:
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501
hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -246,5 +296,12 @@ def get_supported_archs(self) -> AbstractSet[str]:
def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
return self.hf_models[model_arch]

def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
for info in self.hf_models.values():
if info.default == model_id:
return info

raise ValueError(f"No example model defined for {model_id}")


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
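
Taken together, these helpers let each test decide for itself whether a model can run in the current environment. A hedged usage sketch (absolute import path assumed; the test files below use the same pattern through relative imports):

# Sketch of how a test consumes the new registry metadata.
from tests.models.registry import HF_EXAMPLE_MODELS

# Look up by architecture name (as in test_initialization.py)...
info = HF_EXAMPLE_MODELS.get_hf_info("AriaForConditionalGeneration")

# ...or by the default HF model ID (as in test_common.py).
info = HF_EXAMPLE_MODELS.find_hf_info("rhymes-ai/Aria")

# Skip or hard-fail when prerequisites are not met.
info.check_available_online(on_fail="skip")       # pytest.skip if the checkpoint cannot be downloaded
info.check_transformers_version(on_fail="error")  # RuntimeError if transformers < 4.48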
14 changes: 2 additions & 12 deletions tests/models/test_initialization.py
@@ -1,9 +1,7 @@
from unittest.mock import patch

import pytest
from packaging.version import Version
from transformers import PretrainedConfig
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm import LLM

@@ -13,16 +11,8 @@
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
if not model_info.is_available_online:
pytest.skip("Model is not available online")
if model_info.min_transformers_version is not None:
current_version = TRANSFORMERS_VERSION
required_version = model_info.min_transformers_version
if Version(current_version) < Version(required_version):
pytest.skip(
f"You have `transformers=={current_version}` installed, but "
f"`transformers>={required_version}` is required to run this "
"model")
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")

# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
3 changes: 3 additions & 0 deletions tests/models/test_registry.py
@@ -21,6 +21,9 @@

@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_transformers_version(on_fail="skip")

# Ensure all model classes can be imported successfully
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
