From 8b8b5a30835549be912682a51b16d4063d7f2488 Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Mon, 28 Aug 2023 10:42:45 -0300 Subject: [PATCH 1/9] [feat] adding convnext architecture --- .../architectures/convnext.py | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 tensorflow_similarity/architectures/convnext.py diff --git a/tensorflow_similarity/architectures/convnext.py b/tensorflow_similarity/architectures/convnext.py new file mode 100644 index 00000000..e8cd525c --- /dev/null +++ b/tensorflow_similarity/architectures/convnext.py @@ -0,0 +1,144 @@ +# Copyright 2021 The TensorFlow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"ConvNeXt backbone for similarity learning" +from __future__ import annotations + +import re + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.applications import convnext + +from tensorflow_similarity.layers import GeneralizedMeanPooling2D, MetricEmbedding +from tensorflow_similarity.models import SimilarityModel + + +CONVNEXT_ARCHITECTURE = { + "TINY": convnext.ConvNeXtTiny, + "SMALL": convnext.ConvNeXtSmall, + "BASE": convnext.ConvNeXtBase, + "LARGE": convnext.ConvNeXtLarge, + "XLARGE": convnext.ConvNeXtXLarge, +} + + +def ConvNeXtSim( + input_shape: tuple[int, int, int], + embedding_size: int = 128, + variant: str = "BASE", + weights: str = "imagenet", + trainable: str = "frozen", + l2_norm: bool = True, + include_top: bool = True, + pooling: str = "gem", + gem_p: float = 3.0, +) -> SimilarityModel: + """ "Build an ConvNeXt Model backbone for similarity learning + [A ConvNet for the 2020s](https://arxiv.org/pdf/2201.03545.pdf) + Args: + input_shape: Size of the input image. Must match size of ConvNeXt version you use. + See below for version input size. + embedding_size: Size of the output embedding. Usually between 64 + and 512. Defaults to 128. + variant: Which Variant of the ConvNeXt to use. Defaults to "BASE". + weights: Use pre-trained weights - the only available currently being + imagenet. Defaults to "imagenet". + trainable: Make the ConvNeXt backbone fully trainable or partially + trainable. + - "full" to make the entire backbone trainable, + - "partial" to only make the last 3 block trainable + - "frozen" to make it not trainable. + l2_norm: If True and include_top is also True, then + tfsim.layers.MetricEmbedding is used as the last layer, otherwise + keras.layers.Dense is used. This should be true when using cosine + distance. Defaults to True. + include_top: Whether to include the fully-connected layer at the top + of the network. Defaults to True. + pooling: Optional pooling mode for feature extraction when + include_top is False. Defaults to gem. + - None means that the output of the model will be the 4D tensor + output of the last convolutional layer. + - avg means that global average pooling will be applied to the + output of the last convolutional layer, and thus the output of the + model will be a 2D tensor. + - max means that global max pooling will be applied. + - gem means that global GeneralizedMeanPooling2D will be applied. + The gem_p param sets the contrast amount on the pooling. + gem_p: Sets the power in the GeneralizedMeanPooling2D layer. A value + of 1.0 is equivalent to GlobalMeanPooling2D, while larger values + will increase the contrast between activations within each feature + map, and a value of math.inf will be equivalent to MaxPool2d. + """ + inputs = layers.Input(shape=input_shape) + x = inputs + + if variant not in CONVNEXT_ARCHITECTURE: + raise ValueError("Unknown ConvNeXt variant. Valid TINY BASE LARGE SMALL XLARGE") + + x = build_convnext(variant, weights, trainable)(x) + + if pooling == "gem": + x = GeneralizedMeanPooling2D(p=gem_p, name="gem_pool")(x) + elif pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + if include_top and pooling is not None: + if l2_norm: + outputs = MetricEmbedding(embedding_size)(x) + else: + outputs = layers.Dense(embedding_size)(x) + else: + outputs = x + + return SimilarityModel(inputs, outputs) + + +def build_convnext(variant: str, weights: str | None = None, trainable: str = "full") -> tf.keras.Model: + """Build the requested ConvNeXt + + Args: + variant: Which Variant of the ConvNeXt to use. + weights: Use pre-trained weights - the only available currently being + imagenet. + trainable: Make the ConvNeXt backbone fully trainable or partially + trainable. + - "full" to make the entire backbone trainable, + - "partial" to only make the last 3 block trainable + - "frozen" to make it not trainable. + Returns: + The output layer of the convnext model + """ + convnext_fn = CONVNEXT_ARCHITECTURE[variant.upper()] + convnext = convnext_fn(weights=weights, include_top=False) + + if trainable == "full": + convnext.trainable = True + elif trainable == "partial": + convnext.trainable = True + for layer in convnext.layers: + # freeze all layeres befor the last 3 blocks + if not re.search("^block[5,6,7]|^top", layer.name): + layer.trainable = False + elif trainable == "frozen": + convnext.trainable = False + else: + raise ValueError(f"{trainable} is not a supported option for 'trainable'.") + + if weights: + for layer in convnext.layers: + if isinstance(layer, layers.experimental.SyncBatchNormalization): + layer.trainable = False + return convnext From 94444781796dfd365ce857e2fc5c4a29cac9a8f2 Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Mon, 28 Aug 2023 10:42:55 -0300 Subject: [PATCH 2/9] [feat] adding module to init --- tensorflow_similarity/architectures/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_similarity/architectures/__init__.py b/tensorflow_similarity/architectures/__init__.py index f2dc3b1a..74ed98b8 100644 --- a/tensorflow_similarity/architectures/__init__.py +++ b/tensorflow_similarity/architectures/__init__.py @@ -16,3 +16,4 @@ from .efficientnet import EfficientNetSim # noqa from .resnet18 import ResNet18Sim # noqa from .resnet50 import ResNet50Sim # noqa +from .convnext import ConvNeXtSim # noqa From 582079c9ede0b52699987132ce042438ab97ccc2 Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Mon, 28 Aug 2023 10:43:03 -0300 Subject: [PATCH 3/9] [test] convnext architecture testing --- tests/architectures/test_convnext.py | 99 ++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 tests/architectures/test_convnext.py diff --git a/tests/architectures/test_convnext.py b/tests/architectures/test_convnext.py new file mode 100644 index 00000000..5a352f86 --- /dev/null +++ b/tests/architectures/test_convnext.py @@ -0,0 +1,99 @@ +import re + +import pytest +import tensorflow as tf + +from tensorflow_similarity.architectures import convnext as convneXt + +TF_MAJOR_VERSION = int(tf.__version__.split(".")[0]) +TF_MINOR_VERSION = int(tf.__version__.split(".")[1]) + +def tf_version_check(major_version, minor_version): + if TF_MAJOR_VERSION <= major_version and TF_MINOR_VERSION < minor_version: + return True + + return False + +def test_build_convnext_tiny_full(): + input_layer = tf.keras.layers.Input((224, 224, 3)) + output = convneXt.build_convnext("tiny", "imagenet", "full")(input_layer) + + convnext = output._keras_history.layer + assert convnext.name == "convnext_tiny" + assert convnext.trainable + + total_layer_count = 0 + trainable_layer_count = 0 + for layer in convnext._self_tracked_trackables: + total_layer_count += 1 + if layer.trainable: + trainable_layer_count += 1 + + expected_total_layer_count = 151 + expected_trainable_layer_count = 151 + + assert total_layer_count == expected_total_layer_count + assert trainable_layer_count == expected_trainable_layer_count + +def test_build_convnext_small_partial(): + input_layer = tf.keras.layers.Input((224, 224, 3)) + output = convneXt.build_convnext("small", "imagenet", "partial")(input_layer) + + convnext = output._keras_history.layer + assert convnext.name == "convnext_small" + assert convnext.trainable + + total_layer_count = 0 + trainable_layer_count = 0 + for layer in convnext._self_tracked_trackables: + total_layer_count += 1 + if layer.trainable: + trainable_layer_count += 1 + + expected_total_layer_count = 295 + expected_trainable_layer_count = 0 + + assert total_layer_count == expected_total_layer_count + assert trainable_layer_count == expected_trainable_layer_count + +def test_build_convnext_base_frozen(): + input_layer = tf.keras.layers.Input((224, 224, 3)) + output = convneXt.build_convnext("base", "imagenet", "frozen")(input_layer) + + convnext = output._keras_history.layer + assert convnext.name == "convnext_base" + assert not convnext.trainable + + total_layer_count = 0 + trainable_layer_count = 0 + for layer in convnext._self_tracked_trackables: + total_layer_count += 1 + if layer.trainable: + trainable_layer_count += 1 + + expected_total_layer_count = 295 + expected_trainable_layer_count = 0 + + assert total_layer_count == expected_total_layer_count + assert trainable_layer_count == expected_trainable_layer_count + +def test_build_convnext_large_full(): + input_layer = tf.keras.layers.Input((224, 224, 3)) + output = convneXt.build_convnext("large", "imagenet", "full")(input_layer) + + convnext = output._keras_history.layer + assert convnext.name == "convnext_large" + assert convnext.trainable + + total_layer_count = 0 + trainable_layer_count = 0 + for layer in convnext._self_tracked_trackables: + total_layer_count += 1 + if layer.trainable: + trainable_layer_count += 1 + + expected_total_layer_count = 295 + expected_trainable_layer_count = 295 + + assert total_layer_count == expected_total_layer_count + assert trainable_layer_count == expected_trainable_layer_count \ No newline at end of file From 8ca0e5f15aec66dbc2a37121492e068cc2d65dbb Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Fri, 8 Sep 2023 19:25:20 -0300 Subject: [PATCH 4/9] [fix] tf version check for convnext --- tensorflow_similarity/architectures/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow_similarity/architectures/__init__.py b/tensorflow_similarity/architectures/__init__.py index 74ed98b8..75282370 100644 --- a/tensorflow_similarity/architectures/__init__.py +++ b/tensorflow_similarity/architectures/__init__.py @@ -16,4 +16,7 @@ from .efficientnet import EfficientNetSim # noqa from .resnet18 import ResNet18Sim # noqa from .resnet50 import ResNet50Sim # noqa -from .convnext import ConvNeXtSim # noqa +try: + from .convnext import ConvNeXtSim # noqa +except ImportError: + print("Warning: ConvNeXtSim not imported. This requires TensorFlow 2.10 or higher.") From a1f10e45f26ff41380eb9402fab29bc7faef0329 Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Fri, 8 Sep 2023 19:25:35 -0300 Subject: [PATCH 5/9] [fix] convnext version check --- tensorflow_similarity/architectures/convnext.py | 14 +++++++++++++- tests/architectures/test_convnext.py | 9 ++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tensorflow_similarity/architectures/convnext.py b/tensorflow_similarity/architectures/convnext.py index e8cd525c..e499e949 100644 --- a/tensorflow_similarity/architectures/convnext.py +++ b/tensorflow_similarity/architectures/convnext.py @@ -18,8 +18,19 @@ import tensorflow as tf from tensorflow.keras import layers -from tensorflow.keras.applications import convnext +def convnext_exists_tf_version(): + tf_major_version = int(tf.__version__.split(".")[0]) + tf_minor_version = int(tf.__version__.split(".")[1]) + if 2 <= tf_major_version and 10 < tf_minor_version: + return True + return False + +if not convnext_exists_tf_version(): + raise ImportError(f"This code requires TensorFlow version 2.10 or higher. " + f"Please upgrade TensorFlow to use this code.") + +from tensorflow.keras.applications import convnext from tensorflow_similarity.layers import GeneralizedMeanPooling2D, MetricEmbedding from tensorflow_similarity.models import SimilarityModel @@ -33,6 +44,7 @@ } + def ConvNeXtSim( input_shape: tuple[int, int, int], embedding_size: int = 128, diff --git a/tests/architectures/test_convnext.py b/tests/architectures/test_convnext.py index 5a352f86..e61b03e3 100644 --- a/tests/architectures/test_convnext.py +++ b/tests/architectures/test_convnext.py @@ -3,7 +3,14 @@ import pytest import tensorflow as tf -from tensorflow_similarity.architectures import convnext as convneXt + +MIN_TF_MAJOR_VERSION = 2 +MIN_TF_MINOR_VERSION = 10 + +major_version = tf.__version__.split(".")[0] +minor_version = tf.__version__.split(".")[1] + +convneXt = pytest.importorskip("tensorflow_similarity.architectures.convnext") TF_MAJOR_VERSION = int(tf.__version__.split(".")[0]) TF_MINOR_VERSION = int(tf.__version__.split(".")[1]) From 300a1e14662cd33ea297bca1c0429ad81e9f3299 Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Fri, 8 Sep 2023 19:30:53 -0300 Subject: [PATCH 6/9] [fix] remove useless version check --- tensorflow_similarity/architectures/convnext.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/tensorflow_similarity/architectures/convnext.py b/tensorflow_similarity/architectures/convnext.py index e499e949..4bc02960 100644 --- a/tensorflow_similarity/architectures/convnext.py +++ b/tensorflow_similarity/architectures/convnext.py @@ -18,21 +18,9 @@ import tensorflow as tf from tensorflow.keras import layers - -def convnext_exists_tf_version(): - tf_major_version = int(tf.__version__.split(".")[0]) - tf_minor_version = int(tf.__version__.split(".")[1]) - if 2 <= tf_major_version and 10 < tf_minor_version: - return True - return False - -if not convnext_exists_tf_version(): - raise ImportError(f"This code requires TensorFlow version 2.10 or higher. " - f"Please upgrade TensorFlow to use this code.") - -from tensorflow.keras.applications import convnext from tensorflow_similarity.layers import GeneralizedMeanPooling2D, MetricEmbedding from tensorflow_similarity.models import SimilarityModel +from tensorflow.keras.applications import convnext CONVNEXT_ARCHITECTURE = { @@ -44,7 +32,6 @@ def convnext_exists_tf_version(): } - def ConvNeXtSim( input_shape: tuple[int, int, int], embedding_size: int = 128, From c1ded829149b0fcc8f4837dae26c87db3c5dc385 Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Fri, 8 Sep 2023 19:31:58 -0300 Subject: [PATCH 7/9] [lint] --- tensorflow_similarity/architectures/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_similarity/architectures/__init__.py b/tensorflow_similarity/architectures/__init__.py index 75282370..79ab603c 100644 --- a/tensorflow_similarity/architectures/__init__.py +++ b/tensorflow_similarity/architectures/__init__.py @@ -16,6 +16,7 @@ from .efficientnet import EfficientNetSim # noqa from .resnet18 import ResNet18Sim # noqa from .resnet50 import ResNet50Sim # noqa + try: from .convnext import ConvNeXtSim # noqa except ImportError: From 2b7a553a7dbf0ef0b6fe94c84389a79075b5ffb0 Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Tue, 12 Sep 2023 10:08:57 -0300 Subject: [PATCH 8/9] [lint] black on test file --- tests/architectures/test_convnext.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/architectures/test_convnext.py b/tests/architectures/test_convnext.py index e61b03e3..f63edeee 100644 --- a/tests/architectures/test_convnext.py +++ b/tests/architectures/test_convnext.py @@ -15,12 +15,14 @@ TF_MAJOR_VERSION = int(tf.__version__.split(".")[0]) TF_MINOR_VERSION = int(tf.__version__.split(".")[1]) + def tf_version_check(major_version, minor_version): if TF_MAJOR_VERSION <= major_version and TF_MINOR_VERSION < minor_version: return True return False + def test_build_convnext_tiny_full(): input_layer = tf.keras.layers.Input((224, 224, 3)) output = convneXt.build_convnext("tiny", "imagenet", "full")(input_layer) @@ -35,13 +37,14 @@ def test_build_convnext_tiny_full(): total_layer_count += 1 if layer.trainable: trainable_layer_count += 1 - + expected_total_layer_count = 151 expected_trainable_layer_count = 151 assert total_layer_count == expected_total_layer_count assert trainable_layer_count == expected_trainable_layer_count + def test_build_convnext_small_partial(): input_layer = tf.keras.layers.Input((224, 224, 3)) output = convneXt.build_convnext("small", "imagenet", "partial")(input_layer) @@ -63,6 +66,7 @@ def test_build_convnext_small_partial(): assert total_layer_count == expected_total_layer_count assert trainable_layer_count == expected_trainable_layer_count + def test_build_convnext_base_frozen(): input_layer = tf.keras.layers.Input((224, 224, 3)) output = convneXt.build_convnext("base", "imagenet", "frozen")(input_layer) @@ -84,6 +88,7 @@ def test_build_convnext_base_frozen(): assert total_layer_count == expected_total_layer_count assert trainable_layer_count == expected_trainable_layer_count + def test_build_convnext_large_full(): input_layer = tf.keras.layers.Input((224, 224, 3)) output = convneXt.build_convnext("large", "imagenet", "full")(input_layer) @@ -103,4 +108,4 @@ def test_build_convnext_large_full(): expected_trainable_layer_count = 295 assert total_layer_count == expected_total_layer_count - assert trainable_layer_count == expected_trainable_layer_count \ No newline at end of file + assert trainable_layer_count == expected_trainable_layer_count From 3ac3017b1392ff3f3366b42ba17b5888a529eabf Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Tue, 12 Sep 2023 10:31:45 -0300 Subject: [PATCH 9/9] [lint] isort --- tensorflow_similarity/architectures/convnext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_similarity/architectures/convnext.py b/tensorflow_similarity/architectures/convnext.py index 4bc02960..3e02153f 100644 --- a/tensorflow_similarity/architectures/convnext.py +++ b/tensorflow_similarity/architectures/convnext.py @@ -18,10 +18,10 @@ import tensorflow as tf from tensorflow.keras import layers -from tensorflow_similarity.layers import GeneralizedMeanPooling2D, MetricEmbedding -from tensorflow_similarity.models import SimilarityModel from tensorflow.keras.applications import convnext +from tensorflow_similarity.layers import GeneralizedMeanPooling2D, MetricEmbedding +from tensorflow_similarity.models import SimilarityModel CONVNEXT_ARCHITECTURE = { "TINY": convnext.ConvNeXtTiny,