Merged
Commits
29 commits
0dfe3c6
initial commit for h2ovl-mississippi models
cooleel Oct 25, 2024
98dd99c
added offline examples
cooleel Oct 27, 2024
87728ca
format
cooleel Oct 28, 2024
20f2858
Refactor code to eliminate duplicate implementations and avoid redund…
cooleel Oct 29, 2024
b137663
expose use_MSAC for image_pixel_values_mapper
cooleel Oct 30, 2024
ca16ecd
added tests for image preprocessing and model output
cooleel Oct 31, 2024
0473d10
sync to the latest and fix the tests for h2ovl
cooleel Oct 31, 2024
eb58216
fixed format
cooleel Oct 31, 2024
5c2ed33
format checked by format.sh
cooleel Oct 31, 2024
49952d5
Merge remote-tracking branch 'upstream/main' into h2ovl-mississippi
ywang96 Oct 31, 2024
25a1a5b
run yapf alone
ywang96 Nov 1, 2024
e8d8a1a
workaround yapf
ywang96 Nov 1, 2024
cb3bc8b
format
ywang96 Nov 1, 2024
d40eb89
initial commit for h2ovl-mississippi models
cooleel Oct 25, 2024
3d6ed4f
added offline examples
cooleel Oct 27, 2024
c737726
format
cooleel Oct 28, 2024
bc8d3f1
Refactor code to eliminate duplicate implementations and avoid redund…
cooleel Oct 29, 2024
05df66d
expose use_MSAC for image_pixel_values_mapper
cooleel Oct 30, 2024
41ef461
added tests for image preprocessing and model output
cooleel Oct 31, 2024
4792e8c
sync to the latest and fix the tests for h2ovl
cooleel Oct 31, 2024
6d6f20d
fixed format
cooleel Oct 31, 2024
840764d
format checked by format.sh
cooleel Oct 31, 2024
a1846f6
run yapf alone
ywang96 Nov 1, 2024
bafd757
workaround yapf
ywang96 Nov 1, 2024
31ece38
format
ywang96 Nov 1, 2024
356180e
Merge branch 'h2ovl-mississippi' of https://github.com/cooleel/vllm i…
ywang96 Nov 3, 2024
9e27d5d
remove unused code
ywang96 Nov 3, 2024
10f273d
add proper image embeds support and simplify
ywang96 Nov 3, 2024
e68ce30
update typo
ywang96 Nov 3, 2024
6 changes: 6 additions & 0 deletions docs/source/models/supported_models.rst
@@ -440,6 +440,12 @@ Text Generation
- :code:`THUDM/glm-4v-9b` etc.
-
- ✅︎
* - :code:`H2OVLChatModel`
- H2OVL
- T + I\ :sup:`E+`
- :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc.
-
- ✅︎
* - :code:`InternVLChatModel`
- InternVL2
- T + I\ :sup:`E+`
28 changes: 27 additions & 1 deletion examples/offline_inference_vision_language.py
@@ -176,6 +176,31 @@ def run_minicpmv(question: str, modality: str):
return llm, prompt, stop_token_ids


# H2OVL-Mississippi
def run_h2ovl(question: str, modality: str):
assert modality == "image"

model_name = "h2oai/h2ovl-mississippi-2b"

llm = LLM(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-2b
stop_token_ids = [tokenizer.eos_token_id]
return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question: str, modality: str):
assert modality == "image"
@@ -363,6 +388,7 @@ def run_glm4v(question: str, modality: str):
"chameleon": run_chameleon,
"minicpmv": run_minicpmv,
"blip-2": run_blip2,
"h2ovl_chat": run_h2ovl,
"internvl_chat": run_internvl,
"NVLM_D": run_nvlm_d,
"qwen_vl": run_qwen_vl,
@@ -475,4 +501,4 @@ def main(args):
default=16,
help='Number of frames to extract from the video.')
args = parser.parse_args()
main(args)
main(args)
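The run_* helpers above only construct the engine, prompt, and stop tokens; the script's main() (outside this diff) performs the actual generation. A minimal sketch of how run_h2ovl's return values would be consumed, assuming a local PIL image and vLLM's standard multimodal generate() API:

from PIL import Image

from vllm import SamplingParams

llm, prompt, stop_token_ids = run_h2ovl("What is shown here?", "image")
image = Image.open("example.jpg")  # hypothetical local image path

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)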
35 changes: 35 additions & 0 deletions examples/offline_inference_vision_language_multi_image.py
@@ -107,6 +107,40 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
)


def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-2b"

llm = LLM(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"max_dynamic_patch": 4},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-2b
stop_token_ids = [tokenizer.eos_token_id]

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)


def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "OpenGVLab/InternVL2-2B"

@@ -258,6 +292,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:

model_example_map = {
"phi3_v": load_phi3v,
"h2ovl_chat": load_h2onvl,
"internvl_chat": load_internvl,
"NVLM_D": load_nvlm_d,
"qwen2_vl": load_qwen2_vl,
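As in the single-image script, load_h2ovl only prepares the request. A rough sketch of how the returned ModelRequestData would drive a two-image generation (the URLs are placeholders):

from vllm import SamplingParams

image_urls = [
    "https://example.com/image-1.jpg",  # hypothetical URLs
    "https://example.com/image-2.jpg",
]
req = load_h2ovl("Describe the two images in short.", image_urls)
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=128,
                                 stop_token_ids=req.stop_token_ids)
outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        "multi_modal_data": {"image": req.image_data},
    },
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)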
130 changes: 130 additions & 0 deletions tests/models/decoder_only/vision_language/test_h2ovl.py
@@ -0,0 +1,130 @@
from typing import Optional, Tuple

import pytest
import torch
from PIL.Image import Image
from transformers import AutoConfig

# Import the functions to test
from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
image_to_pixel_values_wrapper)
from vllm.multimodal.utils import rescale_image_size

models = [
"h2oai/h2ovl-mississippi-800m", # Replace with your actual model names
"h2oai/h2ovl-mississippi-2b",
]
target_dtype = "bfloat16"


def run_preprocessing_test(
image: Image,
config,
max_dynamic_patch: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Test the image preprocessing and calculate expected blocks."""

if max_dynamic_patch is None:
max_dynamic_patch = config.max_dynamic_patch

width, height = image.size
use_MSAC = config.use_msac

# Create the mapper function with the provided configuration
mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC)
pixel_values = mapper(image)

# Calculate the expected number of blocks
if use_MSAC:
# First pass
blocks1, _, _, aspect_ratio = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
prior_aspect_ratio=None,
)

# Second pass
blocks2, _, _, _ = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False,
prior_aspect_ratio=aspect_ratio,
)

# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0

# Total blocks is the sum of blocks from both passes minus overlapping
total_blocks = blocks1 + blocks2 - 1

expected_blocks = total_blocks

else:
blocks, _, _, _ = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False,
prior_aspect_ratio=None,
)
expected_blocks = blocks

if config.use_thumbnail and expected_blocks > 1:
expected_blocks += 1

return pixel_values, expected_blocks


@pytest.mark.parametrize("model_name", models)
@pytest.mark.parametrize(
"size_factors",
[
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8])
def test_image_preprocessing(image_assets, model_name, size_factors,
max_dynamic_patch):
"""Test image preprocessing pipeline with different configurations."""
# Load the configuration from the model
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

for asset in image_assets:
image = asset.pil_image
for factor in size_factors:
scaled_image = rescale_image_size(image, factor)

# Test preprocessing and get expected number of blocks
pixel_values, expected_blocks = run_preprocessing_test(
scaled_image, config, max_dynamic_patch)

# Verify output shapes and properties
actual_blocks = pixel_values.shape[0]
assert actual_blocks == expected_blocks, (
f"Expected {expected_blocks} blocks, got {actual_blocks}")

# Check image dimensions
expected_size = (
3, # Number of channels (C, H, W)
config.vision_config.image_size,
config.vision_config.image_size,
)
for img in pixel_values:
assert img.shape == expected_size, (
f"Expected image size {expected_size}, got {img.shape}")
17 changes: 17 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -187,6 +187,23 @@
marks=[large_gpu_mark(min_gb=48)],
patch_hf_runner=model_utils.glm_patch_hf_runner,
),
"h2ovl": VLMTestInfo(
models=[
"h2oai/h2ovl-mississippi-800m",
"h2oai/h2ovl-mississippi-2b",
],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501
"cherry_blossom": "<image>\nWhat is the season?",
}),
multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.", # noqa: E501
max_model_len=8192,
dtype="bfloat16",
use_tokenizer_eos=True,
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
),
"intern_vl": VLMTestInfo(
models=[
"OpenGVLab/InternVL2-1B",
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -259,6 +259,66 @@ def processor(*args, text="", images=None, **kwargs):
return hf_model


def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for H2OVL."""

class H2OVLProcessor:
"""A simple processor for H2OVL models."""

def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype

self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size

def __call__(self, text: str, images: Union[Image, List[Image]],
**kwargs):
# yapf: disable
from vllm.model_executor.models.h2ovl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)

# yapf: enable
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values(image,
self.image_size,
self.min_num,
self.max_num,
self.use_thumbnail,
use_MSAC=self.config.use_msac).to(
self.dtype) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
]
pixel_values = torch.cat(pixel_values, dim=0)
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt

img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
"<IMG_CONTEXT>")
hf_model.model.img_context_token_id = img_context_token_id
hf_model.processor = H2OVLProcessor(hf_model)
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.get_output_embeddings()
hf_model.model.generate = types.MethodType(_internvl_generate,
hf_model.model)
return hf_model


def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for InternVL."""

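H2OVLProcessor.__call__ above expands each <image> placeholder into the model's image-token span. A minimal illustration of that replacement, assuming the InternVL-style marker strings <img>, </img>, and <IMG_CONTEXT>, with purely hypothetical sizes:

IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<IMG_CONTEXT>"
num_image_token = 256  # tokens contributed per vision patch (model-specific)
num_patches = 7  # e.g. a 3x2 grid plus one thumbnail tile

text = "<image>\nWhat's in the picture?"
image_tokens = IMG_START + IMG_CONTEXT * (num_image_token *
                                          num_patches) + IMG_END
text = text.replace("<image>", image_tokens, 1)
# text now holds 256 * 7 = 1792 <IMG_CONTEXT> tokens between <img> ... </img>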
3 changes: 2 additions & 1 deletion vllm/entrypoints/chat_utils.py
@@ -187,7 +187,8 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
if model_type in ("chameleon", "internvl_chat", "NVLM_D",
"h2ovl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"