Commit daef218

[Model] Initialize Phi-3-vision support (#4986)
1 parent fa9e385 commit daef218

File tree: 8 files changed, +571 −0 lines changed

docs/source/models/supported_models.rst

Lines changed: 4 additions & 0 deletions
@@ -135,6 +135,10 @@ Alongside each architecture, we include some popular models that use it.
     - Phi-3-Small
     - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
     -
+  * - :code:`Phi3VForCausalLM`
+    - Phi-3-Vision
+    - :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
+    -
   * - :code:`QWenLMHeadModel`
     - Qwen
     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.

examples/phi3v_example.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@

import os
import subprocess

from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData


def run_phi3v():
    model_path = "microsoft/Phi-3-vision-128k-instruct"
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        max_model_len=4096,
        image_input_type="pixel_values",
        image_token_id=32044,
        image_input_shape="1,3,1008,1344",
        image_feature_size=1921,
        disable_image_processor=False,
    )

    image = Image.open("images/cherry_blossom.jpg")

    # single-image prompt
    prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n"  # noqa: E501
    prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "<s>")

    sampling_params = SamplingParams(temperature=0, max_tokens=64)

    outputs = llm.generate({
        "prompt": prompt,
        "sampling_params": sampling_params,
        "multi_modal_data": ImagePixelData(image),
    })
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Make sure the local directory exists or create it
    os.makedirs(local_directory, exist_ok=True)

    # Use AWS CLI to sync the directory, assume anonymous access
    subprocess.check_call([
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ])
    run_phi3v()
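The replace() call above is what ties the prompt to the engine configuration: vLLM expects exactly image_feature_size (1921) copies of the <|image|> placeholder for an image of input shape 1,3,1008,1344. A minimal sketch of that expansion as a reusable helper (the helper is ours for illustration, not part of this commit):

# Hypothetical helper (not in this commit): expand the single <|image_1|>
# marker into the image_feature_size placeholder tokens that vLLM expects
# for pixel-value inputs, mirroring the inline replace() above.
def expand_image_prompt(prompt: str, image_feature_size: int = 1921) -> str:
    return prompt.replace("<|image_1|>",
                          "<|image|>" * image_feature_size + "<s>")


expanded = expand_image_prompt(
    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n")
assert expanded.count("<|image|>") == 1921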

requirements-test.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ peft
 requests
 ray
 sentence-transformers # required for embedding
+torchvision # required for the image processor of phi3v
 
 # Benchmarking
 aiohttp
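The new test dependency is only exercised indirectly, through the phi3v image processor that transformers loads from the model repo. As a rough illustration of the tensor shapes involved (our own sketch, not code from this commit), torchvision can produce pixel values matching the 1,3,1008,1344 image_input_shape used elsewhere in this PR:

# Illustrative only: build a (1, 3, 1008, 1344) pixel tensor matching the
# image_input_shape configured in the example script; the real phi3v
# image processor is loaded from the HF model repo via transformers.
import torchvision.transforms as T
from PIL import Image

transform = T.Compose([T.Resize((1008, 1344)), T.ToTensor()])
pixel_values = transform(Image.open("images/cherry_blossom.jpg")).unsqueeze(0)
print(pixel_values.shape)  # torch.Size([1, 3, 1008, 1344])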

tests/conftest.py

Lines changed: 3 additions & 0 deletions
@@ -144,6 +144,7 @@ def __init__(
         model_name: str,
         dtype: str = "half",
         *,
+        model_kwargs: Optional[Dict[str, Any]] = None,
         is_embedding_model: bool = False,
         is_vision_model: bool = False,
     ) -> None:
@@ -166,11 +167,13 @@ def __init__(
         else:
             auto_cls = AutoModelForCausalLM
 
+        model_kwargs = model_kwargs if model_kwargs is not None else {}
         self.model = self.wrap_device(
             auto_cls.from_pretrained(
                 model_name,
                 torch_dtype=torch_dtype,
                 trust_remote_code=True,
+                **model_kwargs,
             ))
 
         self.tokenizer = AutoTokenizer.from_pretrained(
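The new model_kwargs parameter forwards arbitrary keyword arguments to from_pretrained(), so individual tests can tweak how the HF reference model is constructed. It is used by the phi3v test added in this same commit:

# From tests/models/test_phi3v.py below: force eager attention in the
# HF reference model via the new pass-through parameter.
hf_model_kwargs = {"_attn_implementation": "eager"}
with hf_runner(model_id, dtype=dtype,
               model_kwargs=hf_model_kwargs) as hf_model:
    hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
                                          max_tokens,
                                          images=hf_images)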

tests/models/test_phi3v.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@

from typing import List, Tuple

import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig
from vllm.utils import is_cpu

from ..conftest import IMAGE_FILES

pytestmark = pytest.mark.llava

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
    "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
]

assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)


def iter_phi3v_configs(model_name: str):
    image_hw_to_feature_size = {
        (1008, 1344): 1921,
    }

    for (h, w), f in image_hw_to_feature_size.items():
        for input_type, input_shape in [
            (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
        ]:
            yield (model_name,
                   VisionLanguageConfig(image_input_type=input_type,
                                        image_feature_size=f,
                                        image_token_id=32044,
                                        image_input_shape=input_shape,
                                        image_processor=model_name,
                                        image_processor_revision=None))


model_and_vl_config = [
    *iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"),
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
                      vlm_config: VisionLanguageConfig, model_id: str):
    """Sanitize vLLM output to be comparable with HF output.

    The function maps every occurrence of the image token id in `input_ids`
    to 0, and strips the repeated image placeholder string and the chat
    markup from `output_str` (e.g. "<image><image>bla" becomes "bla").
    """
    input_ids, output_str = vllm_output
    image_token_id = vlm_config.image_token_id

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_token_str = tokenizer.decode(image_token_id)

    hf_input_ids = [
        input_id if input_id != image_token_id else 0
        for idx, input_id in enumerate(input_ids)
    ]
    hf_output_str = output_str \
        .replace(image_token_str * vlm_config.image_feature_size, "") \
        .replace("<s>", " ").replace("<|user|>", "") \
        .replace("<|end|>\n<|assistant|>", " ")

    return hf_input_ids, hf_output_str


target_dtype = "half"
if is_cpu():
    target_dtype = "bfloat16"


# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# Since we use _attn_implementation="eager" for hf_runner, there is a
# numeric difference for longer contexts and the test can't pass.
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [8])
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
                model_and_config, dtype: str, max_tokens: int) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the vllm runner, we provide MultiModalData objects and the
    corresponding vision language config as input.
    Note that the text input is also adjusted to abide by the vllm contract.
    The text output is sanitized to be comparable with hf.
    """
    model_id, vlm_config = model_and_config

    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
    hf_model_kwargs = {"_attn_implementation": "eager"}
    with hf_runner(model_id, dtype=dtype,
                   model_kwargs=hf_model_kwargs) as hf_model:
        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
                                              max_tokens,
                                              images=hf_images)

    vllm_image_prompts = [
        p.replace("<|image_1|>",
                  "<|image|>" * vlm_config.image_feature_size + "<s>")
        for p in HF_IMAGE_PROMPTS
    ]

    with vllm_runner(model_id,
                     max_model_len=2048,
                     dtype=dtype,
                     enforce_eager=True,
                     **vlm_config.as_cli_args_dict()) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                                  max_tokens,
                                                  images=vllm_images)

    for i in range(len(HF_IMAGE_PROMPTS)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
            vllm_outputs[i], vlm_config, model_id)
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
+    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
