
Commit 3710510

Isotr0py authored and shreyankg committed
[VLM] Remove input processor from clip and siglip (vllm-project#13165)
1 parent 9b55be4 commit 3710510

File tree

2 files changed: +10 -213 lines changed


vllm/model_executor/models/clip.py

Lines changed: 8 additions & 141 deletions
@@ -1,156 +1,24 @@
 # SPDX-License-Identifier: Apache-2.0
 """Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
-from PIL import Image
 from transformers import CLIPVisionConfig
 
 from vllm.attention.layer import MultiHeadAttention
-from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.inputs import DecoderOnlyInputs, token_inputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
-from vllm.sequence import SequenceData
 
 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
 
 
-def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
-    assert image_size % patch_size == 0
-    return image_size // patch_size
-
-
-def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
-    grid_length = get_clip_patch_grid_length(image_size=image_size,
-                                             patch_size=patch_size)
-    return grid_length * grid_length
-
-
-def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
-    return get_clip_num_patches(image_size=hf_config.image_size,
-                                patch_size=hf_config.patch_size) + 1
-
-
-def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
-    return get_clip_image_feature_size(hf_config)
-
-
-def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
-                            seq_len: int,
-                            num_images: int,
-                            *,
-                            image_token_id: int,
-                            image_feature_size_override: Optional[int] = None,
-                            mm_key: str = "image"):
-    if image_feature_size_override is None:
-        image_feature_size = get_clip_image_feature_size(hf_config)
-    else:
-        image_feature_size = image_feature_size_override
-
-    return SequenceData.from_prompt_token_counts(
-        (image_token_id, image_feature_size * num_images),
-        (0, seq_len - image_feature_size * num_images),
-    ), {
-        mm_key:
-        consecutive_placeholder_ranges(num_items=num_images,
-                                       item_size=image_feature_size)
-    }
-
-
-def dummy_image_for_clip(
-    hf_config: CLIPVisionConfig,
-    num_images: int,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    width = height = hf_config.image_size
-    if image_width_override is not None:
-        width = image_width_override
-    if image_height_override is not None:
-        height = image_height_override
-
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
-
-
-def dummy_video_for_clip(
-    hf_config: CLIPVisionConfig,
-    num_frames: int,
-    num_videos: int = 1,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    pil_frame = dummy_image_for_clip(
-        hf_config,
-        num_images=1,
-        image_width_override=image_width_override,
-        image_height_override=image_height_override)
-    np_frame = np.array(pil_frame["image"])
-    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
-    video_data = [mm_data_per_video] * num_videos
-    mm_data = {"video": video_data}
-    return mm_data
-
-
-def input_processor_for_clip(
-    model_config: ModelConfig,
-    hf_config: CLIPVisionConfig,
-    inputs: DecoderOnlyInputs,
-    *,
-    image_token_id: int,
-    image_feature_size_override: Optional[Union[int, List[int]]] = None,
-):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    if "multi_modal_placeholders" in inputs and "image" in inputs[
-            "multi_modal_placeholders"]:
-        # The inputs already have placeholders.
-        return inputs
-
-    tokenizer = cached_get_tokenizer(model_config.tokenizer)
-
-    if image_feature_size_override is None:
-        image_data = multi_modal_data["image"]
-        if isinstance(image_data, Image.Image):
-            image_feature_size = get_clip_image_feature_size(hf_config)
-        elif isinstance(image_data, torch.Tensor):
-            num_images, image_feature_size, hidden_size = image_data.shape
-        else:
-            raise TypeError(f"Invalid image type: {type(image_data)}")
-    else:
-        image_feature_size = image_feature_size_override
-
-    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-        tokenizer,
-        inputs.get("prompt"),
-        inputs["prompt_token_ids"],
-        placeholder_token_id=image_token_id,
-        repeat_count=image_feature_size,
-    )
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": ranges})
-
-
 class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
 
     def get_num_image_tokens(
@@ -159,10 +27,10 @@ def get_num_image_tokens(
         image_width: int,
         image_height: int,
     ) -> int:
-        return get_clip_image_feature_size(self.vision_config)
+        return self.get_patch_grid_length()**2 + 1
 
     def get_max_image_tokens(self) -> int:
-        return get_max_clip_image_tokens(self.vision_config)
+        return self.get_patch_grid_length()**2 + 1
 
     def get_image_size(self) -> int:
         return self.vision_config.image_size
@@ -171,10 +39,9 @@ def get_patch_size(self) -> int:
         return self.vision_config.patch_size
 
     def get_patch_grid_length(self) -> int:
-        return get_clip_patch_grid_length(
-            image_size=self.vision_config.image_size,
-            patch_size=self.vision_config.patch_size,
-        )
+        image_size, patch_size = self.get_image_size(), self.get_patch_size()
+        assert image_size % patch_size == 0
+        return image_size // patch_size
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
@@ -186,6 +53,7 @@ def __init__(self, config: CLIPVisionConfig):
         self.embed_dim = config.hidden_size
         self.image_size = config.image_size
         self.patch_size = config.patch_size
+        assert self.image_size % self.patch_size == 0
 
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
@@ -197,8 +65,7 @@ def __init__(self, config: CLIPVisionConfig):
             bias=False,
         )
 
-        self.num_patches = get_clip_num_patches(image_size=self.image_size,
-                                                patch_size=self.patch_size)
+        self.num_patches = (self.image_size // self.patch_size)**2
         self.num_positions = self.num_patches + 1
         self.position_embedding = nn.Embedding(self.num_positions,
                                                self.embed_dim)
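
Note: the helpers removed above (get_clip_patch_grid_length, get_clip_num_patches, get_clip_image_feature_size, get_max_clip_image_tokens) all reduced to the same ViT arithmetic that CLIPEncoderInfo and CLIPVisionEmbeddings now compute inline: a square grid of image_size // patch_size patches plus one class token. A minimal standalone sketch of that computation (the clip_image_feature_size helper name and the example sizes are illustrative, not part of this commit):

from transformers import CLIPVisionConfig


def clip_image_feature_size(config: CLIPVisionConfig) -> int:
    # One token per patch in a square grid, plus the [CLS] token.
    assert config.image_size % config.patch_size == 0
    grid_length = config.image_size // config.patch_size
    return grid_length**2 + 1


# e.g. a 336x336 image with 14x14 patches -> 24**2 + 1 = 577 tokens
print(clip_image_feature_size(CLIPVisionConfig(image_size=336, patch_size=14)))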

vllm/model_executor/models/siglip.py

Lines changed: 2 additions & 72 deletions
@@ -3,18 +3,15 @@
 within a vision language model."""
 
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union
 
-import numpy as np
 import torch
 from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
 
 from vllm.attention.layer import MultiHeadAttention
-from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.inputs import DecoderOnlyInputs, token_inputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -23,9 +20,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
+from vllm.multimodal.utils import consecutive_placeholder_ranges
 from vllm.sequence import SequenceData
 
 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
@@ -93,71 +88,6 @@ def dummy_image_for_siglip(
     return {"image": image if num_images == 1 else [image] * num_images}
 
 
-def dummy_video_for_siglip(
-    hf_config: SiglipVisionConfig,
-    num_frames: int,
-    num_videos: int = 1,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    pil_frame = dummy_image_for_siglip(
-        hf_config,
-        num_images=1,
-        image_width_override=image_width_override,
-        image_height_override=image_height_override)
-    np_frame = np.array(pil_frame["image"])
-    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
-    video_data = [mm_data_per_video] * num_videos
-    mm_data = {"video": video_data}
-    return mm_data
-
-
-def input_processor_for_siglip(
-    model_config: ModelConfig,
-    hf_config: SiglipVisionConfig,
-    inputs: DecoderOnlyInputs,
-    *,
-    image_token_id: int,
-    image_feature_size_override: Optional[Union[int, List[int]]] = None,
-):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    if "multi_modal_placeholders" in inputs and "image" in inputs[
-            "multi_modal_placeholders"]:
-        # The inputs already have placeholders.
-        return inputs
-
-    tokenizer = cached_get_tokenizer(model_config.tokenizer)
-
-    if image_feature_size_override is None:
-        image_data = multi_modal_data["image"]
-        if isinstance(image_data, Image.Image):
-            image_feature_size = get_siglip_image_feature_size(hf_config)
-        elif isinstance(image_data, torch.Tensor):
-            num_images, image_feature_size, hidden_size = image_data.shape
-        else:
-            raise TypeError(f"Invalid image type: {type(image_data)}")
-    else:
-        image_feature_size = image_feature_size_override
-
-    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-        tokenizer,
-        inputs.get("prompt"),
-        inputs["prompt_token_ids"],
-        placeholder_token_id=image_token_id,
-        repeat_count=image_feature_size,
-    )
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": ranges})
-
-
 class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
 
     def get_num_image_tokens(
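
Note: the removed input_processor_for_siglip (like its CLIP counterpart above) expanded each image placeholder token in the prompt into image_feature_size repetitions via repeat_and_pad_placeholder_tokens, so the prompt reserved one position per visual embedding. A simplified, self-contained sketch of the core of that expansion, operating on a plain token-id list rather than vLLM's DecoderOnlyInputs (the expand_image_placeholders helper and the example token ids are illustrative only):

from typing import List


def expand_image_placeholders(prompt_token_ids: List[int],
                              image_token_id: int,
                              image_feature_size: int) -> List[int]:
    # Repeat every image placeholder so the prompt reserves one position
    # per embedding produced by the vision tower.
    new_ids: List[int] = []
    for token_id in prompt_token_ids:
        if token_id == image_token_id:
            new_ids.extend([image_token_id] * image_feature_size)
        else:
            new_ids.append(token_id)
    return new_ids


# A prompt of [BOS, <image>, EOS] with a 577-token image becomes 579 ids.
print(len(expand_image_placeholders([1, 32000, 2], 32000, 577)))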

0 commit comments