From 75eaedb9f6bc088ecd1ac70397fc99e34fbe49e7 Mon Sep 17 00:00:00 2001 From: MathieuJouffroy Date: Tue, 9 Aug 2022 12:31:25 +0200 Subject: [PATCH 01/15] implemented TFCvtModel and TFCvtForImageClassification and modified relevant files, added an exception in convert_tf_weight_name_to_pt_weight_name, added quick testing file to compare with pytorch model --- src/transformers/__init__.py | 8 + src/transformers/modeling_tf_pytorch_utils.py | 5 +- .../models/auto/modeling_tf_auto.py | 2 + src/transformers/models/cvt/__init__.py | 22 +- .../models/cvt/modeling_tf_cvt.py | 779 ++++++++++++++++++ src/transformers/models/cvt/test_tf.py | 55 ++ 6 files changed, 869 insertions(+), 2 deletions(-) create mode 100644 src/transformers/models/cvt/modeling_tf_cvt.py create mode 100644 src/transformers/models/cvt/test_tf.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8a982969fc1e..b1c2cc75d885 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2357,6 +2357,13 @@ "TFCTRLPreTrainedModel", ] ) + _import_structure["models.cvt"].extend( + [ + "TFCvtForImageClassification", + "TFCvtModel", + "TFCvtPreTrainedModel", + ] + ) _import_structure["models.data2vec"].extend( [ "TFData2VecVisionForImageClassification", @@ -5022,6 +5029,7 @@ TFCTRLModel, TFCTRLPreTrainedModel, ) + from .models.cvt import TFCvtForImageClassification, TFCvtModel, TFCvtPreTrainedModel from .models.data2vec import ( TFData2VecVisionForImageClassification, TFData2VecVisionForSemanticSegmentation, diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 73d6a7613fda..d360c6d6f3da 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -88,7 +88,10 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel": - tf_name[-1] = tf_name[-1].replace("_kernel", ".weight") + if tf_name[0] == "cvt": + tf_name[-1] = "weight" + else: + tf_name[-1] = tf_name[-1].replace("_kernel", ".weight") # Remove prefix if needed tf_name = ".".join(tf_name) diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index db462aa62186..95fe3f8aaaec 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -39,6 +39,7 @@ ("convbert", "TFConvBertModel"), ("convnext", "TFConvNextModel"), ("ctrl", "TFCTRLModel"), + ("cvt", "TFCvtModel"), ("data2vec-vision", "TFData2VecVisionModel"), ("deberta", "TFDebertaModel"), ("deberta-v2", "TFDebertaV2Model"), @@ -184,6 +185,7 @@ [ # Model for Image-classsification ("convnext", "TFConvNextForImageClassification"), + ("cvt", "TFCvtForImageClassification"), ("data2vec-vision", "TFData2VecVisionForImageClassification"), ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")), ("mobilevit", "TFMobileViTForImageClassification"), diff --git a/src/transformers/models/cvt/__init__.py b/src/transformers/models/cvt/__init__.py index 36a6f69824ef..86dfb82832a9 100644 --- a/src/transformers/models/cvt/__init__.py +++ b/src/transformers/models/cvt/__init__.py @@ -17,7 +17,7 @@ # limitations under the License. 
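A note on the `convert_tf_weight_name_to_pt_weight_name` change earlier in this patch: CvT builds its convolutional projections with `tf.keras.layers.DepthwiseConv2D`, whose kernel variable name ends in `depthwise_kernel`, while the corresponding PyTorch CvT parameter is presumably a plain convolution weight named `weight` rather than a `depthwise.weight` sub-module. A minimal sketch of the renaming rule, using a hypothetical weight name for illustration:

def rename_conv_kernel(tf_name: list, is_cvt: bool) -> list:
    # Sketch of the branch added above; `tf_name` is the "/"-split TF variable name.
    if tf_name[-1] in ("pointwise_kernel", "depthwise_kernel"):
        if is_cvt:
            # CvT: the PyTorch counterpart is a single convolution, so the suffix is just "weight".
            tf_name[-1] = "weight"
        else:
            # SeparableConv1D: "pointwise_kernel" -> "pointwise.weight", "depthwise_kernel" -> "depthwise.weight".
            tf_name[-1] = tf_name[-1].replace("_kernel", ".weight")
    return tf_name

# Hypothetical example:
# rename_conv_kernel(["cvt", "encoder", "stages", "0", "convolution", "depthwise_kernel"], is_cvt=True)
# -> last component becomes "weight", later joined into a "cvt.encoder. ... .weight" PyTorch key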
from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available _import_structure = {"configuration_cvt": ["CVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CvtConfig"]} @@ -36,6 +36,17 @@ "CvtPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_cvt"] = [ + "TFCvtForImageClassification", + "TFCvtModel", + "TFCvtPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_cvt import CVT_PRETRAINED_CONFIG_ARCHIVE_MAP, CvtConfig @@ -53,6 +64,15 @@ CvtPreTrainedModel, ) + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_cvt import TFCvtForImageClassification, TFCvtModel, TFCvtPreTrainedModel + + else: import sys diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py new file mode 100644 index 000000000000..7064e904bde2 --- /dev/null +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -0,0 +1,779 @@ +""" TF 2.0 Cvt model.""" + + +import collections.abc +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import tensorflow as tf + +from ...modeling_outputs import ModelOutput +from ...modeling_tf_outputs import TFImageClassifierOutputWithNoAttention +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_cvt import CvtConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "CvtConfig" + + +@dataclass +class TFBaseModelOutputWithCLSToken(ModelOutput): + """ + Base class for model's outputs. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + cls_token_value (`tf.Tensor` of shape `(batch_size, 1, num_channels)`): + Classification token at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, num_channels, height, width)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + """ + + last_hidden_state: tf.Tensor = None + cls_token_value: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + + +class TFCvtDropPath(tf.keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
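+    During training, each sample keeps its residual branch with probability `1 - drop_prob`;
+    the kept activations are rescaled by `1 / keep_prob` so that the expected output is unchanged.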
+ References: + (1) github.com:rwightman/pytorch-image-models + """ + + def __init__(self, drop_prob: float, **kwargs): + super().__init__(**kwargs) + self.drop_prob = drop_prob + + def call(self, x: tf.Tensor, training=None): + if training: + keep_prob = 1 - self.drop_prob + shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + +class TFCvtEmbeddings(tf.keras.layers.Layer): + def __init__( + self, + config: CvtConfig, + patch_size: int, + embed_dim: int, + stride: int, + padding: int, + dropout_rate: float, + **kwargs + ): + super().__init__(**kwargs) + self.convolution_embeddings = TFCvtConvEmbeddings( + config, + patch_size=patch_size, + embed_dim=embed_dim, + stride=stride, + padding=padding, + name="convolution_embeddings", + ) + self.dropout = tf.keras.layers.Dropout(dropout_rate) + + def call(self, pixel_values: tf.Tensor) -> tf.Tensor: + hidden_state = self.convolution_embeddings(pixel_values) + hidden_state = self.dropout(hidden_state) + return hidden_state + + +class TFCvtConvEmbeddings(tf.keras.layers.Layer): + def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: int, padding: int, **kwargs): + super().__init__(**kwargs) + self.pad_value = padding + self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + self.proj = tf.keras.layers.Conv2D( + filters=embed_dim, + kernel_size=patch_size, + strides=stride, + padding="valid", + data_format="channels_last", + kernel_initializer=get_initializer(config.initializer_range), + name="projection", + ) + # Using the same default epsilon & momentum as PyTorch + self.Normalization = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="normalization") + + def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: + # Custom padding to match the model implementation in PyTorch + height_pad = width_pad = (self.pad_value, self.pad_value) + hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)]) + hidden_state = self.proj(hidden_state) + return hidden_state + + def call(self, pixel_values: tf.Tensor) -> tf.Tensor: + if isinstance(pixel_values, dict): + pixel_values = pixel_values["pixel_values"] + + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. 
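+        # e.g. pixel values of shape (batch_size, 3, 224, 224) become (batch_size, 224, 224, 3) here;
+        # after the projection they are flattened to (batch_size, height * width, embed_dim) below.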
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + pixel_values = self.convolution(pixel_values) + # rearrange "b h w c" -> b (h w) c" + batch_size, height, width, num_channels = shape_list(pixel_values) + hidden_size = height * width + pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels)) + if self.Normalization: + pixel_values = self.Normalization(pixel_values) + # rearrange "b (h w) c" -> b c h w" + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 1)) + pixel_values = tf.reshape(pixel_values, shape=(batch_size, num_channels, height, width)) + return pixel_values + + +class TFCvtSelfAttentionConvProjection(tf.keras.layers.Layer): + def __init__(self, config: CvtConfig, kernel_size: int, stride: int, padding: int, **kwargs): + super().__init__(**kwargs) + self.pad_value = padding + self.conv = tf.keras.layers.DepthwiseConv2D( + kernel_size=kernel_size, + kernel_initializer=get_initializer(config.initializer_range), + padding="valid", + strides=stride, + use_bias=False, + name="convolution", + ) + # Using the same default epsilon & momentum as PyTorch + self.Normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") + + def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: + # Custom padding to match the model implementation in PyTorch + height_pad = width_pad = (self.pad_value, self.pad_value) + hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)]) + hidden_state = self.conv(hidden_state) + return hidden_state + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.convolution(hidden_state) + hidden_state = self.Normalization(hidden_state, training=training) + return hidden_state + + +class TFCvtSelfAttentionLinearProjection(tf.keras.layers.Layer): + def call(self, hidden_state: tf.Tensor) -> tf.Tensor: + # rearrange " b c h w -> b (h w) c" + batch_size, num_channels, height, width = shape_list(hidden_state) + hidden_size = height * width + hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, hidden_size)) + hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1)) + return hidden_state + + +class TFCvtSelfAttentionProjection(tf.keras.layers.Layer): + def __init__( + self, + config: CvtConfig, + kernel_size: int, + stride: int, + padding: int, + projection_method: str = "dw_bn", + **kwargs + ): + super().__init__(**kwargs) + if projection_method == "dw_bn": + self.convolution_projection = TFCvtSelfAttentionConvProjection( + config, kernel_size, stride, padding, name="convolution_projection" + ) + self.linear_projection = TFCvtSelfAttentionLinearProjection() + + def call(self, hidden_state: tf.Tensor) -> tf.Tensor: + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. 
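+        # Note: `perm=(0, 3, 2, 1)` yields (batch_size, width, height, num_channels); the matching
+        # transpose below applies the same permutation again to restore the channels-first layout.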
+ hidden_state = tf.transpose(hidden_state, perm=(0, 3, 2, 1)) + hidden_state = self.convolution_projection(hidden_state) + hidden_state = tf.transpose(hidden_state, perm=(0, 3, 2, 1)) + hidden_state = self.linear_projection(hidden_state) + return hidden_state + + +class TFCvtSelfAttention(tf.keras.layers.Layer): + def __init__( + self, + config: CvtConfig, + num_heads: int, + embed_dim: int, + kernel_size: int, + stride_q: int, + stride_kv: int, + padding_q: int, + padding_kv: int, + qkv_projection_method: str, + attention_drop_rate: float, + with_cls_token: bool = True, + **kwargs + ): + super().__init__(**kwargs) + self.scale = embed_dim**-0.5 + self.with_cls_token = with_cls_token + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.convolution_projection_query = TFCvtSelfAttentionProjection( + config, + kernel_size, + stride_q, + padding_q, + projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method, + name="convolution_projection_query", + ) + self.convolution_projection_key = TFCvtSelfAttentionProjection( + config, + kernel_size, + stride_kv, + padding_kv, + projection_method=qkv_projection_method, + name="convolution_projection_key", + ) + self.convolution_projection_value = TFCvtSelfAttentionProjection( + config, + kernel_size, + stride_kv, + padding_kv, + projection_method=qkv_projection_method, + name="convolution_projection_value", + ) + + self.projection_query = tf.keras.layers.Dense( + units=self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=True, + bias_initializer="zeros", + name="projection_query", + ) + self.projection_key = tf.keras.layers.Dense( + units=self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=True, + bias_initializer="zeros", + name="projection_key", + ) + self.projection_value = tf.keras.layers.Dense( + units=self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=True, + bias_initializer="zeros", + name="projection_value", + ) + self.dropout = tf.keras.layers.Dropout(attention_drop_rate) + + def rearrange_for_multi_head_attention(self, hidden_state: tf.Tensor) -> tf.Tensor: + batch_size, hidden_size, _ = shape_list(hidden_state) + head_dim = self.embed_dim // self.num_heads + hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, self.num_heads, head_dim)) + hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3)) + return hidden_state + + def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool) -> tf.Tensor: + if self.with_cls_token: + cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) + + # rearrange "b (h w) c -> b c h w" + batch_size, hidden_size, num_channels = shape_list(hidden_state) + hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1)) + hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, height, width)) + + key = self.convolution_projection_key(hidden_state) + query = self.convolution_projection_query(hidden_state) + value = self.convolution_projection_value(hidden_state) + + if self.with_cls_token: + query = tf.concat((cls_token, query), axis=1) + key = tf.concat((cls_token, key), axis=1) + value = tf.concat((cls_token, value), axis=1) + + head_dim = self.embed_dim // self.num_heads + query = self.rearrange_for_multi_head_attention(self.projection_query(query)) + key = self.rearrange_for_multi_head_attention(self.projection_key(key)) + value = 
self.rearrange_for_multi_head_attention(self.projection_value(value)) + + attention_score = tf.matmul(query, key, transpose_b=True) * self.scale + attention_probs = stable_softmax(logits=attention_score, axis=-1) + attention_probs = self.dropout(attention_probs, training=training) + context = tf.matmul(attention_probs, value) + + # rearrange "b h t d -> b t (h d)" + _, _, hidden_size, _ = shape_list(context) + context = tf.transpose(context, perm=(0, 2, 1, 3)) + context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim)) + return context + + +class TFCvtSelfOutput(tf.keras.layers.Layer): + """Output of the Attention layer.""" + + def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: float, **kwargs): + super().__init__(**kwargs) + self.dense = tf.keras.layers.Dense( + units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = tf.keras.layers.Dropout(drop_rate) + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.dense(inputs=hidden_state) + hidden_state = self.dropout(inputs=hidden_state, training=training) + return hidden_state + + +class TFCvtAttention(tf.keras.layers.Layer): + def __init__( + self, + config: CvtConfig, + num_heads: int, + embed_dim: int, + kernel_size: int, + stride_q: int, + stride_kv: int, + padding_q: int, + padding_kv: int, + qkv_projection_method: str, + attention_drop_rate: float, + drop_rate: float, + with_cls_token: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.attention = TFCvtSelfAttention( + config, + num_heads, + embed_dim, + kernel_size, + stride_q, + stride_kv, + padding_q, + padding_kv, + qkv_projection_method, + attention_drop_rate, + with_cls_token, + name="attention", + ) + self.dense_output = TFCvtSelfOutput(config, embed_dim, drop_rate, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool): + self_output = self.attention(hidden_state, height, width, training) + attention_output = self.dense_output(self_output, training) + return attention_output + + +class TFCvtIntermediate(tf.keras.layers.Layer): + def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, **kwargs): + super().__init__(**kwargs) + self.dense = tf.keras.layers.Dense( + units=int(embed_dim * mlp_ratio), + kernel_initializer=get_initializer(config.initializer_range), + activation="gelu", + name="dense", + ) + + def call(self, hidden_state: tf.Tensor) -> tf.Tensor: + hidden_state = self.dense(hidden_state) + return hidden_state + + +class TFCvtOutput(tf.keras.layers.Layer): + def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: int, **kwargs): + super().__init__(**kwargs) + self.dense = tf.keras.layers.Dense( + units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = tf.keras.layers.Dropout(drop_rate) + + def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor) -> tf.Tensor: + hidden_state = self.dense(inputs=hidden_state) + hidden_state = self.dropout(inputs=hidden_state) + hidden_state = hidden_state + input_tensor + return hidden_state + + +class TFCvtLayer(tf.keras.layers.Layer): + def __init__( + self, + config: CvtConfig, + num_heads: int, + embed_dim: int, + kernel_size: int, + stride_q: int, + stride_kv: int, + padding_q: int, + padding_kv: int, + qkv_projection_method: str, + attention_drop_rate: float, + drop_rate: float, + 
mlp_ratio: float, + drop_path_rate: float, + with_cls_token: bool = True, + **kwargs + ): + super().__init__(**kwargs) + self.attention = TFCvtAttention( + config, + num_heads, + embed_dim, + kernel_size, + stride_q, + stride_kv, + padding_q, + padding_kv, + qkv_projection_method, + attention_drop_rate, + drop_rate, + with_cls_token, + name="attention", + ) + self.intermediate = TFCvtIntermediate(config, embed_dim, mlp_ratio, name="intermediate") + self.dense_output = TFCvtOutput(config, embed_dim, drop_rate, name="output") + # Using `layers.Activation` instead of `tf.identity` to better control `training` behaviour. + self.drop_path = ( + TFCvtDropPath(drop_path_rate, name="drop_path") + if drop_path_rate > 0.0 + else tf.keras.layers.Activation("linear", name="drop_path") + ) + self.layernorm_before = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_before") + self.layernorm_after = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after") + + def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: + self_attention_output = self.attention( + # in Cvt, layernorm is applied before self-attention + self.layernorm_before(hidden_state), + width, + height, + ) + attention_output = self_attention_output + attention_output = self.drop_path(attention_output) + + # first residual connection + hidden_state = attention_output + hidden_state + + # in Cvt, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_state) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.dense_output(layer_output, hidden_state) + layer_output = self.drop_path(layer_output, training=training) + return layer_output + + +class TFCvtStage(tf.keras.layers.Layer): + def __init__(self, config: CvtConfig, stage: int, **kwargs): + super().__init__(**kwargs) + self.config = config + self.stage = stage + if self.config.cls_token[self.stage]: + self.cls_token = self.add_weight( + shape=(1, 1, self.config.embed_dim[-1]), + initializer="zeros", + trainable=True, + name="cvt.encoder.stages.2.cls_token", + ) + self.embedding = TFCvtEmbeddings( + self.config, + patch_size=config.patch_sizes[self.stage], + stride=config.patch_stride[self.stage], + embed_dim=config.embed_dim[self.stage], + padding=config.patch_padding[self.stage], + dropout_rate=config.drop_rate[self.stage], + name="embedding", + ) + + drop_path_rates = tf.linspace(0.0, config.drop_path_rate[self.stage], config.depth[stage]) + drop_path_rates = [x.numpy().item() for x in drop_path_rates] + self.layers = [ + TFCvtLayer( + config, + num_heads=config.num_heads[self.stage], + embed_dim=config.embed_dim[self.stage], + kernel_size=config.kernel_qkv[self.stage], + stride_q=config.stride_q[self.stage], + stride_kv=config.stride_kv[self.stage], + padding_q=config.padding_q[self.stage], + padding_kv=config.padding_kv[self.stage], + qkv_projection_method=config.qkv_projection_method[self.stage], + attention_drop_rate=config.attention_drop_rate[self.stage], + drop_rate=config.drop_rate[self.stage], + mlp_ratio=config.mlp_ratio[self.stage], + drop_path_rate=drop_path_rates[self.stage], + with_cls_token=config.cls_token[self.stage], + name=f"layers.{j}", + ) + for j in range(config.depth[self.stage]) + ] + + def call(self, hidden_state: tf.Tensor): + cls_token = None + hidden_state = self.embedding(hidden_state) + + batch_size, num_channels, height, width = shape_list(hidden_state) + # rearrange b c h w -> b (h 
w) c" + hidden_size = height * width + hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, hidden_size)) + hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1)) + + if self.config.cls_token[self.stage]: + cls_token = tf.repeat(self.cls_token, repeats=batch_size, axis=0) + hidden_state = tf.concat((cls_token, hidden_state), axis=1) + + for layer in self.layers: + layer_outputs = layer(hidden_state, height, width) + hidden_state = layer_outputs + + if self.config.cls_token[self.stage]: + cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) + + # rearrange -> b (h w) c" -> b c h w + hidden_state = tf.transpose(hidden_state, (0, 2, 1)) + hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, height, width)) + return hidden_state, cls_token + + +class TFCvtEncoder(tf.keras.layers.Layer): + config_class = CvtConfig + + def __init__(self, config: CvtConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.stages = [ + TFCvtStage(config, stage_idx, name=f"stages.{stage_idx}") for stage_idx in range(len(config.depth)) + ] + + def call( + self, + pixel_values: TFModelInputType, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + hidden_state = pixel_values + + cls_token = None + for _, (stage_module) in enumerate(self.stages): + hidden_state, cls_token = stage_module(hidden_state) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None) + + return TFBaseModelOutputWithCLSToken( + last_hidden_state=hidden_state, + cls_token_value=cls_token, + hidden_states=all_hidden_states, + ) + + +@keras_serializable +class TFCvtMainLayer(tf.keras.layers.Layer): + config_class = CvtConfig + + def __init__(self, config: CvtConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.encoder = TFCvtEncoder(config, name="encoder") + + @unpack_inputs + def call( + self, + pixel_values: Optional[TFModelInputType] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + # pixel_values = tf.transpose(pixel_values, perm=(0, 3, 1, 2)) + # tried reshaping to to `NHWC` directly in main layer and using this format + # throughout the model, but even though I get the same predictions as torch + # CVT model, our sequence_output have an absolute difference > 100 + + encoder_outputs = self.encoder( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + + if not return_dict: + # encoder_outputs -> [last_hidden_state, cls_token, all_hidden_states] + return (sequence_output,) + encoder_outputs[1:] + + return TFBaseModelOutputWithCLSToken( + last_hidden_state=sequence_output, + cls_token_value=encoder_outputs.cls_token_value, + hidden_states=encoder_outputs.hidden_states, + ) + + +class TFCvtPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = CvtConfig + base_model_prefix = "cvt" + main_input_name = "pixel_values" + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + VISION_DUMMY_INPUTS = tf.random.uniform( + shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size), dtype=tf.float32 + ) + return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} + + @tf.function( + input_signature=[ + { + "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"), + } + ] + ) + def serving(self, inputs): + """ + Method used for serving the model. + + Args: + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ + output = self.call(inputs) + return self.serving_output(output) + + +class TFCvtModel(TFCvtPreTrainedModel): + def __init__(self, config: CvtConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.cvt = TFCvtMainLayer(config, name="cvt") + + @unpack_inputs + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]: + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + outputs = self.cvt( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + # outputs -> [last_hidden_sate, cls_token_value, hidden_states] + return (outputs[0],) + outputs[1:] + + return TFBaseModelOutputWithCLSToken( + last_hidden_state=outputs.last_hidden_state, + cls_token_value=outputs.cls_token_value, + hidden_states=outputs.hidden_states, + ) + + def serving_output(self, output: TFBaseModelOutputWithCLSToken) -> TFBaseModelOutputWithCLSToken: + return TFBaseModelOutputWithCLSToken( + last_hidden_state=output.last_hidden_state, + cls_token_value=output.cls_token_value, + hidden_states=output.hidden_states, + ) + + +class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: CvtConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.cvt = TFCvtMainLayer(config, name="cvt") + self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + + # Classifier head + self.classifier = tf.keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=True, + bias_initializer="zeros", + name="classifier", + ) + + @unpack_inputs + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFImageClassifierOutputWithNoAttention, Tuple[tf.Tensor]]: + + outputs = self.cvt( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + cls_token = outputs[1] + if self.config.cls_token[-1]: + sequence_output = self.LayerNorm(cls_token) + else: + # rearrange "b c h w -> b (h w) c" + batch_size, num_channels, height, width = shape_list(sequence_output) + sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * 
width)) + sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1)) + sequence_output = self.LayerNorm(sequence_output) + + sequence_output_mean = tf.reduce_mean(sequence_output, axis=1) + logits = self.classifier(sequence_output_mean) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + # outputs -> [last_hidden_sate, cls_token_value, hidden_states] + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states + ) + + def serving_output(self, output: TFImageClassifierOutputWithNoAttention) -> TFImageClassifierOutputWithNoAttention: + return TFImageClassifierOutputWithNoAttention( + logits=output.logits, + hidden_states=output.hidden_states + ) diff --git a/src/transformers/models/cvt/test_tf.py b/src/transformers/models/cvt/test_tf.py new file mode 100644 index 000000000000..d26fd123f790 --- /dev/null +++ b/src/transformers/models/cvt/test_tf.py @@ -0,0 +1,55 @@ +from transformers import AutoFeatureExtractor, CvtForImageClassification, CvtModel, CvtConfig + +# from modeling_cvt import CvtForImageClassification +from transformers import ResNetForImageClassification, ResNetModel, TFResNetForImageClassification, TFResNetModel +from transformers import TFCvtForImageClassification, TFCvtModel +import tensorflow as tf +from datasets import load_dataset +from transformers.modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model +import numpy as np +import json +from collections import OrderedDict +import torch + + +dataset = load_dataset("huggingface/cats-image") +image = dataset["test"]["image"][0] +feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13") +pt_inputs = feature_extractor(image, return_tensors="pt") +tf_inputs = feature_extractor(image, return_tensors="tf") + +print("\n--------- CVT Classification ------------\n") +# PYTORCH: +# pt_model = CvtForImageClassification.from_pretrained('./pytorch_model.bin', config='./config.json') +pt_model = CvtForImageClassification.from_pretrained("microsoft/cvt-13") +with torch.no_grad(): + pt_logits = pt_model(**pt_inputs).logits +pt_predicted_label = pt_logits.argmax(-1).item() + +##TENSORFLOW: microsoft/cvt-13 +# tf_model = TFCvtForImageClassification.from_pretrained('./pytorch_model.bin', config='./config.json', from_pt=True) +tf_model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13", from_pt=True) +tf_logits = tf_model(**tf_inputs).logits +tf_predicted_label = int(tf.math.argmax(tf_logits, axis=-1)) + +print(f"TF input shape: {tf_inputs['pixel_values'].shape}") +print(f"PT input shape: {pt_inputs['pixel_values'].shape}") +print(f"TF Predicted label: {tf_model.config.id2label[tf_predicted_label]}") +print(f"PT Predicted label: {pt_model.config.id2label[pt_predicted_label]}") + +print("\n--------- Model ------------\n") +# PYTORCH: +model = CvtModel.from_pretrained("microsoft/cvt-13") +with torch.no_grad(): + outputs = model(**pt_inputs) +last_hidden_states = outputs.last_hidden_state +np_pt = last_hidden_states.numpy() + +##TENSORFLOW: +tf_model = TFCvtModel.from_pretrained("microsoft/cvt-13", from_pt=True) +tfo = tf_model(**tf_inputs, training=False) # build the network +np_tf = tfo.last_hidden_state.numpy() + +assert np_pt.shape == np_tf.shape +diff = np.amax(np.abs(np_pt - np_tf)) +print(f"\nMax absolute difference between models outputs {diff}\n") From 
eb8bd90493031a6df59fc9671e50aa7f7ce43ee1 Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Tue, 9 Aug 2022 14:30:01 +0200 Subject: [PATCH 02/15] added docstring + testing file in transformers testing suite --- .../models/cvt/modeling_tf_cvt.py | 174 +++++++++++++++- src/transformers/models/cvt/test_tf.py | 16 +- tests/models/cvt/test_modeling_tf_cvt.py | 195 ++++++++++++++++++ 3 files changed, 370 insertions(+), 15 deletions(-) create mode 100644 tests/models/cvt/test_modeling_tf_cvt.py diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index 7064e904bde2..cc6a1be3d11d 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -70,6 +70,8 @@ def call(self, x: tf.Tensor, training=None): class TFCvtEmbeddings(tf.keras.layers.Layer): + """Construct the Convolutional Token Embeddings.""" + def __init__( self, config: CvtConfig, @@ -98,6 +100,8 @@ def call(self, pixel_values: tf.Tensor) -> tf.Tensor: class TFCvtConvEmbeddings(tf.keras.layers.Layer): + """Image to Conv Embedding. This convolutional operation aims to model local spatial contexts.""" + def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: int, padding: int, **kwargs): super().__init__(**kwargs) self.pad_value = padding @@ -142,6 +146,8 @@ def call(self, pixel_values: tf.Tensor) -> tf.Tensor: class TFCvtSelfAttentionConvProjection(tf.keras.layers.Layer): + """Convolutional projection layer.""" + def __init__(self, config: CvtConfig, kernel_size: int, stride: int, padding: int, **kwargs): super().__init__(**kwargs) self.pad_value = padding @@ -170,6 +176,8 @@ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: class TFCvtSelfAttentionLinearProjection(tf.keras.layers.Layer): + """Linear projection layer used to flatten tokens into 1D.""" + def call(self, hidden_state: tf.Tensor) -> tf.Tensor: # rearrange " b c h w -> b (h w) c" batch_size, num_channels, height, width = shape_list(hidden_state) @@ -180,6 +188,8 @@ def call(self, hidden_state: tf.Tensor) -> tf.Tensor: class TFCvtSelfAttentionProjection(tf.keras.layers.Layer): + """Convolutional Projection for Attention.""" + def __init__( self, config: CvtConfig, @@ -207,6 +217,11 @@ def call(self, hidden_state: tf.Tensor) -> tf.Tensor: class TFCvtSelfAttention(tf.keras.layers.Layer): + """ + Self-attention layer. A depth-wise separable convolution operation (Convolutional Projection), + is applied for query, key, and value embeddings. + """ + def __init__( self, config: CvtConfig, @@ -319,7 +334,7 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool) class TFCvtSelfOutput(tf.keras.layers.Layer): - """Output of the Attention layer.""" + """Output of the Attention layer (MLP).""" def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: float, **kwargs): super().__init__(**kwargs) @@ -335,6 +350,8 @@ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: class TFCvtAttention(tf.keras.layers.Layer): + """Attention layer. First chunk of the convolutional transformer block.""" + def __init__( self, config: CvtConfig, @@ -378,6 +395,8 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool) class TFCvtIntermediate(tf.keras.layers.Layer): + """Intermediate dense layer. 
Second chunk of the convolutional transformer block.""" + def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, **kwargs): super().__init__(**kwargs) self.dense = tf.keras.layers.Dense( @@ -393,6 +412,10 @@ def call(self, hidden_state: tf.Tensor) -> tf.Tensor: class TFCvtOutput(tf.keras.layers.Layer): + """ + Output of the Convolutional Transformer Block (last chunk). It consists of a MLP and a residual connection. + """ + def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: int, **kwargs): super().__init__(**kwargs) self.dense = tf.keras.layers.Dense( @@ -408,6 +431,12 @@ def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor) -> tf.Tensor: class TFCvtLayer(tf.keras.layers.Layer): + """ + Convolutional Transformer Block composed by attention layers, normalization and multi-layer perceptrons (mlps). It + consists of 3 chunks : an attention layer, an intermediate dense layer and an output layer. This corresponds to the + `Block` class in the original implementation. + """ + def __init__( self, config: CvtConfig, @@ -450,6 +479,7 @@ def __init__( if drop_path_rate > 0.0 else tf.keras.layers.Activation("linear", name="drop_path") ) + # Using the same default epsilon as PyTorch self.layernorm_before = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_before") self.layernorm_after = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after") @@ -477,6 +507,17 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool class TFCvtStage(tf.keras.layers.Layer): + """ + Cvt stage (encoder block). Each stage has 2 parts : + - (1) A Convolutional Token Embedding layer + - (2) A Convolutional Transformer Block (layer). + The classification token is added only in the last stage. + + Args: + config ([`CvtConfig`]): Model configuration class. + stage (`int`): Stage number. + """ + def __init__(self, config: CvtConfig, stage: int, **kwargs): super().__init__(**kwargs) self.config = config @@ -549,6 +590,14 @@ def call(self, hidden_state: tf.Tensor): class TFCvtEncoder(tf.keras.layers.Layer): + """ + Convolutional Vision Transformer encoder. CVT has 3 stages of encoder blocks with their respective number of layers + (depth) being 1, 2 and 10. + + Args: + config ([`CvtConfig`]): Model configuration class. + """ + config_class = CvtConfig def __init__(self, config: CvtConfig, **kwargs): @@ -585,6 +634,8 @@ def call( @keras_serializable class TFCvtMainLayer(tf.keras.layers.Layer): + """Construct the Cvt model.""" + config_class = CvtConfig def __init__(self, config: CvtConfig, **kwargs): @@ -605,7 +656,7 @@ def call( # pixel_values = tf.transpose(pixel_values, perm=(0, 3, 1, 2)) # tried reshaping to to `NHWC` directly in main layer and using this format # throughout the model, but even though I get the same predictions as torch - # CVT model, our sequence_output have an absolute difference > 100 + # CVT model, our sequence_output has an absolute difference > 100 with torch prediction encoder_outputs = self.encoder( pixel_values, @@ -669,6 +720,57 @@ def serving(self, inputs): return self.serving_output(output) +TFCVT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. 
Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TF 2.0 models accepts two formats as inputs: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional arguments. + + This second option is useful when using [`tf.keras.Model.fit`] method which currently requires having all the + tensors in the first argument of the model call function: `model(inputs)`. + + + + Args: + config ([`CvtConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TFCVT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False``): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). 
+""" + + +@add_start_docstrings( + "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.", + TFCVT_START_DOCSTRING, +) class TFCvtModel(TFCvtPreTrainedModel): def __init__(self, config: CvtConfig, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -676,6 +778,8 @@ def __init__(self, config: CvtConfig, *inputs, **kwargs): self.cvt = TFCvtMainLayer(config, name="cvt") @unpack_inputs + @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBaseModelOutputWithCLSToken, config_class=_CONFIG_FOR_DOC) def call( self, pixel_values: Optional[tf.Tensor] = None, @@ -683,6 +787,26 @@ def call( return_dict: Optional[bool] = None, training: bool = False, ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoFeatureExtractor, TFCvtModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13") + >>> model = TFCvtModel.from_pretrained("microsoft/cvt-13", from_pt=True) + + >>> inputs = feature_extractor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ```""" if pixel_values is None: raise ValueError("You have to specify pixel_values") @@ -712,14 +836,25 @@ def serving_output(self, output: TFBaseModelOutputWithCLSToken) -> TFBaseModelOu ) +@add_start_docstrings( + """ + Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + TFCVT_START_DOCSTRING, +) class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassificationLoss): def __init__(self, config: CvtConfig, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.num_labels = config.num_labels self.cvt = TFCvtMainLayer(config, name="cvt") - self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - + # In the original implementation the authors use epsilon=1e-5 for Layer Normalization. + # Pytorch CVT model doesn't seem to use config.layer_norm_ep + # Therefore we will be using the same epsilon as in Pytorch CVT model. + # What is the use of config.layer_norm_eps ? + self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm") + # Classifier head self.classifier = tf.keras.layers.Dense( units=config.num_labels, @@ -730,6 +865,8 @@ def __init__(self, config: CvtConfig, *inputs, **kwargs): ) @unpack_inputs + @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC) def call( self, pixel_values: Optional[tf.Tensor] = None, @@ -738,6 +875,35 @@ def call( return_dict: Optional[bool] = None, training: Optional[bool] = False, ) -> Union[TFImageClassifierOutputWithNoAttention, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> from transformers import AutoFeatureExtractor, TFCvtForImageClassification + >>> import tensorflow as tf + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13") + >>> model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13", from_pt=True) + + >>> inputs = feature_extractor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] + >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)]) + ```""" outputs = self.cvt( pixel_values, diff --git a/src/transformers/models/cvt/test_tf.py b/src/transformers/models/cvt/test_tf.py index d26fd123f790..31f73376325b 100644 --- a/src/transformers/models/cvt/test_tf.py +++ b/src/transformers/models/cvt/test_tf.py @@ -1,16 +1,10 @@ -from transformers import AutoFeatureExtractor, CvtForImageClassification, CvtModel, CvtConfig - -# from modeling_cvt import CvtForImageClassification -from transformers import ResNetForImageClassification, ResNetModel, TFResNetForImageClassification, TFResNetModel -from transformers import TFCvtForImageClassification, TFCvtModel -import tensorflow as tf -from datasets import load_dataset -from transformers.modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model -import numpy as np import json -from collections import OrderedDict +import numpy as np import torch - +import tensorflow as tf +from transformers import AutoFeatureExtractor, CvtForImageClassification, CvtModel +from transformers import TFCvtForImageClassification, TFCvtModel +from datasets import load_dataset dataset = load_dataset("huggingface/cats-image") image = dataset["test"]["image"][0] diff --git a/tests/models/cvt/test_modeling_tf_cvt.py b/tests/models/cvt/test_modeling_tf_cvt.py new file mode 100644 index 000000000000..25494c06f6c3 --- /dev/null +++ b/tests/models/cvt/test_modeling_tf_cvt.py @@ -0,0 +1,195 @@ +""" Testing suite for the Tensorflow CvT model. 
""" + + +import inspect +import unittest +from math import floor + +import numpy as np + +from transformers import CvtConfig +from transformers.testing_utils import require_tf, require_vision, slow +from transformers.utils import cached_property, is_tf_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor + + +if is_tf_available(): + import tensorflow as tf + + from transformers import TFCvtForImageClassification, TFCvtModel + from transformers.models.cvt.modeling_cvt import CVT_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class TFCvtModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=64, + num_channels=3, + embed_dim=[16, 48, 96], + num_heads=[1, 3, 6], + depth=[1, 2, 10], + patch_sizes=[7, 3, 3], + patch_stride=[4, 2, 2], + patch_padding=[2, 1, 1], + stride_kv=[2, 2, 2], + cls_token=[False, False, True], + attention_drop_rate=[0.0, 0.0, 0.0], + initializer_range=0.02, + layer_norm_eps=1e-12, + is_training=True, + use_labels=True, + num_labels=2, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_sizes = patch_sizes + self.patch_stride = patch_stride + self.patch_padding = patch_padding + self.is_training = is_training + self.use_labels = use_labels + self.num_labels = num_labels + self.num_channels = num_channels + self.embed_dim = embed_dim + self.num_heads = num_heads + self.stride_kv = stride_kv + self.depth = depth + self.cls_token = cls_token + self.attention_drop_rate = attention_drop_rate + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + # create a random int32 tensor of given shape + labels = ids_tensor([self.batch_size], self.num_labels) + + config = self.get_config() + return config, pixel_values, labels + + def get_config(self): + return CvtConfig( + image_size=self.image_size, + num_labels=self.num_labels, + num_channels=self.num_channels, + embed_dim=self.embed_dim, + num_heads=self.num_heads, + patch_sizes=self.patch_sizes, + patch_padding=self.patch_padding, + patch_stride=self.patch_stride, + stride_kv=self.stride_kv, + depth=self.depth, + cls_token=self.cls_token, + attention_drop_rate=self.attention_drop_rate, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = TFCvtModel(config=config) + result = model(pixel_values, training=False) + image_size = (self.image_size, self.image_size) + height, width = image_size[0], image_size[1] + for i in range(len(self.depth)): + height = floor(((height + 2 * self.patch_padding[i] - self.patch_sizes[i]) / self.patch_stride[i]) + 1) + width = floor(((width + 2 * self.patch_padding[i] - self.patch_sizes[i]) / self.patch_stride[i]) + 1) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.embed_dim[-1], height, width)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.num_labels + model = TFCvtForImageClassification(config) + result = model(pixel_values, labels=labels, training=False) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def 
prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_tf +class TFCvtModelTest(TFModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as Cvt + does not use input_ids, inputs_embeds, attention_mask and seq_length. + """ + + all_model_classes = (TFCvtModel, TFCvtForImageClassification) if is_tf_available() else () + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = False + test_onnx = False + + def setUp(self): + self.model_tester = TFCvtModelTester(self) + self.config_tester = ConfigTester(self, config_class=CvtConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + @unittest.skip(reason="Cvt does not output attentions") + def test_attention_outputs(self): + pass + + @unittest.skip(reason="Cvt does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Cvt does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.call) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in CVT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = TFCvtModel.from_pretrained(model_name, from_pt=True) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image \ No newline at end of file From e4ed73ff00a20ef1bab2ececf1b66c21dc88b8ec Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Tue, 9 Aug 2022 17:05:26 +0200 Subject: [PATCH 03/15] added test in testing file, modified docs to pass repo-consistency, passed formatting test --- docs/source/en/index.mdx | 2 +- docs/source/en/model_doc/cvt.mdx | 11 +++ .../models/cvt/modeling_tf_cvt.py | 23 +++---- src/transformers/models/cvt/test_tf.py | 24 ++++--- src/transformers/utils/dummy_tf_objects.py | 21 ++++++ tests/models/cvt/test_modeling_tf_cvt.py | 67 ++++++++++++++++++- 6 files changed, 122 insertions(+), 26 deletions(-) diff --git a/docs/source/en/index.mdx 
b/docs/source/en/index.mdx index db2f4c843f32..a96440a29d7e 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -225,7 +225,7 @@ Flax), PyTorch, and/or TensorFlow. | ConvBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ConvNeXT | ❌ | ❌ | ✅ | ✅ | ❌ | | CTRL | ✅ | ❌ | ✅ | ✅ | ❌ | -| CvT | ❌ | ❌ | ✅ | ❌ | ❌ | +| CvT | ❌ | ❌ | ✅ | ✅ | ❌ | | Data2VecAudio | ❌ | ❌ | ✅ | ❌ | ❌ | | Data2VecText | ❌ | ❌ | ✅ | ❌ | ❌ | | Data2VecVision | ❌ | ❌ | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/cvt.mdx b/docs/source/en/model_doc/cvt.mdx index 84be7e39a550..a46ae68d58f8 100644 --- a/docs/source/en/model_doc/cvt.mdx +++ b/docs/source/en/model_doc/cvt.mdx @@ -51,3 +51,14 @@ This model was contributed by [anugunj](https://huggingface.co/anugunj). The ori [[autodoc]] CvtForImageClassification - forward + +## TFCvtModel + +[[autodoc]] TFCvtModel + - call + +## TFCvtForImageClassification + +[[autodoc]] TFCvtForImageClassification + - call + diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index cc6a1be3d11d..1afc0444938b 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -218,8 +218,8 @@ def call(self, hidden_state: tf.Tensor) -> tf.Tensor: class TFCvtSelfAttention(tf.keras.layers.Layer): """ - Self-attention layer. A depth-wise separable convolution operation (Convolutional Projection), - is applied for query, key, and value embeddings. + Self-attention layer. A depth-wise separable convolution operation (Convolutional Projection), is applied for + query, key, and value embeddings. """ def __init__( @@ -508,7 +508,7 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool class TFCvtStage(tf.keras.layers.Layer): """ - Cvt stage (encoder block). Each stage has 2 parts : + Cvt stage (encoder block). Each stage has 2 parts : - (1) A Convolutional Token Embedding layer - (2) A Convolutional Transformer Block (layer). The classification token is added only in the last stage. @@ -655,9 +655,9 @@ def call( raise ValueError("You have to specify pixel_values") # pixel_values = tf.transpose(pixel_values, perm=(0, 3, 1, 2)) # tried reshaping to to `NHWC` directly in main layer and using this format - # throughout the model, but even though I get the same predictions as torch + # throughout the model, but even though I get the same predictions as torch # CVT model, our sequence_output has an absolute difference > 100 with torch prediction - + encoder_outputs = self.encoder( pixel_values, output_hidden_states=output_hidden_states, @@ -854,7 +854,7 @@ def __init__(self, config: CvtConfig, *inputs, **kwargs): # Therefore we will be using the same epsilon as in Pytorch CVT model. # What is the use of config.layer_norm_eps ? 
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm") - + # Classifier head self.classifier = tf.keras.layers.Dense( units=config.num_labels, @@ -932,14 +932,7 @@ def call( output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return TFImageClassifierOutputWithNoAttention( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states - ) + return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) def serving_output(self, output: TFImageClassifierOutputWithNoAttention) -> TFImageClassifierOutputWithNoAttention: - return TFImageClassifierOutputWithNoAttention( - logits=output.logits, - hidden_states=output.hidden_states - ) + return TFImageClassifierOutputWithNoAttention(logits=output.logits, hidden_states=output.hidden_states) diff --git a/src/transformers/models/cvt/test_tf.py b/src/transformers/models/cvt/test_tf.py index 31f73376325b..784d644c5413 100644 --- a/src/transformers/models/cvt/test_tf.py +++ b/src/transformers/models/cvt/test_tf.py @@ -1,27 +1,32 @@ -import json import numpy as np -import torch import tensorflow as tf -from transformers import AutoFeatureExtractor, CvtForImageClassification, CvtModel -from transformers import TFCvtForImageClassification, TFCvtModel +import torch from datasets import load_dataset +from transformers import ( + AutoFeatureExtractor, + CvtForImageClassification, + CvtModel, + TFCvtForImageClassification, + TFCvtModel, +) + + dataset = load_dataset("huggingface/cats-image") image = dataset["test"]["image"][0] feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13") pt_inputs = feature_extractor(image, return_tensors="pt") tf_inputs = feature_extractor(image, return_tensors="tf") + print("\n--------- CVT Classification ------------\n") # PYTORCH: -# pt_model = CvtForImageClassification.from_pretrained('./pytorch_model.bin', config='./config.json') pt_model = CvtForImageClassification.from_pretrained("microsoft/cvt-13") with torch.no_grad(): pt_logits = pt_model(**pt_inputs).logits pt_predicted_label = pt_logits.argmax(-1).item() -##TENSORFLOW: microsoft/cvt-13 -# tf_model = TFCvtForImageClassification.from_pretrained('./pytorch_model.bin', config='./config.json', from_pt=True) +# TENSORFLOW: tf_model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13", from_pt=True) tf_logits = tf_model(**tf_inputs).logits tf_predicted_label = int(tf.math.argmax(tf_logits, axis=-1)) @@ -31,6 +36,7 @@ print(f"TF Predicted label: {tf_model.config.id2label[tf_predicted_label]}") print(f"PT Predicted label: {pt_model.config.id2label[pt_predicted_label]}") + print("\n--------- Model ------------\n") # PYTORCH: model = CvtModel.from_pretrained("microsoft/cvt-13") @@ -39,9 +45,9 @@ last_hidden_states = outputs.last_hidden_state np_pt = last_hidden_states.numpy() -##TENSORFLOW: +# TENSORFLOW: tf_model = TFCvtModel.from_pretrained("microsoft/cvt-13", from_pt=True) -tfo = tf_model(**tf_inputs, training=False) # build the network +tfo = tf_model(**tf_inputs, training=False) np_tf = tfo.last_hidden_state.numpy() assert np_pt.shape == np_tf.shape diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 7cec699498c9..be5ac1490f46 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -786,6 +786,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class 
TFCvtForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFCvtModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFCvtPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFData2VecVisionForImageClassification(metaclass=DummyObject): _backends = ["tf"] diff --git a/tests/models/cvt/test_modeling_tf_cvt.py b/tests/models/cvt/test_modeling_tf_cvt.py index 25494c06f6c3..04986368026f 100644 --- a/tests/models/cvt/test_modeling_tf_cvt.py +++ b/tests/models/cvt/test_modeling_tf_cvt.py @@ -162,6 +162,13 @@ def test_inputs_embeds(self): def test_model_common_attributes(self): pass + @unittest.skipIf( + not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, + reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", + ) + def test_dataset_conversion(self): + super().test_dataset_conversion() + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -174,6 +181,38 @@ def test_forward_signature(self): expected_arg_names = ["pixel_values"] self.assertListEqual(arg_names[:1], expected_arg_names) + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + hidden_states = outputs.hidden_states + + expected_num_layers = len(self.model_tester.depth) + self.assertEqual(len(hidden_states), expected_num_layers) + + # verify the first hidden states (first block) + self.assertListEqual( + list(hidden_states[0].shape[-3:]), + [ + self.model_tester.embed_dim[0], + self.model_tester.image_size // 4, + self.model_tester.image_size // 4, + ], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) @@ -192,4 +231,30 @@ def test_model_from_pretrained(self): # We will verify our results on an image of cute cats def prepare_img(): image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") - return image \ No newline at end of file + return image + + +@require_tf +@require_vision +class TFCvtModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return AutoFeatureExtractor.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0], from_pt=True) + + @slow + def test_inference_image_classification_head(self): + model = TFCvtForImageClassification.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0], from_pt=True) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="tf") + + # forward pass + outputs = model(**inputs) + + # verify the logits + expected_shape = tf.TensorShape((1, 1000)) + self.assertEqual(outputs.logits.shape, 
expected_shape) + + expected_slice = tf.constant([0.9285, 0.9015, -0.3150]) + self.assertTrue(np.allclose(outputs.logits[0, :3].numpy(), expected_slice, atol=1e-4)) From f4ded545fa9cc226f685c43543faecf58f5cba64 Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Wed, 10 Aug 2022 18:17:53 +0200 Subject: [PATCH 04/15] refactoring + passing all tests --- .../models/cvt/modeling_tf_cvt.py | 5 +- src/transformers/models/cvt/test_tf.py | 55 ------------------------------------------------------- 2 files changed, 3 insertions(+), 57 deletions(-) delete mode 100644 src/transformers/models/cvt/test_tf.py diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index 1afc0444938b..6aad5f0db4b5 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -115,7 +115,7 @@ def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: i kernel_initializer=get_initializer(config.initializer_range), name="projection", ) - # Using the same default epsilon & momentum as PyTorch + # Using the same default epsilon as PyTorch self.Normalization = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="normalization") def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: @@ -133,12 +133,14 @@ def call(self, pixel_values: tf.Tensor) -> tf.Tensor: # So change the input format from `NCHW` to `NHWC`. pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) pixel_values = self.convolution(pixel_values) + # rearrange "b h w c" -> b (h w) c" batch_size, height, width, num_channels = shape_list(pixel_values) hidden_size = height * width pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels)) if self.Normalization: pixel_values = self.Normalization(pixel_values) + # rearrange "b (h w) c" -> b c h w" pixel_values = tf.transpose(pixel_values, perm=(0, 2, 1)) pixel_values = tf.reshape(pixel_values, shape=(batch_size, num_channels, height, width)) @@ -852,7 +854,6 @@ def __init__(self, config: CvtConfig, *inputs, **kwargs): # In the original implementation the authors use epsilon=1e-5 for Layer Normalization. # Pytorch CVT model doesn't seem to use config.layer_norm_ep # Therefore we will be using the same epsilon as in Pytorch CVT model. - # What is the use of config.layer_norm_eps ?
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm") # Classifier head diff --git a/src/transformers/models/cvt/test_tf.py b/src/transformers/models/cvt/test_tf.py deleted file mode 100644 index 784d644c5413..000000000000 --- a/src/transformers/models/cvt/test_tf.py +++ /dev/null @@ -1,55 +0,0 @@ -import numpy as np -import tensorflow as tf -import torch -from datasets import load_dataset - -from transformers import ( - AutoFeatureExtractor, - CvtForImageClassification, - CvtModel, - TFCvtForImageClassification, - TFCvtModel, -) - - -dataset = load_dataset("huggingface/cats-image") -image = dataset["test"]["image"][0] -feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13") -pt_inputs = feature_extractor(image, return_tensors="pt") -tf_inputs = feature_extractor(image, return_tensors="tf") - - -print("\n--------- CVT Classification ------------\n") -# PYTORCH: -pt_model = CvtForImageClassification.from_pretrained("microsoft/cvt-13") -with torch.no_grad(): - pt_logits = pt_model(**pt_inputs).logits -pt_predicted_label = pt_logits.argmax(-1).item() - -# TENSORFLOW: -tf_model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13", from_pt=True) -tf_logits = tf_model(**tf_inputs).logits -tf_predicted_label = int(tf.math.argmax(tf_logits, axis=-1)) - -print(f"TF input shape: {tf_inputs['pixel_values'].shape}") -print(f"PT input shape: {pt_inputs['pixel_values'].shape}") -print(f"TF Predicted label: {tf_model.config.id2label[tf_predicted_label]}") -print(f"PT Predicted label: {pt_model.config.id2label[pt_predicted_label]}") - - -print("\n--------- Model ------------\n") -# PYTORCH: -model = CvtModel.from_pretrained("microsoft/cvt-13") -with torch.no_grad(): - outputs = model(**pt_inputs) -last_hidden_states = outputs.last_hidden_state -np_pt = last_hidden_states.numpy() - -# TENSORFLOW: -tf_model = TFCvtModel.from_pretrained("microsoft/cvt-13", from_pt=True) -tfo = tf_model(**tf_inputs, training=False) -np_tf = tfo.last_hidden_state.numpy() - -assert np_pt.shape == np_tf.shape -diff = np.amax(np.abs(np_pt - np_tf)) -print(f"\nMax absolute difference between models outputs {diff}\n") From 9aa1843688683fcb30b55655391dbc431712e7db Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Thu, 11 Aug 2022 15:12:22 +0200 Subject: [PATCH 05/15] small refacto, removing unwanted comments --- src/transformers/models/cvt/modeling_tf_cvt.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index 6aad5f0db4b5..9f3ce794ecf5 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -100,7 +100,7 @@ def call(self, pixel_values: tf.Tensor) -> tf.Tensor: class TFCvtConvEmbeddings(tf.keras.layers.Layer): - """Image to Conv Embedding. This convolutional operation aims to model local spatial contexts.""" + """Image to Convolution Embeddings. 
This convolutional operation aims to model local spatial contexts.""" def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: int, padding: int, **kwargs): super().__init__(**kwargs) @@ -336,7 +336,7 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool) class TFCvtSelfOutput(tf.keras.layers.Layer): - """Output of the Attention layer (MLP).""" + """Output of the Attention layer .""" def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: float, **kwargs): super().__init__(**kwargs) @@ -655,10 +655,6 @@ def call( ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]: if pixel_values is None: raise ValueError("You have to specify pixel_values") - # pixel_values = tf.transpose(pixel_values, perm=(0, 3, 1, 2)) - # tried reshaping to to `NHWC` directly in main layer and using this format - # throughout the model, but even though I get the same predictions as torch - # CVT model, our sequence_output has an absolute difference > 100 with torch prediction encoder_outputs = self.encoder( pixel_values, @@ -670,7 +666,6 @@ def call( sequence_output = encoder_outputs[0] if not return_dict: - # encoder_outputs -> [last_hidden_state, cls_token, all_hidden_states] return (sequence_output,) + encoder_outputs[1:] return TFBaseModelOutputWithCLSToken( @@ -821,7 +816,6 @@ def call( ) if not return_dict: - # outputs -> [last_hidden_sate, cls_token_value, hidden_states] return (outputs[0],) + outputs[1:] return TFBaseModelOutputWithCLSToken( @@ -851,9 +845,7 @@ def __init__(self, config: CvtConfig, *inputs, **kwargs): self.num_labels = config.num_labels self.cvt = TFCvtMainLayer(config, name="cvt") - # In the original implementation the authors use epsilon=1e-5 for Layer Normalization. - # Pytorch CVT model doesn't seem to use config.layer_norm_ep - # Therefore we will be using the same epsilon as in Pytorch CVT model. + # Using same default epsilon as in the original implementation. 
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm") # Classifier head @@ -929,7 +921,6 @@ def call( loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) if not return_dict: - # outputs -> [last_hidden_sate, cls_token_value, hidden_states] output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output From 28c8e0018df20f7b7b3123e945e11407ebaa61ed Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Tue, 16 Aug 2022 13:34:15 +0200 Subject: [PATCH 06/15] improved testing config --- src/transformers/models/cvt/modeling_tf_cvt.py | 6 +++--- tests/models/cvt/test_modeling_tf_cvt.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index 9f3ce794ecf5..acbdb47e4bce 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -271,21 +271,21 @@ def __init__( ) self.projection_query = tf.keras.layers.Dense( - units=self.embed_dim, + units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), use_bias=True, bias_initializer="zeros", name="projection_query", ) self.projection_key = tf.keras.layers.Dense( - units=self.embed_dim, + units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), use_bias=True, bias_initializer="zeros", name="projection_key", ) self.projection_value = tf.keras.layers.Dense( - units=self.embed_dim, + units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), use_bias=True, bias_initializer="zeros", diff --git a/tests/models/cvt/test_modeling_tf_cvt.py b/tests/models/cvt/test_modeling_tf_cvt.py index 04986368026f..5e5ca883b46c 100644 --- a/tests/models/cvt/test_modeling_tf_cvt.py +++ b/tests/models/cvt/test_modeling_tf_cvt.py @@ -28,6 +28,13 @@ from transformers import AutoFeatureExtractor +class TFCvtConfigTester(ConfigTester): + def create_and_test_config_common_properties(self): + config = self.config_class(**self.inputs_dict) + self.parent.assertTrue(hasattr(config, "embed_dim")) + self.parent.assertTrue(hasattr(config, "num_heads")) + + class TFCvtModelTester: def __init__( self, @@ -136,10 +143,10 @@ class TFCvtModelTest(TFModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = TFCvtModelTester(self) - self.config_tester = ConfigTester(self, config_class=CvtConfig, has_text_modality=False, hidden_size=37) + self.config_tester = TFCvtConfigTester(self, config_class=CvtConfig, has_text_modality=False, hidden_size=37) def test_config(self): - self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_common_properties() self.config_tester.create_and_test_config_to_json_string() self.config_tester.create_and_test_config_to_json_file() self.config_tester.create_and_test_config_from_and_save_pretrained() @@ -147,9 +154,6 @@ def test_config(self): self.config_tester.check_config_can_be_init_without_params() self.config_tester.check_config_arguments_init() - def create_and_test_config_common_properties(self): - return - @unittest.skip(reason="Cvt does not output attentions") def test_attention_outputs(self): pass From 7cd0a380f2c91828bd0dd0d04157aae7984173d7 Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Tue, 16 Aug 2022 15:47:14 +0200 Subject: [PATCH 07/15] corrected import error --- src/transformers/models/cvt/modeling_tf_cvt.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) 
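For context on the config checks introduced above: CvT keeps its main hyper-parameters as one-entry-per-stage lists, which is all the new TFCvtConfigTester asserts. A minimal sketch of that check follows; it is illustrative only and not part of the patch series, and the concrete values shown are the CvT-13 defaults given purely for orientation.

# Rough illustration of what TFCvtConfigTester.create_and_test_config_common_properties asserts.
from transformers import CvtConfig

config = CvtConfig()
assert hasattr(config, "embed_dim") and hasattr(config, "num_heads")
# Per-stage lists (CvT-13 defaults, shown for orientation; the tester does not check exact values):
# config.embed_dim -> [64, 192, 384]
# config.num_heads -> [1, 3, 6]
# config.depth     -> [1, 2, 10]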
diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index acbdb47e4bce..b3596b90a285 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -7,7 +7,6 @@ import tensorflow as tf -from ...modeling_outputs import ModelOutput from ...modeling_tf_outputs import TFImageClassifierOutputWithNoAttention from ...modeling_tf_utils import ( TFModelInputType, @@ -18,7 +17,13 @@ unpack_inputs, ) from ...tf_utils import shape_list, stable_softmax -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_cvt import CvtConfig From b3e59ad343ccee451947eafe36a85834a622cade Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Tue, 16 Aug 2022 16:48:23 +0200 Subject: [PATCH 08/15] modified access to pretrained model archive list, to pass tf_test --- src/transformers/models/cvt/__init__.py | 7 ++++++- src/transformers/models/cvt/modeling_tf_cvt.py | 9 +++++++++ tests/models/cvt/test_modeling_tf_cvt.py | 8 ++++---- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/cvt/__init__.py b/src/transformers/models/cvt/__init__.py index 86dfb82832a9..079727157994 100644 --- a/src/transformers/models/cvt/__init__.py +++ b/src/transformers/models/cvt/__init__.py @@ -70,7 +70,12 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_tf_cvt import TFCvtForImageClassification, TFCvtModel, TFCvtPreTrainedModel + from .modeling_tf_cvt import ( + TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFCvtForImageClassification, + TFCvtModel, + TFCvtPreTrainedModel, + ) else: diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index b3596b90a285..0d302e7bf9d9 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -32,6 +32,15 @@ # General docstring _CONFIG_FOR_DOC = "CvtConfig" +TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/cvt-13", + "microsoft/cvt-13-384", + "microsoft/cvt-13-384-22k", + "microsoft/cvt-21", + "microsoft/cvt-21-384", + "microsoft/cvt-21-384-22k", + # See all Cvt models at https://huggingface.co/models?filter=cvt +] diff --git a/tests/models/cvt/test_modeling_tf_cvt.py b/tests/models/cvt/test_modeling_tf_cvt.py index 5e5ca883b46c..aeab6060d341 100644 --- a/tests/models/cvt/test_modeling_tf_cvt.py +++ b/tests/models/cvt/test_modeling_tf_cvt.py @@ -19,7 +19,7 @@ import tensorflow as tf from transformers import TFCvtForImageClassification, TFCvtModel - from transformers.models.cvt.modeling_cvt import CVT_PRETRAINED_MODEL_ARCHIVE_LIST + from transformers.models.cvt.modeling_tf_cvt import TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -227,7 +227,7 @@ def test_for_image_classification(self): @slow def test_model_from_pretrained(self): - for model_name in CVT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + for model_name in TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = TFCvtModel.from_pretrained(model_name, from_pt=True) self.assertIsNotNone(model) @@ -243,11 +243,11 @@ def prepare_img(): class TFCvtModelIntegrationTest(unittest.TestCase): @cached_property def default_feature_extractor(self): - return
AutoFeatureExtractor.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0], from_pt=True) + return AutoFeatureExtractor.from_pretrained(TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0], from_pt=True) @slow def test_inference_image_classification_head(self): - model = TFCvtForImageClassification.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0], from_pt=True) + model = TFCvtForImageClassification.from_pretrained(TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0], from_pt=True) feature_extractor = self.default_feature_extractor image = prepare_img() From 70986b0295a5c4a1350be12e84b591b4afb808a0 Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Tue, 16 Aug 2022 17:30:43 +0200 Subject: [PATCH 09/15] corrected import structure in init files --- src/transformers/__init__.py | 8 +++++++- src/transformers/models/cvt/__init__.py | 1 + src/transformers/models/cvt/modeling_tf_cvt.py | 1 + src/transformers/utils/dummy_tf_objects.py | 3 +++ 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b1c2cc75d885..cf937f2a162b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2359,6 +2359,7 @@ ) _import_structure["models.cvt"].extend( [ + "TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST", "TFCvtForImageClassification", "TFCvtModel", "TFCvtPreTrainedModel", @@ -5029,7 +5030,12 @@ TFCTRLModel, TFCTRLPreTrainedModel, ) - from .models.cvt import TFCvtForImageClassification, TFCvtModel, TFCvtPreTrainedModel + from .models.cvt import ( + TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFCvtForImageClassification, + TFCvtModel, + TFCvtPreTrainedModel, + ) from .models.data2vec import ( TFData2VecVisionForImageClassification, TFData2VecVisionForSemanticSegmentation, diff --git a/src/transformers/models/cvt/__init__.py b/src/transformers/models/cvt/__init__.py index 079727157994..66b18f334411 100644 --- a/src/transformers/models/cvt/__init__.py +++ b/src/transformers/models/cvt/__init__.py @@ -43,6 +43,7 @@ pass else: _import_structure["modeling_tf_cvt"] = [ + "TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST", "TFCvtForImageClassification", "TFCvtModel", "TFCvtPreTrainedModel", diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index 0d302e7bf9d9..b0a7bc9addf3 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -42,6 +42,7 @@ # See all Cvt models at https://huggingface.co/models?filter=cvt ] + @dataclass class TFBaseModelOutputWithCLSToken(ModelOutput): """ diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index be5ac1490f46..dbfe14c7f633 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -786,6 +786,9 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + class TFCvtForImageClassification(metaclass=DummyObject): _backends = ["tf"] From 995803967d1e0c51477da7347538e4328f190ff9 Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Fri, 26 Aug 2022 15:53:38 +0200 Subject: [PATCH 10/15] modified testing for keras_fit with cpu --- tests/models/cvt/test_modeling_tf_cvt.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/models/cvt/test_modeling_tf_cvt.py b/tests/models/cvt/test_modeling_tf_cvt.py index aeab6060d341..825e73587f0c 100644 --- a/tests/models/cvt/test_modeling_tf_cvt.py +++ b/tests/models/cvt/test_modeling_tf_cvt.py @@ -173,6 +173,13 @@ def 
test_model_common_attributes(self): def test_dataset_conversion(self): super().test_dataset_conversion() + @unittest.skipIf( + not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, + reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", + ) + def test_keras_fit(self): + super().test_keras_fit() + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() From 1bbe27d866f35646ad151b19549163774966e962 Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Thu, 1 Sep 2022 17:30:47 +0200 Subject: [PATCH 11/15] correcting PR issues + Refactoring --- src/transformers/modeling_tf_pytorch_utils.py | 5 +- .../models/cvt/modeling_tf_cvt.py | 157 ++++++++++-------- tests/models/cvt/test_modeling_tf_cvt.py | 8 +- 3 files changed, 94 insertions(+), 76 deletions(-) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index d360c6d6f3da..73d6a7613fda 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -88,10 +88,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel": - if tf_name[0] == "cvt": - tf_name[-1] = "weight" - else: - tf_name[-1] = tf_name[-1].replace("_kernel", ".weight") + tf_name[-1] = tf_name[-1].replace("_kernel", ".weight") # Remove prefix if needed tf_name = ".".join(tf_name) diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index b0a7bc9addf3..aaaa7fd29969 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ TF 2.0 Cvt model.""" @@ -49,13 +63,13 @@ class TFBaseModelOutputWithCLSToken(ModelOutput): Base class for model's outputs. Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - cls_token_value (`tf.Tensor` of shape `(batch_size, 1, num_channels)`): + cls_token_value (`tf.Tensor` of shape `(batch_size, 1, hidden_size)`): Classification token at the output of the last layer of the model. hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, num_channels, height, width)`. Hidden-states of the model at the output of each layer plus + `(batch_size, sequence_length, hidden_size)`. 
Hidden-states of the model at the output of each layer plus the initial embedding outputs. """ @@ -75,13 +89,13 @@ def __init__(self, drop_prob: float, **kwargs): self.drop_prob = drop_prob def call(self, x: tf.Tensor, training=None): - if training: - keep_prob = 1 - self.drop_prob - shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) - random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) - random_tensor = tf.floor(random_tensor) - return (x / keep_prob) * random_tensor - return x + if self.drop_prob == 0.0 or not training: + return x + keep_prob = 1 - self.drop_prob + shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor class TFCvtEmbeddings(tf.keras.layers.Layer): @@ -108,9 +122,9 @@ def __init__( ) self.dropout = tf.keras.layers.Dropout(dropout_rate) - def call(self, pixel_values: tf.Tensor) -> tf.Tensor: + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: hidden_state = self.convolution_embeddings(pixel_values) - hidden_state = self.dropout(hidden_state) + hidden_state = self.dropout(hidden_state, training=training) return hidden_state @@ -121,7 +135,7 @@ def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: i super().__init__(**kwargs) self.pad_value = padding self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - self.proj = tf.keras.layers.Conv2D( + self.projection = tf.keras.layers.Conv2D( filters=embed_dim, kernel_size=patch_size, strides=stride, @@ -131,32 +145,31 @@ def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: i name="projection", ) # Using the same default epsilon as PyTorch - self.Normalization = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="normalization") + self.normalization = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="normalization") def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: # Custom padding to match the model implementation in PyTorch height_pad = width_pad = (self.pad_value, self.pad_value) hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)]) - hidden_state = self.proj(hidden_state) + hidden_state = self.projection(hidden_state) return hidden_state def call(self, pixel_values: tf.Tensor) -> tf.Tensor: if isinstance(pixel_values, dict): pixel_values = pixel_values["pixel_values"] - # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width) + # as input format. So change the input format to (batch_size, height, width, num_channels). 
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) pixel_values = self.convolution(pixel_values) - # rearrange "b h w c" -> b (h w) c" + # rearrange "batch_size, height, width, num_channels -> batch_size, (height, width), num_channels" batch_size, height, width, num_channels = shape_list(pixel_values) hidden_size = height * width pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels)) - if self.Normalization: - pixel_values = self.Normalization(pixel_values) + pixel_values = self.normalization(pixel_values) - # rearrange "b (h w) c" -> b c h w" + # rearrange "batch_size, (height, width), num_channels -> batch_size, num_channels, height, width" pixel_values = tf.transpose(pixel_values, perm=(0, 2, 1)) pixel_values = tf.reshape(pixel_values, shape=(batch_size, num_channels, height, width)) return pixel_values @@ -165,19 +178,21 @@ def call(self, pixel_values: tf.Tensor) -> tf.Tensor: class TFCvtSelfAttentionConvProjection(tf.keras.layers.Layer): """Convolutional projection layer.""" - def __init__(self, config: CvtConfig, kernel_size: int, stride: int, padding: int, **kwargs): + def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: int, padding: int, **kwargs): super().__init__(**kwargs) self.pad_value = padding - self.conv = tf.keras.layers.DepthwiseConv2D( + self.conv = tf.keras.layers.Conv2D( + filters=embed_dim, kernel_size=kernel_size, kernel_initializer=get_initializer(config.initializer_range), padding="valid", strides=stride, use_bias=False, name="convolution", + groups=embed_dim, ) # Using the same default epsilon & momentum as PyTorch - self.Normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") + self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: # Custom padding to match the model implementation in PyTorch @@ -188,7 +203,7 @@ def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: hidden_state = self.convolution(hidden_state) - hidden_state = self.Normalization(hidden_state, training=training) + hidden_state = self.normalization(hidden_state, training=training) return hidden_state @@ -196,11 +211,10 @@ class TFCvtSelfAttentionLinearProjection(tf.keras.layers.Layer): """Linear projection layer used to flatten tokens into 1D.""" def call(self, hidden_state: tf.Tensor) -> tf.Tensor: - # rearrange " b c h w -> b (h w) c" - batch_size, num_channels, height, width = shape_list(hidden_state) + # rearrange "batch_size, height, width, num_channels -> batch_size, (height, width), num_channels" + batch_size, height, width, num_channels = shape_list(hidden_state) hidden_size = height * width - hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, hidden_size)) - hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1)) + hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels)) return hidden_state @@ -210,6 +224,7 @@ class TFCvtSelfAttentionProjection(tf.keras.layers.Layer): def __init__( self, config: CvtConfig, + embed_dim: int, kernel_size: int, stride: int, padding: int, @@ -219,16 +234,15 @@ def __init__( super().__init__(**kwargs) if projection_method == "dw_bn": self.convolution_projection = TFCvtSelfAttentionConvProjection( - config, kernel_size, stride, padding, name="convolution_projection" + config, embed_dim, kernel_size, 
stride, padding, name="convolution_projection" ) self.linear_projection = TFCvtSelfAttentionLinearProjection() - def call(self, hidden_state: tf.Tensor) -> tf.Tensor: - # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - hidden_state = tf.transpose(hidden_state, perm=(0, 3, 2, 1)) - hidden_state = self.convolution_projection(hidden_state) - hidden_state = tf.transpose(hidden_state, perm=(0, 3, 2, 1)) + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width) + # as input format. So change the input format to (batch_size, height, width, num_channels). + hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1)) + hidden_state = self.convolution_projection(hidden_state, training=training) hidden_state = self.linear_projection(hidden_state) return hidden_state @@ -250,6 +264,7 @@ def __init__( padding_q: int, padding_kv: int, qkv_projection_method: str, + qkv_bias: bool, attention_drop_rate: float, with_cls_token: bool = True, **kwargs @@ -262,6 +277,7 @@ def __init__( self.convolution_projection_query = TFCvtSelfAttentionProjection( config, + embed_dim, kernel_size, stride_q, padding_q, @@ -270,6 +286,7 @@ def __init__( ) self.convolution_projection_key = TFCvtSelfAttentionProjection( config, + embed_dim, kernel_size, stride_kv, padding_kv, @@ -278,6 +295,7 @@ def __init__( ) self.convolution_projection_value = TFCvtSelfAttentionProjection( config, + embed_dim, kernel_size, stride_kv, padding_kv, @@ -288,21 +306,21 @@ def __init__( self.projection_query = tf.keras.layers.Dense( units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), - use_bias=True, + use_bias=qkv_bias, bias_initializer="zeros", name="projection_query", ) self.projection_key = tf.keras.layers.Dense( units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), - use_bias=True, + use_bias=qkv_bias, bias_initializer="zeros", name="projection_key", ) self.projection_value = tf.keras.layers.Dense( units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), - use_bias=True, + use_bias=qkv_bias, bias_initializer="zeros", name="projection_value", ) @@ -315,18 +333,18 @@ def rearrange_for_multi_head_attention(self, hidden_state: tf.Tensor) -> tf.Tens hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3)) return hidden_state - def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool) -> tf.Tensor: + def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: if self.with_cls_token: cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) - # rearrange "b (h w) c -> b c h w" + # rearrange "batch_size, (height, width), num_channels -> batch_size, num_channels, height, width" batch_size, hidden_size, num_channels = shape_list(hidden_state) hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1)) hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, height, width)) - key = self.convolution_projection_key(hidden_state) - query = self.convolution_projection_query(hidden_state) - value = self.convolution_projection_value(hidden_state) + key = self.convolution_projection_key(hidden_state, training=training) + query = self.convolution_projection_query(hidden_state, training=training) + value = self.convolution_projection_value(hidden_state, training=training) if 
self.with_cls_token: query = tf.concat((cls_token, query), axis=1) @@ -343,7 +361,7 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool) attention_probs = self.dropout(attention_probs, training=training) context = tf.matmul(attention_probs, value) - # rearrange "b h t d -> b t (h d)" + # rearrange "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads, head_dim)" _, _, hidden_size, _ = shape_list(context) context = tf.transpose(context, perm=(0, 2, 1, 3)) context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim)) @@ -380,6 +398,7 @@ def __init__( padding_q: int, padding_kv: int, qkv_projection_method: str, + qkv_bias: bool, attention_drop_rate: float, drop_rate: float, with_cls_token: bool = True, @@ -396,6 +415,7 @@ def __init__( padding_q, padding_kv, qkv_projection_method, + qkv_bias, attention_drop_rate, with_cls_token, name="attention", @@ -405,9 +425,9 @@ def __init__( def prune_heads(self, heads): raise NotImplementedError - def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool): - self_output = self.attention(hidden_state, height, width, training) - attention_output = self.dense_output(self_output, training) + def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False): + self_output = self.attention(hidden_state, height, width, training=training) + attention_output = self.dense_output(self_output, training=training) return attention_output @@ -440,9 +460,9 @@ def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: int, **kwargs): ) self.dropout = tf.keras.layers.Dropout(drop_rate) - def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor) -> tf.Tensor: + def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: hidden_state = self.dense(inputs=hidden_state) - hidden_state = self.dropout(inputs=hidden_state) + hidden_state = self.dropout(inputs=hidden_state, training=training) hidden_state = hidden_state + input_tensor return hidden_state @@ -465,6 +485,7 @@ def __init__( padding_q: int, padding_kv: int, qkv_projection_method: str, + qkv_bias: bool, attention_drop_rate: float, drop_rate: float, mlp_ratio: float, @@ -483,6 +504,7 @@ def __init__( padding_q, padding_kv, qkv_projection_method, + qkv_bias, attention_drop_rate, drop_rate, with_cls_token, @@ -501,14 +523,9 @@ def __init__( self.layernorm_after = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after") def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: - self_attention_output = self.attention( - # in Cvt, layernorm is applied before self-attention - self.layernorm_before(hidden_state), - width, - height, - ) - attention_output = self_attention_output - attention_output = self.drop_path(attention_output) + # in Cvt, layernorm is applied before self-attention + attention_output = self.attention(self.layernorm_before(hidden_state), height, width, training=training) + attention_output = self.drop_path(attention_output, training=training) # first residual connection hidden_state = attention_output + hidden_state @@ -569,6 +586,7 @@ def __init__(self, config: CvtConfig, stage: int, **kwargs): padding_q=config.padding_q[self.stage], padding_kv=config.padding_kv[self.stage], qkv_projection_method=config.qkv_projection_method[self.stage], + qkv_bias=config.qkv_bias[self.stage], attention_drop_rate=config.attention_drop_rate[self.stage], 
drop_rate=config.drop_rate[self.stage], mlp_ratio=config.mlp_ratio[self.stage], @@ -579,12 +597,12 @@ def __init__(self, config: CvtConfig, stage: int, **kwargs): for j in range(config.depth[self.stage]) ] - def call(self, hidden_state: tf.Tensor): + def call(self, hidden_state: tf.Tensor, training: bool = False): cls_token = None - hidden_state = self.embedding(hidden_state) + hidden_state = self.embedding(hidden_state, training) batch_size, num_channels, height, width = shape_list(hidden_state) - # rearrange b c h w -> b (h w) c" + # rearrange "batch_size, num_channels, height, width -> batch_size, (height, width), num_channels" hidden_size = height * width hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, hidden_size)) hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1)) @@ -594,13 +612,13 @@ def call(self, hidden_state: tf.Tensor): hidden_state = tf.concat((cls_token, hidden_state), axis=1) for layer in self.layers: - layer_outputs = layer(hidden_state, height, width) + layer_outputs = layer(hidden_state, height, width, training=training) hidden_state = layer_outputs if self.config.cls_token[self.stage]: cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) - # rearrange -> b (h w) c" -> b c h w + # rearrange -> "batch_size, (height, width), num_channels -> batch_size, num_channels, height, width" hidden_state = tf.transpose(hidden_state, (0, 2, 1)) hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, height, width)) return hidden_state, cls_token @@ -629,13 +647,14 @@ def call( pixel_values: TFModelInputType, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + training: Optional[bool] = False, ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]: all_hidden_states = () if output_hidden_states else None hidden_state = pixel_values cls_token = None for _, (stage_module) in enumerate(self.stages): - hidden_state, cls_token = stage_module(hidden_state) + hidden_state, cls_token = stage_module(hidden_state, training=training) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_state,) @@ -666,7 +685,7 @@ def call( pixel_values: Optional[TFModelInputType] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - training: bool = False, + training: Optional[bool] = False, ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]: if pixel_values is None: raise ValueError("You have to specify pixel_values") @@ -797,7 +816,7 @@ def call( pixel_values: Optional[tf.Tensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - training: bool = False, + training: Optional[bool] = False, ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]: r""" Returns: @@ -813,7 +832,7 @@ def call( >>> image = Image.open(requests.get(url, stream=True).raw) >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13") - >>> model = TFCvtModel.from_pretrained("microsoft/cvt-13", from_pt=True) + >>> model = TFCvtModel.from_pretrained("microsoft/cvt-13") >>> inputs = feature_extractor(images=image, return_tensors="tf") >>> outputs = model(**inputs) @@ -903,7 +922,7 @@ def call( >>> image = Image.open(requests.get(url, stream=True).raw) >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13") - >>> model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13", from_pt=True) + >>> model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13") >>> inputs = 
feature_extractor(images=image, return_tensors="tf") >>> outputs = model(**inputs) @@ -925,7 +944,7 @@ def call( if self.config.cls_token[-1]: sequence_output = self.LayerNorm(cls_token) else: - # rearrange "b c h w -> b (h w) c" + # rearrange "batch_size, num_channels, height, width -> batch_size, (height, width), num_channels" batch_size, num_channels, height, width = shape_list(sequence_output) sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * width)) sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1)) diff --git a/tests/models/cvt/test_modeling_tf_cvt.py b/tests/models/cvt/test_modeling_tf_cvt.py index 825e73587f0c..b2bd6a7175df 100644 --- a/tests/models/cvt/test_modeling_tf_cvt.py +++ b/tests/models/cvt/test_modeling_tf_cvt.py @@ -168,14 +168,14 @@ def test_model_common_attributes(self): @unittest.skipIf( not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, - reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", + reason="TF does not support backprop for grouped convolutions on CPU.", ) def test_dataset_conversion(self): super().test_dataset_conversion() @unittest.skipIf( not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, - reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", + reason="TF does not support backprop for grouped convolutions on CPU.", ) def test_keras_fit(self): super().test_keras_fit() @@ -235,6 +235,7 @@ def test_for_image_classification(self): @slow def test_model_from_pretrained(self): for model_name in TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + # Remove 'from_pt=True' after PR to add weights to hub (pt-to-tf) model = TFCvtModel.from_pretrained(model_name, from_pt=True) self.assertIsNotNone(model) @@ -250,10 +251,11 @@ def prepare_img(): class TFCvtModelIntegrationTest(unittest.TestCase): @cached_property def default_feature_extractor(self): - return AutoFeatureExtractor.from_pretrained(TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0], from_pt=True) + return AutoFeatureExtractor.from_pretrained(TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0]) @slow def test_inference_image_classification_head(self): + # Remove 'from_pt=True' after PR to add weights to hub (pt-to-tf) model = TFCvtForImageClassification.from_pretrained(TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0], from_pt=True) feature_extractor = self.default_feature_extractor From 5a7ec2eca8647c27a68dc79a5690bb3037da6ef9 Mon Sep 17 00:00:00 2001 From: mathieujouffroy Date: Fri, 9 Sep 2022 16:22:21 +0200 Subject: [PATCH 12/15] Refactoring : improving readability and reducing the number of permutations --- .../models/cvt/modeling_tf_cvt.py | 81 ++++++++----------- 1 file changed, 32 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index aaaa7fd29969..87ad12138008 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -133,7 +133,7 @@ class TFCvtConvEmbeddings(tf.keras.layers.Layer): def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: int, padding: int, **kwargs): super().__init__(**kwargs) - self.pad_value = padding + self.padding = tf.keras.layers.ZeroPadding2D(padding=padding) self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) self.projection = tf.keras.layers.Conv2D( filters=embed_dim, @@ -147,31 +147,20 @@ def __init__(self, config: CvtConfig, patch_size: 
int, embed_dim: int, stride: i # Using the same default epsilon as PyTorch self.normalization = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="normalization") - def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: - # Custom padding to match the model implementation in PyTorch - height_pad = width_pad = (self.pad_value, self.pad_value) - hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)]) - hidden_state = self.projection(hidden_state) - return hidden_state - def call(self, pixel_values: tf.Tensor) -> tf.Tensor: if isinstance(pixel_values, dict): pixel_values = pixel_values["pixel_values"] - # When running on CPU, `tf.keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width) - # as input format. So change the input format to (batch_size, height, width, num_channels). - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - pixel_values = self.convolution(pixel_values) + pixel_values = self.projection(self.padding(pixel_values)) - # rearrange "batch_size, height, width, num_channels -> batch_size, (height, width), num_channels" + # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels" batch_size, height, width, num_channels = shape_list(pixel_values) hidden_size = height * width pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels)) pixel_values = self.normalization(pixel_values) - # rearrange "batch_size, (height, width), num_channels -> batch_size, num_channels, height, width" - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 1)) - pixel_values = tf.reshape(pixel_values, shape=(batch_size, num_channels, height, width)) + # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels" + pixel_values = tf.reshape(pixel_values, shape=(batch_size, height, width, num_channels)) return pixel_values @@ -180,8 +169,8 @@ class TFCvtSelfAttentionConvProjection(tf.keras.layers.Layer): def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: int, padding: int, **kwargs): super().__init__(**kwargs) - self.pad_value = padding - self.conv = tf.keras.layers.Conv2D( + self.padding = tf.keras.layers.ZeroPadding2D(padding=padding) + self.convolution = tf.keras.layers.Conv2D( filters=embed_dim, kernel_size=kernel_size, kernel_initializer=get_initializer(config.initializer_range), @@ -194,15 +183,8 @@ def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: # Using the same default epsilon & momentum as PyTorch self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") - def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: - # Custom padding to match the model implementation in PyTorch - height_pad = width_pad = (self.pad_value, self.pad_value) - hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)]) - hidden_state = self.conv(hidden_state) - return hidden_state - def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_state = self.convolution(hidden_state) + hidden_state = self.convolution(self.padding(hidden_state)) hidden_state = self.normalization(hidden_state, training=training) return hidden_state @@ -211,7 +193,7 @@ class TFCvtSelfAttentionLinearProjection(tf.keras.layers.Layer): """Linear projection layer used to flatten tokens into 1D.""" def call(self, hidden_state: tf.Tensor) -> tf.Tensor: - # rearrange "batch_size, height, width, num_channels -> batch_size, (height, width), 
num_channels" + # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels" batch_size, height, width, num_channels = shape_list(hidden_state) hidden_size = height * width hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels)) @@ -239,9 +221,6 @@ def __init__( self.linear_projection = TFCvtSelfAttentionLinearProjection() def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - # When running on CPU, `tf.keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width) - # as input format. So change the input format to (batch_size, height, width, num_channels). - hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1)) hidden_state = self.convolution_projection(hidden_state, training=training) hidden_state = self.linear_projection(hidden_state) return hidden_state @@ -337,10 +316,9 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool if self.with_cls_token: cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) - # rearrange "batch_size, (height, width), num_channels -> batch_size, num_channels, height, width" + # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels" batch_size, hidden_size, num_channels = shape_list(hidden_state) - hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1)) - hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, height, width)) + hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels)) key = self.convolution_projection_key(hidden_state, training=training) query = self.convolution_projection_query(hidden_state, training=training) @@ -352,6 +330,7 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool value = tf.concat((cls_token, value), axis=1) head_dim = self.embed_dim // self.num_heads + query = self.rearrange_for_multi_head_attention(self.projection_query(query)) key = self.rearrange_for_multi_head_attention(self.projection_key(key)) value = self.rearrange_for_multi_head_attention(self.projection_value(value)) @@ -359,9 +338,9 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool attention_score = tf.matmul(query, key, transpose_b=True) * self.scale attention_probs = stable_softmax(logits=attention_score, axis=-1) attention_probs = self.dropout(attention_probs, training=training) - context = tf.matmul(attention_probs, value) - # rearrange "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads, head_dim)" + context = tf.matmul(attention_probs, value) + # "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads*head_dim)" _, _, hidden_size, _ = shape_list(context) context = tf.transpose(context, perm=(0, 2, 1, 3)) context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim)) @@ -601,11 +580,10 @@ def call(self, hidden_state: tf.Tensor, training: bool = False): cls_token = None hidden_state = self.embedding(hidden_state, training) - batch_size, num_channels, height, width = shape_list(hidden_state) - # rearrange "batch_size, num_channels, height, width -> batch_size, (height, width), num_channels" + # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels" + batch_size, height, width, num_channels = shape_list(hidden_state) hidden_size = height * width - hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, hidden_size)) - hidden_state = 
tf.transpose(hidden_state, perm=(0, 2, 1)) + hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels)) if self.config.cls_token[self.stage]: cls_token = tf.repeat(self.cls_token, repeats=batch_size, axis=0) @@ -618,9 +596,8 @@ def call(self, hidden_state: tf.Tensor, training: bool = False): if self.config.cls_token[self.stage]: cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) - # rearrange -> "batch_size, (height, width), num_channels -> batch_size, num_channels, height, width" - hidden_state = tf.transpose(hidden_state, (0, 2, 1)) - hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, height, width)) + # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels" + hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels)) return hidden_state, cls_token @@ -651,6 +628,9 @@ def call( ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]: all_hidden_states = () if output_hidden_states else None hidden_state = pixel_values + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width) + # as input format. So change the input format to (batch_size, height, width, num_channels). + hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1)) cls_token = None for _, (stage_module) in enumerate(self.stages): @@ -658,6 +638,11 @@ def call( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_state,) + # Change back to (batch_size, num_channels, height, width) format to have uniformity in the modules + hidden_state = tf.transpose(hidden_state, perm=(0, 3, 1, 2)) + if output_hidden_states: + all_hidden_states = tuple([tf.transpose(hs, perm=(0, 3, 1, 2)) for hs in all_hidden_states]) + if not return_dict: return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None) @@ -727,9 +712,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: Returns: `Dict[str, tf.Tensor]`: The dummy inputs. """ - VISION_DUMMY_INPUTS = tf.random.uniform( - shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size), dtype=tf.float32 - ) + VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 224, 224), dtype=tf.float32) return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} @tf.function( @@ -880,7 +863,7 @@ def __init__(self, config: CvtConfig, *inputs, **kwargs): self.num_labels = config.num_labels self.cvt = TFCvtMainLayer(config, name="cvt") # Using same default epsilon as in the original implementation. 
@@ -880,7 +863,7 @@ def __init__(self, config: CvtConfig, *inputs, **kwargs):
         self.num_labels = config.num_labels
         self.cvt = TFCvtMainLayer(config, name="cvt")
         # Using same default epsilon as in the original implementation.
-        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm")
+        self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm")

         # Classifier head
         self.classifier = tf.keras.layers.Dense(
@@ -942,13 +925,13 @@ def call(
         sequence_output = outputs[0]
         cls_token = outputs[1]
         if self.config.cls_token[-1]:
-            sequence_output = self.LayerNorm(cls_token)
+            sequence_output = self.layernorm(cls_token)
         else:
-            # rearrange "batch_size, num_channels, height, width -> batch_size, (height, width), num_channels"
+            # rearrange "batch_size, num_channels, height, width -> batch_size, (height*width), num_channels"
             batch_size, num_channels, height, width = shape_list(sequence_output)
             sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * width))
             sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1))
-            sequence_output = self.LayerNorm(sequence_output)
+            sequence_output = self.layernorm(sequence_output)

         sequence_output_mean = tf.reduce_mean(sequence_output, axis=1)
         logits = self.classifier(sequence_output_mean)

From 5caea7fba4f48d69c996d5e4a65411bab83ac3af Mon Sep 17 00:00:00 2001
From: mathieujouffroy
Date: Fri, 7 Oct 2022 16:28:39 +0200
Subject: [PATCH 13/15] corrected momentum value + cls_token initialization

---
 src/transformers/models/cvt/modeling_tf_cvt.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py
index 87ad12138008..448bfd230288 100644
--- a/src/transformers/models/cvt/modeling_tf_cvt.py
+++ b/src/transformers/models/cvt/modeling_tf_cvt.py
@@ -180,8 +180,8 @@ def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride:
             name="convolution",
             groups=embed_dim,
         )
-        # Using the same default epsilon & momentum as PyTorch
-        self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
+        # Using the same default epsilon as PyTorch, TF uses (1 - pytorch momentum)
+        self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")

     def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
         hidden_state = self.convolution(self.padding(hidden_state))
@@ -538,10 +538,11 @@ def __init__(self, config: CvtConfig, stage: int, **kwargs):
         if self.config.cls_token[self.stage]:
             self.cls_token = self.add_weight(
                 shape=(1, 1, self.config.embed_dim[-1]),
-                initializer="zeros",
+                initializer=get_initializer(self.config.initializer_range),
                 trainable=True,
                 name="cvt.encoder.stages.2.cls_token",
             )
+
         self.embedding = TFCvtEmbeddings(
             self.config,
             patch_size=config.patch_sizes[self.stage],
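# Not part of the patch: a short note on the momentum fix above. Keras and PyTorch define
# BatchNormalization momentum in opposite directions, so matching PyTorch's default of 0.1
# requires a Keras momentum of 1 - 0.1 = 0.9. Minimal sketch under that assumption:
#   PyTorch: running_mean = (1 - momentum) * running_mean + momentum * batch_mean
#   Keras:   moving_mean  = momentum * moving_mean + (1 - momentum) * batch_mean
import tensorflow as tf

pytorch_momentum = 0.1
normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=1.0 - pytorch_momentum)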
From 4b84c80ff425691cb884659239fd0249003da20e Mon Sep 17 00:00:00 2001
From: mathieujouffroy
Date: Tue, 11 Oct 2022 16:14:54 +0200
Subject: [PATCH 14/15] removed from_pt as weights were added to the hub

---
 tests/models/cvt/test_modeling_tf_cvt.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/models/cvt/test_modeling_tf_cvt.py b/tests/models/cvt/test_modeling_tf_cvt.py
index b2bd6a7175df..13697c8eb313 100644
--- a/tests/models/cvt/test_modeling_tf_cvt.py
+++ b/tests/models/cvt/test_modeling_tf_cvt.py
@@ -235,8 +235,7 @@ def test_for_image_classification(self):
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-            # Remove 'from_pt=True' after PR to add weights to hub (pt-to-tf)
-            model = TFCvtModel.from_pretrained(model_name, from_pt=True)
+            model = TFCvtModel.from_pretrained(model_name)
         self.assertIsNotNone(model)


@@ -256,7 +255,7 @@ def default_feature_extractor(self):

     @slow
     def test_inference_image_classification_head(self):
         # Remove 'from_pt=True' after PR to add weights to hub (pt-to-tf)
-        model = TFCvtForImageClassification.from_pretrained(TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0], from_pt=True)
+        model = TFCvtForImageClassification.from_pretrained(TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0])
         feature_extractor = self.default_feature_extractor
         image = prepare_img()

From 1e33bfc3b663ea59da7f40e95f734b5d354f8290 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 11 Oct 2022 16:43:27 +0100
Subject: [PATCH 15/15] Update tests/models/cvt/test_modeling_tf_cvt.py

---
 tests/models/cvt/test_modeling_tf_cvt.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/models/cvt/test_modeling_tf_cvt.py b/tests/models/cvt/test_modeling_tf_cvt.py
index 13697c8eb313..9e261a5f25be 100644
--- a/tests/models/cvt/test_modeling_tf_cvt.py
+++ b/tests/models/cvt/test_modeling_tf_cvt.py
@@ -254,7 +254,6 @@ def default_feature_extractor(self):

     @slow
     def test_inference_image_classification_head(self):
-        # Remove 'from_pt=True' after PR to add weights to hub (pt-to-tf)
         model = TFCvtForImageClassification.from_pretrained(TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0])
         feature_extractor = self.default_feature_extractor
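# Not part of the patch: with the TF weights now on the hub, the models load without `from_pt=True`.
# Minimal usage sketch; "microsoft/cvt-13" is assumed here to be the published checkpoint and may
# differ from the entries in TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST.
from transformers import TFCvtForImageClassification

model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13")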