
Commit f3f6c86

Authored by ShuaiBai623, minostauros, stevhliu, and gewenbin0992
add qwen2.5vl (#35569)
* add qwen2.5vl
* fix
* pass check table
* add modular file
* fix style
* Update src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py (Co-authored-by: Minho Shim <[email protected]>)
* Update src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py (Co-authored-by: Minho Shim <[email protected]>)
* Update src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py (Co-authored-by: Minho Shim <[email protected]>)
* padd copy check
* use modular
* fix
* fix
* fix
* update flashatt2&sdpa support_list
* Update docs/source/en/_toctree.yml (Co-authored-by: Steven Liu <[email protected]>)
* Update docs/source/en/model_doc/qwen2_5_vl.md (Co-authored-by: Steven Liu <[email protected]>)
* Update docs/source/en/model_doc/qwen2_5_vl.md (Co-authored-by: Steven Liu <[email protected]>)
* Update docs/source/en/model_doc/qwen2_5_vl.md (Co-authored-by: Steven Liu <[email protected]>)
* Update docs/source/en/model_doc/qwen2_5_vl.md (Co-authored-by: Steven Liu <[email protected]>)
* Update src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py (Co-authored-by: Steven Liu <[email protected]>)
* update config
* update
* fix hf path
* rename Qwen2_5_VLVideosKwargs
* fix
* fix
* update
* excuted modular
* rollback init
* fix
* formated
* simpler init
* fix
* fix
* fix
* fix
* fix
* update docs
* fix
* fix
* update Qwen2VLRotaryEmbedding for yarn
* fix

--------

Co-authored-by: Minho Shim <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: gewenbin0992 <[email protected]>
Co-authored-by: gewenbin0992 <[email protected]>
1 parent d3af76d · commit f3f6c86

25 files changed: +5,184 −44 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -928,6 +928,8 @@
       title: Pix2Struct
     - local: model_doc/pixtral
       title: Pixtral
+    - local: model_doc/qwen2_5_vl
+      title: Qwen2.5-VL
     - local: model_doc/qwen2_audio
       title: Qwen2Audio
     - local: model_doc/qwen2_vl

docs/source/en/index.md

Lines changed: 1 addition & 0 deletions
@@ -285,6 +285,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [PVTv2](model_doc/pvt_v2)           | ✅ | ❌ | ❌ |
 | [QDQBert](model_doc/qdqbert)        | ✅ | ❌ | ❌ |
 | [Qwen2](model_doc/qwen2)            | ✅ | ❌ | ❌ |
+| [Qwen2_5_VL](model_doc/qwen2_5_vl)  | ✅ | ❌ | ❌ |
 | [Qwen2Audio](model_doc/qwen2_audio) | ✅ | ❌ | ❌ |
 | [Qwen2MoE](model_doc/qwen2_moe)     | ✅ | ❌ | ❌ |
 | [Qwen2VL](model_doc/qwen2_vl)       | ✅ | ❌ | ❌ |
docs/source/en/model_doc/qwen2_5_vl.md

Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
<!--Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Qwen2.5-VL

## Overview

The [Qwen2.5-VL](https://qwenlm.github.io/blog/qwen2_5-vl/) model is an update to [Qwen2-VL](https://arxiv.org/abs/2409.12191) from the Qwen team at Alibaba Group.

The abstract from this update is the following:

*Qwen2.5-VL marks a major step forward from Qwen2-VL, built upon the latest Qwen2.5 LLM. We've accelerated training and testing through the strategic implementation of window attention within the ViT. The ViT architecture itself has been refined with SwiGLU and RMSNorm, aligning it more closely with the LLM's structure. A key innovation is the expansion of native dynamic resolution to encompass the temporal dimension, in addition to spatial aspects. Furthermore, we've upgraded MRoPE, incorporating absolute time alignment on the time axis to allow the model to effectively capture temporal dynamics, regardless of frame rate, leading to superior video understanding.*
## Usage example

### Single Media Inference

The model can accept both images and videos as input. Here is example code for inference.

```python
from PIL import Image
import requests
import torch
from transformers.image_utils import load_images, load_video
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# Load the model in half-precision on the available device(s)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", device_map="auto")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

# Image
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

conversation = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
            },
            {
                "type": "text",
                "text": "Describe this image."
            }
        ]
    }
]

# Preprocess the inputs
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Expected output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>\n<|im_start|>assistant\n'

inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
inputs = inputs.to('cuda')

# Inference: Generation of the output
output_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(output_text)

# Video
video = load_video(video="/path/to/video.mp4")
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "video"},
            {"type": "text", "text": "What happened in the video?"},
        ],
    }
]

# Preprocess the inputs
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Expected output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>What happened in the video?<|im_end|>\n<|im_start|>assistant\n'

# Qwen2.5-VL modifies the temporal positional encoding (MRoPE) according to the video's frame rate (FPS),
# so the video's FPS information needs to be provided as input.
inputs = processor(text=[text_prompt], videos=[video], fps=[1.0], padding=True, return_tensors="pt")
inputs = inputs.to('cuda')

# Inference: Generation of the output
output_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(output_text)
```
### Batch Mixed Media Inference

The model can batch inputs composed of mixed samples of various types, such as images, videos, and text. Here is an example.

```python
images = load_images([
    "/path/to/image1.jpg",
    "/path/to/image2.jpg",
    "/path/to/image3.jpg",
    "/path/to/image4.jpg",
    "/path/to/image5.jpg",
])
video = load_video(video="/path/to/video.mp4")

# Conversation for the first image
conversation1 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."}
        ]
    }
]

# Conversation with two images
conversation2 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "text", "text": "What is written in the pictures?"}
        ]
    }
]

# Conversation with pure text
conversation3 = [
    {
        "role": "user",
        "content": "who are you?"
    }
]

# Conversation with mixed media
conversation4 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "video"},
            {"type": "text", "text": "What are the common elements in these media?"},
        ],
    }
]

conversations = [conversation1, conversation2, conversation3, conversation4]
# Preparation for batch inference
texts = [processor.apply_chat_template(msg, add_generation_prompt=True) for msg in conversations]
inputs = processor(
    text=texts,
    images=images,
    videos=[video],
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to('cuda')

# Batch Inference
output_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(output_text)
```
185+
186+
### Usage Tips
187+
188+
#### Image Resolution trade-off
189+
190+
The model supports a wide range of resolution inputs. By default, it uses the native resolution for input, but higher resolutions can enhance performance at the cost of more computation. Users can set the minimum and maximum number of pixels to achieve an optimal configuration for their needs.
191+
192+
```python
193+
min_pixels = 224*224
194+
max_pixels = 2048*2048
195+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
196+
```
197+
198+
In case of limited GPU RAM, one can reduce the resolution as follows:
199+
200+
```python
201+
min_pixels = 256*28*28
202+
max_pixels = 1024*28*28
203+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
204+
```
205+
This ensures each image gets encoded using a number between 256-1024 tokens. The 28 comes from the fact that the model uses a patch size of 14 and a temporal patch size of 2 (14 x 2 = 28).
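To see roughly how this pixel budget maps to a visual-token count, here is a minimal, illustrative sketch. The rescaling and rounding used below are simplifying assumptions for illustration only; the processor's own resizing logic is authoritative.

```python
import math

def estimate_visual_tokens(height, width, min_pixels=256 * 28 * 28, max_pixels=1024 * 28 * 28, factor=28):
    """Rough, illustrative estimate: rescale the image so its area falls inside
    [min_pixels, max_pixels], then count how many 28x28 units it contains."""
    area = height * width
    if area > max_pixels:
        # Downscale; round each side down so the budget is not exceeded.
        scale = math.sqrt(max_pixels / area)
        h = math.floor(height * scale / factor) * factor
        w = math.floor(width * scale / factor) * factor
    elif area < min_pixels:
        # Upscale; round each side up so the minimum budget is met.
        scale = math.sqrt(min_pixels / area)
        h = math.ceil(height * scale / factor) * factor
        w = math.ceil(width * scale / factor) * factor
    else:
        h = round(height / factor) * factor
        w = round(width / factor) * factor
    return (h // factor) * (w // factor)

print(estimate_visual_tokens(1920, 1080))  # large photo -> close to the 1024-token cap
print(estimate_visual_tokens(224, 224))    # small image -> scaled up to the 256-token floor
```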
#### Multiple Image Inputs

By default, images and video content are directly included in the conversation. When handling multiple images, it's helpful to add labels to the images and videos for better reference. Users can control this behavior with the following settings:

```python
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Hello, how are you?"}
        ]
    },
    {
        "role": "assistant",
        "content": "I'm doing well, thank you for asking. How can I assist you today?"
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you describe these images and video?"},
            {"type": "image"},
            {"type": "image"},
            {"type": "video"},
            {"type": "text", "text": "These are from my vacation."}
        ]
    },
    {
        "role": "assistant",
        "content": "I'd be happy to describe the images and video for you. Could you please provide more context about your vacation?"
    },
    {
        "role": "user",
        "content": "It was a trip to the mountains. Can you see the details in the images and video?"
    }
]

# default:
prompt_without_id = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Expected output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|><|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'


# add ids
prompt_with_id = processor.apply_chat_template(conversation, add_generation_prompt=True, add_vision_id=True)
# Expected output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?Picture 2: <|vision_start|><|image_pad|><|vision_end|>Picture 3: <|vision_start|><|image_pad|><|vision_end|>Video 1: <|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
```
#### Flash-Attention 2 to speed up generation

First, make sure to install the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

Also, you should have hardware that is compatible with FlashAttention-2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`.

To load and run a model using FlashAttention-2, add `attn_implementation="flash_attention_2"` when loading the model:

```python
import torch
from transformers import Qwen2_5_VLForConditionalGeneration

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```
## Qwen2_5_VLConfig

[[autodoc]] Qwen2_5_VLConfig

## Qwen2_5_VLImageProcessor

[[autodoc]] Qwen2_5_VLImageProcessor
    - preprocess

## Qwen2_5_VLProcessor

[[autodoc]] Qwen2_5_VLProcessor

## Qwen2_5_VLModel

[[autodoc]] Qwen2_5_VLModel
    - forward

## Qwen2_5_VLForConditionalGeneration

[[autodoc]] Qwen2_5_VLForConditionalGeneration
    - forward

docs/source/en/perf_infer_gpu_one.md

Lines changed: 2 additions & 0 deletions
@@ -97,6 +97,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
 * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
 * [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel)
+* [Qwen2.5VL](https://huggingface.co/docs/transformers/model_doc/qwen2_5_vl#transformers.Qwen2_5_VLModel)
 * [RAG](https://huggingface.co/docs/transformers/model_doc/rag#transformers.RagModel)
 * [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel)
 * [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel)

@@ -297,6 +298,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
 * [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
 * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
+* [Qwen2.5VL](https://huggingface.co/docs/transformers/model_doc/qwen2_5_vl#transformers.Qwen2_5_VLModel)
 * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)
 * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
 * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
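Since this change adds Qwen2.5-VL to the SDPA support list, the model can also be loaded with PyTorch's scaled dot-product attention selected explicitly. A minimal sketch follows (SDPA is typically the default when no `attn_implementation` is passed and the backend supports it):

```python
import torch
from transformers import Qwen2_5_VLForConditionalGeneration

# Opt in to PyTorch's scaled_dot_product_attention kernels explicitly.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
```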

src/transformers/__init__.py

Lines changed: 22 additions & 0 deletions
@@ -708,6 +708,10 @@
         "Qwen2Config",
         "Qwen2Tokenizer",
     ],
+    "models.qwen2_5_vl": [
+        "Qwen2_5_VLConfig",
+        "Qwen2_5_VLProcessor",
+    ],
     "models.qwen2_audio": [
         "Qwen2AudioConfig",
         "Qwen2AudioEncoderConfig",

@@ -1263,6 +1267,7 @@
     _import_structure["models.pixtral"].append("PixtralImageProcessor")
     _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
     _import_structure["models.pvt"].extend(["PvtImageProcessor"])
+    _import_structure["models.qwen2_5_vl"].extend(["Qwen2_5_VLImageProcessor"])
     _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
     _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"])
     _import_structure["models.sam"].extend(["SamImageProcessor"])

@@ -3276,6 +3281,13 @@
             "Qwen2PreTrainedModel",
         ]
     )
+    _import_structure["models.qwen2_5_vl"].extend(
+        [
+            "Qwen2_5_VLForConditionalGeneration",
+            "Qwen2_5_VLModel",
+            "Qwen2_5_VLPreTrainedModel",
+        ]
+    )
     _import_structure["models.qwen2_audio"].extend(
         [
             "Qwen2AudioEncoder",

@@ -5783,6 +5795,10 @@
     from .models.pvt import PvtConfig
     from .models.pvt_v2 import PvtV2Config
     from .models.qwen2 import Qwen2Config, Qwen2Tokenizer
+    from .models.qwen2_5_vl import (
+        Qwen2_5_VLConfig,
+        Qwen2_5_VLProcessor,
+    )
     from .models.qwen2_audio import (
         Qwen2AudioConfig,
         Qwen2AudioEncoderConfig,

@@ -6362,6 +6378,7 @@
         PoolFormerImageProcessor,
     )
     from .models.pvt import PvtImageProcessor
+    from .models.qwen2_5_vl import Qwen2_5_VLImageProcessor
     from .models.qwen2_vl import Qwen2VLImageProcessor
     from .models.rt_detr import RTDetrImageProcessor
     from .models.sam import SamImageProcessor

@@ -7980,6 +7997,11 @@
         Qwen2Model,
         Qwen2PreTrainedModel,
     )
+    from .models.qwen2_5_vl import (
+        Qwen2_5_VLForConditionalGeneration,
+        Qwen2_5_VLModel,
+        Qwen2_5_VLPreTrainedModel,
+    )
     from .models.qwen2_audio import (
         Qwen2AudioEncoder,
         Qwen2AudioForConditionalGeneration,
