Commit b1f61bd

Add the extra processor and tests
1 parent e288c68 commit b1f61bd

5 files changed (+1140 lines, -10 lines)

src/transformers/models/sam/image_processing_sam.py

Lines changed: 7 additions & 4 deletions
@@ -454,18 +454,21 @@ def post_process_masks_tf(
             (`tf.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
             is given by original_size.
         """
-        requires_backends(self, ["tensorflow"])
+        requires_backends(self, ["tf"])
         pad_size = self.pad_size if pad_size is None else pad_size
         target_image_size = (pad_size["height"], pad_size["width"])
 
         output_masks = []
         for i, original_size in enumerate(original_sizes):
-            interpolated_mask = tf.image.resize(masks[i], target_image_size, method="bilinear")
+            # tf.image expects NHWC, so we transpose the NCHW inputs for it
+            mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
+            interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
             interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
             interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
             if binarize:
                 interpolated_mask = interpolated_mask > mask_threshold
-            output_masks.append(interpolated_mask)
+            # and then we transpose back to NCHW at the end
+            output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
 
         return output_masks
 
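The NHWC round trip above is the heart of this hunk: tf.image.resize operates on channels-last tensors, while SAM produces channels-first masks. As a minimal standalone sketch (the shapes are illustrative, not taken from the commit), the same transpose-resize-transpose pattern looks like this:

import tensorflow as tf

masks = tf.random.uniform((2, 3, 64, 64))        # NCHW, as the SAM model emits
nhwc = tf.transpose(masks, perm=[0, 2, 3, 1])    # NCHW -> NHWC for tf.image
resized = tf.image.resize(nhwc, (128, 128), method="bilinear")
nchw = tf.transpose(resized, perm=[0, 3, 1, 2])  # back to NCHW
print(nchw.shape)                                # (2, 3, 128, 128)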

@@ -684,7 +687,7 @@ def filter_masks_tf(
             The offset for the stability score used in the `_compute_stability_score` method.
 
         """
-        requires_backends(self, ["tensorflow"])
+        requires_backends(self, ["tf"])
         original_height, original_width = original_size
         iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]])
         masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]])
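As a side note, flattening the leading batch and point-batch axes, as the context lines above set out to do, can be sketched independently (illustrative shapes; `-1` lets TensorFlow infer the trailing dimension):

import tensorflow as tf

iou_scores = tf.random.uniform((2, 4, 3))  # (batch, point_batch, num_masks)
flat = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], -1])
print(flat.shape)  # (8, 3)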

src/transformers/models/sam/processing_sam.py

Lines changed: 25 additions & 6 deletions
@@ -22,12 +22,15 @@
 
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import BatchEncoding
-from ...utils import TensorType, is_torch_available
+from ...utils import TensorType, is_torch_available, is_tf_available
 
 
 if is_torch_available():
     import torch
 
+if is_tf_available():
+    import tensorflow as tf
+
 
 class SamProcessor(ProcessorMixin):
     r"""
@@ -72,7 +75,7 @@ def __call__(
         # pop arguments that are not used in the forward but used nevertheless
         original_sizes = encoding_image_processor["original_sizes"]
 
-        if isinstance(original_sizes, torch.Tensor):
+        if hasattr(original_sizes, "numpy"):  # checks if Torch or TF tensor
             original_sizes = original_sizes.numpy()
 
         input_points, input_labels, input_boxes = self._check_and_preprocess_points(
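The duck-typed `hasattr(x, "numpy")` check works because eager `tf.Tensor` and CPU `torch.Tensor` both expose a `.numpy()` method, while plain lists do not. A small sketch, assuming both backends are installed:

import tensorflow as tf
import torch

for sizes in (torch.tensor([[900, 1200]]), tf.constant([[900, 1200]]), [[900, 1200]]):
    if hasattr(sizes, "numpy"):  # Torch or TF tensor
        sizes = sizes.numpy()
    print(type(sizes))  # ndarray, ndarray, list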
@@ -139,18 +142,30 @@ def _normalize_and_convert(
                 input_boxes = torch.from_numpy(input_boxes)
                 # boxes batch size of 1 by default
                 input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
+            elif return_tensors == "tf":
+                input_boxes = tf.convert_to_tensor(input_boxes)
+                # boxes batch size of 1 by default
+                input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
             encoding_image_processor.update({"input_boxes": input_boxes})
         if input_points is not None:
             if return_tensors == "pt":
                 input_points = torch.from_numpy(input_points)
                 # point batch size of 1 by default
                 input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
+            elif return_tensors == "tf":
+                input_points = tf.convert_to_tensor(input_points)
+                # point batch size of 1 by default
+                input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
             encoding_image_processor.update({"input_points": input_points})
         if input_labels is not None:
             if return_tensors == "pt":
                 input_labels = torch.from_numpy(input_labels)
                 # point batch size of 1 by default
                 input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
+            elif return_tensors == "tf":
+                input_labels = tf.convert_to_tensor(input_labels)
+                # point batch size of 1 by default
+                input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
             encoding_image_processor.update({"input_labels": input_labels})
 
         return encoding_image_processor
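In these branches, `tf.expand_dims(x, 1)` plays the role of torch's `unsqueeze(1)`: both insert the default prompt-batch axis. A quick parity check with illustrative shapes:

import numpy as np
import tensorflow as tf
import torch

points = np.zeros((2, 5, 2))                # (batch, nb_points, 2)
pt = torch.from_numpy(points).unsqueeze(1)  # -> (2, 1, 5, 2)
tf_pts = tf.expand_dims(tf.convert_to_tensor(points), 1)
print(pt.shape, tf_pts.shape)               # both (2, 1, 5, 2)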
@@ -204,7 +219,7 @@ def _check_and_preprocess_points(
         it is converted to a `numpy.ndarray` and then to a `list`.
         """
         if input_points is not None:
-            if isinstance(input_points, torch.Tensor):
+            if hasattr(input_points, "numpy"):  # checks for TF or Torch tensor
                 input_points = input_points.numpy().tolist()
 
             if not isinstance(input_points, list) and not isinstance(input_points[0], list):
@@ -214,7 +229,7 @@
             input_points = None
 
         if input_labels is not None:
-            if isinstance(input_labels, torch.Tensor):
+            if hasattr(input_labels, "numpy"):
                 input_labels = input_labels.numpy().tolist()
 
             if not isinstance(input_labels, list) and not isinstance(input_labels[0], list):
@@ -224,7 +239,7 @@
             input_labels = None
 
         if input_boxes is not None:
-            if isinstance(input_boxes, torch.Tensor):
+            if hasattr(input_boxes, "numpy"):
                 input_boxes = input_boxes.numpy().tolist()
 
             if (
@@ -245,4 +260,8 @@ def model_input_names(self):
         return list(dict.fromkeys(image_processor_input_names))
 
     def post_process_masks(self, *args, **kwargs):
-        return self.image_processor.post_process_masks(*args, **kwargs)
+        return_tensors = kwargs.pop("return_tensors", "pt")
+        if return_tensors == "pt":
+            return self.image_processor.post_process_masks(*args, **kwargs)
+        elif return_tensors == "tf":
+            return self.image_processor.post_process_masks_tf(*args, **kwargs)
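With this dispatch, callers pick the backend with a single keyword. A hypothetical usage sketch — the checkpoint name and all shapes are assumptions for illustration, not part of the commit:

import tensorflow as tf
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
low_res_masks = [tf.random.uniform((1, 3, 256, 256))]  # stand-in for TF model outputs
masks = processor.post_process_masks(
    low_res_masks,
    original_sizes=[(1764, 2646)],
    reshaped_input_sizes=[(683, 1024)],
    return_tensors="tf",
)
print(masks[0].shape)  # (1, 3, 1764, 2646), boolean after the default binarize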
Lines changed: 248 additions & 0 deletions
@@ -0,0 +1,248 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for SAM.
+"""
+from copy import deepcopy
+from typing import Optional, Union
+
+import numpy as np
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import TensorType, is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+
+class SamProcessor(ProcessorMixin):
+    r"""
+    Constructs a SAM processor which wraps a SAM image processor and a 2D points & bounding boxes processor into a
+    single processor.
+
+    [`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of
+    [`~SamImageProcessor.__call__`] for more information.
+
+    Args:
+        image_processor (`SamImageProcessor`):
+            An instance of [`SamImageProcessor`]. The image processor is a required input.
+    """
+    attributes = ["image_processor"]
+    image_processor_class = "SamImageProcessor"
+
+    def __init__(self, image_processor):
+        super().__init__(image_processor)
+        self.current_processor = self.image_processor
+        self.point_pad_value = -10
+        self.target_size = self.image_processor.size["longest_edge"]
+
+    def __call__(
+        self,
+        images=None,
+        input_points=None,
+        input_labels=None,
+        input_boxes=None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        This method uses [`SamImageProcessor.__call__`] to prepare image(s) for the model. It also prepares 2D
+        points and bounding boxes for the model if they are provided.
+        """
+        encoding_image_processor = self.image_processor(
+            images,
+            return_tensors=return_tensors,
+            **kwargs,
+        )
+
+        # pop arguments that are not used in the forward but used nevertheless
+        original_sizes = encoding_image_processor["original_sizes"]
+
+        if isinstance(original_sizes, torch.Tensor):
+            original_sizes = original_sizes.numpy()
+
+        input_points, input_labels, input_boxes = self._check_and_preprocess_points(
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+        )
+
+        encoding_image_processor = self._normalize_and_convert(
+            encoding_image_processor,
+            original_sizes,
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+            return_tensors=return_tensors,
+        )
+
+        return encoding_image_processor
+
+    def _normalize_and_convert(
+        self,
+        encoding_image_processor,
+        original_sizes,
+        input_points=None,
+        input_labels=None,
+        input_boxes=None,
+        return_tensors="pt",
+    ):
+        if input_points is not None:
+            if len(original_sizes) != len(input_points):
+                input_points = [
+                    self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points
+                ]
+            else:
+                input_points = [
+                    self._normalize_coordinates(self.target_size, point, original_size)
+                    for point, original_size in zip(input_points, original_sizes)
+                ]
+            # check that all arrays have the same shape
+            if not all([point.shape == input_points[0].shape for point in input_points]):
+                if input_labels is not None:
+                    input_points, input_labels = self._pad_points_and_labels(input_points, input_labels)
+
+            input_points = np.array(input_points)
+
+        if input_labels is not None:
+            input_labels = np.array(input_labels)
+
+        if input_boxes is not None:
+            if len(original_sizes) != len(input_boxes):
+                input_boxes = [
+                    self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True)
+                    for box in input_boxes
+                ]
+            else:
+                input_boxes = [
+                    self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True)
+                    for box, original_size in zip(input_boxes, original_sizes)
+                ]
+            input_boxes = np.array(input_boxes)
+
+        if input_boxes is not None:
+            if return_tensors == "pt":
+                input_boxes = torch.from_numpy(input_boxes)
+                # boxes batch size of 1 by default
+                input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
+            encoding_image_processor.update({"input_boxes": input_boxes})
+        if input_points is not None:
+            if return_tensors == "pt":
+                input_points = torch.from_numpy(input_points)
+                # point batch size of 1 by default
+                input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
+            encoding_image_processor.update({"input_points": input_points})
+        if input_labels is not None:
+            if return_tensors == "pt":
+                input_labels = torch.from_numpy(input_labels)
+                # point batch size of 1 by default
+                input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
+            encoding_image_processor.update({"input_labels": input_labels})
+
+        return encoding_image_processor
+
+    def _pad_points_and_labels(self, input_points, input_labels):
+        r"""
+        The method pads the 2D points and labels to the maximum number of points in the batch.
+        """
+        expected_nb_points = max([point.shape[0] for point in input_points])
+        processed_input_points = []
+        for i, point in enumerate(input_points):
+            if point.shape[0] != expected_nb_points:
+                point = np.concatenate(
+                    [point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
+                )
+                input_labels[i] = np.append(input_labels[i], [self.point_pad_value])
+            processed_input_points.append(point)
+        input_points = processed_input_points
+        return input_points, input_labels
+
+    def _normalize_coordinates(
+        self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False
+    ) -> np.ndarray:
+        """
+        Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
+        """
+        old_h, old_w = original_size
+        new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size)
+        coords = deepcopy(coords).astype(float)
+
+        if is_bounding_box:
+            coords = coords.reshape(-1, 2, 2)
+
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+
+        if is_bounding_box:
+            coords = coords.reshape(-1, 4)
+
+        return coords
+
+    def _check_and_preprocess_points(
+        self,
+        input_points=None,
+        input_labels=None,
+        input_boxes=None,
+    ):
+        r"""
+        Checks and preprocesses the 2D points, labels and bounding boxes. It checks that the inputs are valid and, if
+        they are, converts the coordinates of the points and bounding boxes. If a user passes a `torch.Tensor`
+        directly, it is converted to a `numpy.ndarray` and then to a `list`.
+        """
+        if input_points is not None:
+            if isinstance(input_points, torch.Tensor):
+                input_points = input_points.numpy().tolist()
+
+            if not isinstance(input_points, list) and not isinstance(input_points[0], list):
+                raise ValueError("Input points must be a list of lists of floating point numbers.")
+            input_points = [np.array(input_point) for input_point in input_points]
+        else:
+            input_points = None
+
+        if input_labels is not None:
+            if isinstance(input_labels, torch.Tensor):
+                input_labels = input_labels.numpy().tolist()
+
+            if not isinstance(input_labels, list) and not isinstance(input_labels[0], list):
+                raise ValueError("Input labels must be a list of lists of integers.")
+            input_labels = [np.array(label) for label in input_labels]
+        else:
+            input_labels = None
+
+        if input_boxes is not None:
+            if isinstance(input_boxes, torch.Tensor):
+                input_boxes = input_boxes.numpy().tolist()
+
+            if (
+                not isinstance(input_boxes, list)
+                and not isinstance(input_boxes[0], list)
+                and not isinstance(input_boxes[0][0], list)
+            ):
+                raise ValueError("Input boxes must be a list of lists of lists of floating point numbers.")
+            input_boxes = [np.array(box).astype(np.float32) for box in input_boxes]
+        else:
+            input_boxes = None
+
+        return input_points, input_labels, input_boxes
+
+    @property
+    def model_input_names(self):
+        image_processor_input_names = self.image_processor.model_input_names
+        return list(dict.fromkeys(image_processor_input_names))
+
+    def post_process_masks(self, *args, **kwargs):
+        return self.image_processor.post_process_masks(*args, **kwargs)
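For reference, the arithmetic in `_normalize_coordinates` can be worked through by hand; a sketch assuming the `int(dim * scale + 0.5)` rounding used by `_get_preprocess_shape` (illustrative image size and point):

old_h, old_w = 600, 900                  # original (H, W)
scale = 1024 / max(old_h, old_w)         # longest edge mapped to 1024
new_h = int(old_h * scale + 0.5)         # 683
new_w = int(old_w * scale + 0.5)         # 1024
x, y = 450, 300                          # a point at the image centre
print(x * (new_w / old_w), y * (new_h / old_h))  # 512.0 341.5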
