From d112bc5b6dfd75e9a23ea1836b7bd714f0d62799 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 23 May 2023 10:03:10 +0000 Subject: [PATCH 1/2] add a dummy pipeline test --- src/transformers/models/sam/image_processing_sam.py | 2 +- tests/models/sam/test_modeling_sam.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 64f3bae22218..821b43624d07 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -934,7 +934,7 @@ def _generate_crop_boxes( cropped_images, point_grid_per_crop = _generate_crop_images( crop_boxes, image, points_grid, layer_idxs, target_size, original_size ) - + crop_boxes = np.array(crop_boxes) crop_boxes = crop_boxes.astype(np.float32) points_per_crop = np.array([point_grid_per_crop]) points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3)) diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index 599ed5e384bc..f9dfcb0bf60f 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -20,7 +20,7 @@ import requests -from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig +from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline from transformers.testing_utils import require_torch, slow, torch_device from transformers.utils import is_torch_available, is_vision_available @@ -751,3 +751,9 @@ def test_inference_mask_generation_three_boxes_point_batch(self): iou_scores = outputs.iou_scores.cpu() self.assertTrue(iou_scores.shape == (1, 3, 3)) torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) + + def test_simple_pipeline(self): + generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device) + raw_image = prepare_image() + + _ = generator(raw_image, points_per_batch=64) From ab5244a62171183016c65e861169f3c681e280a6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 23 May 2023 10:36:37 +0000 Subject: [PATCH 2/2] change test name --- tests/models/sam/test_modeling_sam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index f9dfcb0bf60f..d678fd324018 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -752,7 +752,7 @@ def test_inference_mask_generation_three_boxes_point_batch(self): self.assertTrue(iou_scores.shape == (1, 3, 3)) torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) - def test_simple_pipeline(self): + def test_dummy_pipeline_generation(self): generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device) raw_image = prepare_image()