@@ -909,29 +909,37 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs,
         elif self.model_type == 'pixtral':
             # Hold on to pixel_values and input_ids.
             dtype = str_dtype_to_torch(self.vision_precision)
-            pixel_values = image["pixel_values"].to(device="cuda", dtype=dtype)
-            input_ids = image["input_ids"].to(device="cuda")
-
             # Shape of pixel values from the processor varies with the raw image.
             # So we create a new tensor with a fixed shape as expected by the vision
             # encoder and create a corresponding attention mask.
             image_size = self.image_size
             patch_size = self.patch_size
             d_min = torch.finfo(dtype).min
             num_patches = (image_size // patch_size)
-            image = torch.full((1, 3, image_size, image_size),
-                               fill_value=0,
-                               dtype=dtype,
-                               device="cuda")
-            attention_mask = torch.full((1, num_patches, num_patches),
-                                        fill_value=d_min,
-                                        dtype=dtype,
-                                        device="cuda")
-            h, w = pixel_values.shape[-2:]
-            image[..., :h, :w] = pixel_values
-            attention_mask[..., :h // patch_size, :w // patch_size] = 0
+            padded_image = torch.full(
+                (self.args.batch_size, 3, image_size, image_size),
+                fill_value=0,
+                dtype=dtype,
+                device="cuda")
+            padded_attention_mask = torch.full(
+                (self.args.batch_size, num_patches, num_patches),
+                fill_value=d_min,
+                dtype=dtype,
+                device="cuda")
+            h, w, input_ids = [], [], []
+            for img_idx in range(self.args.batch_size):
+                pixel_values = image["pixel_values"][img_idx]
+                img_h, img_w = pixel_values.shape[-2:]
+                padded_image[img_idx, :, :img_h, :img_w] = pixel_values
+                padded_attention_mask[img_idx, :img_h // patch_size,
+                                      :img_w // patch_size] = 0
+                input_ids.append(image["input_ids"][img_idx])
+                h.append(img_h)
+                w.append(img_w)
+
+            image = padded_image
             other_vision_inputs = {
-                "attention_mask": attention_mask,
+                "attention_mask": padded_attention_mask,
             }
         elif self.model_type == 'llava_next':
             input = image
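For reference, a minimal standalone sketch of the padding scheme this hunk introduces (all names, shapes, and sizes are illustrative, not the TensorRT-LLM API): each image is copied into the top-left corner of a fixed-size batch tensor, and patches in the padded region get an additive mask of dtype-min so attention softmax zeroes them out.

```python
import torch

batch_size, image_size, patch_size = 2, 1024, 16
dtype = torch.float16
d_min = torch.finfo(dtype).min
num_patches = image_size // patch_size

# Two fake "preprocessed" images with different spatial sizes.
pixel_values = [torch.randn(3, 512, 768, dtype=dtype),
                torch.randn(3, 1024, 256, dtype=dtype)]

padded = torch.zeros(batch_size, 3, image_size, image_size, dtype=dtype)
mask = torch.full((batch_size, num_patches, num_patches), d_min, dtype=dtype)
for i, pv in enumerate(pixel_values):
    h, w = pv.shape[-2:]
    padded[i, :, :h, :w] = pv               # real pixels, top-left aligned
    mask[i, :h // patch_size, :w // patch_size] = 0  # unmask real patches
```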
@@ -1150,12 +1158,29 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs,
         elif self.model_type == 'pixtral':
             relevant_patch_size = self.patch_size * self.spatial_merge_size
             output_img_size = self.image_size // relevant_patch_size
-            visual_features = visual_features.reshape(
-                output_img_size, output_img_size,
-                -1)[:h // relevant_patch_size, :w //
-                    relevant_patch_size].flatten(0, 1)
+            # Note: max_h * max_w serves as `tokens_per_task` in the ptuning prompt table.
+            max_h = max(h) // relevant_patch_size
+            max_w = max(w) // relevant_patch_size
+            visual_embed_dim = visual_features.shape[-1]
+            relevant_visual_features = torch.zeros(self.args.batch_size,
+                                                   max_h * max_w,
+                                                   visual_embed_dim)
+            for img_idx in range(self.args.batch_size):
+                complete_features = visual_features[img_idx].reshape(
+                    output_img_size, output_img_size, visual_embed_dim)
+                relevant_h = h[img_idx] // relevant_patch_size
+                relevant_w = w[img_idx] // relevant_patch_size
+                flattened_features = complete_features[:relevant_h,
+                                                       :relevant_w, :].flatten(0, 1)
+                relevant_visual_features[img_idx, :relevant_h *
+                                         relevant_w, :] = flattened_features
+            visual_features = relevant_visual_features
             input_ids = self.ptuning_setup_pixtral(input_ids=input_ids)
-            length = input_ids.shape[1]
+            # Note: `length` is not used downstream for the pixtral model, and a
+            # list of per-sample lengths causes errors downstream, so supply a
+            # scalar placeholder instead.
+            length = input_ids[0].shape[0]

         elif self.model_type == 'llava_next':
             visual_features = LlavaNextUtils.rearrange_image_features(
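A sketch of the feature-gathering step above, with assumed toy shapes rather than the real encoder output: each image's grid of patch features is cropped to the patches that came from real pixels, flattened row-major, and scattered into a zero-padded `[batch, max_tokens, dim]` tensor that later backs the prompt table.

```python
import torch

batch_size, grid, dim = 2, 64, 8   # grid = image_size // relevant_patch_size
h = [32, 64]                       # per-image valid patch rows
w = [48, 16]                       # per-image valid patch cols
features = torch.randn(batch_size, grid * grid, dim)

max_tokens = max(h) * max(w)       # plays the role of tokens_per_task
table = torch.zeros(batch_size, max_tokens, dim)
for i in range(batch_size):
    grid_feats = features[i].reshape(grid, grid, dim)
    valid = grid_feats[:h[i], :w[i], :].flatten(0, 1)  # [h*w, dim]
    table[i, :valid.shape[0]] = valid                  # rest stays zero-padded
```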
@@ -2027,16 +2052,19 @@ def ptuning_setup_fuyu(self, input_ids, image_patches_indices):
 
     def ptuning_setup_pixtral(self, input_ids):
         # input_ids obtained from the processor has token_ids for text as well as image tokens,
-        # where each image token is represented the same image_token_index (10 for this model).
+        # where each image token is represented by the same image_token_index.
         image_token_index = self.image_token_index
         vocab_size = self.vocab_size
         # Replace each image token with a unique token_id > text_vocab_size.
         # This shall be used to look up the prompt table.
-        replacer = vocab_size
-        for i in range(len(input_ids[0])):
-            if input_ids[0][i] == image_token_index:
-                input_ids[0][i] = replacer
-                replacer += 1
+        for img_idx in range(self.args.batch_size):
+            # Note: replacer is reset to vocab_size for each sample, as opposed to
+            # using `replacer = vocab_size + img_idx * tokens_per_task`; that part
+            # of the look-up arithmetic is handled by the `task_ids` input to
+            # PromptEmbedding's forward.
+            replacer = vocab_size
+            for token_idx in range(len(input_ids[img_idx])):
+                if input_ids[img_idx][token_idx] == image_token_index:
+                    input_ids[img_idx][token_idx] = replacer
+                    replacer += 1
         return input_ids
 
     def ptuning_setup_llava_next(self, visual_features, pre_prompt,
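The replacement loop above can be condensed into a vectorized one-liner per sample; here is a hedged, self-contained illustration (the vocab size and placeholder id are made up): every occurrence of the image placeholder id becomes `vocab_size`, `vocab_size + 1`, ..., so the engine reads those positions as rows 0, 1, ... of that sample's prompt table.

```python
import torch

vocab_size, image_token_index = 32000, 10
ids = torch.tensor([1, 10, 10, 10, 5, 7])     # one sample's input_ids
is_img = ids == image_token_index
ids[is_img] = vocab_size + torch.arange(int(is_img.sum()))
print(ids)  # tensor([    1, 32000, 32001, 32002,     5,     7])
```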
@@ -2166,7 +2194,24 @@ def load_images(image_paths):
             if isinstance(image_path, str):
                 image_path = image_path.split(self.args.path_sep)
             images = load_images(image_path)
-
+        elif "pixtral" in self.model_type:
+            if image_path is None:
+                image_urls = [
+                    "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png",
+                    "https://www.ilankelman.org/stopsigns/australia.jpg",
+                    "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png",
+                    "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                ]
+                while len(image_urls) < self.args.batch_size:
+                    image_urls *= 2
+                image_urls = image_urls[:self.args.batch_size]
+                self.args.image_path = ",".join(image_urls)
+                images = load_images(image_urls)
+            else:
+                if isinstance(image_path, str):
+                    image_path = image_path.split(self.args.path_sep)
+                images = load_images(image_path)
+            images = [images] if not isinstance(images, list) else images
         elif "nougat" in self.model_type:
             filepath = hf_hub_download(
                 repo_id="hf-internal-testing/fixtures_docvqa",
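The default-URL branch simply repeats the four seed URLs until they cover the batch, then truncates. A compact equivalent, purely illustrative, using the standard library:

```python
from itertools import cycle, islice

seed_urls = ["a.png", "b.jpg", "c.png"]            # stand-ins for the real URLs
batch_size = 8
image_urls = list(islice(cycle(seed_urls), batch_size))
```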
@@ -2413,9 +2458,15 @@ def setup_inputs(self, input_text, raw_image, raw_audio=None):
             post_prompt = "[/INST]"
             prompt = pre_prompt + input_text + post_prompt
             dtype = str_dtype_to_torch(self.vision_precision)
-            image = self.processor(text=prompt,
-                                   images=[raw_image],
-                                   return_tensors="pt").to(dtype)
+            image = {'pixel_values': [], 'input_ids': []}
+            for img_idx in range(self.args.batch_size):
+                image_info = self.processor(text=prompt,
+                                            images=[raw_image[img_idx]],
+                                            return_tensors="pt").to(dtype)
+                image['pixel_values'].append(image_info['pixel_values'].to(
+                    self.device))
+                image['input_ids'].append(image_info['input_ids'][0].to(
+                    self.device))
 
         elif 'internvl' in self.model_type:
             pre_prompt = "<|system|>\n你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。<|end|><|user|>\n<image>\n"
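The dict-of-lists shape here exists because each processor call can return differently sized `pixel_values`, so outputs cannot be stacked into one tensor. A stubbed sketch of the pattern (the fake processor below only mimics the ragged-output behavior, it is not the real Hugging Face API):

```python
import torch

def fake_processor(text, images, return_tensors="pt"):
    # Stand-in for an HF processor: pixel_values shape depends on the image.
    h, w = images[0]
    return {"pixel_values": torch.randn(1, 3, h, w),
            "input_ids": torch.arange(len(text)).unsqueeze(0)}

raw_images = [(512, 768), (256, 1024)]   # stand-ins for PIL images of varied size
image = {"pixel_values": [], "input_ids": []}
for raw in raw_images:
    out = fake_processor(text="[INST] describe [/INST]", images=[raw])
    image["pixel_values"].append(out["pixel_values"])   # ragged; kept as a list
    image["input_ids"].append(out["input_ids"][0])
```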
@@ -2619,7 +2670,7 @@ def setup_inputs(self, input_text, raw_image, raw_audio=None):
             image = image.expand(
                 min(self.args.batch_size, len(input_text)), -1, -1,
                 -1).contiguous()
-        if image is not None:
+        if image is not None and isinstance(image, torch.Tensor):
             image = image.to(self.device)
         # Generate decoder_input_ids for enc-dec models
         # Custom prompts can be added as: