1414# limitations under the License.
1515"""Image processor class for Pixtral."""
1616
17- from typing import Dict , List , Optional , Tuple , Union
17+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
1818
1919import numpy as np
2020
3131 get_image_size ,
3232 infer_channel_dimension_format ,
3333 is_scaled_image ,
34+ is_valid_image ,
3435 make_list_of_images ,
3536 to_numpy_array ,
3637 valid_images ,
4849 import PIL
4950
5051
51- # Adapted from function in image_transforms.py t oensure any transparent pixels are converted to white.
52+ # Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
53+ def make_list_of_images (images : ImageInput ) -> List [List [np .ndarray ]]:
54+ """
55+ Convert a single image or a list of images to a list of numpy arrays.
56+
57+ Args:
58+ images (`ImageInput`):
59+ A single image or a list of images.
60+
61+ Returns:
62+ A list of numpy arrays.
63+ """
64+ # If it's a single image, convert it to a list of lists
65+ if is_valid_image (images ):
66+ images = [[images ]]
67+ # If it's a list of images, it's a single batch, so convert it to a list of lists
68+ elif isinstance (images , (list , tuple )) and len (images ) > 0 and is_valid_image (images [0 ]):
69+ images = [images ]
70+ # If it's a list of batches, it's already in the right format
71+ elif (
72+ isinstance (images , (list , tuple ))
73+ and len (images ) > 0
74+ and isinstance (images [0 ], (list , tuple ))
75+ and is_valid_image (images [0 ][0 ])
76+ ):
77+ pass
78+ else :
79+ raise ValueError (
80+ "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
81+ )
82+ return images
83+
84+
85+ # Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
5286def convert_to_rgb (image : ImageInput ) -> ImageInput :
5387 """
5488 Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
@@ -134,6 +168,18 @@ def get_resize_output_image_size(
134168 return num_height_tokens * patch_height , num_width_tokens * patch_width
135169
136170
171+ # Hack to get tensor conversion used in BatchFeature without batching the images
172+ def _get_is_as_tensor_fns (tensor_type : Union [str , TensorType ]) -> Tuple [Callable , Callable ]:
173+ return BatchFeature ()._get_is_as_tensor_fns (tensor_type )
174+
175+
176+ def convert_to_tensor (array , tensor_type : Union [str , TensorType ]) -> Any :
177+ is_tensor , as_tensor = _get_is_as_tensor_fns (tensor_type )
178+ if is_tensor (array ):
179+ return array
180+ return as_tensor (array )
181+
182+
137183class PixtralImageProcessor (BaseImageProcessor ):
138184 r"""
139185 Constructs a Pixtral image processor.
@@ -333,11 +379,11 @@ def preprocess(
333379 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
334380 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
335381 """
382+ patch_size = patch_size if patch_size is not None else self .patch_size
336383 patch_size = get_size_dict (patch_size , default_to_square = True )
337384
338385 do_resize = do_resize if do_resize is not None else self .do_resize
339386 size = size if size is not None else self .size
340- patch_size = patch_size if patch_size is not None else self .patch_size
341387 resample = resample if resample is not None else self .resample
342388 do_rescale = do_rescale if do_rescale is not None else self .do_rescale
343389 rescale_factor = rescale_factor if rescale_factor is not None else self .rescale_factor
@@ -348,13 +394,14 @@ def preprocess(
348394
349395 validate_kwargs (captured_kwargs = kwargs .keys (), valid_processor_keys = self ._valid_processor_keys )
350396
351- images = make_list_of_images (images )
397+ images_list = make_list_of_images (images )
352398
353- if not valid_images (images ):
399+ if not valid_images (images_list [ 0 ][ 0 ] ):
354400 raise ValueError (
355401 "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
356402 "torch.Tensor, tf.Tensor or jax.ndarray."
357403 )
404+
358405 validate_preprocess_arguments (
359406 do_rescale = do_rescale ,
360407 rescale_factor = rescale_factor ,
@@ -367,46 +414,54 @@ def preprocess(
367414 )
368415
369416 if do_convert_rgb :
370- images = [convert_to_rgb (image ) for image in images ]
417+ images_list = [[ convert_to_rgb (image ) for image in images ] for images in images_list ]
371418
372419 # All transformations expect numpy arrays.
373- images = [to_numpy_array (image ) for image in images ]
420+ images_list = [[ to_numpy_array (image ) for image in images ] for images in images_list ]
374421
375- if is_scaled_image (images [0 ]) and do_rescale :
422+ if is_scaled_image (images_list [ 0 ] [0 ]) and do_rescale :
376423 logger .warning_once (
377424 "It looks like you are trying to rescale already rescaled images. If the input"
378425 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
379426 )
380427
381428 if input_data_format is None :
382429 # We assume that all images have the same channel dimension format.
383- input_data_format = infer_channel_dimension_format (images [0 ])
384-
385- all_images = []
386- for image in images :
387- if do_resize :
388- image = self .resize (
389- image = image ,
390- size = size ,
391- patch_size = patch_size ,
392- resample = resample ,
393- input_data_format = input_data_format ,
394- )
395-
396- if do_rescale :
397- image = self .rescale (image = image , scale = rescale_factor , input_data_format = input_data_format )
398-
399- if do_normalize :
400- image = self .normalize (
401- image = image , mean = image_mean , std = image_std , input_data_format = input_data_format
402- )
403-
404- all_images .append (image )
405-
406- images = [
407- to_channel_dimension_format (image , data_format , input_channel_dim = input_data_format )
408- for image in all_images
430+ input_data_format = infer_channel_dimension_format (images_list [0 ][0 ])
431+
432+ batch_images = []
433+ batch_image_sizes = []
434+ for sample_images in images_list :
435+ images = []
436+ image_sizes = []
437+ for image in sample_images :
438+ if do_resize :
439+ image = self .resize (
440+ image = image ,
441+ size = size ,
442+ patch_size = patch_size ,
443+ resample = resample ,
444+ input_data_format = input_data_format ,
445+ )
446+
447+ if do_rescale :
448+ image = self .rescale (image = image , scale = rescale_factor , input_data_format = input_data_format )
449+
450+ if do_normalize :
451+ image = self .normalize (
452+ image = image , mean = image_mean , std = image_std , input_data_format = input_data_format
453+ )
454+
455+ images .append (image )
456+ image_sizes .append (get_image_size (image , input_data_format ))
457+ batch_images .append (images )
458+ batch_image_sizes .append (image_sizes )
459+
460+ images_list = [
461+ [to_channel_dimension_format (image , data_format , input_channel_dim = input_data_format ) for image in images ]
462+ for images in batch_images
409463 ]
410464
411- data = {"pixel_values" : images }
412- return BatchFeature (data = data , tensor_type = return_tensors )
465+ # Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes
466+ images_list = [[convert_to_tensor (image , return_tensors ) for image in images ] for images in images_list ]
467+ return BatchFeature (data = {"images" : images_list , "image_sizes" : batch_image_sizes }, tensor_type = None )
0 commit comments