
Commit 6ee62a7

Process list of list of images (#33465)
1 parent ea2d9fb commit 6ee62a7

2 files changed: +117 -50 lines

src/transformers/models/pixtral/image_processing_pixtral.py

Lines changed: 91 additions & 36 deletions
```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Image processor class for Pixtral."""
 
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -31,6 +31,7 @@
     get_image_size,
     infer_channel_dimension_format,
     is_scaled_image,
+    is_valid_image,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -48,7 +49,40 @@
 import PIL
 
 
-# Adapted from function in image_transforms.py t oensure any transparent pixels are converted to white.
+# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
+def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
+    """
+    Convert a single image or a list of images to a list of numpy arrays.
+
+    Args:
+        images (`ImageInput`):
+            A single image or a list of images.
+
+    Returns:
+        A list of numpy arrays.
+    """
+    # If it's a single image, convert it to a list of lists
+    if is_valid_image(images):
+        images = [[images]]
+    # If it's a list of images, it's a single batch, so convert it to a list of lists
+    elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
+        images = [images]
+    # If it's a list of batches, it's already in the right format
+    elif (
+        isinstance(images, (list, tuple))
+        and len(images) > 0
+        and isinstance(images[0], (list, tuple))
+        and is_valid_image(images[0][0])
+    ):
+        pass
+    else:
+        raise ValueError(
+            "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
+        )
+    return images
+
+
+# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
 def convert_to_rgb(image: ImageInput) -> ImageInput:
     """
     Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
```
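For reference, the new helper normalizes all three accepted input shapes into a list of per-sample image lists. A minimal sketch of the behavior, assuming a transformers build that includes this commit:

```python
import numpy as np

from transformers.models.pixtral.image_processing_pixtral import make_list_of_images

img = np.zeros((16, 16, 3), dtype=np.uint8)

single = make_list_of_images(img)             # single image  -> [[img]]
batch = make_list_of_images([img, img])       # one batch     -> [[img, img]]
nested = make_list_of_images([[img], [img]])  # list of batches, returned as-is

assert single[0][0] is img
assert len(batch[0]) == 2 and len(nested) == 2
```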
```diff
@@ -134,6 +168,18 @@ def get_resize_output_image_size(
     return num_height_tokens * patch_height, num_width_tokens * patch_width
 
 
+# Hack to get tensor conversion used in BatchFeature without batching the images
+def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]:
+    return BatchFeature()._get_is_as_tensor_fns(tensor_type)
+
+
+def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any:
+    is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type)
+    if is_tensor(array):
+        return array
+    return as_tensor(array)
+
+
 class PixtralImageProcessor(BaseImageProcessor):
     r"""
     Constructs a Pixtral image processor.
```
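`convert_to_tensor` reuses `BatchFeature`'s tensor-conversion machinery on one array at a time, so images of different resolutions are never stacked into a single tensor. A small sketch of the effect; the shapes are invented for illustration, and passing `"pt"` assumes torch is installed:

```python
import numpy as np

from transformers.models.pixtral.image_processing_pixtral import convert_to_tensor

# Two images of different sizes cannot be stacked along a batch axis,
# so each one is converted to a tensor individually.
small = np.random.rand(3, 64, 64).astype(np.float32)
large = np.random.rand(3, 128, 96).astype(np.float32)

tensors = [convert_to_tensor(arr, "pt") for arr in (small, large)]
print([tuple(t.shape) for t in tensors])  # [(3, 64, 64), (3, 128, 96)]
```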
```diff
@@ -333,11 +379,11 @@ def preprocess(
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
+        patch_size = patch_size if patch_size is not None else self.patch_size
         patch_size = get_size_dict(patch_size, default_to_square=True)
 
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        patch_size = patch_size if patch_size is not None else self.patch_size
         resample = resample if resample is not None else self.resample
         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
@@ -348,13 +394,14 @@ def preprocess(
 
         validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
 
-        images = make_list_of_images(images)
+        images_list = make_list_of_images(images)
 
-        if not valid_images(images):
+        if not valid_images(images_list[0][0]):
             raise ValueError(
                 "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                 "torch.Tensor, tf.Tensor or jax.ndarray."
             )
+
         validate_preprocess_arguments(
             do_rescale=do_rescale,
             rescale_factor=rescale_factor,
@@ -367,46 +414,54 @@ def preprocess(
         )
 
         if do_convert_rgb:
-            images = [convert_to_rgb(image) for image in images]
+            images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
 
         # All transformations expect numpy arrays.
-        images = [to_numpy_array(image) for image in images]
+        images_list = [[to_numpy_array(image) for image in images] for images in images_list]
 
-        if is_scaled_image(images[0]) and do_rescale:
+        if is_scaled_image(images_list[0][0]) and do_rescale:
             logger.warning_once(
                 "It looks like you are trying to rescale already rescaled images. If the input"
                 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
             )
 
         if input_data_format is None:
             # We assume that all images have the same channel dimension format.
-            input_data_format = infer_channel_dimension_format(images[0])
-
-        all_images = []
-        for image in images:
-            if do_resize:
-                image = self.resize(
-                    image=image,
-                    size=size,
-                    patch_size=patch_size,
-                    resample=resample,
-                    input_data_format=input_data_format,
-                )
-
-            if do_rescale:
-                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
-
-            if do_normalize:
-                image = self.normalize(
-                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
-                )
-
-            all_images.append(image)
-
-        images = [
-            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
-            for image in all_images
+            input_data_format = infer_channel_dimension_format(images_list[0][0])
+
+        batch_images = []
+        batch_image_sizes = []
+        for sample_images in images_list:
+            images = []
+            image_sizes = []
+            for image in sample_images:
+                if do_resize:
+                    image = self.resize(
+                        image=image,
+                        size=size,
+                        patch_size=patch_size,
+                        resample=resample,
+                        input_data_format=input_data_format,
+                    )
+
+                if do_rescale:
+                    image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+                if do_normalize:
+                    image = self.normalize(
+                        image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+                    )
+
+                images.append(image)
+                image_sizes.append(get_image_size(image, input_data_format))
+            batch_images.append(images)
+            batch_image_sizes.append(image_sizes)
+
+        images_list = [
+            [to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images]
+            for images in batch_images
         ]
 
-        data = {"pixel_values": images}
-        return BatchFeature(data=data, tensor_type=return_tensors)
+        # Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes
+        images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list]
+        return BatchFeature(data={"images": images_list, "image_sizes": batch_image_sizes}, tensor_type=None)
```
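End to end, `preprocess` now returns `images` as a list (one entry per sample) of lists of per-image tensors, plus matching `image_sizes`, instead of a single batched `pixel_values`. A hedged usage sketch; the checkpoint name is assumed purely for illustration, and the printed sizes depend on resizing:

```python
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b")

images = [
    [Image.new("RGB", (512, 384)), Image.new("RGB", (256, 256))],  # sample 1: two images
    [Image.new("RGB", (640, 480))],                                # sample 2: one image
]
out = processor(images=images, return_tensors="pt")

# out["images"] mirrors the input nesting; out["image_sizes"] holds the
# post-resize (height, width) per image, e.g. [[(h1, w1), (h2, w2)], [(h3, w3)]].
print(len(out["images"]), [len(sample) for sample in out["images"]])  # 2 [2, 1]
```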

src/transformers/models/pixtral/processing_pixtral.py

Lines changed: 26 additions & 14 deletions
```diff
@@ -19,7 +19,7 @@
 from typing import List, Optional, Union
 
 from ...feature_extraction_utils import BatchFeature
-from ...image_utils import ImageInput, get_image_size, to_numpy_array
+from ...image_utils import ImageInput
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
 from ...utils import TensorType, logging
@@ -146,21 +146,33 @@ def __call__(
 
         # try to expand inputs in processing if we have the necessary parts
         prompt_strings = text
-        if image_inputs.get("pixel_values") is not None:
+        if image_inputs.get("images") is not None:
             # Replace the image token with the expanded image token sequence
-            pixel_values = image_inputs["pixel_values"]
-            height, width = get_image_size(to_numpy_array(pixel_values[0]))
-            num_height_tokens = height // self.patch_size
-            num_width_tokens = width // self.patch_size
-
+            images = image_inputs["images"]
+            image_sizes = image_inputs.pop("image_sizes")
             prompt_strings = []
-            replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens
-            # Flatten list
-            replace_tokens = [item for sublist in replace_tokens for item in sublist]
-            replace_tokens[-1] = self.image_end_token
-            replace_str = "".join(replace_tokens)
-            for sample in text:
-                sample = sample.replace(self.image_token, replace_str)
+
+            for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text):
+                replace_strings = []
+                # First calculate the number of tokens needed for each image and put in a placeholder
+                for image, image_size in zip(sample_images, sample_image_sizes):
+                    height, width = image_size
+                    num_height_tokens = height // self.patch_size
+                    num_width_tokens = width // self.patch_size
+                    replace_tokens = [
+                        [self.image_token] * num_width_tokens + [self.image_break_token]
+                    ] * num_height_tokens
+                    # Flatten list
+                    replace_tokens = [item for sublist in replace_tokens for item in sublist]
+                    replace_tokens[-1] = self.image_end_token
+                    replace_str = "".join(replace_tokens)
+                    replace_strings.append(replace_str)
+                    sample = sample.replace(self.image_token, "<placeholder>", 1)
+
+                while "<placeholder>" in sample:
+                    replace_str = replace_strings.pop(0)
+                    sample = sample.replace("<placeholder>", replace_str, 1)
+
                 prompt_strings.append(sample)
 
         text_inputs = self.tokenizer(
```
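Since each image can now have its own size, the token expansion runs in two passes: every image token is first swapped for a placeholder (so a longer replacement can never be matched again by a later pass), then each placeholder is replaced by its image's token string in order. A stripped-down sketch of the same logic, with hypothetical token strings and patch size standing in for the processor's attributes:

```python
# Hypothetical stand-ins for self.image_token, self.image_break_token,
# self.image_end_token and self.patch_size in the diff above.
IMG, BREAK, END, PATCH = "[IMG]", "[IMG_BREAK]", "[IMG_END]", 16

def expand_image_tokens(sample: str, image_sizes: list[tuple[int, int]]) -> str:
    replace_strings = []
    # Pass 1: compute each image's token string and drop in a placeholder.
    for height, width in image_sizes:
        num_h, num_w = height // PATCH, width // PATCH
        rows = [[IMG] * num_w + [BREAK] for _ in range(num_h)]
        tokens = [tok for row in rows for tok in row]
        tokens[-1] = END  # the final row break becomes the end-of-image token
        replace_strings.append("".join(tokens))
        sample = sample.replace(IMG, "<placeholder>", 1)
    # Pass 2: substitute the per-image token strings, one placeholder at a time.
    while "<placeholder>" in sample:
        sample = sample.replace("<placeholder>", replace_strings.pop(0), 1)
    return sample

print(expand_image_tokens("describe [IMG] and [IMG]", [(32, 32), (16, 48)]))
```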
