From cd8df3ad3b2685c84e1aa28c3a246a1c45309aa2 Mon Sep 17 00:00:00 2001 From: Matt Date: Sat, 22 Apr 2023 13:05:04 +0100 Subject: [PATCH 01/49] First commit --- .../models/sam/image_processing_sam.py | 303 +++++++++++++++++- 1 file changed, 302 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 53852961041b..e4bf74f411bc 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -34,7 +34,7 @@ to_numpy_array, valid_images, ) -from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends +from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends, is_tf_available if is_torch_available(): @@ -44,6 +44,10 @@ if is_torchvision_available(): from torchvision.ops.boxes import batched_nms +if is_tf_available(): + import tensorflow as tf + from tensorflow.experimental import numpy as tnp + logger = logging.get_logger(__name__) @@ -418,6 +422,43 @@ def post_process_masks( return output_masks + def post_process_masks_tf(self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`tf.Tensor`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`tf.Tensor`): + The original size of the images before resizing for input to the model, in (height, width) format. + reshaped_input_sizes (`tf.Tensor`): + The size of the image input to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + Returns: + (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. + """ + requires_backends(self, ["tensorflow"]) + pad_size = self.pad_size if pad_size is None else pad_size + target_image_size = (pad_size["height"], pad_size["width"]) + + output_masks = [] + for i, original_size in enumerate(original_sizes): + interpolated_mask = tf.image.resize(masks[i], target_image_size, method="bilinear") + interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]] + interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear") + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + output_masks.append(interpolated_mask) + + return output_masks + def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh): """ Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. 
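A note on tensor layout for the TF post-processing above: `tf.image.resize` only accepts channels-last tensors (`(height, width, channels)` or `(batch, height, width, channels)`), while the docstring states the mask decoder emits masks in `(batch_size, num_channels, height, width)` format as in the PyTorch path, so the resize and crop steps need a transpose around them. A minimal sketch of such a helper (illustrative only, not necessarily how the final port resolves this):

```python
import tensorflow as tf


def resize_channels_first(mask: tf.Tensor, size) -> tf.Tensor:
    """Bilinearly resize a (channels, height, width) mask with tf.image.resize.

    tf.image.resize operates on channels-last tensors, so transpose before and after.
    Hypothetical helper, shown for illustration.
    """
    mask = tf.transpose(mask, perm=[1, 2, 0])  # (C, H, W) -> (H, W, C)
    mask = tf.image.resize(mask, size, method="bilinear")
    return tf.transpose(mask, perm=[2, 0, 1])  # back to (C, H, W)
```

With a helper along these lines, the per-image loop can keep cropping with `[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]` on `(channels, height, width)` tensors, exactly as the PyTorch `post_process_masks` does.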
@@ -434,6 +475,22 @@ def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, cro """ return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh) + def post_process_for_mask_generation_tf(self, all_masks, all_scores, all_boxes, crops_nms_thresh): + """ + Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. + + Args: + all_masks (`List[tf.Tensor]`): + List of all predicted segmentation masks + all_scores (`List[tf.Tensor]`): + List of all predicted iou scores + all_boxes (`List[tf.Tensor]`): + List of all bounding boxes of the predicted masks + crops_nms_thresh (`float`): + Threshold for NMS (Non Maximum Suppression) algorithm. + """ + return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh) + def generate_crop_boxes( self, image, @@ -469,6 +526,38 @@ def generate_crop_boxes( image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device ) + def generate_crop_boxes_tf( + self, + image, + target_size, + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[List[int]] = 1, + ): + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (`np.array`): + Input original image + target_size (`int`): + Target size of the resized image + crop_n_layers (`int`, *optional*, defaults to 0): + If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where + each layer has 2**i_layer number of image crops. + overlap_ratio (`float`, *optional*, defaults to 512/1500): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*, defaults to 32): + Number of points to sample from each crop. + crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + """ + return _generate_crop_boxes_tf( + image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor + ) + def filter_masks( self, masks, @@ -549,6 +638,83 @@ def filter_masks( return masks, scores, converted_boxes + def filter_masks_tf( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`tf.Tensor`): + Input masks. + iou_scores (`tf.Tensor`): + List of IoU scores. + original_size (`Tuple[int,int]`): + Size of the orginal image. + cropped_box_image (`np.array`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. 
+ mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + + """ + requires_backends(self, ["tensorflow"]) + original_height, original_width = original_size + iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]]) + masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]]) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the same batch size.") + + batch_size = masks.shape[0] + + keep_mask = tf.ones(batch_size, dtype=tf.bool) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box_tf(masks) + + keep_mask = ~_is_box_near_crop_edge_tf( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppresion + masks = _mask_to_rle_tf(masks) + + return masks, scores, converted_boxes + def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): # One mask is always contained inside the other. @@ -560,6 +726,14 @@ def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stabi stability_scores = intersections / unions return stability_scores +def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int): + # Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure + # we get the right division results. + intersections = tf.count_nonzero(masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32) + unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32) + stability_scores = intersections / unions + return stability_scores + def _build_point_grid(n_per_side: int) -> np.ndarray: """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" @@ -657,6 +831,59 @@ def _generate_crop_boxes( return crop_boxes, points_per_crop, cropped_images, input_labels +def _generate_crop_boxes_tf( + image, + target_size: int, # Is it tuple here? + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[List[int]] = 1, +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (Union[`numpy.ndarray`, `PIL.Image`, `tf.Tensor`]): + Image to generate crops for. + target_size (`int`): + Size of the smallest crop. + crop_n_layers (`int`, *optional*): + If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers + to run, where each layer has 2**i_layer number of image crops. 
+ overlap_ratio (`int`, *optional*): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the + image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*): + Number of points to sample per crop. + crop_n_points_downscale_factor (`int`, *optional*): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + """ + + if isinstance(image, list): + raise ValueError("Only one image is allowed for crop generation.") + image = to_numpy_array(image) + original_size = get_image_size(image) + + points_grid = [] + for i in range(crop_n_layers + 1): + n_points = int(points_per_crop / (crop_n_points_downscale_factor**i)) + points_grid.append(_build_point_grid(n_points)) + + crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) + + cropped_images, point_grid_per_crop = _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size + ) + + crop_boxes = tf.convert_to_tensor(crop_boxes, dtype=tf.float32) + point_grid_per_crop = np.array([point_grid_per_crop]) + points_per_crop = tf.convert_to_tensor(point_grid_per_crop) + points_per_crop = tf.transpose(points_per_crop, perm=(0, 2, 1, 3)) + + input_labels = tf.ones_like(points_per_crop[:, :, :, 0], dtype=tf.int64) + + return crop_boxes, points_per_crop, cropped_images, input_labels + def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size): """ @@ -729,6 +956,14 @@ def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int): pad = (left, pad_x - left, top, pad_y - top) return torch.nn.functional.pad(masks, pad, value=0) +def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int): + left, top, right, bottom = crop_box + if left == 0 and top == 0 and right == orig_width and bottom == orig_height: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top) + pad = (left, pad_x - left, top, pad_y - top) + return tf.pad(masks, pad, constant_values=0) def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): """Filter masks at the edge of a crop, but not at the edge of the original image.""" @@ -747,6 +982,23 @@ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) return torch.any(near_crop_edge, dim=1) +def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0): + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32) + orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32) + + left, top, _, _ = crop_box + offset = tf.convert_to_tensor([[left, top, left, top]]) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = tf.expand_dims(offset, 1) + boxes = tf.cast(boxes + offset, tf.float32) + + near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0) + near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0) + near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge) + return tf.reduce_any(near_crop_edge, axis=1) + def _batched_mask_to_box(masks: "torch.Tensor"): """ @@ -796,6 +1048,55 @@ def _batched_mask_to_box(masks: "torch.Tensor"): out = out.reshape(*shape[:-2], 4) return out +# TODO CONTINUE FROM HERE +def _batched_mask_to_box_tf(masks: "tf.Tensor"): 
+    """
+    Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
+    corresponds to the following required indices:
+        - LEFT: left hand side of the bounding box
+        - TOP: top of the bounding box
+        - RIGHT: right of the bounding box
+        - BOTTOM: bottom of the bounding box
+
+    Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
+    is channel_1 x channel_2 x ... x 4.
+
+    Args:
+        - masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
+    """
+    # An empty mask has no box; return all zeros in this case instead of reducing over an empty tensor
+
+    if tf.size(masks) == 0:
+        return tf.zeros([*masks.shape[:-2], 4])
+
+    # Normalize shape to Cxheightxwidth
+    shape = masks.shape
+    height, width = shape[-2:]
+    # Cast the boolean masks to integers so that the reductions and arithmetic below are well-defined in TF
+    masks = tf.cast(masks, tf.int32)
+
+    # Get top and bottom edges
+    in_height = tf.reduce_max(masks, axis=-1)
+    in_height_coords = in_height * tf.range(height, dtype=tf.int32)[None, :]
+    bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
+    in_height_coords = in_height_coords + height * (1 - in_height)
+    top_edges = tf.reduce_min(in_height_coords, axis=-1)
+
+    # Get left and right edges
+    in_width = tf.reduce_max(masks, axis=-2)
+    in_width_coords = in_width * tf.range(width, dtype=tf.int32)[None, :]
+    right_edges = tf.reduce_max(in_width_coords, axis=-1)
+    in_width_coords = in_width_coords + width * (1 - in_width)
+    left_edges = tf.reduce_min(in_width_coords, axis=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
+    out = out * tf.expand_dims(1 - tf.cast(empty_filter, out.dtype), -1)
+
+    # Return to original shape
+    out = tf.reshape(out, [*shape[:-2], 4])
+    return out
+
 
 def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
     """

From eb103df9e11e6c846ed39f00b7f987d09ef7d33d Mon Sep 17 00:00:00 2001
From: Matt
Date: Mon, 24 Apr 2023 18:35:02 +0100
Subject: [PATCH 02/49] Add auto-translation with GPT-4

---
 .../models/sam/modeling_tf_sam.py | 1288 +++++++++++++++++
 1 file changed, 1288 insertions(+)
 create mode 100644 src/transformers/models/sam/modeling_tf_sam.py

diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py
new file mode 100644
index 000000000000..5ae000279e71
--- /dev/null
+++ b/src/transformers/models/sam/modeling_tf_sam.py
@@ -0,0 +1,1288 @@
+# coding=utf-8
+# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original with GPT-4.
+In the event of a discrepancy, the original file should be regarded as the 'reference' version.
+""" + +import collections +import math +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...modeling_tf_outputs import TFBaseModelOutput +from ...modeling_tf_utils import TFPreTrainedModel, shape_list +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SamConfig" +_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge" + +SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/sam-vit-huge", + "facebook/sam-vit-large", + "facebook/sam-vit-base", + # See all SAM models at https://huggingface.co/models?filter=sam +] +@dataclass +class TFSamVisionEncoderOutput(ModelOutput): + """ + Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. + + Args: + image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[tf.Tensor] = None + last_hidden_state: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None +@dataclass +class TFSamImageSegmentationOutput(ModelOutput): + """ + Base class for Segment-Anything model's output + + Args: + iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`): + The iou scores of the predicted masks. + pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`): + The predicted low resolutions masks. Needs to be post-processed by the processor + vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. 
+ vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + iou_scores: tf.Tensor = None + pred_masks: tf.Tensor = None + vision_hidden_states: Optional[Tuple[tf.Tensor]] = None + vision_attentions: Optional[Tuple[tf.Tensor]] = None + mask_decoder_attentions: Optional[Tuple[tf.Tensor]] = None + + +class TFSamPatchEmbeddings(tf.keras.layers.Layer): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = tf.keras.layers.Conv2D(hidden_size, kernel_size=patch_size, strides=patch_size, name="projection") + + def call(self, pixel_values): + batch_size, num_channels, height, width = shape_list(pixel_values) + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = tf.transpose(self.projection(pixel_values), perm=[0, 2, 3, 1]) + return embeddings +class TFSamMLPBlock(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.lin1 = tf.keras.layers.Dense(config.mlp_dim, name="lin1") + self.lin2 = tf.keras.layers.Dense(config.hidden_size, name="lin2") + self.act = ACT2FN[config.hidden_act] + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +class TFSamLayerNorm(tf.keras.layers.Layer): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(**kwargs) + self.weight = self.add_weight(shape=normalized_shape, initializer="ones", name="weight") + self.bias = self.add_weight(shape=normalized_shape, initializer="zeros", name="bias") + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + self.normalized_shape = (normalized_shape,) + + def call(self, x: tf.Tensor) -> tf.Tensor: + if self.data_format == "channels_last": + x = tf.keras.layers.LayerNormalization(axis=self.normalized_shape, epsilon=self.eps)(x) + elif self.data_format == "channels_first": + input_dtype = x.dtype + x = tf.cast(x, tf.float32) + u = tf.reduce_mean(x, axis=1, keepdims=True) + s = tf.math.square(x - u) + s = tf.reduce_mean(s, axis=1, keepdims=True) + x = (x - u) / tf.math.sqrt(s + self.eps) + x = tf.cast(x, input_dtype) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x +class TFSamAttention(tf.keras.layers.Layer): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. + """ + + def __init__(self, config, downsample_rate=None, **kwargs) -> None: + super().__init__(**kwargs) + self.hidden_size = config.hidden_size + + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + if self.internal_dim % config.num_attention_heads != 0: + raise ValueError("num_attention_heads must divide hidden_size.") + + self.q_proj = tf.keras.layers.Dense(self.internal_dim, name="q_proj") + self.k_proj = tf.keras.layers.Dense(self.internal_dim, name="k_proj") + self.v_proj = tf.keras.layers.Dense(self.internal_dim, name="v_proj") + self.out_proj = tf.keras.layers.Dense(self.hidden_size, name="out_proj") + + def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor: + batch, point_batch_size, n_tokens, channel = shape_list(hidden_states) + c_per_head = channel // num_attention_heads + hidden_states = tf.reshape(hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)) + return tf.transpose(hidden_states, perm=[0, 2, 1, 3]) + + def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor: + batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states) + hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3]) + return tf.reshape( + hidden_states, + (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head) + ) + + def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = shape_list(query)[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # SamAttention + _, _, _, c_per_head = shape_list(query) + attn = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = attn / tf.math.sqrt(float(c_per_head)) + attn = tf.nn.softmax(attn, axis=-1) + + # Get output + out = 
tf.matmul(attn, value) + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out +class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer): + def __init__( + self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs + ) -> None: + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`SamMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. + """ + super().__init__(**kwargs) + + self.hidden_size = config.hidden_size + self.layer_norm_eps = config.layer_norm_eps + + self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn") + self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1") + + self.cross_attn_token_to_image = TFSamAttention( + config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image" + ) + self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2") + + self.mlp = TFSamMLPBlock(config, name="mlp") + self.layer_norm3 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3") + + self.layer_norm4 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4") + self.cross_attn_image_to_token = TFSamAttention( + config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token" + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def call( + self, + queries: tf.Tensor, + keys: tf.Tensor, + query_point_embedding: tf.Tensor, + key_point_embedding: tf.Tensor, + output_attentions: bool = False, + ): + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + + outputs = (queries, keys) + + if output_attentions: + outputs = outputs + (attn_out,) + else: + outputs = outputs + (None,) + + return outputs +class TFSamTwoWayTransformer(tf.keras.layers.Layer): + def __init__(self, config: SamMaskDecoderConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = [] + + for i in range(self.num_hidden_layers): + 
self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_{i}")) + + self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image") + self.layer_norm_final_attn = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_final_attn") + + def call( + self, + point_embeddings: tf.Tensor, + image_embeddings: tf.Tensor, + image_positional_embeddings: tf.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TFBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_attentions = () + + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = tf.expand_dims(tf.transpose(tf.reshape(image_embeddings, (shape_list(image_embeddings)[0], -1, shape_list(image_embeddings)[-1])), (0, 2, 1)), 1) + image_positional_embeddings = tf.expand_dims(tf.transpose(tf.reshape(image_positional_embeddings, (shape_list(image_positional_embeddings)[0], -1, shape_list(image_positional_embeddings)[-1])), (0, 2, 1)), 1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys, attention_outputs = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (attention_outputs,) + + # Apply the final attenion layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys, all_attentions +class TFSamFeedForward(tf.keras.layers.Layer): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs + ): + super().__init__(**kwargs) + self.num_layers = num_layers + self.activation = tf.keras.layers.ReLU() + self.proj_in = tf.keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in") + self.proj_out = tf.keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out") + self.layers = [tf.keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layer_{i}") for i in range(num_layers - 2)] + self.sigmoid_output = sigmoid_output + + def call(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = tf.sigmoid(hidden_states) + return hidden_states +class TFSamMaskDecoder(tf.keras.layers.Layer): + def __init__(self, config: SamMaskDecoderConfig, **kwargs): + super().__init__(**kwargs) + + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + 
self.iou_token = tf.keras.layers.Embedding(1, self.hidden_size, name="iou_token") + self.mask_tokens = tf.keras.layers.Embedding(self.num_mask_tokens, self.hidden_size, name="mask_tokens") + + self.transformer = TFSamTwoWayTransformer(config) + + self.upscale_conv1 = tf.keras.layers.Conv2DTranspose(self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1") + self.upscale_conv2 = tf.keras.layers.Conv2DTranspose(self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2") + self.upscale_layer_norm = TFSamLayerNorm(self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm") + self.activation = tf.keras.layers.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [TFSamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = tf.keras.layers.LayerList(mlps_list) + + self.iou_prediction_head = TFSamFeedForward( + self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth + ) + + def call( + self, + image_embeddings: tf.Tensor, + image_positional_embeddings: tf.Tensor, + sparse_prompt_embeddings: tf.Tensor, + dense_prompt_embeddings: tf.Tensor, + multimask_output: bool, + output_attentions: Optional[bool] = None, + ) -> Tuple[tf.Tensor, tf.Tensor]: + batch_size, num_channels, height, width = shape_list(image_embeddings) + point_batch_size = max(1, shape_list(sparse_prompt_embeddings)[1]) + output_tokens = tf.concat([self.iou_token.embeddings, self.mask_tokens.embeddings], axis=0) + output_tokens = tf.tile(tf.expand_dims(output_tokens, axis=0), [batch_size, point_batch_size, 1, 1]) + + if tf.reduce_sum(sparse_prompt_embeddings) != 0: + tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) + else: + tokens = output_tokens + point_embeddings = tf.cast(tokens, self.iou_token.embeddings.dtype) + + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1]) + image_positional_embeddings = tf.tile(image_positional_embeddings, [point_batch_size, 1, 1, 1]) + + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width]) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + hyper_in_list = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = tf.stack(hyper_in_list, axis=2) + + _, num_channels, height, width = shape_list(upscaled_embedding) + upscaled_embedding = tf.reshape(upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]) + masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width]) + + iou_pred = self.iou_prediction_head(iou_token_out) + + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, 
mask_slice] + + outputs = (masks, iou_pred) + + if output_attentions: + outputs = outputs + (attentions,) + else: + outputs = outputs + (None,) + + return outputs +class TFSamPositionalEmbedding(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.scale = config.hidden_size // 2 + self.positional_embedding = self.add_weight( + name="positional_embedding", + shape=(2, config.num_pos_feats), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale), + trainable=False, + ) + + def call(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = tf.identity(input_coords) + + if input_shape is not None: + coordinates = tf.tensor_scatter_nd_update( + coordinates, + indices=[[0, 0]], + updates=coordinates[:, :, :, 0] / input_shape[1], + ) + coordinates = tf.tensor_scatter_nd_update( + coordinates, + indices=[[0, 1]], + updates=coordinates[:, :, :, 1] / input_shape[0], + ) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = tf.cast(coordinates, self.positional_embedding.dtype) + coordinates = tf.linalg.matmul(coordinates, self.positional_embedding) + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1) +class TFSamMaskEmbedding(tf.keras.layers.Layer): + def __init__(self, config: SamPromptEncoderConfig, **kwargs): + super().__init__(**kwargs) + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1") + self.conv2 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv2") + self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3") + self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") + self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") + + def call(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings +class TFSamPromptEncoder(tf.keras.layers.Layer): + def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs): + super().__init__(**kwargs) + self.shared_embedding = shared_patch_embedding + self.mask_embed = TFSamMaskEmbedding(config) + self.no_mask_embed = tf.keras.layers.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) + self.input_image_size = config.image_size + + self.point_embed = [ + tf.keras.layers.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings) + ] + self.hidden_size = config.hidden_size + self.not_a_point_embed = tf.keras.layers.Embedding(1, config.hidden_size) + + def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1]) + target_labels_shape = 
(points.shape[0], points.shape[1], 1) + padding_point = tf.zeros(target_point_shape, dtype=points.dtype) + padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype) + points = tf.concat([points, padding_point], axis=2) + labels = tf.concat([labels, padding_label], axis=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed.weights[0], point_embedding) + + point_embedding = tf.where( + labels[..., None] != -10, + point_embedding, + tf.zeros_like(point_embedding), + ) + + point_embedding = tf.where(labels == 0, point_embedding + self.point_embed[0].weights[0], point_embedding) + point_embedding = tf.where(labels == 1, point_embedding + self.point_embed[1].weights[0], point_embedding) + return point_embedding + + def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + batch_size, nb_boxes = boxes.shape[:2] + coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2)) + input_shape = (self.input_image_size, self.input_image_size) + corner_embedding = self.shared_embedding(coords, input_shape) + corner_embedding += tf.where(tf.range(corner_embedding.shape[2])[None, None, :, None] == 0, self.point_embed[2].weights[0], self.point_embed[3].weights[0]) + return corner_embedding + + def call( + self, + input_points: Optional[Tuple[tf.Tensor, tf.Tensor]], + input_labels: Optional[tf.Tensor], + input_boxes: Optional[tf.Tensor], + input_masks: Optional[tf.Tensor], + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`tf.Tensor`, **optionnal**): + point coordinates and labels to embed. 
+ boxes (`tf.Tensor`, **optionnal**): + boxes to embed + masks (`tf.Tensor`, **optionnal**): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + if input_points is not None: + batch_size, point_batch_size = input_points.shape[:2] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = tf.zeros((batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype) + sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2) + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = tf.reshape(self.no_mask_embed.weights[0], (1, -1, 1, 1)) + dense_embeddings = tf.tile(dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])) + + if sparse_embeddings is None: + sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype) + + return sparse_embeddings, dense_embeddings +class TFSamVisionAttention(tf.keras.layers.Layer): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size, **kwargs) -> None: + super().__init__(**kwargs) + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = tf.keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv") + self.proj = tf.keras.layers.Dense(config.hidden_size, name="proj") + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = self.add_weight(shape=(2 * input_size[0] - 1, head_dim), initializer="zeros", name="rel_pos_h") + self.rel_pos_w = self.add_weight(shape=(2 * input_size[1] - 1, head_dim), initializer="zeros", name="rel_pos_w") + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`tf.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = tf.image.resize( + tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)), + size=(max_rel_dist, rel_pos.shape[1]), + method="bilinear", + ) + rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist)) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. 
+ q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0) + k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32)) + + def add_decomposed_rel_pos( + self, + attn: tf.Tensor, + query: tf.Tensor, + rel_pos_h: tf.Tensor, + rel_pos_w: tf.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> tf.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`tf.Tensor`): + attention map. + query (`tf.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`tf.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`tf.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`tf.Tensor`): + attention map with added relative positional embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim)) + rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width)) + attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2) + attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width)) + return attn + + def call(self, hidden_states: tf.Tensor, output_attentions=False) -> tf.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = tf.unstack(tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0) + + attn_weights = (query * self.scale) @ tf.transpose(key, perm=(-2, -1)) + + if self.use_rel_pos: + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + attn_weights = tf.nn.softmax(attn_weights, axis=-1) + + attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = tf.reshape(attn_output, (batch_size, height, width, -1)) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs +class TFSamVisionLayer(tf.keras.layers.Layer): + def __init__(self, config, window_size, **kwargs) -> None: + super().__init__(**kwargs) + self.layer_norm1 = 
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") + self.attn = TFSamVisionAttention(config, window_size, name="attn") + self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") + self.mlp = TFSamMLPBlock(config, name="mlp") + self.window_size = window_size + + def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]: + batch_size, height, width, channel = shape_list(hidden_states) + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + if pad_h > 0 or pad_w > 0: + hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]]) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = tf.reshape( + hidden_states, + [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel] + ) + windows = tf.reshape( + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), + [-1, window_size, window_size, channel] + ) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int] + ) -> tf.Tensor: + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = tf.reshape( + windows, + [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1] + ) + hidden_states = tf.reshape( + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), + [batch_size, pad_height, pad_width, -1] + ) + + if pad_height > height or pad_width > width: + hidden_states = hidden_states[:, :height, :width, :] + return hidden_states + + def call( + self, + hidden_states: tf.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[tf.Tensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs +class TFSamVisionNeck(tf.keras.layers.Layer): + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.conv1 = tf.keras.layers.Conv2D(config.output_channels, kernel_size=1, use_bias=False, name="conv1") + self.layer_norm1 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm1") + self.conv2 = tf.keras.layers.Conv2D(config.output_channels, kernel_size=3, padding="same", use_bias=False, name="conv2") + self.layer_norm2 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm2") + + def call(self, hidden_states): + hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + 
hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states +class TFSamVisionEncoder(tf.keras.layers.Layer): + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.image_size = config.image_size + + self.patch_embed = TFSamPatchEmbeddings(config) + + self.pos_embed = None + if config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = tf.Variable( + tf.zeros( + [ + 1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + ] + ), + trainable=True, + ) + + self.layers = [] + for i in range(config.num_hidden_layers): + layer = TFSamVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = TFSamVisionNeck(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.patch_embed + + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TFSamVisionEncoderOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_self_attentions,) + return outputs + + return TFSamVisionEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) +class TFSamPreTrainedModel(TFPreTrainedModel): + config_class = SamConfig + base_model_prefix = "sam" + main_input_name = "pixel_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (tf.keras.layers.Dense, tf.keras.layers.Conv2D, tf.keras.layers.Conv2DTranspose)): + module.kernel.assign(tf.random.normal(module.kernel.shape, mean=0.0, stddev=std)) + if module.bias is not None: + module.bias.assign(tf.zeros(module.bias.shape)) + elif isinstance(module, tf.keras.layers.Embedding): + module.embeddings.assign(tf.random.normal(module.embeddings.shape, mean=0.0, stddev=std)) + if module.padding_idx is not None: + module.embeddings[module.padding_idx].assign(tf.zeros(module.embeddings[module.padding_idx].shape)) + + 
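For orientation, the intended end-to-end usage of the TF port mirrors the existing PyTorch `SamModel` API. The sketch below is hypothetical at this point in the patch series: it assumes TF weights are available for the checkpoint (or that `from_pt=True` is passed), that `SamProcessor` can return TF tensors, and it reuses the `post_process_masks_tf` helper added in the first patch.

```python
import requests
from PIL import Image

from transformers import SamProcessor, TFSamModel

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = TFSamModel.from_pretrained("facebook/sam-vit-base")  # may need from_pt=True until TF weights are published

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]  # one (x, y) prompt point for the single input image

# Assumes the processor supports return_tensors="tf"
inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)

# Upscale the low-resolution predicted masks back to the original image size
masks = processor.image_processor.post_process_masks_tf(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
scores = outputs.iou_scores
```

As the input docstrings below describe, the model returns three candidate masks per prompt by default; passing `multimask_output=False` keeps only the single "best" mask.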
+
+SAM_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a TensorFlow [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass.
+    Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`SamConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+SAM_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`):
+            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
+            better results. The points can be obtained by passing a list of list of list to the processor that will
+            create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the
+            second dimension is the point batch size (i.e. how many segmentation masks we want the model to predict
+            per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+            multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+            coordinates of the point. If a different number of points is passed either for each image, or for each
+            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+            computation of the embedding will be skipped for these points using the labels.
+        input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`):
+            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
+            official implementation, there are 3 types of labels:
+
+            - `1`: the point is a point that contains the object of interest
+            - `0`: the point is a point that does not contain the object of interest
+            - `-1`: the point corresponds to the background
+
+            We added the label:
+
+            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+            The padding labels should be automatically done by the processor.
+        input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`):
+            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
+            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
+            that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch
+            size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+ In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): + SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `call` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +@add_start_docstrings( + "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", + " optional 2D location and bounding boxes.", + SAM_START_DOCSTRING, +) +class TFSamModel(TFSamPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"] + + def __init__(self, config, **kwargs) -> None: + super().__init__(config, **kwargs) + self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config) + + self.vision_encoder = TFSamVisionEncoder(config.vision_config) + self.prompt_encoder = TFSamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding) + self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config) + + self.post_init() + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embedding_size + grid = tf.ones((size, size)) + y_embed = tf.math.cumsum(grid, axis=0) - 0.5 + x_embed = tf.math.cumsum(grid, axis=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1)) + return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) # channel x height x width + + def get_image_embeddings( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. 
+ + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple. + + """ + vision_output = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_output[0] + return image_embeddings + + def get_prompt_embeddings( + self, + input_points: Optional[tf.Tensor] = None, + input_labels: Optional[tf.Tensor] = None, + input_boxes: Optional[tf.Tensor] = None, + input_masks: Optional[tf.Tensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + input_points: Optional[tf.Tensor] = None, + input_labels: Optional[tf.Tensor] = None, + input_boxes: Optional[tf.Tensor] = None, + input_masks: Optional[tf.Tensor] = None, + image_embeddings: Optional[tf.Tensor] = None, + multimask_output: bool = True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict=None, + **kwargs, + ) -> List[Dict[str, tf.Tensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor. 
Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.", + " got {}.".format(input_points.shape), + ) + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.", + " got {}.".format(input_boxes.shape), + ) + if input_points is not None and input_boxes is not None: + point_batch_size = input_points.shape[1] + box_batch_size = input_boxes.shape[1] + if point_batch_size != box_batch_size: + raise ValueError( + "You should provide as many bounding boxes as input points per box. Got {} and {}.".format( + point_batch_size, box_batch_size + ) + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0] + image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0) + + vision_attentions = None + mask_decoder_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_outputs[0] + + if output_hidden_states: + vision_hidden_states = vision_outputs[1] + if output_attentions: + vision_attentions = vision_outputs[-1] + + if input_points is not None and input_labels is None: + input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32) + + if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]: + raise ValueError( + "The batch size of the image embeddings and the input points must be the same. ", + "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]), + " if you want to pass multiple points for the same image, make sure that you passed ", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + ) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + output_attentions=output_attentions, + ) + + if not return_dict: + output = (iou_predictions, low_res_masks) + if output_hidden_states: + output = output + (vision_hidden_states,) + + if output_attentions: + output = output + (vision_attentions, mask_decoder_attentions) + return output + + return TFSamImageSegmentationOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + mask_decoder_attentions=mask_decoder_attentions, + ) \ No newline at end of file From 0a611f1fe9ac1aca4b01d9c9a2e33d5888060a2f Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 24 Apr 2023 18:39:20 +0100 Subject: [PATCH 03/49] make fixup --- .../models/sam/modeling_tf_sam.py | 144 ++++++++++++++---- 1 file changed, 111 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 5ae000279e71..1c158f536a82 100644 
--- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -18,7 +18,6 @@ """ import collections -import math from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union @@ -43,6 +42,8 @@ "facebook/sam-vit-base", # See all SAM models at https://huggingface.co/models?filter=sam ] + + @dataclass class TFSamVisionEncoderOutput(ModelOutput): """ @@ -71,6 +72,8 @@ class TFSamVisionEncoderOutput(ModelOutput): last_hidden_state: tf.Tensor = None hidden_states: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None + + @dataclass class TFSamImageSegmentationOutput(ModelOutput): """ @@ -126,7 +129,9 @@ def __init__(self, config, **kwargs): self.num_channels = num_channels self.num_patches = num_patches - self.projection = tf.keras.layers.Conv2D(hidden_size, kernel_size=patch_size, strides=patch_size, name="projection") + self.projection = tf.keras.layers.Conv2D( + hidden_size, kernel_size=patch_size, strides=patch_size, name="projection" + ) def call(self, pixel_values): batch_size, num_channels, height, width = shape_list(pixel_values) @@ -140,6 +145,8 @@ def call(self, pixel_values): ) embeddings = tf.transpose(self.projection(pixel_values), perm=[0, 2, 3, 1]) return embeddings + + class TFSamMLPBlock(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) @@ -183,6 +190,8 @@ def call(self, x: tf.Tensor) -> tf.Tensor: x = tf.cast(x, input_dtype) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x + + class TFSamAttention(tf.keras.layers.Layer): """ SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and @@ -208,15 +217,16 @@ def __init__(self, config, downsample_rate=None, **kwargs) -> None: def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor: batch, point_batch_size, n_tokens, channel = shape_list(hidden_states) c_per_head = channel // num_attention_heads - hidden_states = tf.reshape(hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)) + hidden_states = tf.reshape( + hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + ) return tf.transpose(hidden_states, perm=[0, 2, 1, 3]) def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor: batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states) hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3]) return tf.reshape( - hidden_states, - (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head) + hidden_states, (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head) ) def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: @@ -233,7 +243,9 @@ def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: # SamAttention _, _, _, c_per_head = shape_list(query) - attn = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = tf.matmul( + query, tf.transpose(key, perm=[0, 1, 3, 2]) + ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens attn = attn / tf.math.sqrt(float(c_per_head)) attn = tf.nn.softmax(attn, axis=-1) @@ -243,6 +255,8 @@ def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: out = self.out_proj(out) return out + + class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer): def 
__init__( self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs @@ -331,6 +345,8 @@ def call( outputs = outputs + (None,) return outputs + + class TFSamTwoWayTransformer(tf.keras.layers.Layer): def __init__(self, config: SamMaskDecoderConfig, **kwargs): super().__init__(**kwargs) @@ -343,7 +359,9 @@ def __init__(self, config: SamMaskDecoderConfig, **kwargs): self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_{i}")) self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image") - self.layer_norm_final_attn = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_final_attn") + self.layer_norm_final_attn = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="layer_norm_final_attn" + ) def call( self, @@ -365,8 +383,23 @@ def call( if image_embeddings is None: raise ValueError("You have to specify an image_embedding") - image_embeddings = tf.expand_dims(tf.transpose(tf.reshape(image_embeddings, (shape_list(image_embeddings)[0], -1, shape_list(image_embeddings)[-1])), (0, 2, 1)), 1) - image_positional_embeddings = tf.expand_dims(tf.transpose(tf.reshape(image_positional_embeddings, (shape_list(image_positional_embeddings)[0], -1, shape_list(image_positional_embeddings)[-1])), (0, 2, 1)), 1) + image_embeddings = tf.expand_dims( + tf.transpose( + tf.reshape(image_embeddings, (shape_list(image_embeddings)[0], -1, shape_list(image_embeddings)[-1])), + (0, 2, 1), + ), + 1, + ) + image_positional_embeddings = tf.expand_dims( + tf.transpose( + tf.reshape( + image_positional_embeddings, + (shape_list(image_positional_embeddings)[0], -1, shape_list(image_positional_embeddings)[-1]), + ), + (0, 2, 1), + ), + 1, + ) # Prepare queries queries = point_embeddings @@ -394,6 +427,8 @@ def call( queries = queries + attn_out queries = self.layer_norm_final_attn(queries) return queries, keys, all_attentions + + class TFSamFeedForward(tf.keras.layers.Layer): def __init__( self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs @@ -403,7 +438,10 @@ def __init__( self.activation = tf.keras.layers.ReLU() self.proj_in = tf.keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in") self.proj_out = tf.keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out") - self.layers = [tf.keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layer_{i}") for i in range(num_layers - 2)] + self.layers = [ + tf.keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layer_{i}") + for i in range(num_layers - 2) + ] self.sigmoid_output = sigmoid_output def call(self, hidden_states): @@ -416,6 +454,8 @@ def call(self, hidden_states): if self.sigmoid_output: hidden_states = tf.sigmoid(hidden_states) return hidden_states + + class TFSamMaskDecoder(tf.keras.layers.Layer): def __init__(self, config: SamMaskDecoderConfig, **kwargs): super().__init__(**kwargs) @@ -430,9 +470,15 @@ def __init__(self, config: SamMaskDecoderConfig, **kwargs): self.transformer = TFSamTwoWayTransformer(config) - self.upscale_conv1 = tf.keras.layers.Conv2DTranspose(self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1") - self.upscale_conv2 = tf.keras.layers.Conv2DTranspose(self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2") - self.upscale_layer_norm = TFSamLayerNorm(self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm") + self.upscale_conv1 = 
tf.keras.layers.Conv2DTranspose( + self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1" + ) + self.upscale_conv2 = tf.keras.layers.Conv2DTranspose( + self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2" + ) + self.upscale_layer_norm = TFSamLayerNorm( + self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm" + ) self.activation = tf.keras.layers.GELU() mlps_list = [] @@ -490,7 +536,9 @@ def call( hyper_in = tf.stack(hyper_in_list, axis=2) _, num_channels, height, width = shape_list(upscaled_embedding) - upscaled_embedding = tf.reshape(upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]) + upscaled_embedding = tf.reshape( + upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width] + ) masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width]) iou_pred = self.iou_prediction_head(iou_token_out) @@ -510,6 +558,8 @@ def call( outputs = outputs + (None,) return outputs + + class TFSamPositionalEmbedding(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) @@ -544,6 +594,8 @@ def call(self, input_coords, input_shape=None): coordinates = 2 * np.pi * coordinates # outputs d_1 x ... x d_n x channel shape return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1) + + class TFSamMaskEmbedding(tf.keras.layers.Layer): def __init__(self, config: SamPromptEncoderConfig, **kwargs): super().__init__(**kwargs) @@ -565,6 +617,8 @@ def call(self, masks): hidden_states = self.activation(hidden_states) dense_embeddings = self.conv3(hidden_states) return dense_embeddings + + class TFSamPromptEncoder(tf.keras.layers.Layer): def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs): super().__init__(**kwargs) @@ -613,7 +667,11 @@ def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2)) input_shape = (self.input_image_size, self.input_image_size) corner_embedding = self.shared_embedding(coords, input_shape) - corner_embedding += tf.where(tf.range(corner_embedding.shape[2])[None, None, :, None] == 0, self.point_embed[2].weights[0], self.point_embed[3].weights[0]) + corner_embedding += tf.where( + tf.range(corner_embedding.shape[2])[None, None, :, None] == 0, + self.point_embed[2].weights[0], + self.point_embed[3].weights[0], + ) return corner_embedding def call( @@ -641,7 +699,9 @@ def call( if input_labels is None: raise ValueError("If points are provided, labels must also be provided.") point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) - sparse_embeddings = tf.zeros((batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype) + sparse_embeddings = tf.zeros( + (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype + ) sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2) if input_boxes is not None: batch_size = input_boxes.shape[0] @@ -654,12 +714,16 @@ def call( dense_embeddings = self.mask_embed(input_masks) else: dense_embeddings = tf.reshape(self.no_mask_embed.weights[0], (1, -1, 1, 1)) - dense_embeddings = tf.tile(dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])) + dense_embeddings = tf.tile( + dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1]) + ) if sparse_embeddings is None: sparse_embeddings = tf.zeros((batch_size, 0, 1, 
self.hidden_size), dtype=dense_embeddings.dtype) return sparse_embeddings, dense_embeddings + + class TFSamVisionAttention(tf.keras.layers.Layer): """Multi-head Attention block with relative position embeddings.""" @@ -685,8 +749,12 @@ def __init__(self, config, window_size, **kwargs) -> None: raise ValueError("Input size must be provided if using relative positional encoding.") # initialize relative positional embeddings - self.rel_pos_h = self.add_weight(shape=(2 * input_size[0] - 1, head_dim), initializer="zeros", name="rel_pos_h") - self.rel_pos_w = self.add_weight(shape=(2 * input_size[1] - 1, head_dim), initializer="zeros", name="rel_pos_w") + self.rel_pos_h = self.add_weight( + shape=(2 * input_size[0] - 1, head_dim), initializer="zeros", name="rel_pos_h" + ) + self.rel_pos_w = self.add_weight( + shape=(2 * input_size[1] - 1, head_dim), initializer="zeros", name="rel_pos_w" + ) def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor: """ @@ -772,12 +840,13 @@ def add_decomposed_rel_pos( def call(self, hidden_states: tf.Tensor, output_attentions=False) -> tf.Tensor: batch_size, height, width, _ = hidden_states.shape # qkv with shape (3, batch_size, nHead, height * width, channel) - qkv = ( - tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) - .permute(2, 0, 3, 1, 4) - ) + qkv = tf.reshape( + self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1) + ).permute(2, 0, 3, 1, 4) # q, k, v with shape (batch_size * nHead, height * width, channel) - query, key, value = tf.unstack(tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0) + query, key, value = tf.unstack( + tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0 + ) attn_weights = (query * self.scale) @ tf.transpose(key, perm=(-2, -1)) @@ -801,6 +870,8 @@ def call(self, hidden_states: tf.Tensor, output_attentions=False) -> tf.Tensor: outputs = (attn_output, None) return outputs + + class TFSamVisionLayer(tf.keras.layers.Layer): def __init__(self, config, window_size, **kwargs) -> None: super().__init__(**kwargs) @@ -821,11 +892,10 @@ def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[ hidden_states = tf.reshape( hidden_states, - [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel] + [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel], ) windows = tf.reshape( - tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), - [-1, window_size, window_size, channel] + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel] ) return windows, (pad_height, pad_width) @@ -836,12 +906,10 @@ def window_unpartition( height, width = original_shape batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) hidden_states = tf.reshape( - windows, - [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1] + windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1] ) hidden_states = tf.reshape( - tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), - [batch_size, pad_height, pad_width, -1] + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1] ) if pad_height > height or pad_width > width: @@ -876,6 +944,8 @@ def call( outputs += (attn_weights,) return outputs + + class 
TFSamVisionNeck(tf.keras.layers.Layer): def __init__(self, config: SamVisionConfig, **kwargs): super().__init__(**kwargs) @@ -883,7 +953,9 @@ def __init__(self, config: SamVisionConfig, **kwargs): self.conv1 = tf.keras.layers.Conv2D(config.output_channels, kernel_size=1, use_bias=False, name="conv1") self.layer_norm1 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm1") - self.conv2 = tf.keras.layers.Conv2D(config.output_channels, kernel_size=3, padding="same", use_bias=False, name="conv2") + self.conv2 = tf.keras.layers.Conv2D( + config.output_channels, kernel_size=3, padding="same", use_bias=False, name="conv2" + ) self.layer_norm2 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm2") def call(self, hidden_states): @@ -894,6 +966,8 @@ def call(self, hidden_states): hidden_states = self.conv2(hidden_states) hidden_states = self.layer_norm2(hidden_states) return hidden_states + + class TFSamVisionEncoder(tf.keras.layers.Layer): def __init__(self, config: SamVisionConfig, **kwargs): super().__init__(**kwargs) @@ -984,6 +1058,8 @@ def call( hidden_states=all_hidden_states, attentions=all_self_attentions, ) + + class TFSamPreTrainedModel(TFPreTrainedModel): config_class = SamConfig base_model_prefix = "sam" @@ -1079,6 +1155,8 @@ def _init_weights(self, module): return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ + + @add_start_docstrings( "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", " optional 2D location and bounding boxes.", @@ -1285,4 +1363,4 @@ def call( vision_hidden_states=vision_hidden_states, vision_attentions=vision_attentions, mask_decoder_attentions=mask_decoder_attentions, - ) \ No newline at end of file + ) From 5067b6036fda45244994ae61ee18e31247b48894 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 24 Apr 2023 20:07:56 +0100 Subject: [PATCH 04/49] Add a functional layernorm for TF --- .../models/sam/modeling_tf_sam.py | 7 +++--- src/transformers/tf_utils.py | 22 +++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 1c158f536a82..cfe8011616d1 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -28,6 +28,7 @@ from ...modeling_tf_outputs import TFBaseModelOutput from ...modeling_tf_utils import TFPreTrainedModel, shape_list from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...tf_utils import functional_layernorm from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig @@ -179,7 +180,7 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kw def call(self, x: tf.Tensor) -> tf.Tensor: if self.data_format == "channels_last": - x = tf.keras.layers.LayerNormalization(axis=self.normalized_shape, epsilon=self.eps)(x) + x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps) elif self.data_format == "channels_first": input_dtype = x.dtype x = tf.cast(x, tf.float32) @@ -479,7 +480,7 @@ def __init__(self, config: SamMaskDecoderConfig, **kwargs): self.upscale_layer_norm = TFSamLayerNorm( self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm" ) - self.activation = tf.keras.layers.GELU() + self.activation = tf.nn.gelu mlps_list = [] for _ in 
range(self.num_mask_tokens): @@ -590,7 +591,7 @@ def call(self, input_coords, input_shape=None): # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coordinates = 2 * coordinates - 1 coordinates = tf.cast(coordinates, self.positional_embedding.dtype) - coordinates = tf.linalg.matmul(coordinates, self.positional_embedding) + coordinates = tf.matmul(coordinates, self.positional_embedding) coordinates = 2 * np.pi * coordinates # outputs d_1 x ... x d_n x channel shape return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1) diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py index 20fe71d6ae5a..ca3766b5a0cc 100644 --- a/src/transformers/tf_utils.py +++ b/src/transformers/tf_utils.py @@ -70,6 +70,28 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name) +def functional_layernorm(inputs, weight, bias, epsilon=1e-5): + # Matt: This is a very simplified functional layernorm, designed to duplicate + # the functionality of PyTorch nn.functional.layer_norm when this is needed to port + # models in Transformers. It assumes the dimension to be normalized is always the last one. + # If you need it to handle multiple dimensions, yell at me and I'll patch it. + + # Calculate the moments on the last axis (layer activations). + mean, variance = tf.nn.moments(inputs, -1, keepdims=True) + + # Compute layer normalization using the batch_normalization + # function. + outputs = tf.nn.batch_normalization( + inputs, + mean, + variance, + offset=bias, + scale=weight, + variance_epsilon=epsilon, + ) + return outputs + + def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor: """ Invert an attention mask (e.g., switches 0. and 1.). From 75eb390ff9f66f03a7c97ca32236ffbe11d2639f Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 25 Apr 2023 16:14:32 +0100 Subject: [PATCH 05/49] Add all the auxiliary imports etc. 
--- docs/source/en/model_doc/sam.mdx | 6 + src/transformers/__init__.py | 12 ++ .../models/auto/modeling_tf_auto.py | 12 ++ src/transformers/models/sam/__init__.py | 21 ++- .../models/sam/image_processing_sam.py | 139 +++++++++++++----- .../models/sam/modeling_tf_sam.py | 4 +- src/transformers/tf_utils.py | 22 ++- 7 files changed, 176 insertions(+), 40 deletions(-) diff --git a/docs/source/en/model_doc/sam.mdx b/docs/source/en/model_doc/sam.mdx index 969b7e2b2290..33dd6857d10c 100644 --- a/docs/source/en/model_doc/sam.mdx +++ b/docs/source/en/model_doc/sam.mdx @@ -99,3 +99,9 @@ Resources: [[autodoc]] SamModel - forward + + +## TFSamModel + +[[autodoc]] SamModel + - call \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ca476f30c291..f4c6921a6384 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3405,6 +3405,13 @@ "TFRoFormerPreTrainedModel", ] ) + _import_structure["models.sam"].extend( + [ + "TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFSamModel", + "TFSamPreTrainedModel", + ] + ) _import_structure["models.segformer"].extend( [ "TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -6646,6 +6653,11 @@ TFRoFormerModel, TFRoFormerPreTrainedModel, ) + from .models.sam import ( + TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, + TFSamModel, + TFSamPreTrainedModel, + ) from .models.segformer import ( TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, TFSegformerDecodeHead, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 5d637271f98d..3f15396c34d5 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -76,6 +76,7 @@ ("roberta", "TFRobertaModel"), ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"), ("roformer", "TFRoFormerModel"), + ("sam", "TFSamModel"), ("segformer", "TFSegformerModel"), ("speech_to_text", "TFSpeech2TextModel"), ("swin", "TFSwinModel"), @@ -426,6 +427,11 @@ ("mobilebert", "TFMobileBertForNextSentencePrediction"), ] ) +TFMODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( + [ + ("sam", "TFSamModel"), + ] +) TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES) TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES) @@ -476,6 +482,12 @@ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES ) +TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) + + +class TFAutoModelForMaskGeneration(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING + class TFAutoModel(_BaseAutoModelClass): _model_mapping = TF_MODEL_MAPPING diff --git a/src/transformers/models/sam/__init__.py b/src/transformers/models/sam/__init__.py index d9bbf5f5eaf2..2c9c39c7f9af 100644 --- a/src/transformers/models/sam/__init__.py +++ b/src/transformers/models/sam/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available, is_tf_available _import_structure = { @@ -39,6 +39,17 @@ "SamModel", "SamPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_sam"] = [ + "TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFSamModel", + "TFSamPreTrainedModel", + ] try: if not is_vision_available(): raise OptionalDependencyNotAvailable() @@ -66,6 +77,14 @@ else: from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel + try: if not is_vision_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index e4bf74f411bc..038b8fc19117 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -34,7 +34,14 @@ to_numpy_array, valid_images, ) -from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends, is_tf_available +from ...utils import ( + TensorType, + is_tf_available, + is_torch_available, + is_torchvision_available, + logging, + requires_backends, +) if is_torch_available(): @@ -47,6 +54,7 @@ if is_tf_available(): import tensorflow as tf from tensorflow.experimental import numpy as tnp + from ...tf_utils import shape_list, flatten logger = logging.get_logger(__name__) @@ -422,28 +430,30 @@ def post_process_masks( return output_masks - def post_process_masks_tf(self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None): + def post_process_masks_tf( + self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`tf.Tensor`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`tf.Tensor`): + The original size of the images before resizing for input to the model, in (height, width) format. + reshaped_input_sizes (`tf.Tensor`): + The size of the image input to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + Returns: + (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. """ - Remove padding and upscale masks to the original image size. - - Args: - masks (`tf.Tensor`): - Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. 
- original_sizes (`tf.Tensor`): - The original size of the images before resizing for input to the model, in (height, width) format. - reshaped_input_sizes (`tf.Tensor`): - The size of the image input to the model, in (height, width) format. Used to remove padding. - mask_threshold (`float`, *optional*, defaults to 0.0): - The threshold to use for binarizing the masks. - binarize (`bool`, *optional*, defaults to `True`): - Whether to binarize the masks. - pad_size (`int`, *optional*, defaults to `self.pad_size`): - The target size the images were padded to before being passed to the model. If None, the target size is - assumed to be the processor's `pad_size`. - Returns: - (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) - is given by original_size. - """ requires_backends(self, ["tensorflow"]) pad_size = self.pad_size if pad_size is None else pad_size target_image_size = (pad_size["height"], pad_size["width"]) @@ -726,10 +736,13 @@ def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stabi stability_scores = intersections / unions return stability_scores + def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int): # Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure # we get the right division results. - intersections = tf.count_nonzero(masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32) + intersections = tf.count_nonzero( + masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32 + ) unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32) stability_scores = intersections / unions return stability_scores @@ -831,6 +844,7 @@ def _generate_crop_boxes( return crop_boxes, points_per_crop, cropped_images, input_labels + def _generate_crop_boxes_tf( image, target_size: int, # Is it tuple here? @@ -956,6 +970,7 @@ def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int): pad = (left, pad_x - left, top, pad_y - top) return torch.nn.functional.pad(masks, pad, value=0) + def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int): left, top, right, bottom = crop_box if left == 0 and top == 0 and right == orig_width and bottom == orig_height: @@ -965,6 +980,7 @@ def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int) pad = (left, pad_x - left, top, pad_y - top) return tf.pad(masks, pad, constant_values=0) + def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): """Filter masks at the edge of a crop, but not at the edge of the original image.""" crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) @@ -982,6 +998,7 @@ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) return torch.any(near_crop_edge, dim=1) + def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0): """Filter masks at the edge of a crop, but not at the edge of the original image.""" crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32) @@ -1048,7 +1065,7 @@ def _batched_mask_to_box(masks: "torch.Tensor"): out = out.reshape(*shape[:-2], 4) return out -# TODO CONTINUE FROM HERE + def _batched_mask_to_box_tf(masks: "tf.Tensor"): """ Computes the bounding boxes around the given input masks. 
The bounding boxes are in the XYXY format which @@ -1062,15 +1079,14 @@ def _batched_mask_to_box_tf(masks: "tf.Tensor"): is channel_1 x channel_2 x ... x 4. Args: - - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`) + - masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`) """ - # torch.max below raises an error on empty inputs, just skip in this case if tf.size(masks) == 0: return tf.zeros([*masks.shape[:-2], 4]) # Normalize shape to Cxheightxwidth - shape = masks.shape + shape = shape_list(masks) height, width = shape[-2:] # Get top and bottom edges @@ -1081,20 +1097,20 @@ def _batched_mask_to_box_tf(masks: "tf.Tensor"): top_edges = tf.reduce_min(in_height_coords, axis=-1) # Get left and right edges - in_width, _ = torch.max(masks, dim=-2) - in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :] - right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width, _ = tf.reduce_max(masks, axis=-2) + in_width_coords = in_width * tf.range(width)[None, :] + right_edges, _ = tf.reduce_max(in_width_coords, axis=-1) in_width_coords = in_width_coords + width * (~in_width) - left_edges, _ = torch.min(in_width_coords, dim=-1) + left_edges, _ = tf.reduce_min(in_width_coords, axis=-1) # If the mask is empty the right edge will be to the left of the left edge. # Replace these boxes with [0, 0, 0, 0] empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) - out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) - out = out * (~empty_filter).unsqueeze(-1) + out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1) + out = out * tf.expand_dims(~empty_filter, -1) # Return to original shape - out = out.reshape(*shape[:-2], 4) + out = tf.reshape(out, *shape[:-2], 4) return out @@ -1121,6 +1137,29 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"): return out +def _mask_to_rle_tf(input_mask: "tf.Tensor"): + """ + Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools. + """ + # Put in fortran order and flatten height and width + batch_size, height, width = input_mask.shape + input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1) + + # Compute change indices + diff = input_mask[:, 1:] ^ input_mask[:, :-1] + change_indices = tf.where(diff) + + # Encode run length + out = [] + for i in range(batch_size): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if input_mask[i, 0] == 0 else [0] + counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]] + out.append({"size": [height, width], "counts": counts}) + return out + + def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: """Compute a binary mask from an uncompressed RLE.""" height, width = rle["size"] @@ -1137,7 +1176,7 @@ def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): """ - Perform NMS (Non Maxium Suppression) on the outputs. + Perform NMS (Non Maximum Suppression) on the outputs. Args: rle_masks (`torch.Tensor`): @@ -1162,3 +1201,33 @@ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh= masks = [_rle_to_mask(rle) for rle in rle_masks] return masks, iou_scores, rle_masks, mask_boxes + +def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): + """ + Perform NMS (Non Maximum Suppression) on the outputs. 
+
+    Args:
+        rle_masks (`tf.Tensor`):
+            binary masks in the RLE format
+        iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
+            iou_scores predicted by the model
+        mask_boxes (`tf.Tensor`):
+            The bounding boxes corresponding to segmentation masks
+        amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
+            NMS threshold.
+    """
+    keep_by_nms = tf.image.non_max_suppression(
+        boxes=tf.cast(mask_boxes, tf.float32),
+        scores=iou_scores,
+        max_output_size=shape_list(mask_boxes)[0],
+        iou_threshold=amg_crops_nms_thresh,
+    )
+
+    iou_scores = tf.gather(iou_scores, keep_by_nms)
+    rle_masks = [rle_masks[i] for i in keep_by_nms]
+    mask_boxes = tf.gather(mask_boxes, keep_by_nms)
+    masks = [_rle_to_mask(rle) for rle in rle_masks]
+
+    return masks, iou_scores, rle_masks, mask_boxes
\ No newline at end of file
diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py
index cfe8011616d1..8f7783a61198 100644
--- a/src/transformers/models/sam/modeling_tf_sam.py
+++ b/src/transformers/models/sam/modeling_tf_sam.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original with GPT-4.
+TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original.
 In the event of a discrepancy, the original file should be regarded as the 'reference' version.
 """
 
@@ -37,7 +37,7 @@
 _CONFIG_FOR_DOC = "SamConfig"
 _CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
 
-SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [
+TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/sam-vit-huge",
     "facebook/sam-vit-large",
     "facebook/sam-vit-base",
diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py
index ca3766b5a0cc..60848c65b4f2 100644
--- a/src/transformers/tf_utils.py
+++ b/src/transformers/tf_utils.py
@@ -71,10 +71,10 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
+
+
 def functional_layernorm(inputs, weight, bias, epsilon=1e-5):
-    # Matt: This is a very simplified functional layernorm, designed to duplicate
+    # This is a very simplified functional layernorm, designed to duplicate
     # the functionality of PyTorch nn.functional.layer_norm when this is needed to port
     # models in Transformers. It assumes the dimension to be normalized is always the last one.
-    # If you need it to handle multiple dimensions, yell at me and I'll patch it.
+    # If you need it to handle multiple dimensions, yell at me (Matt) and I'll patch it.
 
     # Calculate the moments on the last axis (layer activations).
     mean, variance = tf.nn.moments(inputs, -1, keepdims=True)
@@ -92,6 +92,24 @@ def functional_layernorm(inputs, weight, bias, epsilon=1e-5):
     return outputs
 
 
+def flatten(input, start_dim=0, end_dim=-1):
+    # Replicates the behavior of torch.flatten in TF
+
+    # If end_dim or start_dim is negative, count them from the end
+    if end_dim < 0:
+        end_dim += input.shape.rank
+    if start_dim < 0:
+        start_dim += input.shape.rank
+
+    if start_dim == end_dim:
+        return input
+
+    in_shape = tf.shape(input)
+    flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
+    out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
+    return tf.reshape(input, out_shape)
+
+
 def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
     """
     Invert an attention mask (e.g., switches 0. and 1.).
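
The two helpers added to tf_utils.py above are meant to mirror PyTorch semantics: functional_layernorm normalizes over the last axis like torch.nn.functional.layer_norm, and flatten collapses a contiguous range of dimensions like torch.flatten. A quick illustrative check, not part of the patch, assuming a build of transformers that includes these commits:

import numpy as np
import tensorflow as tf

from transformers.tf_utils import flatten, functional_layernorm

x = tf.random.normal([2, 3, 4, 5])

# functional_layernorm normalizes over the last axis only, like torch.nn.functional.layer_norm
out = functional_layernorm(x, weight=tf.ones([5]), bias=tf.zeros([5]), epsilon=1e-6)
ref = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6)(x)  # default gamma=1, beta=0
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=1e-5)

# flatten(x, 1) collapses every dimension from axis 1 onward, like torch.flatten(x, 1)
print(tuple(flatten(x, 1).shape))  # (2, 60)
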
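
For orientation, the end-to-end flow this branch is working toward, together with the processor changes in the next commit, looks roughly like the sketch below. The checkpoint name comes from the archive list above, while the image URL, the prompt point and the exact output handling are illustrative assumptions rather than guaranteed behaviour at this stage of the series:

import requests
from PIL import Image

from transformers import SamProcessor, TFSamModel

model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

img_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # placeholder test image
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]  # 1 image, 1 prompt, one (x, y) point

inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)

# Upscale the low-resolution masks back to the original image size; return_tensors="tf"
# routes this through post_process_masks_tf, as wired up in the next commit.
masks = processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"], return_tensors="tf"
)
scores = outputs.iou_scores
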
From 6b38c835c41c7199ca6322036386d898e3da37e5 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 26 Apr 2023 14:34:14 +0100 Subject: [PATCH 06/49] Add the extra processor and tests --- .../models/sam/image_processing_sam.py | 11 +- src/transformers/models/sam/processing_sam.py | 31 +- .../models/sam/processing_tf_sam.py | 248 ++++++ tests/models/sam/test_modeling_tf_sam.py | 745 ++++++++++++++++++ tests/models/sam/test_processor_tf_sam.py | 115 +++ 5 files changed, 1140 insertions(+), 10 deletions(-) create mode 100644 src/transformers/models/sam/processing_tf_sam.py create mode 100644 tests/models/sam/test_modeling_tf_sam.py create mode 100644 tests/models/sam/test_processor_tf_sam.py diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 038b8fc19117..3609e1ccddb3 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -454,18 +454,21 @@ def post_process_masks_tf( (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is given by original_size. """ - requires_backends(self, ["tensorflow"]) + requires_backends(self, ["tf"]) pad_size = self.pad_size if pad_size is None else pad_size target_image_size = (pad_size["height"], pad_size["width"]) output_masks = [] for i, original_size in enumerate(original_sizes): - interpolated_mask = tf.image.resize(masks[i], target_image_size, method="bilinear") + # tf.image expects NHWC, we transpose the NCHW inputs for it + mask = tf.transpose(masks[i], perm=[0, 2, 3, 1]) + interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear") interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]] interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear") if binarize: interpolated_mask = interpolated_mask > mask_threshold - output_masks.append(interpolated_mask) + # And then we transpose them back at the end + output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2])) return output_masks @@ -684,7 +687,7 @@ def filter_masks_tf( The offset for the stability score used in the `_compute_stability_score` method. 
""" - requires_backends(self, ["tensorflow"]) + requires_backends(self, ["tf"]) original_height, original_width = original_size iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]]) masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]]) diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py index b5ae51d7db29..24f3cc8306dc 100644 --- a/src/transformers/models/sam/processing_sam.py +++ b/src/transformers/models/sam/processing_sam.py @@ -22,12 +22,15 @@ from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import BatchEncoding -from ...utils import TensorType, is_torch_available +from ...utils import TensorType, is_torch_available, is_tf_available if is_torch_available(): import torch +if is_tf_available(): + import tensorflow as tf + class SamProcessor(ProcessorMixin): r""" @@ -72,7 +75,7 @@ def __call__( # pop arguments that are not used in the foward but used nevertheless original_sizes = encoding_image_processor["original_sizes"] - if isinstance(original_sizes, torch.Tensor): + if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor original_sizes = original_sizes.numpy() input_points, input_labels, input_boxes = self._check_and_preprocess_points( @@ -139,18 +142,30 @@ def _normalize_and_convert( input_boxes = torch.from_numpy(input_boxes) # boxes batch size of 1 by default input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes + elif return_tensors == "tf": + input_boxes = tf.convert_to_tensor(input_boxes) + # boxes batch size of 1 by default + input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes encoding_image_processor.update({"input_boxes": input_boxes}) if input_points is not None: if return_tensors == "pt": input_points = torch.from_numpy(input_points) # point batch size of 1 by default input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points + elif return_tensors == "tf": + input_points = tf.convert_to_tensor(input_points) + # point batch size of 1 by default + input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points encoding_image_processor.update({"input_points": input_points}) if input_labels is not None: if return_tensors == "pt": input_labels = torch.from_numpy(input_labels) # point batch size of 1 by default input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels + elif return_tensors == "tf": + input_labels = tf.convert_to_tensor(input_labels) + # point batch size of 1 by default + input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels encoding_image_processor.update({"input_labels": input_labels}) return encoding_image_processor @@ -204,7 +219,7 @@ def _check_and_preprocess_points( it is converted to a `numpy.ndarray` and then to a `list`. 
""" if input_points is not None: - if isinstance(input_points, torch.Tensor): + if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor input_points = input_points.numpy().tolist() if not isinstance(input_points, list) or not isinstance(input_points[0], list): @@ -214,7 +229,7 @@ def _check_and_preprocess_points( input_points = None if input_labels is not None: - if isinstance(input_labels, torch.Tensor): + if hasattr(input_labels, "numpy"): input_labels = input_labels.numpy().tolist() if not isinstance(input_labels, list) or not isinstance(input_labels[0], list): @@ -224,7 +239,7 @@ def _check_and_preprocess_points( input_labels = None if input_boxes is not None: - if isinstance(input_boxes, torch.Tensor): + if hasattr(input_boxes, "numpy"): input_boxes = input_boxes.numpy().tolist() if ( @@ -245,4 +260,8 @@ def model_input_names(self): return list(dict.fromkeys(image_processor_input_names)) def post_process_masks(self, *args, **kwargs): - return self.image_processor.post_process_masks(*args, **kwargs) + return_tensors = kwargs.pop("return_tensors", "pt") + if return_tensors == "pt": + return self.image_processor.post_process_masks(*args, **kwargs) + elif return_tensors == "tf": + return self.image_processor.post_process_masks_tf(*args, **kwargs) diff --git a/src/transformers/models/sam/processing_tf_sam.py b/src/transformers/models/sam/processing_tf_sam.py new file mode 100644 index 000000000000..83356f41ee16 --- /dev/null +++ b/src/transformers/models/sam/processing_tf_sam.py @@ -0,0 +1,248 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for SAM. +""" +from copy import deepcopy +from typing import Optional, Union + +import numpy as np + +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding +from ...utils import TensorType, is_torch_available + + +if is_torch_available(): + import torch + + +class SamProcessor(ProcessorMixin): + r""" + Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a + single processor. + + [`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of + [`~SamImageProcessor.__call__`] for more information. + + Args: + image_processor (`SamImageProcessor`): + An instance of [`SamImageProcessor`]. The image processor is a required input. + """ + attributes = ["image_processor"] + image_processor_class = "SamImageProcessor" + + def __init__(self, image_processor): + super().__init__(image_processor) + self.current_processor = self.image_processor + self.point_pad_value = -10 + self.target_size = self.image_processor.size["longest_edge"] + + def __call__( + self, + images=None, + input_points=None, + input_labels=None, + input_boxes=None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchEncoding: + """ + This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. 
It also prepares 2D + points and bounding boxes for the model if they are provided. + """ + encoding_image_processor = self.image_processor( + images, + return_tensors=return_tensors, + **kwargs, + ) + + # pop arguments that are not used in the foward but used nevertheless + original_sizes = encoding_image_processor["original_sizes"] + + if isinstance(original_sizes, torch.Tensor): + original_sizes = original_sizes.numpy() + + input_points, input_labels, input_boxes = self._check_and_preprocess_points( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + ) + + encoding_image_processor = self._normalize_and_convert( + encoding_image_processor, + original_sizes, + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + return_tensors=return_tensors, + ) + + return encoding_image_processor + + def _normalize_and_convert( + self, + encoding_image_processor, + original_sizes, + input_points=None, + input_labels=None, + input_boxes=None, + return_tensors="pt", + ): + if input_points is not None: + if len(original_sizes) != len(input_points): + input_points = [ + self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points + ] + else: + input_points = [ + self._normalize_coordinates(self.target_size, point, original_size) + for point, original_size in zip(input_points, original_sizes) + ] + # check that all arrays have the same shape + if not all([point.shape == input_points[0].shape for point in input_points]): + if input_labels is not None: + input_points, input_labels = self._pad_points_and_labels(input_points, input_labels) + + input_points = np.array(input_points) + + if input_labels is not None: + input_labels = np.array(input_labels) + + if input_boxes is not None: + if len(original_sizes) != len(input_boxes): + input_boxes = [ + self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True) + for box in input_boxes + ] + else: + input_boxes = [ + self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True) + for box, original_size in zip(input_boxes, original_sizes) + ] + input_boxes = np.array(input_boxes) + + if input_boxes is not None: + if return_tensors == "pt": + input_boxes = torch.from_numpy(input_boxes) + # boxes batch size of 1 by default + input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes + encoding_image_processor.update({"input_boxes": input_boxes}) + if input_points is not None: + if return_tensors == "pt": + input_points = torch.from_numpy(input_points) + # point batch size of 1 by default + input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points + encoding_image_processor.update({"input_points": input_points}) + if input_labels is not None: + if return_tensors == "pt": + input_labels = torch.from_numpy(input_labels) + # point batch size of 1 by default + input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels + encoding_image_processor.update({"input_labels": input_labels}) + + return encoding_image_processor + + def _pad_points_and_labels(self, input_points, input_labels): + r""" + The method pads the 2D points and labels to the maximum number of points in the batch. 
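The padding described just above can be pictured with a small standalone sketch (not part of the patch; it only illustrates the idea, reusing the -10 sentinel the processor stores as `self.point_pad_value`):

    import numpy as np

    point_pad_value = -10  # same sentinel value the processor keeps in self.point_pad_value
    input_points = [np.array([[10.0, 20.0], [30.0, 40.0]]), np.array([[50.0, 60.0]])]
    input_labels = [np.array([1, 1]), np.array([1])]

    expected_nb_points = max(p.shape[0] for p in input_points)
    padded_points, padded_labels = [], []
    for points, labels in zip(input_points, input_labels):
        missing = expected_nb_points - points.shape[0]
        # Pad short prompts with sentinel points/labels so everything stacks into one array.
        padded_points.append(np.concatenate([points, np.full((missing, 2), point_pad_value, dtype=float)], axis=0))
        padded_labels.append(np.concatenate([labels, np.full((missing,), point_pad_value)], axis=0))

    print(np.stack(padded_points).shape)  # (2, 2, 2): every prompt now has the same number of points
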
+ """ + expected_nb_points = max([point.shape[0] for point in input_points]) + processed_input_points = [] + for i, point in enumerate(input_points): + if point.shape[0] != expected_nb_points: + point = np.concatenate( + [point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0 + ) + input_labels[i] = np.append(input_labels[i], [self.point_pad_value]) + processed_input_points.append(point) + input_points = processed_input_points + return input_points, input_labels + + def _normalize_coordinates( + self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False + ) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size) + coords = deepcopy(coords).astype(float) + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + def _check_and_preprocess_points( + self, + input_points=None, + input_labels=None, + input_boxes=None, + ): + r""" + Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they + are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`, + it is converted to a `numpy.ndarray` and then to a `list`. + """ + if input_points is not None: + if isinstance(input_points, torch.Tensor): + input_points = input_points.numpy().tolist() + + if not isinstance(input_points, list) and not isinstance(input_points[0], list): + raise ValueError("Input points must be a list of list of floating integers.") + input_points = [np.array(input_point) for input_point in input_points] + else: + input_points = None + + if input_labels is not None: + if isinstance(input_labels, torch.Tensor): + input_labels = input_labels.numpy().tolist() + + if not isinstance(input_labels, list) and not isinstance(input_labels[0], list): + raise ValueError("Input labels must be a list of list integers.") + input_labels = [np.array(label) for label in input_labels] + else: + input_labels = None + + if input_boxes is not None: + if isinstance(input_boxes, torch.Tensor): + input_boxes = input_boxes.numpy().tolist() + + if ( + not isinstance(input_boxes, list) + and not isinstance(input_boxes[0], list) + and not isinstance(input_boxes[0][0], list) + ): + raise ValueError("Input boxes must be a list of list of list of floating integers.") + input_boxes = [np.array(box).astype(np.float32) for box in input_boxes] + else: + input_boxes = None + + return input_points, input_labels, input_boxes + + @property + def model_input_names(self): + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(image_processor_input_names)) + + def post_process_masks(self, *args, **kwargs): + return self.image_processor.post_process_masks(*args, **kwargs) diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py new file mode 100644 index 000000000000..e51eb07dd311 --- /dev/null +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -0,0 +1,745 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch SAM model. """ + + +import inspect +import unittest + +import requests + +from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig +from transformers.testing_utils import require_torch, slow, torch_device +from transformers.utils import is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import SamModel, SamProcessor + from transformers.models.sam.modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + +class SamPromptEncoderTester: + def __init__( + self, + hidden_size=32, + input_image_size=24, + patch_size=2, + mask_input_channels=4, + num_point_embeddings=4, + hidden_act="gelu", + ): + self.hidden_size = hidden_size + self.input_image_size = input_image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + + def get_config(self): + return SamPromptEncoderConfig( + image_size=self.input_image_size, + patch_size=self.patch_size, + mask_input_channels=self.mask_input_channels, + hidden_size=self.hidden_size, + num_point_embeddings=self.num_point_embeddings, + hidden_act=self.hidden_act, + ) + + def prepare_config_and_inputs(self): + dummy_points = floats_tensor([self.batch_size, 3, 2]) + config = self.get_config() + + return config, dummy_points + + +class SamMaskDecoderTester: + def __init__( + self, + hidden_size=32, + hidden_act="relu", + mlp_dim=64, + num_hidden_layers=2, + num_attention_heads=4, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=32, + layer_norm_eps=1e-6, + ): + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downsample_rate = attention_downsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + + def get_config(self): + return SamMaskDecoderConfig( + hidden_size=self.hidden_size, + hidden_act=self.hidden_act, + mlp_dim=self.mlp_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + attention_downsample_rate=self.attention_downsample_rate, + num_multimask_outputs=self.num_multimask_outputs, + iou_head_depth=self.iou_head_depth, + iou_head_hidden_dim=self.iou_head_hidden_dim, + layer_norm_eps=self.layer_norm_eps, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + + dummy_inputs = { + "image_embedding": 
floats_tensor([self.batch_size, self.hidden_size]), + } + + return config, dummy_inputs + + +class SamModelTester: + def __init__( + self, + parent, + hidden_size=36, + intermediate_size=72, + projection_dim=62, + output_channels=32, + num_hidden_layers=2, + num_attention_heads=4, + num_channels=3, + image_size=24, + patch_size=2, + hidden_act="gelu", + layer_norm_eps=1e-06, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + rel_pos_zero_init=False, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=16, + mlp_dim=None, + batch_size=2, + ): + self.parent = parent + self.image_size = image_size + self.patch_size = patch_size + self.output_channels = output_channels + self.num_channels = num_channels + self.hidden_size = hidden_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.rel_pos_zero_init = rel_pos_zero_init + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = mlp_dim + self.batch_size = batch_size + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + self.prompt_encoder_tester = SamPromptEncoderTester() + self.mask_decoder_tester = SamMaskDecoderTester() + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + vision_config = SamVisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + projection_dim=self.projection_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + initializer_factor=self.initializer_factor, + output_channels=self.output_channels, + qkv_bias=self.qkv_bias, + mlp_ratio=self.mlp_ratio, + use_abs_pos=self.use_abs_pos, + use_rel_pos=self.use_rel_pos, + rel_pos_zero_init=self.rel_pos_zero_init, + window_size=self.window_size, + global_attn_indexes=self.global_attn_indexes, + num_pos_feats=self.num_pos_feats, + mlp_dim=self.mlp_dim, + ) + + prompt_encoder_config = self.prompt_encoder_tester.get_config() + + mask_decoder_config = self.mask_decoder_tester.get_config() + + return SamConfig( + vision_config=vision_config, + prompt_encoder_config=prompt_encoder_config, + mask_decoder_config=mask_decoder_config, + ) + + def create_and_check_model(self, config, pixel_values): + model = SamModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3)) + 
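For readers checking the asserted `(self.batch_size, 1, 3)` shape: it follows from one prompt ("point batch") per image and the decoder's three multimask outputs in the tester configuration. A quick arithmetic sketch using the assumed tester defaults:

    batch_size = 2              # SamModelTester default
    point_batch_size = 1        # one prompt per image when none is supplied
    num_multimask_outputs = 3   # SamMaskDecoderTester default
    print((batch_size, point_batch_size, num_multimask_outputs))  # (2, 1, 3)
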
self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3)) + + def create_and_check_get_image_features(self, config, pixel_values): + model = SamModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model.get_image_embeddings(pixel_values) + self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12)) + + def create_and_check_get_image_hidden_states(self, config, pixel_values): + model = SamModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=True, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) + + with torch.no_grad(): + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=False, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds, + attention_mask and seq_length. 
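The 12 x 12 spatial dimensions asserted for the image embeddings and vision hidden states above fall out of the tester's image and patch sizes; a quick check of that arithmetic, using the tester values as assumptions:

    batch_size, image_size, patch_size = 2, 24, 2
    output_channels, hidden_size = 32, 36
    patches_per_side = image_size // patch_size                           # 12
    print((output_channels, patches_per_side, patches_per_side))          # (32, 12, 12)  image embeddings
    print((batch_size, patches_per_side, patches_per_side, hidden_size))  # (2, 12, 12, 36)  hidden states
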
+ """ + + all_model_classes = (SamModel,) if is_torch_available() else () + pipeline_model_mapping = ( + {"feature-extraction": SamModel, "mask-generation": SamModel} if is_torch_available() else {} + ) + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_torchscript = False + + # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working + def is_pipeline_test_to_skip( + self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name + ): + return True + + def setUp(self): + self.model_tester = SamModelTester(self) + self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False) + self.prompt_encoder_config_tester = ConfigTester( + self, + config_class=SamPromptEncoderConfig, + has_text_modality=False, + num_attention_heads=12, + num_hidden_layers=2, + ) + self.mask_decoder_config_tester = ConfigTester( + self, config_class=SamMaskDecoderConfig, has_text_modality=False + ) + + def test_config(self): + self.vision_config_tester.run_common_tests() + self.prompt_encoder_config_tester.run_common_tests() + self.mask_decoder_config_tester.run_common_tests() + + @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_get_image_features(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_features(*config_and_inputs) + + def test_image_hidden_states(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + expected_vision_attention_shape = ( + self.model_tester.batch_size * self.model_tester.num_attention_heads, + 196, + 196, + ) + expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = 
outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + self.assertListEqual( + list(vision_attentions[0].shape[-4:]), + list(expected_vision_attention_shape), + ) + + self.assertListEqual( + list(mask_decoder_attentions[0].shape[-4:]), + list(expected_mask_decoder_attention_shape), + ) + + @unittest.skip(reason="SamModel does not support training") + def test_training(self): + pass + + @unittest.skip(reason="SamModel does not support training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="SamModel does not support training") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests") + def test_hidden_states_output(self): + pass + + @slow + def test_model_from_pretrained(self): + for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = SamModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +def prepare_image(): + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_dog_img(): + img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +@slow +class SamModelIntegrationTest(unittest.TestCase): + def test_inference_mask_generation_no_point(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + inputs = processor(images=raw_image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4)) + + def test_inference_mask_generation_one_point_one_bb(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_boxes = [[650, 900, 1000, 1250]] + input_points = [[[820, 1080]]] + + inputs = processor( + images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + 
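The integration tests in this file only pin the IoU scores; as a hedged, illustrative continuation (assuming hub access and reusing the checkpoint name and image URL from the tests), the same pipeline usually ends with the `post_process_masks` call added earlier in this patch to recover full-resolution masks:

    import requests
    import torch
    from PIL import Image
    from transformers import SamModel, SamProcessor

    model = SamModel.from_pretrained("facebook/sam-vit-huge")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    inputs = processor(images=raw_image, input_points=[[[820, 1080]]], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the low-resolution mask logits back to the original image size and binarize them.
    masks = processor.post_process_masks(
        outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
    )
    print(outputs.iou_scores.shape)  # (1, 1, 3)
    print(masks[0].shape)            # (1, 3, original_height, original_width)
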
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4)) + + def test_inference_mask_generation_batched_points_batched_images(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_points = [ + [[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + [[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + ] + + inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze().cpu() + + EXPECTED_SCORES = torch.tensor( + [ + [ + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + ], + [ + [0.8405, 0.6292, 0.3840], + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + ], + ] + ) + self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) + + def test_inference_mask_generation_one_point_one_bb_zero(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_boxes = [[620, 900, 1000, 1255]] + input_points = [[[820, 1080]]] + labels = [[0]] + + inputs = processor( + images=raw_image, + input_boxes=input_boxes, + input_points=input_points, + input_labels=labels, + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9689), atol=1e-4)) + + def test_inference_mask_generation_one_point(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650]]] + input_labels = [[1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9712), atol=1e-4)) + + # With no label + input_points = [[[400, 650]]] + + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9712), atol=1e-4)) + + def test_inference_mask_generation_two_points(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9936), atol=1e-4)) + + # no labels + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + 
outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9936), atol=1e-4)) + + def test_inference_mask_generation_two_points_batched(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]], [[400, 650]]] + input_labels = [[1, 1], [1]] + + inputs = processor( + images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9936), atol=1e-4)) + self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9716), atol=1e-4)) + + def test_inference_mask_generation_one_box(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_boxes = [[[75, 275, 1725, 850]]] + + inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8686), atol=1e-4)) + + def test_inference_mask_generation_batched_image_one_point(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + raw_dog_image = prepare_dog_img() + + input_points = [[[820, 1080]], [[220, 470]]] + + inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = model(**inputs) + scores_batched = outputs.iou_scores.squeeze() + + input_points = [[[220, 470]]] + + inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores_single = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4)) + + def test_inference_mask_generation_two_points_point_batch(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + # fmt: off + input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu() + # fmt: on + + input_points = input_points.unsqueeze(0) + + inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 2, 3)) + torch.testing.assert_allclose( + iou_scores, torch.tensor([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), atol=1e-4, rtol=1e-4 + ) + + def test_inference_mask_generation_three_boxes_point_batch(self): + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + # fmt: off + input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 
275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu() + EXPECTED_IOU = torch.tensor([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]]) + # fmt: on + input_boxes = input_boxes.unsqueeze(0) + + inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 3, 3)) + torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) diff --git a/tests/models/sam/test_processor_tf_sam.py b/tests/models/sam/test_processor_tf_sam.py new file mode 100644 index 000000000000..004847b51631 --- /dev/null +++ b/tests/models/sam/test_processor_tf_sam.py @@ -0,0 +1,115 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import shutil +import tempfile +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_torchvision, require_vision, require_tf +from transformers.utils import is_torch_available, is_vision_available, is_tf_available + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoProcessor, SamImageProcessor, SamProcessor + +if is_torch_available(): + import torch + +if is_tf_available(): + import tensorflow as tf + + +@require_vision +@require_tf +class TFSamProcessorTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = SamImageProcessor() + processor = SamProcessor(image_processor) + processor.save_pretrained(self.tmpdirname) + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. 
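The fixture above saves a freshly constructed processor to a temporary directory and reloads it through `AutoProcessor`. The same round trip outside of a test, as a small illustrative sketch (assuming only that the saved config records the processor class, as the test relies on):

    import tempfile
    from transformers import AutoProcessor, SamImageProcessor, SamProcessor

    with tempfile.TemporaryDirectory() as tmpdir:
        SamProcessor(SamImageProcessor()).save_pretrained(tmpdir)
        reloaded = AutoProcessor.from_pretrained(tmpdir)
        # Expected to resolve back to SamProcessor / SamImageProcessor.
        print(type(reloaded).__name__, type(reloaded.image_processor).__name__)
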
+ """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + def test_save_load_pretrained_additional_features(self): + processor = SamProcessor(image_processor=self.get_image_processor()) + processor.save_pretrained(self.tmpdirname) + + image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0) + + processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0) + + self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.image_processor, SamImageProcessor) + + def test_image_processor(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + + image_input = self.prepare_image_inputs() + + input_feat_extract = image_processor(image_input, return_tensors="np") + input_processor = processor(images=image_input, return_tensors="np") + + input_feat_extract.pop("original_sizes") # pop original_sizes as it is popped in the processor + input_feat_extract.pop("reshaped_input_sizes") # pop original_sizes as it is popped in the processor + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + @require_tf + def test_post_process_masks(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + dummy_masks = [tf.ones((1, 3, 5, 5))] + + original_sizes = [[1764, 2646]] + + reshaped_input_size = [[683, 1024]] + masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf") + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + masks = processor.post_process_masks( + dummy_masks, tf.convert_to_tensor(original_sizes), tf.convert_to_tensor(reshaped_input_size), return_tensors="tf" + ) + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + # should also work with np + dummy_masks = [np.ones((1, 3, 5, 5))] + masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf") + + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + dummy_masks = [[1, 0], [0, 1]] + with self.assertRaises(tf.errors.InvalidArgumentError): + masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf") From ebc235c3fe2c3b9643cc472dc25c114f213970c4 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 26 Apr 2023 15:30:20 +0100 Subject: [PATCH 07/49] rebase to main --- docs/source/en/model_doc/sam.mdx | 2 +- .../models/auto/modeling_tf_auto.py | 6 +- src/transformers/models/sam/__init__.py | 8 +- .../models/sam/image_processing_sam.py | 10 +- .../models/sam/modeling_tf_sam.py | 32 +- src/transformers/models/sam/processing_sam.py | 2 +- src/transformers/utils/dummy_tf_objects.py | 17 + tests/models/sam/test_modeling_tf_sam.py | 296 +++++++----------- tests/models/sam/test_processor_tf_sam.py | 19 +- 9 files changed, 177 insertions(+), 215 deletions(-) diff --git a/docs/source/en/model_doc/sam.mdx b/docs/source/en/model_doc/sam.mdx index 33dd6857d10c..ac8aa4a57302 100644 --- a/docs/source/en/model_doc/sam.mdx +++ b/docs/source/en/model_doc/sam.mdx @@ -103,5 +103,5 @@ Resources: ## TFSamModel -[[autodoc]] SamModel +[[autodoc]] TFSamModel - call \ No newline at end of file diff --git 
a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 3f15396c34d5..bfc29f2dd35f 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -427,7 +427,7 @@ ("mobilebert", "TFMobileBertForNextSentencePrediction"), ] ) -TFMODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( +TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( [ ("sam", "TFSamModel"), ] @@ -482,7 +482,9 @@ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES ) -TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) +TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES +) class TFAutoModelForMaskGeneration(_BaseAutoModelClass): diff --git a/src/transformers/models/sam/__init__.py b/src/transformers/models/sam/__init__.py index 2c9c39c7f9af..e8006e89e0f1 100644 --- a/src/transformers/models/sam/__init__.py +++ b/src/transformers/models/sam/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available, is_tf_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_tf_available, + is_torch_available, + is_vision_available, +) _import_structure = { diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 3609e1ccddb3..c27e84d0078c 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -54,7 +54,8 @@ if is_tf_available(): import tensorflow as tf from tensorflow.experimental import numpy as tnp - from ...tf_utils import shape_list, flatten + + from ...tf_utils import flatten, shape_list logger = logging.get_logger(__name__) @@ -451,8 +452,8 @@ def post_process_masks_tf( The target size the images were padded to before being passed to the model. If None, the target size is assumed to be the processor's `pad_size`. Returns: - (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) - is given by original_size. + (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is + given by original_size. """ requires_backends(self, ["tf"]) pad_size = self.pad_size if pad_size is None else pad_size @@ -1205,6 +1206,7 @@ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh= return masks, iou_scores, rle_masks, mask_boxes + def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): """ Perform NMS (Non Maximum Suppression) on the outputs. @@ -1233,4 +1235,4 @@ def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thre mask_boxes = mask_boxes[keep_by_nms] masks = [_rle_to_mask(rle) for rle in rle_masks] - return masks, iou_scores, rle_masks, mask_boxes \ No newline at end of file + return masks, iou_scores, rle_masks, mask_boxes diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 8f7783a61198..8f3248b5641e 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
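The hunk above imports `flatten` and `shape_list` from `...tf_utils`. As background, `shape_list` exists because TF tensors may only have partially known static shapes at graph-build time; a simplified sketch of the idea (a paraphrase under that assumption, not the library implementation):

    import tensorflow as tf

    def shape_list_sketch(tensor):
        # Prefer statically known dimensions; fall back to dynamic ones where the static shape is None.
        static = tensor.shape.as_list()
        dynamic = tf.shape(tensor)
        return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None, 256], dtype=tf.float32)])
    def show(t):
        dims = shape_list_sketch(t)  # [dynamic batch, dynamic length, 256]
        tf.print(dims[0], dims[1], dims[2])
        return t

    show(tf.ones((2, 7, 256)))  # prints: 2 7 256
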
""" -TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. -In the event of a discrepancy, the original file should be regarded as the 'reference' version. +TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a +discrepancy, the original file should be regarded as the 'reference' version. """ import collections @@ -27,8 +27,8 @@ from ...activations_tf import ACT2FN from ...modeling_tf_outputs import TFBaseModelOutput from ...modeling_tf_utils import TFPreTrainedModel, shape_list -from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging from ...tf_utils import functional_layernorm +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig @@ -57,8 +57,8 @@ class TFSamVisionEncoderOutput(ModelOutput): last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -86,8 +86,8 @@ class TFSamImageSegmentationOutput(ModelOutput): pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`): The predicted low resolutions masks. Needs to be post-processed by the processor vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -1083,9 +1083,9 @@ def _init_weights(self, module): library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) - This model is also a TensorFlow [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. - Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to general usage - and behavior. + This model is also a TensorFlow [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) + subclass. 
Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to + general usage and behavior. Parameters: config ([`SamConfig`]): Model configuration class with all the parameters of the model. @@ -1102,9 +1102,9 @@ def _init_weights(self, module): input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`): Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much better results. The points can be obtained by passing a list of list of list to the processor that will - create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the - second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict - per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second + dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per + input point), the third dimension is the number of points per segmentation mask (it is possible to pass multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) coordinates of the point. If a different number of points is passed either for each image, or for each mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the @@ -1125,9 +1125,9 @@ def _init_weights(self, module): input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`): Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, - that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch - size, the number of boxes per image and the coordinates of the top left and botton right point of the box. - In the order (`x1`, `y1`, `x2`, `y2`): + that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size, + the number of boxes per image and the coordinates of the top left and botton right point of the box. 
In the + order (`x1`, `y1`, `x2`, `y2`): - `x1`: the x coordinate of the top left point of the input box - `y1`: the y coordinate of the top left point of the input box diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py index 24f3cc8306dc..a658030d36ff 100644 --- a/src/transformers/models/sam/processing_sam.py +++ b/src/transformers/models/sam/processing_sam.py @@ -22,7 +22,7 @@ from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import BatchEncoding -from ...utils import TensorType, is_torch_available, is_tf_available +from ...utils import TensorType, is_tf_available, is_torch_available if is_torch_available(): diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 2a043f50f350..658d7f689fce 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -2317,6 +2317,23 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFSamModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSamPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index e51eb07dd311..17dd54f41b4a 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -18,23 +18,24 @@ import inspect import unittest +import numpy as np import requests from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig -from transformers.testing_utils import require_torch, slow, torch_device -from transformers.utils import is_torch_available, is_vision_available +from transformers.testing_utils import require_tf, slow +from transformers.utils import is_tf_available, is_vision_available from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_modeling_common import ModelTesterMixin +from ...test_modeling_tf_common import floats_tensor from ...test_pipeline_mixin import PipelineTesterMixin -if is_torch_available(): - import torch - from torch import nn +if is_tf_available(): + import tensorflow as tf - from transformers import SamModel, SamProcessor - from transformers.models.sam.modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST + from transformers import SamProcessor, TFSamModel + from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -124,7 +125,7 @@ def prepare_config_and_inputs(self): return config, dummy_inputs -class SamModelTester: +class TFSamModelTester: def __init__( self, parent, @@ -231,44 +232,34 @@ def get_config(self): ) def create_and_check_model(self, config, pixel_values): - model = SamModel(config=config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - result = model(pixel_values) + model = TFSamModel(config=config) + result = model(pixel_values) self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3)) self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3)) def create_and_check_get_image_features(self, config, pixel_values): - model = SamModel(config=config) - 
model.to(torch_device) - model.eval() - with torch.no_grad(): - result = model.get_image_embeddings(pixel_values) + model = TFSamModel(config=config) + result = model.get_image_embeddings(pixel_values) self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12)) def create_and_check_get_image_hidden_states(self, config, pixel_values): - model = SamModel(config=config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - result = model.vision_encoder( - pixel_values, - output_hidden_states=True, - return_dict=True, - ) + model = TFSamModel(config=config) + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=True, + ) # after computing the convolutional features expected_hidden_states_shape = (self.batch_size, 12, 12, 36) self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) - with torch.no_grad(): - result = model.vision_encoder( - pixel_values, - output_hidden_states=True, - return_dict=False, - ) + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=False, + ) # after computing the convolutional features expected_hidden_states_shape = (self.batch_size, 12, 12, 36) @@ -282,22 +273,20 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict -@require_torch -class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): +@require_tf +class TFSamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): """ Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds, attention_mask and seq_length. """ - all_model_classes = (SamModel,) if is_torch_available() else () + all_model_classes = (TFSamModel,) if is_tf_available() else () pipeline_model_mapping = ( - {"feature-extraction": SamModel, "mask-generation": SamModel} if is_torch_available() else {} + {"feature-extraction": TFSamModel, "mask-generation": TFSamModel} if is_tf_available() else {} ) - fx_compatible = False test_pruning = False test_resize_embeddings = False test_head_masking = False - test_torchscript = False # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working def is_pipeline_test_to_skip( @@ -306,7 +295,7 @@ def is_pipeline_test_to_skip( return True def setUp(self): - self.model_tester = SamModelTester(self) + self.model_tester = TFSamModelTester(self) self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False) self.prompt_encoder_config_tester = ConfigTester( self, @@ -333,16 +322,16 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) - self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer)) x = model.get_output_embeddings() - self.assertTrue(x is None or isinstance(x, nn.Linear)) + self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense)) def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config) - signature = inspect.signature(model.forward) + signature = inspect.signature(model.call) # signature.parameters is an OrderedDict => so arg_names order is deterministic arg_names = [*signature.parameters.keys()] @@ -377,10 +366,7 @@ def test_attention_outputs(self): 
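The hunks around here delete the PyTorch-only scaffolding (`model.to(torch_device)`, `model.eval()`, `torch.no_grad()`) without adding a TF counterpart. A tiny illustrative sketch of why nothing is needed on the TF side: eager Keras layers place themselves on a device automatically and record no gradients outside a `tf.GradientTape`.

    import tensorflow as tf

    # A plain eager forward pass: no explicit device move, no eval-mode toggle, and no gradients
    # are recorded unless the call happens inside a tf.GradientTape context.
    layer = tf.keras.layers.Dense(4)
    outputs = layer(tf.ones((2, 3)))
    print(outputs.shape)  # (2, 4)
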
inputs_dict["output_hidden_states"] = False config.return_dict = True model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) vision_attentions = outputs.vision_attentions self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) @@ -392,10 +378,7 @@ def test_attention_outputs(self): del inputs_dict["output_attentions"] config.output_attentions = True model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) vision_attentions = outputs.vision_attentions self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) @@ -438,8 +421,8 @@ def test_hidden_states_output(self): @slow def test_model_from_pretrained(self): - for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = SamModel.from_pretrained(model_name) + for model_name in TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = TFSamModel.from_pretrained(model_name) self.assertIsNotNone(model) @@ -458,64 +441,48 @@ def prepare_dog_img(): @slow class SamModelIntegrationTest(unittest.TestCase): def test_inference_mask_generation_no_point(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() - inputs = processor(images=raw_image, return_tensors="pt").to(torch_device) + inputs = processor(images=raw_image, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) - scores = outputs.iou_scores.squeeze() + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4)) + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=1e-4)) def test_inference_mask_generation_one_point_one_bb(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() input_boxes = [[650, 900, 1000, 1250]] input_points = [[[820, 1080]]] - inputs = processor( - images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt" - ).to(torch_device) + inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) + outputs = model(**inputs) scores = outputs.iou_scores.squeeze() - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4)) + self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=1e-4)) def test_inference_mask_generation_batched_points_batched_images(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() input_points = [ [[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], [[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], ] - inputs = processor(images=[raw_image, raw_image], 
input_points=input_points, return_tensors="pt").to( - torch_device - ) + inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) + outputs = model(**inputs) scores = outputs.iou_scores.squeeze().cpu() - EXPECTED_SCORES = torch.tensor( + EXPECTED_SCORES = np.array( [ [ [0.9673, 0.9441, 0.9084], @@ -531,15 +498,12 @@ def test_inference_mask_generation_batched_points_batched_images(self): ], ] ) - self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) + self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3)) def test_inference_mask_generation_one_point_one_bb_zero(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() input_boxes = [[620, 900, 1000, 1255]] input_points = [[[820, 1080]]] @@ -550,196 +514,160 @@ def test_inference_mask_generation_one_point_one_bb_zero(self): input_boxes=input_boxes, input_points=input_points, input_labels=labels, - return_tensors="pt", - ).to(torch_device) + return_tensors="tf", + ) - with torch.no_grad(): - outputs = model(**inputs) - scores = outputs.iou_scores.squeeze() + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9689), atol=1e-4)) + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9689), atol=1e-4)) def test_inference_mask_generation_one_point(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() input_points = [[[400, 650]]] input_labels = [[1]] - inputs = processor( - images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" - ).to(torch_device) + inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) - scores = outputs.iou_scores.squeeze() + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9712), atol=1e-4)) + self.assertTrue(np.allclose(scores[-1], np.array(0.9712), atol=1e-4)) # With no label input_points = [[[400, 650]]] - inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) + outputs = model(**inputs) scores = outputs.iou_scores.squeeze() - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9712), atol=1e-4)) + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9712), atol=1e-4)) def test_inference_mask_generation_two_points(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - - model.to(torch_device) - model.eval() - raw_image = prepare_image() input_points = [[[400, 650], [800, 650]]] input_labels = [[1, 1]] - inputs = processor( - images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" - ).to(torch_device) + inputs = 
processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) - scores = outputs.iou_scores.squeeze() + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9936), atol=1e-4)) + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4)) # no labels - inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) - scores = outputs.iou_scores.squeeze() + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9936), atol=1e-4)) + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4)) def test_inference_mask_generation_two_points_batched(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() input_points = [[[400, 650], [800, 650]], [[400, 650]]] input_labels = [[1, 1], [1]] inputs = processor( - images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="pt" - ).to(torch_device) + images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="tf" + ) - with torch.no_grad(): - outputs = model(**inputs) + outputs = model(**inputs) scores = outputs.iou_scores.squeeze() - self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9936), atol=1e-4)) - self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9716), atol=1e-4)) + self.assertTrue(np.allclose(scores[0][-1].numpy(), np.array(0.9936), atol=1e-4)) + self.assertTrue(np.allclose(scores[1][-1], np.array(0.9716), atol=1e-4)) def test_inference_mask_generation_one_box(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() input_boxes = [[[75, 275, 1725, 850]]] - inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) - scores = outputs.iou_scores.squeeze() + outputs = model(**inputs) + scores = tf.squeeze(outputs) - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8686), atol=1e-4)) + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.8686), atol=1e-4)) def test_inference_mask_generation_batched_image_one_point(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() raw_dog_image = prepare_dog_img() input_points = [[[820, 1080]], [[220, 470]]] - inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to( - torch_device - ) + inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="tf") - with torch.no_grad(): - 
outputs = model(**inputs) - scores_batched = outputs.iou_scores.squeeze() + outputs = model(**inputs) + scores_batched = tf.squeeze(outputs.iou_scores) input_points = [[[220, 470]]] - inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device) + inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) - scores_single = outputs.iou_scores.squeeze() - self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4)) + outputs = model(**inputs) + scores_single = tf.squeeze(outputs.iou_scores) + self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4)) def test_inference_mask_generation_two_points_point_batch(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() # fmt: off - input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu() + input_points = tf.conver_to_tensor([[[400, 650]], [[220, 470]]]) # fmt: on - input_points = input_points.unsqueeze(0) + input_points = tf.expand_dims(input_points, 0) - inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + inputs = processor(raw_image, input_points=input_points, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) + outputs = model(**inputs) - iou_scores = outputs.iou_scores.cpu() + iou_scores = outputs.iou_scores self.assertTrue(iou_scores.shape == (1, 2, 3)) - torch.testing.assert_allclose( - iou_scores, torch.tensor([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), atol=1e-4, rtol=1e-4 + self.assertTrue( + np.allclose(iou_scores.numpy()), + np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), + atol=1e-4, + rtol=1e-4, ) def test_inference_mask_generation_three_boxes_point_batch(self): - model = SamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - model.to(torch_device) - model.eval() - raw_image = prepare_image() # fmt: off - input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu() - EXPECTED_IOU = torch.tensor([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]]) + input_boxes = tf.convert_to_tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]) + EXPECTED_IOU = np.array([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]]) # fmt: on - input_boxes = input_boxes.unsqueeze(0) + input_boxes = tf.expand_dims(input_boxes, 0) - inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="tf") - with torch.no_grad(): - outputs = model(**inputs) + outputs = model(**inputs) - iou_scores = outputs.iou_scores.cpu() + iou_scores = outputs.iou_scores self.assertTrue(iou_scores.shape == (1, 3, 3)) - torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) + self.assertTrue(np.allclose(iou_scores.numpy()), EXPECTED_IOU, atol=1e-4, rtol=1e-4) diff --git a/tests/models/sam/test_processor_tf_sam.py b/tests/models/sam/test_processor_tf_sam.py index 004847b51631..c66aa4b89a4c 100644 --- 
a/tests/models/sam/test_processor_tf_sam.py +++ b/tests/models/sam/test_processor_tf_sam.py @@ -17,8 +17,8 @@ import numpy as np -from transformers.testing_utils import require_torch, require_torchvision, require_vision, require_tf -from transformers.utils import is_torch_available, is_vision_available, is_tf_available +from transformers.testing_utils import require_tf, require_vision +from transformers.utils import is_tf_available, is_torch_available, is_vision_available if is_vision_available(): @@ -27,7 +27,7 @@ from transformers import AutoProcessor, SamImageProcessor, SamProcessor if is_torch_available(): - import torch + pass if is_tf_available(): import tensorflow as tf @@ -100,16 +100,23 @@ def test_post_process_masks(self): self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) masks = processor.post_process_masks( - dummy_masks, tf.convert_to_tensor(original_sizes), tf.convert_to_tensor(reshaped_input_size), return_tensors="tf" + dummy_masks, + tf.convert_to_tensor(original_sizes), + tf.convert_to_tensor(reshaped_input_size), + return_tensors="tf", ) self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) # should also work with np dummy_masks = [np.ones((1, 3, 5, 5))] - masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf") + masks = processor.post_process_masks( + dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" + ) self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) dummy_masks = [[1, 0], [0, 1]] with self.assertRaises(tf.errors.InvalidArgumentError): - masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf") + masks = processor.post_process_masks( + dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" + ) From b969a9d57a46827ecf0858f734cb8ee3045efa50 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 28 Apr 2023 18:57:05 +0100 Subject: [PATCH 08/49] Add all the needed fixes to the GPT code --- src/transformers/models/sam/modeling_sam.py | 6 +- .../models/sam/modeling_tf_sam.py | 286 ++++++++++-------- tests/models/sam/test_modeling_tf_sam.py | 72 +++-- 3 files changed, 213 insertions(+), 151 deletions(-) diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index bf14a4b2413a..5346d3706773 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -668,11 +668,11 @@ def forward( Embeds different types of prompts, returning both sparse and dense embeddings. Args: - points (`torch.Tensor`, **optionnal**): + points (`torch.Tensor`, **optional**): point coordinates and labels to embed. 
- boxes (`torch.Tensor`, **optionnal**): + boxes (`torch.Tensor`, **optional**): boxes to embed - masks (`torch.Tensor`, **optionnal**): + masks (`torch.Tensor`, **optional**): masks to embed """ sparse_embeddings = None diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 8f3248b5641e..839fe215cebb 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -26,8 +26,8 @@ from ...activations_tf import ACT2FN from ...modeling_tf_outputs import TFBaseModelOutput -from ...modeling_tf_utils import TFPreTrainedModel, shape_list -from ...tf_utils import functional_layernorm +from ...modeling_tf_utils import TFPreTrainedModel, shape_list, unpack_inputs +from ...tf_utils import functional_layernorm, flatten from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig @@ -131,7 +131,7 @@ def __init__(self, config, **kwargs): self.num_patches = num_patches self.projection = tf.keras.layers.Conv2D( - hidden_size, kernel_size=patch_size, strides=patch_size, name="projection" + hidden_size, kernel_size=patch_size, strides=patch_size, name="projection", data_format="channels_first" ) def call(self, pixel_values): @@ -170,13 +170,16 @@ class TFSamLayerNorm(tf.keras.layers.Layer): def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs): super().__init__(**kwargs) - self.weight = self.add_weight(shape=normalized_shape, initializer="ones", name="weight") - self.bias = self.add_weight(shape=normalized_shape, initializer="zeros", name="bias") self.eps = eps self.data_format = data_format + self.normalized_shape = normalized_shape if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError(f"Unsupported data format: {self.data_format}") - self.normalized_shape = (normalized_shape,) + + def build(self, input_shape): + self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight") + self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias") + super().build(input_shape) def call(self, x: tf.Tensor) -> tf.Tensor: if self.data_format == "channels_last": @@ -357,7 +360,7 @@ def __init__(self, config: SamMaskDecoderConfig, **kwargs): self.layers = [] for i in range(self.num_hidden_layers): - self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_{i}")) + self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}")) self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image") self.layer_norm_final_attn = tf.keras.layers.LayerNormalization( @@ -384,23 +387,8 @@ def call( if image_embeddings is None: raise ValueError("You have to specify an image_embedding") - image_embeddings = tf.expand_dims( - tf.transpose( - tf.reshape(image_embeddings, (shape_list(image_embeddings)[0], -1, shape_list(image_embeddings)[-1])), - (0, 2, 1), - ), - 1, - ) - image_positional_embeddings = tf.expand_dims( - tf.transpose( - tf.reshape( - image_positional_embeddings, - (shape_list(image_positional_embeddings)[0], -1, shape_list(image_positional_embeddings)[-1]), - ), - (0, 2, 1), - ), - 1, - ) + image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None] + image_positional_embeddings = 
tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None] # Prepare queries queries = point_embeddings @@ -440,7 +428,7 @@ def __init__( self.proj_in = tf.keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in") self.proj_out = tf.keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out") self.layers = [ - tf.keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layer_{i}") + tf.keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}") for i in range(num_layers - 2) ] self.sigmoid_output = sigmoid_output @@ -466,16 +454,13 @@ def __init__(self, config: SamMaskDecoderConfig, **kwargs): self.num_multimask_outputs = config.num_multimask_outputs self.num_mask_tokens = config.num_multimask_outputs + 1 - self.iou_token = tf.keras.layers.Embedding(1, self.hidden_size, name="iou_token") - self.mask_tokens = tf.keras.layers.Embedding(self.num_mask_tokens, self.hidden_size, name="mask_tokens") - - self.transformer = TFSamTwoWayTransformer(config) + self.transformer = TFSamTwoWayTransformer(config, name="transformer") self.upscale_conv1 = tf.keras.layers.Conv2DTranspose( - self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1" + self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first" ) self.upscale_conv2 = tf.keras.layers.Conv2DTranspose( - self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2" + self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first" ) self.upscale_layer_norm = TFSamLayerNorm( self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm" @@ -483,14 +468,20 @@ def __init__(self, config: SamMaskDecoderConfig, **kwargs): self.activation = tf.nn.gelu mlps_list = [] - for _ in range(self.num_mask_tokens): - mlps_list += [TFSamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] - self.output_hypernetworks_mlps = tf.keras.layers.LayerList(mlps_list) + for i in range(self.num_mask_tokens): + mlps_list += [TFSamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3, name=f"output_hypernetworks_mlps_._{i}")] + self.output_hypernetworks_mlps = mlps_list self.iou_prediction_head = TFSamFeedForward( - self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth + self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth, name="iou_prediction_head" ) + def build(self, input_shape): + self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True) + self.mask_tokens = self.add_weight(shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", + trainable=True) + super().build(input_shape) + def call( self, image_embeddings: tf.Tensor, @@ -501,15 +492,18 @@ def call( output_attentions: Optional[bool] = None, ) -> Tuple[tf.Tensor, tf.Tensor]: batch_size, num_channels, height, width = shape_list(image_embeddings) - point_batch_size = max(1, shape_list(sparse_prompt_embeddings)[1]) - output_tokens = tf.concat([self.iou_token.embeddings, self.mask_tokens.embeddings], axis=0) - output_tokens = tf.tile(tf.expand_dims(output_tokens, axis=0), [batch_size, point_batch_size, 1, 1]) + point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1]) - if tf.reduce_sum(sparse_prompt_embeddings) != 0: + output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 32) + (4, 32) = (5, 32) + output_tokens = 
tf.tile(output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]) # Should be (batch_size, point_size, 5, 32) + + # Matt: I think this sum is actually checking that the sparse prompt embeddings aren't an empty tensor + # with shape[1] == 0, so I'm going to replace this + if sparse_prompt_embeddings.shape[1] != 0: tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) else: tokens = output_tokens - point_embeddings = tf.cast(tokens, self.iou_token.embeddings.dtype) + point_embeddings = tf.cast(tokens, self.iou_token.dtype) image_embeddings = image_embeddings + dense_prompt_embeddings image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1]) @@ -565,9 +559,13 @@ class TFSamPositionalEmbedding(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) self.scale = config.hidden_size // 2 + self.config = config + + def build(self, input_shape): + # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized? self.positional_embedding = self.add_weight( name="positional_embedding", - shape=(2, config.num_pos_feats), + shape=(2, self.config.num_pos_feats), initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale), trainable=False, ) @@ -577,16 +575,7 @@ def call(self, input_coords, input_shape=None): coordinates = tf.identity(input_coords) if input_shape is not None: - coordinates = tf.tensor_scatter_nd_update( - coordinates, - indices=[[0, 0]], - updates=coordinates[:, :, :, 0] / input_shape[1], - ) - coordinates = tf.tensor_scatter_nd_update( - coordinates, - indices=[[0, 1]], - updates=coordinates[:, :, :, 1] / input_shape[0], - ) + coordinates = tf.stack([tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1], tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0]], axis=-1) # assuming coords are in [0, 1]^2 square and have d_1 x ... 
x d_n x 2 shape coordinates = 2 * coordinates - 1 @@ -602,9 +591,9 @@ def __init__(self, config: SamPromptEncoderConfig, **kwargs): super().__init__(**kwargs) self.mask_input_channels = config.mask_input_channels // 4 self.activation = ACT2FN[config.hidden_act] - self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1") - self.conv2 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv2") - self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3") + self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1", data_format="channels_first") + self.conv2 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv2", data_format="channels_first") + self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3", data_format="channels_first") self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") @@ -624,17 +613,37 @@ class TFSamPromptEncoder(tf.keras.layers.Layer): def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs): super().__init__(**kwargs) self.shared_embedding = shared_patch_embedding - self.mask_embed = TFSamMaskEmbedding(config) - self.no_mask_embed = tf.keras.layers.Embedding(1, config.hidden_size) + self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed") + self.no_mask_embed = None self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) self.input_image_size = config.image_size - self.point_embed = [ - tf.keras.layers.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings) - ] + self.point_embed = [] self.hidden_size = config.hidden_size - self.not_a_point_embed = tf.keras.layers.Embedding(1, config.hidden_size) + self.not_a_point_embed = None + self.config = config + + def build(self, input_shape): + self.no_mask_embed = self.add_weight( + name="no_mask_embed.weight", + shape=(1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + self.point_embed = [self.add_weight( + name=f"point_embed_._{i}.weight", + shape=(1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) for i in range(self.config.num_point_embeddings)] + self.not_a_point_embed = self.add_weight( + name="not_a_point_embed.weight", + shape=(1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + super().build(input_shape) def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor: """Embeds point prompts.""" @@ -649,16 +658,15 @@ def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.T input_shape = (self.input_image_size, self.input_image_size) point_embedding = self.shared_embedding(points, input_shape) - point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed.weights[0], point_embedding) + point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding) point_embedding = tf.where( labels[..., None] != -10, point_embedding, tf.zeros_like(point_embedding), ) - - point_embedding = tf.where(labels == 0, point_embedding + self.point_embed[0].weights[0], point_embedding) - point_embedding = tf.where(labels == 
1, point_embedding + self.point_embed[1].weights[0], point_embedding) + point_embedding = tf.where((labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding) + point_embedding = tf.where((labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding) return point_embedding def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: @@ -670,13 +678,14 @@ def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: corner_embedding = self.shared_embedding(coords, input_shape) corner_embedding += tf.where( tf.range(corner_embedding.shape[2])[None, None, :, None] == 0, - self.point_embed[2].weights[0], - self.point_embed[3].weights[0], + self.point_embed[2][0], + self.point_embed[3][0], ) return corner_embedding def call( self, + batch_size: Optional[int], input_points: Optional[Tuple[tf.Tensor, tf.Tensor]], input_labels: Optional[tf.Tensor], input_boxes: Optional[tf.Tensor], @@ -686,15 +695,14 @@ def call( Embeds different types of prompts, returning both sparse and dense embeddings. Args: - points (`tf.Tensor`, **optionnal**): + points (`tf.Tensor`, **optional**): point coordinates and labels to embed. - boxes (`tf.Tensor`, **optionnal**): + boxes (`tf.Tensor`, **optional**): boxes to embed - masks (`tf.Tensor`, **optionnal**): + masks (`tf.Tensor`, **optional**): masks to embed """ sparse_embeddings = None - batch_size = 1 if input_points is not None: batch_size, point_batch_size = input_points.shape[:2] if input_labels is None: @@ -714,11 +722,11 @@ def call( if input_masks is not None: dense_embeddings = self.mask_embed(input_masks) else: - dense_embeddings = tf.reshape(self.no_mask_embed.weights[0], (1, -1, 1, 1)) + dense_embeddings = self.no_mask_embed[0] + dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1)) dense_embeddings = tf.tile( dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1]) ) - if sparse_embeddings is None: sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype) @@ -735,9 +743,12 @@ def __init__(self, config, window_size, **kwargs) -> None: if window_size == 0 else (window_size, window_size) ) + self.input_size = input_size + self.num_attention_heads = config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads + self.head_dim = head_dim self.scale = head_dim**-0.5 self.dropout = config.attention_dropout @@ -749,13 +760,17 @@ def __init__(self, config, window_size, **kwargs) -> None: if input_size is None: raise ValueError("Input size must be provided if using relative positional encoding.") - # initialize relative positional embeddings + + def build(self, input_shape): + if self.input_size is not None: + # initialize relative positional embeddings self.rel_pos_h = self.add_weight( - shape=(2 * input_size[0] - 1, head_dim), initializer="zeros", name="rel_pos_h" + shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h" ) self.rel_pos_w = self.add_weight( - shape=(2 * input_size[1] - 1, head_dim), initializer="zeros", name="rel_pos_w" + shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w" ) + super().build(input_shape) def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor: """ @@ -841,15 +856,13 @@ def add_decomposed_rel_pos( def call(self, hidden_states: tf.Tensor, output_attentions=False) -> tf.Tensor: batch_size, height, width, _ = hidden_states.shape # qkv with shape (3, batch_size, nHead, height * width, channel) - qkv = 
tf.reshape( - self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1) - ).permute(2, 0, 3, 1, 4) + qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) + qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4)) # q, k, v with shape (batch_size * nHead, height * width, channel) query, key, value = tf.unstack( tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0 ) - - attn_weights = (query * self.scale) @ tf.transpose(key, perm=(-2, -1)) + attn_weights = tf.matmul(query * self.scale, key, transpose_b=True) if self.use_rel_pos: attn_weights = self.add_decomposed_rel_pos( @@ -858,9 +871,11 @@ def call(self, hidden_states: tf.Tensor, output_attentions=False) -> tf.Tensor: attn_weights = tf.nn.softmax(attn_weights, axis=-1) - attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) + # attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) + attn_probs = attn_weights - attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1)) + attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4)) attn_output = tf.reshape(attn_output, (batch_size, height, width, -1)) attn_output = self.proj(attn_output) @@ -952,10 +967,10 @@ def __init__(self, config: SamVisionConfig, **kwargs): super().__init__(**kwargs) self.config = config - self.conv1 = tf.keras.layers.Conv2D(config.output_channels, kernel_size=1, use_bias=False, name="conv1") + self.conv1 = tf.keras.layers.Conv2D(config.output_channels, kernel_size=1, use_bias=False, name="conv1", data_format="channels_first") self.layer_norm1 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm1") self.conv2 = tf.keras.layers.Conv2D( - config.output_channels, kernel_size=3, padding="same", use_bias=False, name="conv2" + config.output_channels, kernel_size=3, padding="same", use_bias=False, name="conv2", data_format="channels_first" ) self.layer_norm2 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm2") @@ -975,34 +990,36 @@ def __init__(self, config: SamVisionConfig, **kwargs): self.config = config self.image_size = config.image_size - self.patch_embed = TFSamPatchEmbeddings(config) + self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed") self.pos_embed = None - if config.use_abs_pos: - # Initialize absolute positional embedding with pretrain image size. - self.pos_embed = tf.Variable( - tf.zeros( - [ - 1, - config.image_size // config.patch_size, - config.image_size // config.patch_size, - config.hidden_size, - ] - ), - trainable=True, - ) self.layers = [] for i in range(config.num_hidden_layers): layer = TFSamVisionLayer( config, window_size=config.window_size if i not in config.global_attn_indexes else 0, + name=f"layers_._{i}", ) self.layers.append(layer) - self.neck = TFSamVisionNeck(config) + self.neck = TFSamVisionNeck(config, name="neck") - self.gradient_checkpointing = False + def build(self, input_shape): + if self.config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
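A pattern that recurs throughout this commit is worth spelling out: weights such as `pos_embed`, `rel_pos_h`/`rel_pos_w`, `iou_token`, and the point embeddings are now created inside `build()` with `add_weight` instead of in `__init__`. The following is a minimal, self-contained sketch of that Keras idiom; the class and argument names are illustrative only, not the real SAM ones.

import tensorflow as tf

class AbsolutePositionSketch(tf.keras.layers.Layer):
    """Illustrative only: shows the build()/add_weight idiom used in this patch."""

    def __init__(self, grid_size: int, hidden_size: int, **kwargs):
        super().__init__(**kwargs)
        self.grid_size = grid_size
        self.hidden_size = hidden_size

    def build(self, input_shape):
        # Deferring creation to build() means the variable is only allocated
        # when the layer is first called, and it is created under the layer's
        # name scope at that point.
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=(1, self.grid_size, self.grid_size, self.hidden_size),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, patch_embeddings: tf.Tensor) -> tf.Tensor:
        return patch_embeddings + self.pos_embed

layer = AbsolutePositionSketch(grid_size=4, hidden_size=8)
layer(tf.zeros((2, 4, 4, 8)))  # pos_embed is created here, on first call, not at construction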
+ self.pos_embed = self.add_weight( + shape=[ + 1, + self.config.image_size // self.config.patch_size, + self.config.image_size // self.config.patch_size, + self.config.hidden_size, + ], + initializer="zeros", + trainable=True, + name="pos_embed", + ) + super().build(input_shape) def get_input_embeddings(self): return self.patch_embed @@ -1066,16 +1083,38 @@ class TFSamPreTrainedModel(TFPreTrainedModel): base_model_prefix = "sam" main_input_name = "pixel_values" - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (tf.keras.layers.Dense, tf.keras.layers.Conv2D, tf.keras.layers.Conv2DTranspose)): - module.kernel.assign(tf.random.normal(module.kernel.shape, mean=0.0, stddev=std)) - if module.bias is not None: - module.bias.assign(tf.zeros(module.bias.shape)) - elif isinstance(module, tf.keras.layers.Embedding): - module.embeddings.assign(tf.random.normal(module.embeddings.shape, mean=0.0, stddev=std)) - if module.padding_idx is not None: - module.embeddings[module.padding_idx].assign(tf.zeros(module.embeddings[module.padding_idx].shape)) + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + VISION_DUMMY_INPUTS = tf.random.uniform( + shape=(3, self.config.vision_config.num_channels, self.config.vision_config.image_size, self.config.vision_config.image_size), dtype=tf.float32 + ) + return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} + + + @tf.function( + input_signature=[ + { + "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"), + } + ] + ) + def serving(self, inputs): + """ + Method used for serving the model. + + Args: + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. 
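The `dummy_inputs` property and `serving` function added above mirror the boilerplate used by the library's other TF models. As a rough, self-contained illustration of how a `tf.function` input signature with all-`None` dimensions becomes a flexible SavedModel serving signature (a toy model, not SAM):

import tensorflow as tf

class TinyServingSketch(tf.keras.Model):
    """Toy stand-in: only the serving-signature mechanics mirror the patch."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, inputs):
        return {"logits": self.dense(inputs["pixel_values"])}

    @tf.function(
        input_signature=[{"pixel_values": tf.TensorSpec((None, None), tf.float32, name="pixel_values")}]
    )
    def serving(self, inputs):
        # Every dimension is None, so the exported signature accepts any
        # batch size and feature width of float32 inputs.
        return self.call(inputs)

model = TinyServingSketch()
model({"pixel_values": tf.zeros((2, 8))})  # run once so the weights exist
tf.saved_model.save(model, "tiny_serving_sketch", signatures={"serving_default": model.serving})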
+ """ + output = self.call(inputs) + + return self.serving_output(output) SAM_START_DOCSTRING = r""" @@ -1168,13 +1207,11 @@ class TFSamModel(TFSamPreTrainedModel): def __init__(self, config, **kwargs) -> None: super().__init__(config, **kwargs) - self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config) - - self.vision_encoder = TFSamVisionEncoder(config.vision_config) - self.prompt_encoder = TFSamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding) - self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config) + self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding") - self.post_init() + self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder") + self.prompt_encoder = TFSamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder") + self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder") def get_input_embeddings(self): return self.vision_encoder.get_input_embeddings() @@ -1252,6 +1289,7 @@ def get_prompt_embeddings( ) return prompt_output + @unpack_inputs @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING) def call( self, @@ -1298,10 +1336,9 @@ def call( point_batch_size, box_batch_size ) ) - image_positional_embeddings = self.get_image_wide_positional_embeddings() # repeat with batch size - batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0] + batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0] image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0) vision_attentions = None @@ -1334,6 +1371,7 @@ def call( ) sparse_embeddings, dense_embeddings = self.prompt_encoder( + batch_size=image_embeddings.shape[0], input_points=input_points, input_labels=input_labels, input_boxes=input_boxes, @@ -1365,3 +1403,15 @@ def call( vision_attentions=vision_attentions, mask_decoder_attentions=mask_decoder_attentions, ) + + def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput: + hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None + + return TFSamImageSegmentationOutput( + iou_scores=output.iou_scores, + pred_masks=output.pred_masks, + vision_hidden_states=hs, + vision_attentions=attns, + mask_decoder_attentions=output.mask_decoder_attentions, + ) diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index 17dd54f41b4a..f43ba67f4fcb 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -23,11 +23,13 @@ from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig from transformers.testing_utils import require_tf, slow -from transformers.utils import is_tf_available, is_vision_available +from transformers.utils import is_tf_available, is_vision_available, is_torch_available + +import tempfile +import os from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin -from ...test_modeling_tf_common import floats_tensor +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, is_pt_tf_cross_test from ...test_pipeline_mixin import PipelineTesterMixin @@ -37,12 +39,15 @@ from transformers import SamProcessor, 
TFSamModel from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST +if is_torch_available(): + import torch + if is_vision_available(): from PIL import Image -class SamPromptEncoderTester: +class TFSamPromptEncoderTester: def __init__( self, hidden_size=32, @@ -76,7 +81,7 @@ def prepare_config_and_inputs(self): return config, dummy_points -class SamMaskDecoderTester: +class TFSamMaskDecoderTester: def __init__( self, hidden_size=32, @@ -186,8 +191,8 @@ def __init__( num_patches = (image_size // patch_size) ** 2 self.seq_length = num_patches + 1 - self.prompt_encoder_tester = SamPromptEncoderTester() - self.mask_decoder_tester = SamMaskDecoderTester() + self.prompt_encoder_tester = TFSamPromptEncoderTester() + self.mask_decoder_tester = TFSamMaskDecoderTester() def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -274,7 +279,7 @@ def prepare_config_and_inputs_for_common(self): @require_tf -class TFSamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): +class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): """ Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds, attention_mask and seq_length. @@ -287,6 +292,7 @@ class TFSamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): test_pruning = False test_resize_embeddings = False test_head_masking = False + test_onnx = False # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working def is_pipeline_test_to_skip( @@ -422,9 +428,16 @@ def test_hidden_states_output(self): @slow def test_model_from_pretrained(self): for model_name in TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = TFSamModel.from_pretrained(model_name) + model = TFSamModel.from_pretrained(model_name, from_pt=True) self.assertIsNotNone(model) + def test_pt_tf_model_equivalence(self, allow_missing_keys=True): + super().test_pt_tf_model_equivalence(allow_missing_keys=True) + + @unittest.skip(reason="Temporary skip while we resolve other issues - do not merge until this is removed!") + def test_saved_model_creation(self): + pass + def prepare_image(): img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" @@ -441,7 +454,7 @@ def prepare_dog_img(): @slow class SamModelIntegrationTest(unittest.TestCase): def test_inference_mask_generation_no_point(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -449,11 +462,10 @@ def test_inference_mask_generation_no_point(self): outputs = model(**inputs) scores = tf.squeeze(outputs.iou_scores) - self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=1e-4)) def test_inference_mask_generation_one_point_one_bb(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -463,12 +475,12 @@ def test_inference_mask_generation_one_point_one_bb(self): inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf") outputs = model(**inputs) - scores = outputs.iou_scores.squeeze() + scores = 
tf.squeeze(outputs.iou_scores) self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=1e-4)) def test_inference_mask_generation_batched_points_batched_images(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -480,7 +492,7 @@ def test_inference_mask_generation_batched_points_batched_images(self): inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="tf") outputs = model(**inputs) - scores = outputs.iou_scores.squeeze().cpu() + scores = tf.squeeze(outputs.iou_scores) EXPECTED_SCORES = np.array( [ @@ -501,7 +513,7 @@ def test_inference_mask_generation_batched_points_batched_images(self): self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3)) def test_inference_mask_generation_one_point_one_bb_zero(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -523,7 +535,7 @@ def test_inference_mask_generation_one_point_one_bb_zero(self): self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9689), atol=1e-4)) def test_inference_mask_generation_one_point(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -544,12 +556,12 @@ def test_inference_mask_generation_one_point(self): inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf") outputs = model(**inputs) - scores = outputs.iou_scores.squeeze() + scores = tf.squeeze(outputs.iou_scores) self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9712), atol=1e-4)) def test_inference_mask_generation_two_points(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -572,7 +584,7 @@ def test_inference_mask_generation_two_points(self): self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4)) def test_inference_mask_generation_two_points_batched(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -585,13 +597,13 @@ def test_inference_mask_generation_two_points_batched(self): ) outputs = model(**inputs) - scores = outputs.iou_scores.squeeze() + scores = tf.squeeze(outputs.iou_scores) self.assertTrue(np.allclose(scores[0][-1].numpy(), np.array(0.9936), atol=1e-4)) self.assertTrue(np.allclose(scores[1][-1], np.array(0.9716), atol=1e-4)) def test_inference_mask_generation_one_box(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -601,12 +613,12 @@ def test_inference_mask_generation_one_box(self): inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="tf") outputs = model(**inputs) - scores = 
tf.squeeze(outputs) + scores = tf.squeeze(outputs.iou_scores) self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.8686), atol=1e-4)) def test_inference_mask_generation_batched_image_one_point(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -628,13 +640,13 @@ def test_inference_mask_generation_batched_image_one_point(self): self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4)) def test_inference_mask_generation_two_points_point_batch(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() # fmt: off - input_points = tf.conver_to_tensor([[[400, 650]], [[220, 470]]]) + input_points = tf.convert_to_tensor([[[400, 650]], [[220, 470]]]) # fmt: on input_points = tf.expand_dims(input_points, 0) @@ -646,14 +658,14 @@ def test_inference_mask_generation_two_points_point_batch(self): iou_scores = outputs.iou_scores self.assertTrue(iou_scores.shape == (1, 2, 3)) self.assertTrue( - np.allclose(iou_scores.numpy()), + np.allclose(iou_scores.numpy(), np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), atol=1e-4, rtol=1e-4, - ) + )) def test_inference_mask_generation_three_boxes_point_batch(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -670,4 +682,4 @@ def test_inference_mask_generation_three_boxes_point_batch(self): iou_scores = outputs.iou_scores self.assertTrue(iou_scores.shape == (1, 3, 3)) - self.assertTrue(np.allclose(iou_scores.numpy()), EXPECTED_IOU, atol=1e-4, rtol=1e-4) + self.assertTrue(np.allclose(iou_scores.numpy(), EXPECTED_IOU, atol=1e-4, rtol=1e-4)) From d3b13926e0458cad51624e80228cfea7e1fda3fd Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 28 Apr 2023 18:59:58 +0100 Subject: [PATCH 09/49] make fixup --- .../models/sam/modeling_tf_sam.py | 98 ++++++++++++++----- tests/models/sam/test_modeling_tf_sam.py | 21 ++-- 2 files changed, 82 insertions(+), 37 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 839fe215cebb..7c3402a3db3e 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -27,7 +27,7 @@ from ...activations_tf import ACT2FN from ...modeling_tf_outputs import TFBaseModelOutput from ...modeling_tf_utils import TFPreTrainedModel, shape_list, unpack_inputs -from ...tf_utils import functional_layernorm, flatten +from ...tf_utils import flatten, functional_layernorm from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig @@ -469,17 +469,30 @@ def __init__(self, config: SamMaskDecoderConfig, **kwargs): mlps_list = [] for i in range(self.num_mask_tokens): - mlps_list += [TFSamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3, name=f"output_hypernetworks_mlps_._{i}")] + mlps_list += [ + TFSamFeedForward( + self.hidden_size, + self.hidden_size, + self.hidden_size // 
8, + 3, + name=f"output_hypernetworks_mlps_._{i}", + ) + ] self.output_hypernetworks_mlps = mlps_list self.iou_prediction_head = TFSamFeedForward( - self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth, name="iou_prediction_head" + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + name="iou_prediction_head", ) def build(self, input_shape): self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True) - self.mask_tokens = self.add_weight(shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", - trainable=True) + self.mask_tokens = self.add_weight( + shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True + ) super().build(input_shape) def call( @@ -495,7 +508,9 @@ def call( point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1]) output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 32) + (4, 32) = (5, 32) - output_tokens = tf.tile(output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]) # Should be (batch_size, point_size, 5, 32) + output_tokens = tf.tile( + output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1] + ) # Should be (batch_size, point_size, 5, 32) # Matt: I think this sum is actually checking that the sparse prompt embeddings aren't an empty tensor # with shape[1] == 0, so I'm going to replace this @@ -575,7 +590,13 @@ def call(self, input_coords, input_shape=None): coordinates = tf.identity(input_coords) if input_shape is not None: - coordinates = tf.stack([tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1], tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0]], axis=-1) + coordinates = tf.stack( + [ + tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1], + tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0], + ], + axis=-1, + ) # assuming coords are in [0, 1]^2 square and have d_1 x ... 
x d_n x 2 shape coordinates = 2 * coordinates - 1 @@ -591,9 +612,15 @@ def __init__(self, config: SamPromptEncoderConfig, **kwargs): super().__init__(**kwargs) self.mask_input_channels = config.mask_input_channels // 4 self.activation = ACT2FN[config.hidden_act] - self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1", data_format="channels_first") - self.conv2 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv2", data_format="channels_first") - self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3", data_format="channels_first") + self.conv1 = tf.keras.layers.Conv2D( + self.mask_input_channels, kernel_size=2, strides=2, name="conv1", data_format="channels_first" + ) + self.conv2 = tf.keras.layers.Conv2D( + self.mask_input_channels, kernel_size=2, strides=2, name="conv2", data_format="channels_first" + ) + self.conv3 = tf.keras.layers.Conv2D( + config.hidden_size, kernel_size=1, name="conv3", data_format="channels_first" + ) self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") @@ -631,12 +658,15 @@ def build(self, input_shape): initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), trainable=True, ) - self.point_embed = [self.add_weight( - name=f"point_embed_._{i}.weight", - shape=(1, self.hidden_size), - initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), - trainable=True, - ) for i in range(self.config.num_point_embeddings)] + self.point_embed = [ + self.add_weight( + name=f"point_embed_._{i}.weight", + shape=(1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + for i in range(self.config.num_point_embeddings) + ] self.not_a_point_embed = self.add_weight( name="not_a_point_embed.weight", shape=(1, self.hidden_size), @@ -665,8 +695,12 @@ def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.T point_embedding, tf.zeros_like(point_embedding), ) - point_embedding = tf.where((labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding) - point_embedding = tf.where((labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding) + point_embedding = tf.where( + (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding + ) + point_embedding = tf.where( + (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding + ) return point_embedding def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: @@ -745,7 +779,6 @@ def __init__(self, config, window_size, **kwargs) -> None: ) self.input_size = input_size - self.num_attention_heads = config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads self.head_dim = head_dim @@ -760,10 +793,9 @@ def __init__(self, config, window_size, **kwargs) -> None: if input_size is None: raise ValueError("Input size must be provided if using relative positional encoding.") - def build(self, input_shape): if self.input_size is not None: - # initialize relative positional embeddings + # initialize relative positional embeddings self.rel_pos_h = self.add_weight( shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h" ) @@ -967,10 +999,17 @@ def __init__(self, config: SamVisionConfig, **kwargs): super().__init__(**kwargs) 
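The `_embed_points` change above leans on broadcasting a per-point boolean mask over the embedding axis via `[:, :, :, None]`. A tiny, self-contained illustration of that `tf.where` pattern with toy shapes (not the real prompt-encoder tensors):

import tensorflow as tf

# Toy shapes: 1 image, 1 prompt group, 3 points, embedding size 4.
labels = tf.constant([[[1, 0, -1]]])              # (1, 1, 3)
point_embedding = tf.zeros((1, 1, 3, 4))          # (1, 1, 3, 4)
positive_embed = tf.ones((1, 4))                  # stand-in for a learned point embedding

# Expanding the mask to (1, 1, 3, 1) lets tf.where broadcast the per-point
# condition across the embedding dimension.
updated = tf.where((labels == 1)[:, :, :, None], point_embedding + positive_embed, point_embedding)
print(updated.shape)  # (1, 1, 3, 4); only the point labelled 1 received the offset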
self.config = config - self.conv1 = tf.keras.layers.Conv2D(config.output_channels, kernel_size=1, use_bias=False, name="conv1", data_format="channels_first") + self.conv1 = tf.keras.layers.Conv2D( + config.output_channels, kernel_size=1, use_bias=False, name="conv1", data_format="channels_first" + ) self.layer_norm1 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm1") self.conv2 = tf.keras.layers.Conv2D( - config.output_channels, kernel_size=3, padding="same", use_bias=False, name="conv2", data_format="channels_first" + config.output_channels, + kernel_size=3, + padding="same", + use_bias=False, + name="conv2", + data_format="channels_first", ) self.layer_norm2 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm2") @@ -1092,11 +1131,16 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: `Dict[str, tf.Tensor]`: The dummy inputs. """ VISION_DUMMY_INPUTS = tf.random.uniform( - shape=(3, self.config.vision_config.num_channels, self.config.vision_config.image_size, self.config.vision_config.image_size), dtype=tf.float32 + shape=( + 3, + self.config.vision_config.num_channels, + self.config.vision_config.image_size, + self.config.vision_config.image_size, + ), + dtype=tf.float32, ) return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} - @tf.function( input_signature=[ { @@ -1210,7 +1254,9 @@ def __init__(self, config, **kwargs) -> None: self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding") self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder") - self.prompt_encoder = TFSamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder") + self.prompt_encoder = TFSamPromptEncoder( + config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder" + ) self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder") def get_input_embeddings(self): diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index f43ba67f4fcb..70991fd1e567 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -23,13 +23,10 @@ from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig from transformers.testing_utils import require_tf, slow -from transformers.utils import is_tf_available, is_vision_available, is_torch_available - -import tempfile -import os +from transformers.utils import is_tf_available, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester -from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, is_pt_tf_cross_test +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -40,7 +37,7 @@ from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST if is_torch_available(): - import torch + pass if is_vision_available(): @@ -658,11 +655,13 @@ def test_inference_mask_generation_two_points_point_batch(self): iou_scores = outputs.iou_scores self.assertTrue(iou_scores.shape == (1, 2, 3)) self.assertTrue( - np.allclose(iou_scores.numpy(), - np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), - atol=1e-4, - rtol=1e-4, - )) + np.allclose( + iou_scores.numpy(), + np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), + atol=1e-4, + rtol=1e-4, + ) + ) def 
test_inference_mask_generation_three_boxes_point_batch(self): model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) From 9c3066b8a54138d0294bc8c5b46e5dfd891325d0 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 2 May 2023 15:13:13 +0100 Subject: [PATCH 10/49] Make convolutions channels-last so they run on CPU --- .../models/sam/modeling_tf_sam.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 7c3402a3db3e..7257a3f66415 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -131,7 +131,7 @@ def __init__(self, config, **kwargs): self.num_patches = num_patches self.projection = tf.keras.layers.Conv2D( - hidden_size, kernel_size=patch_size, strides=patch_size, name="projection", data_format="channels_first" + hidden_size, kernel_size=patch_size, strides=patch_size, name="projection" ) def call(self, pixel_values): @@ -144,7 +144,7 @@ def call(self, pixel_values): raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." ) - embeddings = tf.transpose(self.projection(pixel_values), perm=[0, 2, 3, 1]) + embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1])) return embeddings @@ -613,18 +613,19 @@ def __init__(self, config: SamPromptEncoderConfig, **kwargs): self.mask_input_channels = config.mask_input_channels // 4 self.activation = ACT2FN[config.hidden_act] self.conv1 = tf.keras.layers.Conv2D( - self.mask_input_channels, kernel_size=2, strides=2, name="conv1", data_format="channels_first" + self.mask_input_channels, kernel_size=2, strides=2, name="conv1" ) self.conv2 = tf.keras.layers.Conv2D( - self.mask_input_channels, kernel_size=2, strides=2, name="conv2", data_format="channels_first" + self.mask_input_channels, kernel_size=2, strides=2, name="conv2" ) self.conv3 = tf.keras.layers.Conv2D( - config.hidden_size, kernel_size=1, name="conv3", data_format="channels_first" + config.hidden_size, kernel_size=1, name="conv3" ) self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") def call(self, masks): + masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last hidden_states = self.conv1(masks) hidden_states = self.layer_norm1(hidden_states) hidden_states = self.activation(hidden_states) @@ -633,6 +634,7 @@ def call(self, masks): hidden_states = self.layer_norm2(hidden_states) hidden_states = self.activation(hidden_states) dense_embeddings = self.conv3(hidden_states) + masks = tf.transpose(masks, perm=(0, 3, 1, 2)) # Convert back to channels-first return dense_embeddings @@ -1000,26 +1002,25 @@ def __init__(self, config: SamVisionConfig, **kwargs): self.config = config self.conv1 = tf.keras.layers.Conv2D( - config.output_channels, kernel_size=1, use_bias=False, name="conv1", data_format="channels_first" + config.output_channels, kernel_size=1, use_bias=False, name="conv1", ) - self.layer_norm1 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm1") + self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1") self.conv2 = tf.keras.layers.Conv2D( config.output_channels, kernel_size=3, padding="same", use_bias=False, name="conv2", - data_format="channels_first", ) - 
self.layer_norm2 = TFSamLayerNorm(config.output_channels, data_format="channels_first", name="layer_norm2") + self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2") def call(self, hidden_states): - hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) hidden_states = self.conv1(hidden_states) hidden_states = self.layer_norm1(hidden_states) hidden_states = self.conv2(hidden_states) hidden_states = self.layer_norm2(hidden_states) + hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) return hidden_states From 6fe6674133de90ecbd7aab6ae95fb1b1bc27e149 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 2 May 2023 15:27:13 +0100 Subject: [PATCH 11/49] make fixup --- src/transformers/models/sam/modeling_tf_sam.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 7257a3f66415..28d14d4c2b8d 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -612,15 +612,9 @@ def __init__(self, config: SamPromptEncoderConfig, **kwargs): super().__init__(**kwargs) self.mask_input_channels = config.mask_input_channels // 4 self.activation = ACT2FN[config.hidden_act] - self.conv1 = tf.keras.layers.Conv2D( - self.mask_input_channels, kernel_size=2, strides=2, name="conv1" - ) - self.conv2 = tf.keras.layers.Conv2D( - self.mask_input_channels, kernel_size=2, strides=2, name="conv2" - ) - self.conv3 = tf.keras.layers.Conv2D( - config.hidden_size, kernel_size=1, name="conv3" - ) + self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1") + self.conv2 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv2") + self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3") self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") @@ -1002,7 +996,10 @@ def __init__(self, config: SamVisionConfig, **kwargs): self.config = config self.conv1 = tf.keras.layers.Conv2D( - config.output_channels, kernel_size=1, use_bias=False, name="conv1", + config.output_channels, + kernel_size=1, + use_bias=False, + name="conv1", ) self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1") self.conv2 = tf.keras.layers.Conv2D( From d6dec9aaf68e58b72f6b9a23d31310e583009985 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 2 May 2023 16:38:46 +0100 Subject: [PATCH 12/49] Fix final issues --- .../models/sam/modeling_tf_sam.py | 32 +++++++++++++++++-- tests/models/sam/test_modeling_tf_sam.py | 8 ++--- tests/test_modeling_tf_common.py | 14 ++++---- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 28d14d4c2b8d..942b9844ef1f 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -613,7 +613,7 @@ def __init__(self, config: SamPromptEncoderConfig, **kwargs): self.mask_input_channels = config.mask_input_channels // 4 self.activation = ACT2FN[config.hidden_act] self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1") - self.conv2 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv2") + self.conv2 = 
tf.keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2") self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3") self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") @@ -628,9 +628,28 @@ def call(self, masks): hidden_states = self.layer_norm2(hidden_states) hidden_states = self.activation(hidden_states) dense_embeddings = self.conv3(hidden_states) - masks = tf.transpose(masks, perm=(0, 3, 1, 2)) # Convert back to channels-first + dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first return dense_embeddings + def build(self, input_shape): + # This class needs an explicit build method because it isn't called with the standard dummy inputs + conv1_shape = [None, None, None, 1] + conv2_shape = [None, None, None, self.mask_input_channels] + conv3_shape = [None, None, None, self.mask_input_channels * 4] + layer_norm1_shape = [None, None, None, self.mask_input_channels] + layer_norm2_shape = [None, None, None, self.mask_input_channels * 4] + with tf.name_scope("conv1"): + self.conv1.build(conv1_shape) + with tf.name_scope("conv2"): + self.conv2.build(conv2_shape) + with tf.name_scope("conv3"): + self.conv3.build(conv3_shape) + with tf.name_scope("layer_norm1"): + self.layer_norm1.build(layer_norm1_shape) + with tf.name_scope("layer_norm2"): + self.layer_norm2.build(layer_norm2_shape) + super().build(input_shape) + class TFSamPromptEncoder(tf.keras.layers.Layer): def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs): @@ -669,6 +688,11 @@ def build(self, input_shape): initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), trainable=True, ) + with tf.name_scope("mask_embed"): + # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs + self.mask_embed.build( + (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size) + ) super().build(input_shape) def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor: @@ -1256,6 +1280,7 @@ def __init__(self, config, **kwargs) -> None: config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder" ) self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder") + self.config = config def get_input_embeddings(self): return self.vision_encoder.get_input_embeddings() @@ -1380,6 +1405,9 @@ def call( point_batch_size, box_batch_size ) ) + if pixel_values is not None: + # Ensures that later checks pass even with an all-None shape from the serving signature + pixel_values = tf.ensure_shape(pixel_values, [None, self.config.vision_config.num_channels, None, None]) image_positional_embeddings = self.get_image_wide_positional_embeddings() # repeat with batch size batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0] diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index 70991fd1e567..4be6d3c9bac0 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -428,12 +428,8 @@ def test_model_from_pretrained(self): model = TFSamModel.from_pretrained(model_name, from_pt=True) self.assertIsNotNone(model) - def test_pt_tf_model_equivalence(self, allow_missing_keys=True): - 
super().test_pt_tf_model_equivalence(allow_missing_keys=True) - - @unittest.skip(reason="Temporary skip while we resolve other issues - do not merge until this is removed!") - def test_saved_model_creation(self): - pass + def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4): + super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol) def prepare_image(): diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 220560d9238a..277738b39026 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -640,7 +640,7 @@ def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict): return pt_inputs_dict - def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): + def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5): pt_inputs_dict = self.prepare_pt_inputs_from_tf_inputs(tf_inputs_dict) # send pytorch inputs to the correct device @@ -665,10 +665,10 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): if tf_loss is not None: tf_outputs.loss = tf.math.reduce_mean(tf_loss) - self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model)) + self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model), tol=tol) @is_pt_tf_cross_test - def test_pt_tf_model_equivalence(self, allow_missing_keys=False): + def test_pt_tf_model_equivalence(self, allow_missing_keys=False, tol=1e-5): import transformers for model_class in self.all_model_classes: @@ -711,10 +711,10 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False): ) # Original test: check without `labels` - self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict, tol=tol) # check with `labels` if tf_inputs_dict_with_labels: - self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels) + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels, tol=tol) # Check we can load pt model in tf and vice-versa with checkpoint => model functions with tempfile.TemporaryDirectory() as tmpdirname: @@ -731,10 +731,10 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False): ) # Original test: check without `labels` - self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict, tol=tol) # check with `labels` if tf_inputs_dict_with_labels: - self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels) + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels, tol=tol) def test_compile_tf_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From dba0920f2cdea9c0d759344c2cb7c1aa3ea63442 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 2 May 2023 17:51:23 +0100 Subject: [PATCH 13/49] Fix other models affected by test change --- tests/models/data2vec/test_modeling_tf_data2vec_vision.py | 3 +++ tests/models/sam/test_modeling_sam.py | 3 +++ tests/models/vit_mae/test_modeling_tf_vit_mae.py | 4 ++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py index dfa890d25a9e..6d4ca2eea083 100644 --- a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py @@ -459,6 +459,9 @@ def test_model_from_pretrained(self): model = TFData2VecVisionModel.from_pretrained(model_name) self.assertIsNotNone(model) + def 
test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4): + super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol) + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index e51eb07dd311..a9f7c204fccd 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -436,6 +436,9 @@ def test_retain_grad_hidden_states_attentions(self): def test_hidden_states_output(self): pass + def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4): + super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol) + @slow def test_model_from_pretrained(self): for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index 53d68b644ac8..1449fc0b7597 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -267,7 +267,7 @@ def prepare_numpy_arrays(inputs_dict): # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise # to generate masks during test - def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): + def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5): # make masks reproducible np.random.seed(2) @@ -279,7 +279,7 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): # PT inputs will be prepared in `super().check_pt_tf_models()` with this added `noise` argument tf_inputs_dict["noise"] = tf_noise - super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict, tol=tol) # overwrite from common since TFViTMAEForPretraining outputs loss along with # logits and mask indices. loss and mask indices are not suitable for integration From 989dd3fbd38fb2c4d55ba1ddca859b8a6cd8c422 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 14:05:06 +0100 Subject: [PATCH 14/49] Clarify comment on the sparse_prompt_embeddings check --- src/transformers/models/sam/modeling_tf_sam.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 942b9844ef1f..b648ce6b87f2 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -512,8 +512,9 @@ def call( output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1] ) # Should be (batch_size, point_size, 5, 32) - # Matt: I think this sum is actually checking that the sparse prompt embeddings aren't an empty tensor - # with shape[1] == 0, so I'm going to replace this + # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only + # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced + # it with an explicit shape check to avoid data-dependent control flow which breaks XLA. 
if sparse_prompt_embeddings.shape[1] != 0: tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) else: From f842d4337b8f4324a98c8c81b24fbd80463697dd Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 15:32:44 +0100 Subject: [PATCH 15/49] Refactor functional_layernorm, use shape_list in place of .shape in some places --- .../models/sam/modeling_tf_sam.py | 37 ++++++++++++------- src/transformers/tf_utils.py | 18 +++++++-- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index b648ce6b87f2..693c91450b69 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -183,16 +183,17 @@ def build(self, input_shape): def call(self, x: tf.Tensor) -> tf.Tensor: if self.data_format == "channels_last": - x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps) + x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1) elif self.data_format == "channels_first": - input_dtype = x.dtype - x = tf.cast(x, tf.float32) - u = tf.reduce_mean(x, axis=1, keepdims=True) - s = tf.math.square(x - u) - s = tf.reduce_mean(s, axis=1, keepdims=True) - x = (x - u) / tf.math.sqrt(s + self.eps) - x = tf.cast(x, input_dtype) - x = self.weight[:, None, None] * x + self.bias[:, None, None] + x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1) + # input_dtype = x.dtype + # x = tf.cast(x, tf.float32) + # u = tf.reduce_mean(x, axis=1, keepdims=True) + # s = tf.math.square(x - u) + # s = tf.reduce_mean(s, axis=1, keepdims=True) + # x = (x - u) / tf.math.sqrt(s + self.eps) + # x = tf.cast(x, input_dtype) + # x = self.weight[:, None, None] * x + self.bias[:, None, None] return x @@ -897,7 +898,7 @@ def add_decomposed_rel_pos( relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) - batch_size, _, dim = query.shape + batch_size, _, dim = shape_list(query) reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim)) rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) @@ -907,7 +908,7 @@ def add_decomposed_rel_pos( return attn def call(self, hidden_states: tf.Tensor, output_attentions=False) -> tf.Tensor: - batch_size, height, width, _ = hidden_states.shape + batch_size, height, width, _ = shape_list(hidden_states) # qkv with shape (3, batch_size, nHead, height * width, channel) qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4)) @@ -973,7 +974,7 @@ def window_unpartition( ) -> tf.Tensor: pad_height, pad_width = padding_shape height, width = original_shape - batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size) hidden_states = tf.reshape( windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1] ) @@ -1408,7 +1409,15 @@ def call( ) if pixel_values is not None: # Ensures that later checks pass even with an all-None shape from the serving signature - pixel_values = tf.ensure_shape(pixel_values, [None, self.config.vision_config.num_channels, None, None]) + 
pixel_values = tf.ensure_shape( + pixel_values, + [ + None, + self.config.vision_config.num_channels, + self.config.vision_config.image_size, + self.config.vision_config.image_size, + ], + ) image_positional_embeddings = self.get_image_wide_positional_embeddings() # repeat with batch size batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0] @@ -1444,7 +1453,7 @@ def call( ) sparse_embeddings, dense_embeddings = self.prompt_encoder( - batch_size=image_embeddings.shape[0], + batch_size=shape_list(image_embeddings)[0], input_points=input_points, input_labels=input_labels, input_boxes=input_boxes, diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py index 60848c65b4f2..7aa64e3743d5 100644 --- a/src/transformers/tf_utils.py +++ b/src/transformers/tf_utils.py @@ -70,14 +70,24 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name) -def functional_layernorm(inputs, weight, bias, epsilon=1e-5): +def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1): # This is a very simplified functional layernorm, designed to duplicate # the functionality of PyTorch nn.functional.layer_norm when this is needed to port - # models in Transformers. It assumes the dimension to be normalized is always the last one. - # If you need it to handle multiple dimensions, yell at me (Matt) and I'll patch it. + # models in Transformers. + + if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int): + raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.") # Calculate the moments on the last axis (layer activations). - mean, variance = tf.nn.moments(inputs, -1, keepdims=True) + mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True) + + if axis != -1: + # Reshape scale and weight to have the same rank as inputs, but with 1 dimensions + # on every dimension except axis + shape = [1] * inputs.shape.rank + shape[axis] = shape_list(inputs)[axis] + weight = tf.reshape(weight, shape) + bias = tf.reshape(bias, shape) # Compute layer normalization using the batch_normalization # function. 
From d6653e795ce094be09762cc32f8e895bc06db0ea Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 15:33:29 +0100 Subject: [PATCH 16/49] Remove deprecated torch-alike code --- src/transformers/models/sam/modeling_tf_sam.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 693c91450b69..93baf82a245a 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -186,14 +186,6 @@ def call(self, x: tf.Tensor) -> tf.Tensor: x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1) elif self.data_format == "channels_first": x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1) - # input_dtype = x.dtype - # x = tf.cast(x, tf.float32) - # u = tf.reduce_mean(x, axis=1, keepdims=True) - # s = tf.math.square(x - u) - # s = tf.reduce_mean(s, axis=1, keepdims=True) - # x = (x - u) / tf.math.sqrt(s + self.eps) - # x = tf.cast(x, input_dtype) - # x = self.weight[:, None, None] * x + self.bias[:, None, None] return x From e872394d58828838d82c2ab62ae4df74fadd05a0 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 15:50:50 +0100 Subject: [PATCH 17/49] Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- tests/models/sam/test_modeling_tf_sam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index 4be6d3c9bac0..976e28906a33 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Testing suite for the PyTorch SAM model. """ +""" Testing suite for the TensorFlow SAM model. 
""" import inspect From 25197cfc01751390a542150962633d9aea64cc51 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 15:52:23 +0100 Subject: [PATCH 18/49] Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- tests/models/sam/test_modeling_tf_sam.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index 976e28906a33..8efad29f6869 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -36,10 +36,6 @@ from transformers import SamProcessor, TFSamModel from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST -if is_torch_available(): - pass - - if is_vision_available(): from PIL import Image From 5aabb169d083dea05a01e4a2fe5ad5ecbe4dbeac Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 16:59:55 +0100 Subject: [PATCH 19/49] Refactor processor with common methods and separated private methods --- .../models/sam/image_processing_sam.py | 187 ++++++++++++++---- 1 file changed, 149 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index c27e84d0078c..5d92356879e3 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -385,6 +385,56 @@ def preprocess( return encoded_outputs def post_process_masks( + self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None, return_tensors="pt" + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`): + The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + return_tensors (`str`, *optional*, defaults to `"pt"`): + If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors. + Returns: + (`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. 
+ """ + if return_tensors == "pt": + return self._post_process_masks_pt( + masks=masks, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + mask_threshold=mask_threshold, + binarize=binarize, + pad_size=pad_size, + ) + elif return_tensors == "tf": + return self._post_process_masks_tf( + masks=masks, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + mask_threshold=mask_threshold, + binarize=binarize, + pad_size=pad_size, + ) + else: + raise ValueError("return_tensors must be either 'pt' or 'tf'") + + + + def _post_process_masks_pt( self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None ): """ @@ -431,7 +481,7 @@ def post_process_masks( return output_masks - def post_process_masks_tf( + def _post_process_masks_tf( self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None ): """ @@ -473,7 +523,28 @@ def post_process_masks_tf( return output_masks - def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh): + def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"): + """ + Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. + + Args: + all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`): + List of all predicted segmentation masks + all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`): + List of all predicted iou scores + all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`): + List of all bounding boxes of the predicted masks + crops_nms_thresh (`float`): + Threshold for NMS (Non Maximum Suppression) algorithm. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + """ + if return_tensors == "pt": + return self._post_process_for_mask_generation_pt(all_masks, all_scores, all_boxes, crops_nms_thresh) + elif return_tensors == "tf": + return self._post_process_for_mask_generation_tf(all_masks, all_scores, all_boxes, crops_nms_thresh) + + def _post_process_for_mask_generation_pt(self, all_masks, all_scores, all_boxes, crops_nms_thresh): """ Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. @@ -489,7 +560,7 @@ def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, cro """ return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh) - def post_process_for_mask_generation_tf(self, all_masks, all_scores, all_boxes, crops_nms_thresh): + def _post_process_for_mask_generation_tf(self, all_masks, all_scores, all_boxes, crops_nms_thresh): """ Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. @@ -514,6 +585,7 @@ def generate_crop_boxes( points_per_crop: Optional[int] = 32, crop_n_points_downscale_factor: Optional[List[int]] = 1, device: Optional["torch.device"] = None, + return_tensors: str = "pt", ): """ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. @@ -535,44 +607,85 @@ def generate_crop_boxes( The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. device (`torch.device`, *optional*, defaults to None): Device to use for the computation. If None, cpu will be used. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. 
If `tf`, returns `tf.Tensor`. """ - return _generate_crop_boxes( - image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device - ) + if return_tensors == "pt": + return _generate_crop_boxes_pt( + image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device + ) + elif return_tensors == "tf": + if device is not None: + raise ValueError("device is not a supported argument when return_tensors is tf!") + return _generate_crop_boxes_tf( + image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor + ) + else: + raise ValueError("return_tensors must be either 'pt' or 'tf'.") - def generate_crop_boxes_tf( + def filter_masks( self, - image, - target_size, - crop_n_layers: int = 0, - overlap_ratio: float = 512 / 1500, - points_per_crop: Optional[int] = 32, - crop_n_points_downscale_factor: Optional[List[int]] = 1, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + return_tensors="pt", ): """ - Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. Args: - image (`np.array`): - Input original image - target_size (`int`): - Target size of the resized image - crop_n_layers (`int`, *optional*, defaults to 0): - If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where - each layer has 2**i_layer number of image crops. - overlap_ratio (`float`, *optional*, defaults to 512/1500): - Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of - the image length. Later layers with more crops scale down this overlap. - points_per_crop (`int`, *optional*, defaults to 32): - Number of points to sample from each crop. - crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1): - The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + masks (`Union[torch.Tensor, tf.Tensor]`): + Input masks. + iou_scores (`Union[torch.Tensor, tf.Tensor]`): + List of IoU scores. + original_size (`Tuple[int,int]`): + Size of the orginal image. + cropped_box_image (`np.array`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. 
""" - return _generate_crop_boxes_tf( - image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor - ) + if return_tensors == "pt": + return self._filter_masks_pt( + masks=masks, + iou_scores=iou_scores, + original_size=original_size, + cropped_box_image=cropped_box_image, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + mask_threshold=mask_threshold, + stability_score_offset=stability_score_offset, + ) + elif return_tensors == "tf": + return self._filter_masks_tf( + masks=masks, + iou_scores=iou_scores, + original_size=original_size, + cropped_box_image=cropped_box_image, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + mask_threshold=mask_threshold, + stability_score_offset=stability_score_offset, + ) - def filter_masks( + + def _filter_masks_pt( self, masks, iou_scores, @@ -628,7 +741,7 @@ def filter_masks( # compute stability score if stability_score_thresh > 0.0: - stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset) + stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset) keep_mask = keep_mask & (stability_scores > stability_score_thresh) scores = iou_scores[keep_mask] @@ -652,7 +765,7 @@ def filter_masks( return masks, scores, converted_boxes - def filter_masks_tf( + def _filter_masks_tf( self, masks, iou_scores, @@ -730,7 +843,7 @@ def filter_masks_tf( return masks, scores, converted_boxes -def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): +def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): # One mask is always contained inside the other. # Save memory by preventing unnecesary cast to torch.int64 intersections = ( @@ -790,7 +903,7 @@ def _normalize_coordinates( return coords -def _generate_crop_boxes( +def _generate_crop_boxes_pt( image, target_size: int, # Is it tuple here? crop_n_layers: int = 0, @@ -1221,8 +1334,6 @@ def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thre amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7): NMS threshold. """ - breakpoint() - print() # Need to check the input shapes here so I know where to pad them keep_by_nms = tf.image.combined_non_max_suppression( boxes=mask_boxes.float(), scores=iou_scores, From d5e1fee3eba1b416de07f87093639cd438c1ca99 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 17:00:21 +0100 Subject: [PATCH 20/49] make fixup --- .../models/sam/image_processing_sam.py | 28 +++++++++++++------ tests/models/sam/test_modeling_tf_sam.py | 2 +- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 5d92356879e3..f710d68f1397 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -385,7 +385,14 @@ def preprocess( return encoded_outputs def post_process_masks( - self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None, return_tensors="pt" + self, + masks, + original_sizes, + reshaped_input_sizes, + mask_threshold=0.0, + binarize=True, + pad_size=None, + return_tensors="pt", ): """ Remove padding and upscale masks to the original image size. @@ -408,8 +415,8 @@ def post_process_masks( return_tensors (`str`, *optional*, defaults to `"pt"`): If `"pt"`, return PyTorch tensors. 
If `"tf"`, return TensorFlow tensors. Returns: - (`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) - is given by original_size. + (`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where + (height, width) is given by original_size. """ if return_tensors == "pt": return self._post_process_masks_pt( @@ -432,8 +439,6 @@ def post_process_masks( else: raise ValueError("return_tensors must be either 'pt' or 'tf'") - - def _post_process_masks_pt( self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None ): @@ -523,7 +528,9 @@ def _post_process_masks_tf( return output_masks - def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"): + def post_process_for_mask_generation( + self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt" + ): """ Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. @@ -612,7 +619,13 @@ def generate_crop_boxes( """ if return_tensors == "pt": return _generate_crop_boxes_pt( - image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device + image, + target_size, + crop_n_layers, + overlap_ratio, + points_per_crop, + crop_n_points_downscale_factor, + device, ) elif return_tensors == "tf": if device is not None: @@ -684,7 +697,6 @@ def filter_masks( stability_score_offset=stability_score_offset, ) - def _filter_masks_pt( self, masks, diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index 8efad29f6869..688d74d063f5 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -23,7 +23,7 @@ from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig from transformers.testing_utils import require_tf, slow -from transformers.utils import is_tf_available, is_torch_available, is_vision_available +from transformers.utils import is_tf_available, is_vision_available from ...test_configuration_common import ConfigTester from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor From 4650fcfd0713b3502b590e029a2288bd648c88f8 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 17:03:03 +0100 Subject: [PATCH 21/49] Quietly delete the file that didn't do anything (sorry Sylvain) --- .../models/sam/processing_tf_sam.py | 248 ------------------ 1 file changed, 248 deletions(-) delete mode 100644 src/transformers/models/sam/processing_tf_sam.py diff --git a/src/transformers/models/sam/processing_tf_sam.py b/src/transformers/models/sam/processing_tf_sam.py deleted file mode 100644 index 83356f41ee16..000000000000 --- a/src/transformers/models/sam/processing_tf_sam.py +++ /dev/null @@ -1,248 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Processor class for SAM. -""" -from copy import deepcopy -from typing import Optional, Union - -import numpy as np - -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding -from ...utils import TensorType, is_torch_available - - -if is_torch_available(): - import torch - - -class SamProcessor(ProcessorMixin): - r""" - Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a - single processor. - - [`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of - [`~SamImageProcessor.__call__`] for more information. - - Args: - image_processor (`SamImageProcessor`): - An instance of [`SamImageProcessor`]. The image processor is a required input. - """ - attributes = ["image_processor"] - image_processor_class = "SamImageProcessor" - - def __init__(self, image_processor): - super().__init__(image_processor) - self.current_processor = self.image_processor - self.point_pad_value = -10 - self.target_size = self.image_processor.size["longest_edge"] - - def __call__( - self, - images=None, - input_points=None, - input_labels=None, - input_boxes=None, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, - ) -> BatchEncoding: - """ - This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D - points and bounding boxes for the model if they are provided. - """ - encoding_image_processor = self.image_processor( - images, - return_tensors=return_tensors, - **kwargs, - ) - - # pop arguments that are not used in the foward but used nevertheless - original_sizes = encoding_image_processor["original_sizes"] - - if isinstance(original_sizes, torch.Tensor): - original_sizes = original_sizes.numpy() - - input_points, input_labels, input_boxes = self._check_and_preprocess_points( - input_points=input_points, - input_labels=input_labels, - input_boxes=input_boxes, - ) - - encoding_image_processor = self._normalize_and_convert( - encoding_image_processor, - original_sizes, - input_points=input_points, - input_labels=input_labels, - input_boxes=input_boxes, - return_tensors=return_tensors, - ) - - return encoding_image_processor - - def _normalize_and_convert( - self, - encoding_image_processor, - original_sizes, - input_points=None, - input_labels=None, - input_boxes=None, - return_tensors="pt", - ): - if input_points is not None: - if len(original_sizes) != len(input_points): - input_points = [ - self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points - ] - else: - input_points = [ - self._normalize_coordinates(self.target_size, point, original_size) - for point, original_size in zip(input_points, original_sizes) - ] - # check that all arrays have the same shape - if not all([point.shape == input_points[0].shape for point in input_points]): - if input_labels is not None: - input_points, input_labels = self._pad_points_and_labels(input_points, input_labels) - - input_points = np.array(input_points) - - if input_labels is not None: - input_labels = np.array(input_labels) - - if input_boxes is not None: - if len(original_sizes) != len(input_boxes): - input_boxes = [ - self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True) - for box in input_boxes - ] - else: - input_boxes = [ - self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True) - for box, original_size in zip(input_boxes, original_sizes) - ] - 
input_boxes = np.array(input_boxes) - - if input_boxes is not None: - if return_tensors == "pt": - input_boxes = torch.from_numpy(input_boxes) - # boxes batch size of 1 by default - input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes - encoding_image_processor.update({"input_boxes": input_boxes}) - if input_points is not None: - if return_tensors == "pt": - input_points = torch.from_numpy(input_points) - # point batch size of 1 by default - input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points - encoding_image_processor.update({"input_points": input_points}) - if input_labels is not None: - if return_tensors == "pt": - input_labels = torch.from_numpy(input_labels) - # point batch size of 1 by default - input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels - encoding_image_processor.update({"input_labels": input_labels}) - - return encoding_image_processor - - def _pad_points_and_labels(self, input_points, input_labels): - r""" - The method pads the 2D points and labels to the maximum number of points in the batch. - """ - expected_nb_points = max([point.shape[0] for point in input_points]) - processed_input_points = [] - for i, point in enumerate(input_points): - if point.shape[0] != expected_nb_points: - point = np.concatenate( - [point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0 - ) - input_labels[i] = np.append(input_labels[i], [self.point_pad_value]) - processed_input_points.append(point) - input_points = processed_input_points - return input_points, input_labels - - def _normalize_coordinates( - self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False - ) -> np.ndarray: - """ - Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. - """ - old_h, old_w = original_size - new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size) - coords = deepcopy(coords).astype(float) - - if is_bounding_box: - coords = coords.reshape(-1, 2, 2) - - coords[..., 0] = coords[..., 0] * (new_w / old_w) - coords[..., 1] = coords[..., 1] * (new_h / old_h) - - if is_bounding_box: - coords = coords.reshape(-1, 4) - - return coords - - def _check_and_preprocess_points( - self, - input_points=None, - input_labels=None, - input_boxes=None, - ): - r""" - Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they - are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`, - it is converted to a `numpy.ndarray` and then to a `list`. 
- """ - if input_points is not None: - if isinstance(input_points, torch.Tensor): - input_points = input_points.numpy().tolist() - - if not isinstance(input_points, list) and not isinstance(input_points[0], list): - raise ValueError("Input points must be a list of list of floating integers.") - input_points = [np.array(input_point) for input_point in input_points] - else: - input_points = None - - if input_labels is not None: - if isinstance(input_labels, torch.Tensor): - input_labels = input_labels.numpy().tolist() - - if not isinstance(input_labels, list) and not isinstance(input_labels[0], list): - raise ValueError("Input labels must be a list of list integers.") - input_labels = [np.array(label) for label in input_labels] - else: - input_labels = None - - if input_boxes is not None: - if isinstance(input_boxes, torch.Tensor): - input_boxes = input_boxes.numpy().tolist() - - if ( - not isinstance(input_boxes, list) - and not isinstance(input_boxes[0], list) - and not isinstance(input_boxes[0][0], list) - ): - raise ValueError("Input boxes must be a list of list of list of floating integers.") - input_boxes = [np.array(box).astype(np.float32) for box in input_boxes] - else: - input_boxes = None - - return input_points, input_labels, input_boxes - - @property - def model_input_names(self): - image_processor_input_names = self.image_processor.model_input_names - return list(dict.fromkeys(image_processor_input_names)) - - def post_process_masks(self, *args, **kwargs): - return self.image_processor.post_process_masks(*args, **kwargs) From b72bfc14b6420b99fa3c9e80b252201dee586542 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 17:08:56 +0100 Subject: [PATCH 22/49] Refactor the processor tests into one file --- src/transformers/models/sam/processing_sam.py | 6 +- tests/models/sam/test_processor_sam.py | 95 +++++++++++++- tests/models/sam/test_processor_tf_sam.py | 122 ------------------ 3 files changed, 94 insertions(+), 129 deletions(-) delete mode 100644 tests/models/sam/test_processor_tf_sam.py diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py index a658030d36ff..1907f69eae3b 100644 --- a/src/transformers/models/sam/processing_sam.py +++ b/src/transformers/models/sam/processing_sam.py @@ -260,8 +260,4 @@ def model_input_names(self): return list(dict.fromkeys(image_processor_input_names)) def post_process_masks(self, *args, **kwargs): - return_tensors = kwargs.pop("return_tensors", "pt") - if return_tensors == "pt": - return self.image_processor.post_process_masks(*args, **kwargs) - elif return_tensors == "tf": - return self.image_processor.post_process_masks_tf(*args, **kwargs) + return self.image_processor.post_process_masks(*args, **kwargs) \ No newline at end of file diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index 13efa22e3e3c..a35da074abf2 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -17,8 +17,8 @@ import numpy as np -from transformers.testing_utils import require_torch, require_torchvision, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_torchvision, require_vision, require_tf +from transformers.utils import is_torch_available, is_vision_available, is_tf_available if is_vision_available(): @@ -29,6 +29,9 @@ if is_torch_available(): import torch +if is_tf_available(): + import tensorflow as tf + @require_vision 
@require_torchvision @@ -110,3 +113,91 @@ def test_post_process_masks(self): dummy_masks = [[1, 0], [0, 1]] with self.assertRaises(ValueError): masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size)) + +@require_vision +@require_tf +class TFSamProcessorTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = SamImageProcessor() + processor = SamProcessor(image_processor) + processor.save_pretrained(self.tmpdirname) + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + def test_save_load_pretrained_additional_features(self): + processor = SamProcessor(image_processor=self.get_image_processor()) + processor.save_pretrained(self.tmpdirname) + + image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0) + + processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0) + + self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.image_processor, SamImageProcessor) + + def test_image_processor(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + + image_input = self.prepare_image_inputs() + + input_feat_extract = image_processor(image_input, return_tensors="np") + input_processor = processor(images=image_input, return_tensors="np") + + input_feat_extract.pop("original_sizes") # pop original_sizes as it is popped in the processor + input_feat_extract.pop("reshaped_input_sizes") # pop original_sizes as it is popped in the processor + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + @require_tf + def test_post_process_masks(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + dummy_masks = [tf.ones((1, 3, 5, 5))] + + original_sizes = [[1764, 2646]] + + reshaped_input_size = [[683, 1024]] + masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf") + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + masks = processor.post_process_masks( + dummy_masks, + tf.convert_to_tensor(original_sizes), + tf.convert_to_tensor(reshaped_input_size), + return_tensors="tf", + ) + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + # should also work with np + dummy_masks = [np.ones((1, 3, 5, 5))] + masks = processor.post_process_masks( + dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" + ) + + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + dummy_masks = [[1, 0], [0, 1]] + with self.assertRaises(tf.errors.InvalidArgumentError): + masks = processor.post_process_masks( + dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" + ) diff --git a/tests/models/sam/test_processor_tf_sam.py 
b/tests/models/sam/test_processor_tf_sam.py deleted file mode 100644 index c66aa4b89a4c..000000000000 --- a/tests/models/sam/test_processor_tf_sam.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import shutil -import tempfile -import unittest - -import numpy as np - -from transformers.testing_utils import require_tf, require_vision -from transformers.utils import is_tf_available, is_torch_available, is_vision_available - - -if is_vision_available(): - from PIL import Image - - from transformers import AutoProcessor, SamImageProcessor, SamProcessor - -if is_torch_available(): - pass - -if is_tf_available(): - import tensorflow as tf - - -@require_vision -@require_tf -class TFSamProcessorTest(unittest.TestCase): - def setUp(self): - self.tmpdirname = tempfile.mkdtemp() - image_processor = SamImageProcessor() - processor = SamProcessor(image_processor) - processor.save_pretrained(self.tmpdirname) - - def get_image_processor(self, **kwargs): - return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor - - def tearDown(self): - shutil.rmtree(self.tmpdirname) - - def prepare_image_inputs(self): - """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, - or a list of PyTorch tensors if one specifies torchify=True. 
- """ - - image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] - - image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] - - return image_inputs - - def test_save_load_pretrained_additional_features(self): - processor = SamProcessor(image_processor=self.get_image_processor()) - processor.save_pretrained(self.tmpdirname) - - image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0) - - processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0) - - self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string()) - self.assertIsInstance(processor.image_processor, SamImageProcessor) - - def test_image_processor(self): - image_processor = self.get_image_processor() - - processor = SamProcessor(image_processor=image_processor) - - image_input = self.prepare_image_inputs() - - input_feat_extract = image_processor(image_input, return_tensors="np") - input_processor = processor(images=image_input, return_tensors="np") - - input_feat_extract.pop("original_sizes") # pop original_sizes as it is popped in the processor - input_feat_extract.pop("reshaped_input_sizes") # pop original_sizes as it is popped in the processor - - for key in input_feat_extract.keys(): - self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) - - @require_tf - def test_post_process_masks(self): - image_processor = self.get_image_processor() - - processor = SamProcessor(image_processor=image_processor) - dummy_masks = [tf.ones((1, 3, 5, 5))] - - original_sizes = [[1764, 2646]] - - reshaped_input_size = [[683, 1024]] - masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf") - self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) - - masks = processor.post_process_masks( - dummy_masks, - tf.convert_to_tensor(original_sizes), - tf.convert_to_tensor(reshaped_input_size), - return_tensors="tf", - ) - self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) - - # should also work with np - dummy_masks = [np.ones((1, 3, 5, 5))] - masks = processor.post_process_masks( - dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" - ) - - self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) - - dummy_masks = [[1, 0], [0, 1]] - with self.assertRaises(tf.errors.InvalidArgumentError): - masks = processor.post_process_masks( - dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" - ) From fc5136a53b401acb05c9b5613d906ba06490f98d Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 May 2023 17:15:33 +0100 Subject: [PATCH 23/49] make fixup --- src/transformers/models/sam/processing_sam.py | 2 +- tests/models/sam/test_processor_sam.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py index 1907f69eae3b..d0b3caf0b239 100644 --- a/src/transformers/models/sam/processing_sam.py +++ b/src/transformers/models/sam/processing_sam.py @@ -260,4 +260,4 @@ def model_input_names(self): return list(dict.fromkeys(image_processor_input_names)) def post_process_masks(self, *args, **kwargs): - return self.image_processor.post_process_masks(*args, **kwargs) \ No newline at end of file + return self.image_processor.post_process_masks(*args, **kwargs) diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index 
a35da074abf2..912e43eb3a88 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -17,8 +17,8 @@ import numpy as np -from transformers.testing_utils import require_torch, require_torchvision, require_vision, require_tf -from transformers.utils import is_torch_available, is_vision_available, is_tf_available +from transformers.testing_utils import require_tf, require_torch, require_torchvision, require_vision +from transformers.utils import is_tf_available, is_torch_available, is_vision_available if is_vision_available(): @@ -114,6 +114,7 @@ def test_post_process_masks(self): with self.assertRaises(ValueError): masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size)) + @require_vision @require_tf class TFSamProcessorTest(unittest.TestCase): From 2e5b4e5858c930b3d7927f5f7407bf8aaab7829f Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 13:36:57 +0100 Subject: [PATCH 24/49] Clean up some unnecessary indirection --- .../models/sam/image_processing_sam.py | 36 ++----------------- 1 file changed, 2 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index f710d68f1397..31d30bdff128 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -547,41 +547,9 @@ def post_process_for_mask_generation( If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. """ if return_tensors == "pt": - return self._post_process_for_mask_generation_pt(all_masks, all_scores, all_boxes, crops_nms_thresh) + return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh) elif return_tensors == "tf": - return self._post_process_for_mask_generation_tf(all_masks, all_scores, all_boxes, crops_nms_thresh) - - def _post_process_for_mask_generation_pt(self, all_masks, all_scores, all_boxes, crops_nms_thresh): - """ - Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. - - Args: - all_masks (`List[torch.Tensor]`): - List of all predicted segmentation masks - all_scores (`List[torch.Tensor]`): - List of all predicted iou scores - all_boxes (`List[torch.Tensor]`): - List of all bounding boxes of the predicted masks - crops_nms_thresh (`float`): - Threshold for NMS (Non Maximum Suppression) algorithm. - """ - return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh) - - def _post_process_for_mask_generation_tf(self, all_masks, all_scores, all_boxes, crops_nms_thresh): - """ - Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. - - Args: - all_masks (`List[tf.Tensor]`): - List of all predicted segmentation masks - all_scores (`List[tf.Tensor]`): - List of all predicted iou scores - all_boxes (`List[tf.Tensor]`): - List of all bounding boxes of the predicted masks - crops_nms_thresh (`float`): - Threshold for NMS (Non Maximum Suppression) algorithm. 
- """ - return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh) + return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh) def generate_crop_boxes( self, From b1cfcdf5d1c99968f6a96bdf48278a7c7f077299 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 16:02:34 +0100 Subject: [PATCH 25/49] Fix TF mask postprocessing --- .../models/sam/image_processing_sam.py | 2 +- tests/models/sam/test_processor_sam.py | 46 ++++++++++++++++++- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 31d30bdff128..2e5d16b160c9 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -519,7 +519,7 @@ def _post_process_masks_tf( # tf.image expects NHWC, we transpose the NCHW inputs for it mask = tf.transpose(masks[i], perm=[0, 2, 3, 1]) interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear") - interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]] + interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :] interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear") if binarize: interpolated_mask = interpolated_mask > mask_threshold diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index 912e43eb3a88..51524101a8cc 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -17,7 +17,7 @@ import numpy as np -from transformers.testing_utils import require_tf, require_torch, require_torchvision, require_vision +from transformers.testing_utils import require_tf, require_torch, require_torchvision, require_vision, is_pt_tf_cross_test from transformers.utils import is_tf_available, is_torch_available, is_vision_available @@ -202,3 +202,47 @@ def test_post_process_masks(self): masks = processor.post_process_masks( dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" ) + + +@require_vision +@require_torchvision +class SamProcessorEquivalenceTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = SamImageProcessor() + processor = SamProcessor(image_processor) + processor.save_pretrained(self.tmpdirname) + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. 
+ """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + @is_pt_tf_cross_test + def test_post_process_masks_equivalence(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32) + tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)] + pt_dummy_masks = [torch.tensor(dummy_masks)] + + original_sizes = [[1764, 2646]] + + reshaped_input_size = [[683, 1024]] + tf_masks = processor.post_process_masks(tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf") + pt_masks = processor.post_process_masks(pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt") + + self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy())) From f7549ba741b613c0745c777b1e54736cdf1eeb23 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 16:41:55 +0100 Subject: [PATCH 26/49] Add more processor equivalence tests --- tests/models/sam/test_processor_sam.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index 51524101a8cc..621ead316f9f 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -246,3 +246,20 @@ def test_post_process_masks_equivalence(self): pt_masks = processor.post_process_masks(pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt") self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy())) + + def test_image_processor_equivalence(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + + image_input = self.prepare_image_inputs() + + pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy() + pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy() + + tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy() + tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy() + + self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor)) + self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract)) + self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor)) \ No newline at end of file From 7945a2ddd346bc5783f7d5f72505b75883b584b9 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 17:28:49 +0100 Subject: [PATCH 27/49] Refactor generate_crop_boxes to use framework-neutral np code --- .../models/sam/image_processing_sam.py | 102 +++++------------- tests/models/sam/test_processor_sam.py | 18 +++- 2 files changed, 39 insertions(+), 81 deletions(-) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 2e5d16b160c9..64f3bae22218 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -585,24 +585,32 @@ def generate_crop_boxes( return_tensors (`str`, *optional*, defaults to `pt`): If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. 
""" + crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( + image, + target_size, + crop_n_layers, + overlap_ratio, + points_per_crop, + crop_n_points_downscale_factor, + ) if return_tensors == "pt": - return _generate_crop_boxes_pt( - image, - target_size, - crop_n_layers, - overlap_ratio, - points_per_crop, - crop_n_points_downscale_factor, - device, - ) + if device is None: + device = torch.device("cpu") + crop_boxes = torch.tensor(crop_boxes, device=device) + points_per_crop = torch.tensor(points_per_crop, device=device) + # cropped_images stays as np + input_labels = torch.tensor(input_labels, device=device) + elif return_tensors == "tf": if device is not None: raise ValueError("device is not a supported argument when return_tensors is tf!") - return _generate_crop_boxes_tf( - image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor - ) + crop_boxes = tf.convert_to_tensor(crop_boxes) + points_per_crop = tf.convert_to_tensor(points_per_crop) + # cropped_images stays as np + input_labels = tf.convert_to_tensor(input_labels) else: raise ValueError("return_tensors must be either 'pt' or 'tf'.") + return crop_boxes, points_per_crop, cropped_images, input_labels def filter_masks( self, @@ -883,14 +891,13 @@ def _normalize_coordinates( return coords -def _generate_crop_boxes_pt( +def _generate_crop_boxes( image, target_size: int, # Is it tuple here? crop_n_layers: int = 0, overlap_ratio: float = 512 / 1500, points_per_crop: Optional[int] = 32, crop_n_points_downscale_factor: Optional[List[int]] = 1, - device: Optional["torch.device"] = None, ) -> Tuple[List[List[int]], List[int]]: """ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. @@ -910,64 +917,6 @@ def _generate_crop_boxes_pt( Number of points to sample per crop. crop_n_points_downscale_factor (`int`, *optional*): The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. - device (`torch.device`, *optional*): - Device to run the crop generation on. Defaults to CPU. - """ - if device is None: - device = torch.device("cpu") - - if isinstance(image, list): - raise ValueError("Only one image is allowed for crop generation.") - image = to_numpy_array(image) - original_size = get_image_size(image) - - points_grid = [] - for i in range(crop_n_layers + 1): - n_points = int(points_per_crop / (crop_n_points_downscale_factor**i)) - points_grid.append(_build_point_grid(n_points)) - - crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) - - cropped_images, point_grid_per_crop = _generate_crop_images( - crop_boxes, image, points_grid, layer_idxs, target_size, original_size - ) - - crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device) - point_grid_per_crop = np.array([point_grid_per_crop]) - points_per_crop = torch.tensor(point_grid_per_crop, device=device) - points_per_crop = points_per_crop.permute(0, 2, 1, 3) - - input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device) - - return crop_boxes, points_per_crop, cropped_images, input_labels - - -def _generate_crop_boxes_tf( - image, - target_size: int, # Is it tuple here? - crop_n_layers: int = 0, - overlap_ratio: float = 512 / 1500, - points_per_crop: Optional[int] = 32, - crop_n_points_downscale_factor: Optional[List[int]] = 1, -) -> Tuple[List[List[int]], List[int]]: - """ - Generates a list of crop boxes of different sizes. 
Each layer has (2**i)**2 boxes for the ith layer. - - Args: - image (Union[`numpy.ndarray`, `PIL.Image`, `tf.Tensor`]): - Image to generate crops for. - target_size (`int`): - Size of the smallest crop. - crop_n_layers (`int`, *optional*): - If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers - to run, where each layer has 2**i_layer number of image crops. - overlap_ratio (`int`, *optional*): - Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the - image length. Later layers with more crops scale down this overlap. - points_per_crop (`int`, *optional*): - Number of points to sample per crop. - crop_n_points_downscale_factor (`int`, *optional*): - The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. """ if isinstance(image, list): @@ -986,12 +935,11 @@ def _generate_crop_boxes_tf( crop_boxes, image, points_grid, layer_idxs, target_size, original_size ) - crop_boxes = tf.convert_to_tensor(crop_boxes, dtype=tf.float32) - point_grid_per_crop = np.array([point_grid_per_crop]) - points_per_crop = tf.convert_to_tensor(point_grid_per_crop) - points_per_crop = tf.transpose(points_per_crop, perm=(0, 2, 1, 3)) + crop_boxes = crop_boxes.astype(np.float32) + points_per_crop = np.array([point_grid_per_crop]) + points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3)) - input_labels = tf.ones_like(points_per_crop[:, :, :, 0], dtype=tf.int64) + input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64) return crop_boxes, points_per_crop, cropped_images, input_labels diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index 621ead316f9f..a837ebaea459 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -17,7 +17,13 @@ import numpy as np -from transformers.testing_utils import require_tf, require_torch, require_torchvision, require_vision, is_pt_tf_cross_test +from transformers.testing_utils import ( + is_pt_tf_cross_test, + require_tf, + require_torch, + require_torchvision, + require_vision, +) from transformers.utils import is_tf_available, is_torch_available, is_vision_available @@ -242,8 +248,12 @@ def test_post_process_masks_equivalence(self): original_sizes = [[1764, 2646]] reshaped_input_size = [[683, 1024]] - tf_masks = processor.post_process_masks(tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf") - pt_masks = processor.post_process_masks(pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt") + tf_masks = processor.post_process_masks( + tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf" + ) + pt_masks = processor.post_process_masks( + pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt" + ) self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy())) @@ -262,4 +272,4 @@ def test_image_processor_equivalence(self): self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor)) self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract)) - self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor)) \ No newline at end of file + self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor)) From 8305ab4ac71c35b4f7afe8319a63bf85ff21cc30 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 18:10:41 +0100 Subject: [PATCH 28/49] Make the serving output correctly conditional --- 
src/transformers/models/sam/modeling_tf_sam.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 93baf82a245a..1d928a0c6ea3 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -1485,7 +1485,7 @@ def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegm return TFSamImageSegmentationOutput( iou_scores=output.iou_scores, pred_masks=output.pred_masks, - vision_hidden_states=hs, - vision_attentions=attns, - mask_decoder_attentions=output.mask_decoder_attentions, + vision_hidden_states=hs if self.config.output_hidden_states else None, + vision_attentions=attns if self.config.output_attentions else None, + mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None, ) From f9de05482e3bba3bfece242042532e86ed15fc8d Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 18:12:11 +0100 Subject: [PATCH 29/49] Fix error message line length --- src/transformers/models/sam/modeling_sam.py | 3 ++- src/transformers/models/sam/modeling_tf_sam.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 5346d3706773..c4aa97a0b30d 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1359,7 +1359,8 @@ def forward( "The batch size of the image embeddings and the input points must be the same. ", "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]), " if you want to pass multiple points for the same image, make sure that you passed ", - " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", + " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", ) sparse_embeddings, dense_embeddings = self.prompt_encoder( diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 1d928a0c6ea3..30ddf0587303 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -1441,7 +1441,8 @@ def call( "The batch size of the image embeddings and the input points must be the same. 
", "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]), " if you want to pass multiple points for the same image, make sure that you passed ", - " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", + " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", ) sparse_embeddings, dense_embeddings = self.prompt_encoder( From 63d1b68cb17c7ce16460989cf4f50c3120a6e518 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 18:25:37 +0100 Subject: [PATCH 30/49] Use dict keys rather than indices internally in both TF and PT SAM call/forward --- src/transformers/models/sam/modeling_sam.py | 9 ++++----- src/transformers/models/sam/modeling_tf_sam.py | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index c4aa97a0b30d..5b963386b79b 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1334,7 +1334,6 @@ def forward( image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) vision_attentions = None - mask_decoder_attentions = None vision_hidden_states = None if pixel_values is not None: @@ -1342,14 +1341,14 @@ def forward( pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) - image_embeddings = vision_outputs[0] + image_embeddings = vision_outputs['last_hidden_state'] if output_hidden_states: - vision_hidden_states = vision_outputs[1] + vision_hidden_states = vision_outputs['hidden_states'] if output_attentions: - vision_attentions = vision_outputs[-1] + vision_attentions = vision_outputs['attentions'] if input_points is not None and input_labels is None: input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 30ddf0587303..af484a0ccf4f 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -1416,7 +1416,6 @@ def call( image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0) vision_attentions = None - mask_decoder_attentions = None vision_hidden_states = None if pixel_values is not None: @@ -1424,14 +1423,14 @@ def call( pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) - image_embeddings = vision_outputs[0] + image_embeddings = vision_outputs['last_hidden_state'] if output_hidden_states: - vision_hidden_states = vision_outputs[1] + vision_hidden_states = vision_outputs['hidden_states'] if output_attentions: - vision_attentions = vision_outputs[-1] + vision_attentions = vision_outputs['attentions'] if input_points is not None and input_labels is None: input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32) From faf5cb05a35228e1724334c9a87b23db905da7c9 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 18:58:57 +0100 Subject: [PATCH 31/49] Return dicts internally in the call/forward methods --- src/transformers/models/sam/modeling_sam.py | 6 +++--- src/transformers/models/sam/modeling_tf_sam.py | 6 +++--- 
2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 5b963386b79b..a8e50b3c946d 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1343,12 +1343,12 @@ def forward( output_hidden_states=output_hidden_states, return_dict=True, ) - image_embeddings = vision_outputs['last_hidden_state'] + image_embeddings = vision_outputs["last_hidden_state"] if output_hidden_states: - vision_hidden_states = vision_outputs['hidden_states'] + vision_hidden_states = vision_outputs["hidden_states"] if output_attentions: - vision_attentions = vision_outputs['attentions'] + vision_attentions = vision_outputs["attentions"] if input_points is not None and input_labels is None: input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index af484a0ccf4f..98bf3edc0eb3 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -1425,12 +1425,12 @@ def call( output_hidden_states=output_hidden_states, return_dict=True, ) - image_embeddings = vision_outputs['last_hidden_state'] + image_embeddings = vision_outputs["last_hidden_state"] if output_hidden_states: - vision_hidden_states = vision_outputs['hidden_states'] + vision_hidden_states = vision_outputs["hidden_states"] if output_attentions: - vision_attentions = vision_outputs['attentions'] + vision_attentions = vision_outputs["attentions"] if input_points is not None and input_labels is None: input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32) From ce669d4d57cb325fd9897fb9c30c8742dbc4840a Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 18:59:16 +0100 Subject: [PATCH 32/49] Revert changes to common tests and just override check_pt_tf_outputs --- tests/models/sam/test_modeling_tf_sam.py | 11 +++++++++-- tests/test_modeling_tf_common.py | 14 +++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index 688d74d063f5..4583e826264a 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -424,8 +424,15 @@ def test_model_from_pretrained(self): model = TFSamModel.from_pretrained(model_name, from_pt=True) self.assertIsNotNone(model) - def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4): - super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol) + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None): + super().check_pt_tf_outputs( + tf_outputs=tf_outputs, + pt_outputs=pt_outputs, + model_class=model_class, + tol=tol, + name=name, + attributes=attributes, + ) def prepare_image(): diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 277738b39026..220560d9238a 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -640,7 +640,7 @@ def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict): return pt_inputs_dict - def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5): + def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): pt_inputs_dict = self.prepare_pt_inputs_from_tf_inputs(tf_inputs_dict) # send pytorch inputs to the correct device @@ -665,10 
+665,10 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5): if tf_loss is not None: tf_outputs.loss = tf.math.reduce_mean(tf_loss) - self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model), tol=tol) + self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model)) @is_pt_tf_cross_test - def test_pt_tf_model_equivalence(self, allow_missing_keys=False, tol=1e-5): + def test_pt_tf_model_equivalence(self, allow_missing_keys=False): import transformers for model_class in self.all_model_classes: @@ -711,10 +711,10 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False, tol=1e-5): ) # Original test: check without `labels` - self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict, tol=tol) + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) # check with `labels` if tf_inputs_dict_with_labels: - self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels, tol=tol) + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels) # Check we can load pt model in tf and vice-versa with checkpoint => model functions with tempfile.TemporaryDirectory() as tmpdirname: @@ -731,10 +731,10 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False, tol=1e-5): ) # Original test: check without `labels` - self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict, tol=tol) + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) # check with `labels` if tf_inputs_dict_with_labels: - self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels, tol=tol) + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels) def test_compile_tf_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 0ff4dc7ebf51a03073b120339c271efe45ae6d96 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 4 May 2023 19:05:28 +0100 Subject: [PATCH 33/49] Revert changes to other model tests --- tests/models/data2vec/test_modeling_tf_data2vec_vision.py | 3 --- tests/models/vit_mae/test_modeling_tf_vit_mae.py | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py index 6d4ca2eea083..dfa890d25a9e 100644 --- a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py @@ -459,9 +459,6 @@ def test_model_from_pretrained(self): model = TFData2VecVisionModel.from_pretrained(model_name) self.assertIsNotNone(model) - def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4): - super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol) - # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index 1449fc0b7597..53d68b644ac8 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -267,7 +267,7 @@ def prepare_numpy_arrays(inputs_dict): # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise # to generate masks during test - def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5): + def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): # make masks reproducible np.random.seed(2) @@ -279,7 +279,7 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5): # PT inputs will be prepared in 
`super().check_pt_tf_models()` with this added `noise` argument tf_inputs_dict["noise"] = tf_noise - super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict, tol=tol) + super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) # overwrite from common since TFViTMAEForPretraining outputs loss along with # logits and mask indices. loss and mask indices are not suitable for integration From 34bca0fe0592466229b188f80b756d9e2633dbc1 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 5 May 2023 12:33:09 +0100 Subject: [PATCH 34/49] Clarify comments for functional layernorm --- src/transformers/tf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py index 7aa64e3743d5..d0d13eff2ed3 100644 --- a/src/transformers/tf_utils.py +++ b/src/transformers/tf_utils.py @@ -78,7 +78,7 @@ def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1): if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int): raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.") - # Calculate the moments on the last axis (layer activations). + # Get mean and variance on the axis to be normalized mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True) if axis != -1: From 74f3291e024e9ab4977a07c4f357b8a8a86f8c61 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 9 May 2023 16:53:46 +0100 Subject: [PATCH 35/49] Add missing transpose from PT code --- src/transformers/models/sam/modeling_tf_sam.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 98bf3edc0eb3..0b25bd06b6bd 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -528,6 +528,7 @@ def call( mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width]) + image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2)) upscaled_embedding = self.upscale_conv1(image_embeddings) upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) From fcfef1fede07403b930a965c11b6b89cc01e326f Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 10 May 2023 17:49:37 +0100 Subject: [PATCH 36/49] Remove unused 'Copied from' comment in PT code --- src/transformers/models/sam/modeling_sam.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index a8e50b3c946d..fadad302cb70 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput): mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings class SamPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial From 392486d9693bf4d1bea81665d6440d1c34a8cae3 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 10 May 2023 17:53:23 +0100 Subject: [PATCH 37/49] Remove overrides for tests that don't exist in TF --- tests/models/sam/test_modeling_tf_sam.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/models/sam/test_modeling_tf_sam.py
b/tests/models/sam/test_modeling_tf_sam.py index 4583e826264a..b6df3c6a4ec5 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -394,26 +394,6 @@ def test_attention_outputs(self): list(expected_mask_decoder_attention_shape), ) - @unittest.skip(reason="SamModel does not support training") - def test_training(self): - pass - - @unittest.skip(reason="SamModel does not support training") - def test_training_gradient_checkpointing(self): - pass - - @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - - @unittest.skip(reason="SamModel does not support training") - def test_retain_grad_hidden_states_attentions(self): - pass - @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests") def test_hidden_states_output(self): pass From b7f9dd4512e4ec28055d615f635881825e69ee48 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 11 May 2023 17:32:04 +0100 Subject: [PATCH 38/49] Fix transpose and update tests for PT and TF to check pred_masks --- src/transformers/models/sam/modeling_tf_sam.py | 2 +- tests/models/sam/test_modeling_sam.py | 13 +++++++++++-- tests/models/sam/test_modeling_tf_sam.py | 12 ++++++++++-- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 0b25bd06b6bd..3484a8509ffa 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -527,8 +527,8 @@ def call( iou_token_out = point_embedding[:, :, 0, :] mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] - image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width]) image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2)) + image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width]) upscaled_embedding = self.upscale_conv1(image_embeddings) upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index a9f7c204fccd..5eb933031be8 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -473,8 +473,10 @@ def test_inference_mask_generation_no_point(self): with torch.no_grad(): outputs = model(**inputs) scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4)) + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4)) + self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4)) def test_inference_mask_generation_one_point_one_bb(self): model = SamModel.from_pretrained("facebook/sam-vit-huge") @@ -494,8 +496,12 @@ def test_inference_mask_generation_one_point_one_bb(self): with torch.no_grad(): outputs = model(**inputs) scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4)) + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4)) + self.assertTrue( + torch.allclose(masks, 
torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4) + ) def test_inference_mask_generation_batched_points_batched_images(self): model = SamModel.from_pretrained("facebook/sam-vit-huge") @@ -517,6 +523,7 @@ def test_inference_mask_generation_batched_points_batched_images(self): with torch.no_grad(): outputs = model(**inputs) scores = outputs.iou_scores.squeeze().cpu() + masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu() EXPECTED_SCORES = torch.tensor( [ @@ -534,7 +541,9 @@ def test_inference_mask_generation_batched_points_batched_images(self): ], ] ) + EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406]) self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) + self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3)) def test_inference_mask_generation_one_point_one_bb_zero(self): model = SamModel.from_pretrained("facebook/sam-vit-huge") diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index b6df3c6a4ec5..64dc270c5c84 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -438,7 +438,10 @@ def test_inference_mask_generation_no_point(self): outputs = model(**inputs) scores = tf.squeeze(outputs.iou_scores) - self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=1e-4)) + masks = outputs.pred_masks[0, 0, 0, 0, :3] + + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=2e-4)) + self.assertTrue(np.allclose(masks.numpy(), np.array([-6.6381, -6.0734, -7.5308]), atol=1e-2)) def test_inference_mask_generation_one_point_one_bb(self): model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) @@ -452,8 +455,10 @@ def test_inference_mask_generation_one_point_one_bb(self): outputs = model(**inputs) scores = tf.squeeze(outputs.iou_scores) + masks = outputs.pred_masks[0, 0, 0, 0, :3] - self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=1e-4)) + self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=2e-4)) + self.assertTrue(np.allclose(masks.numpy(), np.array([-21.5465, -23.1122, -22.3331]), atol=2e-2)) def test_inference_mask_generation_batched_points_batched_images(self): model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) @@ -469,6 +474,7 @@ def test_inference_mask_generation_batched_points_batched_images(self): outputs = model(**inputs) scores = tf.squeeze(outputs.iou_scores) + masks = outputs.pred_masks[0, 0, 0, 0, :3] EXPECTED_SCORES = np.array( [ @@ -486,7 +492,9 @@ def test_inference_mask_generation_batched_points_batched_images(self): ], ] ) + EXPECTED_MASKS = np.array([-26.5424, -34.0901, -30.6406]) self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3)) + self.assertTrue(np.allclose(masks.numpy(), EXPECTED_MASKS, atol=3e-2)) def test_inference_mask_generation_one_point_one_bb_zero(self): model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) From 4f966c8e92900f6f8071b2357995cf8a74f16d7b Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 11 May 2023 18:09:27 +0100 Subject: [PATCH 39/49] Add training flag --- src/transformers/models/sam/modeling_tf_sam.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 3484a8509ffa..e67d82497cbf 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -900,7 +900,7 @@ def add_decomposed_rel_pos( attn = 
tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width)) return attn - def call(self, hidden_states: tf.Tensor, output_attentions=False) -> tf.Tensor: + def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor: batch_size, height, width, _ = shape_list(hidden_states) # qkv with shape (3, batch_size, nHead, height * width, channel) qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) @@ -918,8 +918,10 @@ def call(self, hidden_states: tf.Tensor, output_attentions=False) -> tf.Tensor: attn_weights = tf.nn.softmax(attn_weights, axis=-1) - # attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) - attn_probs = attn_weights + if training: + attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) + else: + attn_probs = attn_weights attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1)) attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4)) @@ -983,6 +985,7 @@ def call( self, hidden_states: tf.Tensor, output_attentions: Optional[bool] = False, + training: Optional[bool] = False, ) -> Tuple[tf.Tensor]: residual = hidden_states @@ -994,6 +997,7 @@ def call( hidden_states, attn_weights = self.attn( hidden_states=hidden_states, output_attentions=output_attentions, + training=training, ) if self.window_size > 0: hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) @@ -1086,6 +1090,7 @@ def call( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + training: Optional[bool] = False, ) -> Union[Tuple, TFSamVisionEncoderOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1107,7 +1112,7 @@ def call( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training) hidden_states = layer_outputs[0] @@ -1367,6 +1372,7 @@ def call( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict=None, + training=False, **kwargs, ) -> List[Dict[str, tf.Tensor]]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1425,6 +1431,7 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + training=training, ) image_embeddings = vision_outputs["last_hidden_state"] From f29b1091df4c337a2bc216a901b3808040ab15c6 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 11 May 2023 18:36:37 +0100 Subject: [PATCH 40/49] Update tests to use TF checkpoints --- tests/models/sam/test_modeling_tf_sam.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index 64dc270c5c84..4282d679a67d 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -401,7 +401,7 @@ def test_hidden_states_output(self): @slow def test_model_from_pretrained(self): for model_name in TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = TFSamModel.from_pretrained(model_name, from_pt=True) + model = TFSamModel.from_pretrained(model_name) self.assertIsNotNone(model) def 
check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None): @@ -430,7 +430,7 @@ def prepare_dog_img(): @slow class SamModelIntegrationTest(unittest.TestCase): def test_inference_mask_generation_no_point(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -444,7 +444,7 @@ def test_inference_mask_generation_no_point(self): self.assertTrue(np.allclose(masks.numpy(), np.array([-6.6381, -6.0734, -7.5308]), atol=1e-2)) def test_inference_mask_generation_one_point_one_bb(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -461,7 +461,7 @@ def test_inference_mask_generation_one_point_one_bb(self): self.assertTrue(np.allclose(masks.numpy(), np.array([-21.5465, -23.1122, -22.3331]), atol=2e-2)) def test_inference_mask_generation_batched_points_batched_images(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -497,7 +497,7 @@ def test_inference_mask_generation_batched_points_batched_images(self): self.assertTrue(np.allclose(masks.numpy(), EXPECTED_MASKS, atol=3e-2)) def test_inference_mask_generation_one_point_one_bb_zero(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -519,7 +519,7 @@ def test_inference_mask_generation_one_point_one_bb_zero(self): self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9689), atol=1e-4)) def test_inference_mask_generation_one_point(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -545,7 +545,7 @@ def test_inference_mask_generation_one_point(self): self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9712), atol=1e-4)) def test_inference_mask_generation_two_points(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -568,7 +568,7 @@ def test_inference_mask_generation_two_points(self): self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4)) def test_inference_mask_generation_two_points_batched(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -587,7 +587,7 @@ def test_inference_mask_generation_two_points_batched(self): self.assertTrue(np.allclose(scores[1][-1], np.array(0.9716), atol=1e-4)) def test_inference_mask_generation_one_box(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") 
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -602,7 +602,7 @@ def test_inference_mask_generation_one_box(self): self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.8686), atol=1e-4)) def test_inference_mask_generation_batched_image_one_point(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -624,7 +624,7 @@ def test_inference_mask_generation_batched_image_one_point(self): self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4)) def test_inference_mask_generation_two_points_point_batch(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() @@ -651,7 +651,7 @@ def test_inference_mask_generation_two_points_point_batch(self): ) def test_inference_mask_generation_three_boxes_point_batch(self): - model = TFSamModel.from_pretrained("facebook/sam-vit-huge", from_pt=True) + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") raw_image = prepare_image() From 792f376abdfe6145ab6436f194190cb6732a76ae Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 11 May 2023 18:39:16 +0100 Subject: [PATCH 41/49] Update index.mdx --- docs/source/en/index.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index d0a3c0babbbe..9f6cff89bc77 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow. 
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | | RWKV | ❌ | ❌ | ✅ | ❌ | ❌ | -| SAM | ❌ | ❌ | ✅ | ❌ | ❌ | +| SAM | ❌ | ❌ | ✅ | ✅ | ❌ | | SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ | | SEW | ❌ | ❌ | ✅ | ❌ | ❌ | | SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ | From 7249cb5a1739a21cdf17363234b5df41b7933106 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 12 May 2023 13:10:56 +0100 Subject: [PATCH 42/49] Add missing cross-test decorator --- tests/models/sam/test_processor_sam.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index a837ebaea459..d509b9921aa5 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -257,6 +257,7 @@ def test_post_process_masks_equivalence(self): self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy())) + @is_pt_tf_cross_test def test_image_processor_equivalence(self): image_processor = self.get_image_processor() From 28dac3eeddc18ae132569aadffc74a00881cd424 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 16 May 2023 16:00:22 +0100 Subject: [PATCH 43/49] Remove optional extra asterisks --- src/transformers/models/sam/modeling_sam.py | 8 ++++---- src/transformers/models/sam/modeling_tf_sam.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index fadad302cb70..62be2a83cece 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -475,7 +475,7 @@ def forward( the embeddings of the mask inputs multimask_output (bool): Whether to return multiple masks or a single mask. - output_attentions (bool, **optional**): + output_attentions (bool, *optional*): Whether or not to return the attentions tensors of all attention layers. """ batch_size, num_channels, height, width = image_embeddings.shape @@ -667,11 +667,11 @@ def forward( Embeds different types of prompts, returning both sparse and dense embeddings. Args: - points (`torch.Tensor`, **optional**): + points (`torch.Tensor`, *optional*): point coordinates and labels to embed. - boxes (`torch.Tensor`, **optional**): + boxes (`torch.Tensor`, *optional*): boxes to embed - masks (`torch.Tensor`, **optional**): + masks (`torch.Tensor`, *optional*): masks to embed """ sparse_embeddings = None diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index e67d82497cbf..2540c783834a 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -744,11 +744,11 @@ def call( Embeds different types of prompts, returning both sparse and dense embeddings. Args: - points (`tf.Tensor`, **optional**): + points (`tf.Tensor`, *optional*): point coordinates and labels to embed. 
-            boxes (`tf.Tensor`, **optional**):
+            boxes (`tf.Tensor`, *optional*):
                 boxes to embed
-            masks (`tf.Tensor`, **optional**):
+            masks (`tf.Tensor`, *optional*):
                 masks to embed
         """
         sparse_embeddings = None

From bbc688665feca3b0db209794ad8e230e889af964 Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 16 May 2023 16:01:58 +0100
Subject: [PATCH 44/49] Revert return_dict changes in PT code

---
 src/transformers/models/sam/modeling_sam.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py
index 62be2a83cece..8f47b22a4c55 100644
--- a/src/transformers/models/sam/modeling_sam.py
+++ b/src/transformers/models/sam/modeling_sam.py
@@ -1340,14 +1340,14 @@ def forward(
             pixel_values,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=True,
+            return_dict=return_dict,
         )
-        image_embeddings = vision_outputs["last_hidden_state"]
+        image_embeddings = vision_outputs[0]

         if output_hidden_states:
-            vision_hidden_states = vision_outputs["hidden_states"]
+            vision_hidden_states = vision_outputs[1]
         if output_attentions:
-            vision_attentions = vision_outputs["attentions"]
+            vision_attentions = vision_outputs[-1]

         if input_points is not None and input_labels is None:
             input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

From 79d2b81c6841a4afd29ea730b82701540882609f Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 16 May 2023 16:03:21 +0100
Subject: [PATCH 45/49] Update src/transformers/models/sam/modeling_tf_sam.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/transformers/models/sam/modeling_tf_sam.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py
index 2540c783834a..d87d9635073e 100644
--- a/src/transformers/models/sam/modeling_tf_sam.py
+++ b/src/transformers/models/sam/modeling_tf_sam.py
@@ -257,7 +257,7 @@ def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
 class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer):
     def __init__(
         self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs
-    ) -> None:
+    ):
         """
         A transformer block with four layers:
             (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on

From 3d59612f5759f754e1b5e89f3ee4e17e762bb7c5 Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 16 May 2023 16:04:38 +0100
Subject: [PATCH 46/49] Remove None return annotations on init methods

---
 src/transformers/models/sam/modeling_sam.py    | 10 +++++-----
 src/transformers/models/sam/modeling_tf_sam.py |  8 ++++----
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py
index 8f47b22a4c55..7df461175097 100644
--- a/src/transformers/models/sam/modeling_sam.py
+++ b/src/transformers/models/sam/modeling_sam.py
@@ -197,7 +197,7 @@ class SamAttention(nn.Module):
     values.
     """

-    def __init__(self, config, downsample_rate=None) -> None:
+    def __init__(self, config, downsample_rate=None):
         super().__init__()
         self.hidden_size = config.hidden_size

@@ -251,7 +251,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:


 class SamTwoWayAttentionBlock(nn.Module):
-    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None:
+    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
         """
         A transformer block with four layers:
             (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
@@ -706,7 +706,7 @@ def forward(
 class SamVisionAttention(nn.Module):
     """Multi-head Attention block with relative position embeddings."""

-    def __init__(self, config, window_size) -> None:
+    def __init__(self, config, window_size):
         super().__init__()
         input_size = (
             (config.image_size // config.patch_size, config.image_size // config.patch_size)
@@ -844,7 +844,7 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch


 class SamVisionLayer(nn.Module):
-    def __init__(self, config, window_size) -> None:
+    def __init__(self, config, window_size):
         super().__init__()
         self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attn = SamVisionAttention(config, window_size)
@@ -1165,7 +1165,7 @@ def _init_weights(self, module):
 class SamModel(SamPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]

-    def __init__(self, config) -> None:
+    def __init__(self, config):
         super().__init__(config)
         self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py
index d87d9635073e..ca20e137e58d 100644
--- a/src/transformers/models/sam/modeling_tf_sam.py
+++ b/src/transformers/models/sam/modeling_tf_sam.py
@@ -195,7 +195,7 @@ class TFSamAttention(tf.keras.layers.Layer):
     values.
     """

-    def __init__(self, config, downsample_rate=None, **kwargs) -> None:
+    def __init__(self, config, downsample_rate=None, **kwargs):
         super().__init__(**kwargs)
         self.hidden_size = config.hidden_size

@@ -785,7 +785,7 @@ def call(
 class TFSamVisionAttention(tf.keras.layers.Layer):
     """Multi-head Attention block with relative position embeddings."""

-    def __init__(self, config, window_size, **kwargs) -> None:
+    def __init__(self, config, window_size, **kwargs):
         super().__init__(**kwargs)
         input_size = (
             (config.image_size // config.patch_size, config.image_size // config.patch_size)
@@ -938,7 +938,7 @@ def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False


 class TFSamVisionLayer(tf.keras.layers.Layer):
-    def __init__(self, config, window_size, **kwargs) -> None:
+    def __init__(self, config, window_size, **kwargs):
         super().__init__(**kwargs)
         self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
         self.attn = TFSamVisionAttention(config, window_size, name="attn")
@@ -1271,7 +1271,7 @@ def serving(self, inputs):
 class TFSamModel(TFSamPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]

-    def __init__(self, config, **kwargs) -> None:
+    def __init__(self, config, **kwargs):
         super().__init__(config, **kwargs)
         self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")

From ee4057fee7c998b480136b9e90536187c9849be8 Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 16 May 2023 17:45:06 +0100
Subject: [PATCH 47/49] Update tests/models/sam/test_processor_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 tests/models/sam/test_processor_sam.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py
index d509b9921aa5..7d669bb96914 100644
--- a/tests/models/sam/test_processor_sam.py
+++ b/tests/models/sam/test_processor_sam.py
@@ -169,7 +169,7 @@ def test_image_processor(self):
         input_processor = processor(images=image_input, return_tensors="np")

         input_feat_extract.pop("original_sizes")  # pop original_sizes as it is popped in the processor
-        input_feat_extract.pop("reshaped_input_sizes")  # pop original_sizes as it is popped in the processor
+        input_feat_extract.pop("reshaped_input_sizes")  # pop reshaped_input_sizes as it is popped in the processor

         for key in input_feat_extract.keys():
             self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

From 3902969ddb7e841faf2527895318cdcea156a707 Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 16 May 2023 18:03:21 +0100
Subject: [PATCH 48/49] Fix input_boxes shapes

---
 src/transformers/models/sam/processing_sam.py | 4 ++--
 tests/models/sam/test_modeling_tf_sam.py      | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py
index d0b3caf0b239..fd73260f9401 100644
--- a/src/transformers/models/sam/processing_sam.py
+++ b/src/transformers/models/sam/processing_sam.py
@@ -223,7 +223,7 @@ def _check_and_preprocess_points(
                 input_points = input_points.numpy().tolist()

             if not isinstance(input_points, list) or not isinstance(input_points[0], list):
-                raise ValueError("Input points must be a list of list of floating integers.")
+                raise ValueError("Input points must be a list of list of floating points.")
             input_points = [np.array(input_point) for input_point in input_points]
         else:
             input_points = None
@@ -247,7 +247,7 @@ def _check_and_preprocess_points(
                 or not isinstance(input_boxes[0], list)
                 or not isinstance(input_boxes[0][0], list)
             ):
-                raise ValueError("Input boxes must be a list of list of list of floating integers.")
+                raise ValueError("Input boxes must be a list of list of list of floating points.")
             input_boxes = [np.array(box).astype(np.float32) for box in input_boxes]
         else:
             input_boxes = None
diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py
index 4282d679a67d..a07398365fff 100644
--- a/tests/models/sam/test_modeling_tf_sam.py
+++ b/tests/models/sam/test_modeling_tf_sam.py
@@ -448,7 +448,7 @@ def test_inference_mask_generation_one_point_one_bb(self):
         processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

         raw_image = prepare_image()
-        input_boxes = [[650, 900, 1000, 1250]]
+        input_boxes = [[[650, 900, 1000, 1250]]]
         input_points = [[[820, 1080]]]

         inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf")
@@ -501,7 +501,7 @@ def test_inference_mask_generation_one_point_one_bb_zero(self):
         processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

         raw_image = prepare_image()
-        input_boxes = [[620, 900, 1000, 1255]]
+        input_boxes = [[[620, 900, 1000, 1255]]]
         input_points = [[[820, 1080]]]
         labels = [[0]]

From 07813f0f66960c2398a6dbc20934bc54a048dfb6 Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 16 May 2023 18:43:07 +0100
Subject: [PATCH 49/49] make fixup

---
 src/transformers/models/sam/modeling_tf_sam.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py
index ca20e137e58d..ddd8e526a79a 100644
--- a/src/transformers/models/sam/modeling_tf_sam.py
+++ b/src/transformers/models/sam/modeling_tf_sam.py
@@ -255,9 +255,7 @@ def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
 class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer):
-    def __init__(
-        self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs
-    ):
+    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
         """
         A transformer block with four layers:
             (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
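Note on the `input_boxes` fix in PATCH 48/49: boxes are now nested the same way as points, one list per image and one list per box inside it, so a single box is `[[[x1, y1, x2, y2]]]`. A minimal usage sketch mirroring the updated TF tests follows; the image path and the `TFSamModel.from_pretrained` load of this checkpoint in TF are assumptions for illustration, not taken from the patches above.

    # Sketch only: shows the corrected prompt nesting expected by SamProcessor.
    # "path/to/image.png" is a placeholder; TF weights for the checkpoint are assumed available.
    from PIL import Image
    from transformers import SamProcessor, TFSamModel

    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    model = TFSamModel.from_pretrained("facebook/sam-vit-huge")

    raw_image = Image.open("path/to/image.png").convert("RGB")

    input_boxes = [[[650, 900, 1000, 1250]]]  # one image -> one box -> (x1, y1, x2, y2)
    input_points = [[[820, 1080]]]            # one image -> one point -> (x, y)

    inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf")
    outputs = model(**inputs)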