# SPDX-License-Identifier: Apache-2.0
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

-import numpy as np
import torch
import torch.nn as nn
-from PIL import Image
from transformers import CLIPVisionConfig

from vllm.attention.layer import MultiHeadAttention
-from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
-from vllm.sequence import SequenceData

from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs


-def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
-    assert image_size % patch_size == 0
-    return image_size // patch_size
-
-
-def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
-    grid_length = get_clip_patch_grid_length(image_size=image_size,
-                                             patch_size=patch_size)
-    return grid_length * grid_length
-
-
-def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
-    return get_clip_num_patches(image_size=hf_config.image_size,
-                                patch_size=hf_config.patch_size) + 1
-
-
-def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
-    return get_clip_image_feature_size(hf_config)
-
-
-def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
-                            seq_len: int,
-                            num_images: int,
-                            *,
-                            image_token_id: int,
-                            image_feature_size_override: Optional[int] = None,
-                            mm_key: str = "image"):
-    if image_feature_size_override is None:
-        image_feature_size = get_clip_image_feature_size(hf_config)
-    else:
-        image_feature_size = image_feature_size_override
-
-    return SequenceData.from_prompt_token_counts(
-        (image_token_id, image_feature_size * num_images),
-        (0, seq_len - image_feature_size * num_images),
-    ), {
-        mm_key:
-        consecutive_placeholder_ranges(num_items=num_images,
-                                       item_size=image_feature_size)
-    }
-
-
-def dummy_image_for_clip(
-    hf_config: CLIPVisionConfig,
-    num_images: int,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    width = height = hf_config.image_size
-    if image_width_override is not None:
-        width = image_width_override
-    if image_height_override is not None:
-        height = image_height_override
-
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
-
-
-def dummy_video_for_clip(
-    hf_config: CLIPVisionConfig,
-    num_frames: int,
-    num_videos: int = 1,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    pil_frame = dummy_image_for_clip(
-        hf_config,
-        num_images=1,
-        image_width_override=image_width_override,
-        image_height_override=image_height_override)
-    np_frame = np.array(pil_frame["image"])
-    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
-    video_data = [mm_data_per_video] * num_videos
-    mm_data = {"video": video_data}
-    return mm_data
-
-
-def input_processor_for_clip(
-    model_config: ModelConfig,
-    hf_config: CLIPVisionConfig,
-    inputs: DecoderOnlyInputs,
-    *,
-    image_token_id: int,
-    image_feature_size_override: Optional[Union[int, List[int]]] = None,
-):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    if "multi_modal_placeholders" in inputs and "image" in inputs[
-            "multi_modal_placeholders"]:
-        # The inputs already have placeholders.
-        return inputs
-
-    tokenizer = cached_get_tokenizer(model_config.tokenizer)
-
-    if image_feature_size_override is None:
-        image_data = multi_modal_data["image"]
-        if isinstance(image_data, Image.Image):
-            image_feature_size = get_clip_image_feature_size(hf_config)
-        elif isinstance(image_data, torch.Tensor):
-            num_images, image_feature_size, hidden_size = image_data.shape
-        else:
-            raise TypeError(f"Invalid image type: {type(image_data)}")
-    else:
-        image_feature_size = image_feature_size_override
-
-    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-        tokenizer,
-        inputs.get("prompt"),
-        inputs["prompt_token_ids"],
-        placeholder_token_id=image_token_id,
-        repeat_count=image_feature_size,
-    )
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": ranges})
-
-
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):

    def get_num_image_tokens(
@@ -159,10 +27,10 @@ def get_num_image_tokens(
        image_width: int,
        image_height: int,
    ) -> int:
-        return get_clip_image_feature_size(self.vision_config)
+        return self.get_patch_grid_length()**2 + 1

    def get_max_image_tokens(self) -> int:
-        return get_max_clip_image_tokens(self.vision_config)
+        return self.get_patch_grid_length()**2 + 1

    def get_image_size(self) -> int:
        return self.vision_config.image_size
@@ -171,10 +39,9 @@ def get_patch_size(self) -> int:
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
-        return get_clip_patch_grid_length(
-            image_size=self.vision_config.image_size,
-            patch_size=self.vision_config.patch_size,
-        )
+        image_size, patch_size = self.get_image_size(), self.get_patch_size()
+        assert image_size % patch_size == 0
+        return image_size // patch_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164  # noqa
@@ -186,6 +53,7 @@ def __init__(self, config: CLIPVisionConfig):
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
+        assert self.image_size % self.patch_size == 0

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

@@ -197,8 +65,7 @@ def __init__(self, config: CLIPVisionConfig):
            bias=False,
        )

-        self.num_patches = get_clip_num_patches(image_size=self.image_size,
-                                                patch_size=self.patch_size)
+        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
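
Note: with the legacy module-level helpers removed, the patch-grid and image-token arithmetic now lives entirely in `CLIPEncoderInfo` and `CLIPVisionEmbeddings`. The snippet below is a minimal standalone sketch of that arithmetic, not part of the diff; the config values (`image_size=336`, `patch_size=14`, i.e. a ViT-L/14-336-style CLIP encoder) are assumed purely for illustration.

```python
# Standalone sketch of the token-count math encapsulated above.
# Assumes a ViT-L/14-336-style vision config (image_size=336, patch_size=14).
from transformers import CLIPVisionConfig

config = CLIPVisionConfig(image_size=336, patch_size=14)

# Mirrors CLIPEncoderInfo.get_patch_grid_length()
assert config.image_size % config.patch_size == 0
grid_length = config.image_size // config.patch_size  # 24

# Mirrors CLIPVisionEmbeddings.num_patches / num_positions and
# CLIPEncoderInfo.get_num_image_tokens(): one extra position is
# reserved for the class (CLS) embedding.
num_patches = grid_length**2        # 576
num_image_tokens = num_patches + 1  # 577

print(grid_length, num_patches, num_image_tokens)
```

Since CLIP resizes inputs to a fixed `image_size`, `get_num_image_tokens()` ignores the `image_width`/`image_height` arguments, which is why it returns the same value as `get_max_image_tokens()`.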