Skip to content

Commit 3902969

Browse files
committed
Fix input_boxes shapes
1 parent ee4057f commit 3902969

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/transformers/models/sam/processing_sam.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def _check_and_preprocess_points(
223223
input_points = input_points.numpy().tolist()
224224

225225
if not isinstance(input_points, list) or not isinstance(input_points[0], list):
226-
raise ValueError("Input points must be a list of list of floating integers.")
226+
raise ValueError("Input points must be a list of list of floating points.")
227227
input_points = [np.array(input_point) for input_point in input_points]
228228
else:
229229
input_points = None
@@ -247,7 +247,7 @@ def _check_and_preprocess_points(
247247
or not isinstance(input_boxes[0], list)
248248
or not isinstance(input_boxes[0][0], list)
249249
):
250-
raise ValueError("Input boxes must be a list of list of list of floating integers.")
250+
raise ValueError("Input boxes must be a list of list of list of floating points.")
251251
input_boxes = [np.array(box).astype(np.float32) for box in input_boxes]
252252
else:
253253
input_boxes = None

tests/models/sam/test_modeling_tf_sam.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def test_inference_mask_generation_one_point_one_bb(self):
448448
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
449449

450450
raw_image = prepare_image()
451-
input_boxes = [[650, 900, 1000, 1250]]
451+
input_boxes = [[[650, 900, 1000, 1250]]]
452452
input_points = [[[820, 1080]]]
453453

454454
inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf")
@@ -501,7 +501,7 @@ def test_inference_mask_generation_one_point_one_bb_zero(self):
501501
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
502502

503503
raw_image = prepare_image()
504-
input_boxes = [[620, 900, 1000, 1255]]
504+
input_boxes = [[[620, 900, 1000, 1255]]]
505505
input_points = [[[820, 1080]]]
506506
labels = [[0]]
507507

0 commit comments

Comments
 (0)