19 changes: 18 additions & 1 deletion tests/conftest.py
@@ -24,7 +24,9 @@
from vllm.config import TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
@@ -90,6 +92,21 @@ def init_test_http_connection():
    global_http_connection.reuse_client = False


@pytest.fixture
def dist_init():
    temp_file = tempfile.mkstemp()[1]
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup()


def cleanup():
    destroy_model_parallel()
    destroy_distributed_environment()
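A minimal sketch (not part of this diff) of how a test could request the new `dist_init` fixture. The test name, body, and the CUDA skip guard are illustrative assumptions, added only to show what the fixture provides:

```python
import pytest
import torch

from vllm.distributed import get_tensor_model_parallel_world_size


# Hypothetical consumer of the dist_init fixture above: requesting it sets up
# a single-rank distributed environment (world_size=1) with a 1x1
# model-parallel group, and cleanup() runs after the test body returns.
@pytest.mark.skipif(not torch.cuda.is_available(),
                    reason="the fixture requests the nccl backend")
def test_single_rank_setup(dist_init):
    # With initialize_model_parallel(1, 1), the tensor-parallel group
    # contains exactly one rank.
    assert get_tensor_model_parallel_world_size() == 1
```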
80 changes: 80 additions & 0 deletions tests/models/test_intern_vit.py
@@ -0,0 +1,80 @@
from typing import Optional

import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

from vllm.model_executor.models.intern_vit import InternVisionModel

from ..conftest import _ImageAssets, cleanup

pytestmark = pytest.mark.vlm

# we use snapshot_download to prevent conflicts between
# dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
models = [
    snapshot_download("OpenGVLab/InternViT-300M-448px",
                      allow_patterns=DOWNLOAD_PATTERN),
    snapshot_download("OpenGVLab/InternViT-6B-448px-V1-5",
                      allow_patterns=DOWNLOAD_PATTERN),
]


def run_intern_vit_test(
    image_assets: _ImageAssets,
    model: str,
    *,
    dtype: str,
    distributed_executor_backend: Optional[str] = None,
):
    img_processor = CLIPImageProcessor.from_pretrained(model)
    images = [asset.pil_image for asset in image_assets]
    pixel_values = [
        img_processor(images, return_tensors='pt').pixel_values.to(dtype)
        for images in images
    ]

    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    if not getattr(config, "norm_type", None):
        config.norm_type = "rms_norm"

    hf_model = AutoModel.from_pretrained(model,
                                         torch_dtype=dtype,
                                         trust_remote_code=True).to("cuda")
    hf_outputs_per_image = [
        hf_model(pixel_value.to("cuda")).last_hidden_state
        for pixel_value in pixel_values
    ]

    vllm_model = InternVisionModel(config)
    vllm_model.load_weights(hf_model.state_dict().items())

    del hf_model
    cleanup()

    vllm_model = vllm_model.to("cuda", dtype)
    vllm_outputs_per_image = [
        vllm_model(pixel_values=pixel_value.to("cuda"))
        for pixel_value in pixel_values
    ]
    del vllm_model
    cleanup()

    cos_similar = nn.CosineSimilarity(dim=-1)
    for vllm_output, hf_output in zip(vllm_outputs_per_image,
                                      hf_outputs_per_image):
        assert cos_similar(vllm_output, hf_output).mean() > 0.99


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [torch.half])
@torch.inference_mode()
def test_models(dist_init, image_assets, model, dtype: str) -> None:
    run_intern_vit_test(
        image_assets,
        model,
        dtype=dtype,
    )
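For reference, a self-contained sketch of the comparison performed at the end of `run_intern_vit_test`: token-wise cosine similarity over the hidden dimension, averaged and checked against the 0.99 threshold. The tensors and shapes below are synthetic placeholders, not real model outputs:

```python
import torch
import torch.nn as nn

# Synthetic stand-ins for the HF and vLLM last_hidden_state tensors
# (shape [batch, seq_len, hidden_size]); the small additive noise mimics
# minor numerical drift between the two implementations.
hf_hidden = torch.randn(1, 1025, 1024)
vllm_hidden = hf_hidden + 1e-3 * torch.randn_like(hf_hidden)

cos_similar = nn.CosineSimilarity(dim=-1)
score = cos_similar(vllm_hidden, hf_hidden).mean()
assert score > 0.99, f"mean cosine similarity too low: {score:.4f}"
```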