diff --git a/tensorflow_similarity/architectures/__init__.py b/tensorflow_similarity/architectures/__init__.py index f2dc3b1a..79ab603c 100644 --- a/tensorflow_similarity/architectures/__init__.py +++ b/tensorflow_similarity/architectures/__init__.py @@ -16,3 +16,8 @@ from .efficientnet import EfficientNetSim # noqa from .resnet18 import ResNet18Sim # noqa from .resnet50 import ResNet50Sim # noqa + +try: + from .convnext import ConvNeXtSim # noqa +except ImportError: + print("Warning: ConvNeXtSim not imported. This requires TensorFlow 2.10 or higher.") diff --git a/tensorflow_similarity/architectures/convnext.py b/tensorflow_similarity/architectures/convnext.py new file mode 100644 index 00000000..3e02153f --- /dev/null +++ b/tensorflow_similarity/architectures/convnext.py @@ -0,0 +1,143 @@ +# Copyright 2021 The TensorFlow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"ConvNeXt backbone for similarity learning" +from __future__ import annotations + +import re + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.applications import convnext + +from tensorflow_similarity.layers import GeneralizedMeanPooling2D, MetricEmbedding +from tensorflow_similarity.models import SimilarityModel + +CONVNEXT_ARCHITECTURE = { + "TINY": convnext.ConvNeXtTiny, + "SMALL": convnext.ConvNeXtSmall, + "BASE": convnext.ConvNeXtBase, + "LARGE": convnext.ConvNeXtLarge, + "XLARGE": convnext.ConvNeXtXLarge, +} + + +def ConvNeXtSim( + input_shape: tuple[int, int, int], + embedding_size: int = 128, + variant: str = "BASE", + weights: str = "imagenet", + trainable: str = "frozen", + l2_norm: bool = True, + include_top: bool = True, + pooling: str = "gem", + gem_p: float = 3.0, +) -> SimilarityModel: + """ "Build an ConvNeXt Model backbone for similarity learning + [A ConvNet for the 2020s](https://arxiv.org/pdf/2201.03545.pdf) + Args: + input_shape: Size of the input image. Must match size of ConvNeXt version you use. + See below for version input size. + embedding_size: Size of the output embedding. Usually between 64 + and 512. Defaults to 128. + variant: Which Variant of the ConvNeXt to use. Defaults to "BASE". + weights: Use pre-trained weights - the only available currently being + imagenet. Defaults to "imagenet". + trainable: Make the ConvNeXt backbone fully trainable or partially + trainable. + - "full" to make the entire backbone trainable, + - "partial" to only make the last 3 block trainable + - "frozen" to make it not trainable. + l2_norm: If True and include_top is also True, then + tfsim.layers.MetricEmbedding is used as the last layer, otherwise + keras.layers.Dense is used. This should be true when using cosine + distance. Defaults to True. + include_top: Whether to include the fully-connected layer at the top + of the network. Defaults to True. + pooling: Optional pooling mode for feature extraction when + include_top is False. Defaults to gem. + - None means that the output of the model will be the 4D tensor + output of the last convolutional layer. + - avg means that global average pooling will be applied to the + output of the last convolutional layer, and thus the output of the + model will be a 2D tensor. + - max means that global max pooling will be applied. + - gem means that global GeneralizedMeanPooling2D will be applied. + The gem_p param sets the contrast amount on the pooling. + gem_p: Sets the power in the GeneralizedMeanPooling2D layer. A value + of 1.0 is equivalent to GlobalMeanPooling2D, while larger values + will increase the contrast between activations within each feature + map, and a value of math.inf will be equivalent to MaxPool2d. + """ + inputs = layers.Input(shape=input_shape) + x = inputs + + if variant not in CONVNEXT_ARCHITECTURE: + raise ValueError("Unknown ConvNeXt variant. Valid TINY BASE LARGE SMALL XLARGE") + + x = build_convnext(variant, weights, trainable)(x) + + if pooling == "gem": + x = GeneralizedMeanPooling2D(p=gem_p, name="gem_pool")(x) + elif pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + if include_top and pooling is not None: + if l2_norm: + outputs = MetricEmbedding(embedding_size)(x) + else: + outputs = layers.Dense(embedding_size)(x) + else: + outputs = x + + return SimilarityModel(inputs, outputs) + + +def build_convnext(variant: str, weights: str | None = None, trainable: str = "full") -> tf.keras.Model: + """Build the requested ConvNeXt + + Args: + variant: Which Variant of the ConvNeXt to use. + weights: Use pre-trained weights - the only available currently being + imagenet. + trainable: Make the ConvNeXt backbone fully trainable or partially + trainable. + - "full" to make the entire backbone trainable, + - "partial" to only make the last 3 block trainable + - "frozen" to make it not trainable. + Returns: + The output layer of the convnext model + """ + convnext_fn = CONVNEXT_ARCHITECTURE[variant.upper()] + convnext = convnext_fn(weights=weights, include_top=False) + + if trainable == "full": + convnext.trainable = True + elif trainable == "partial": + convnext.trainable = True + for layer in convnext.layers: + # freeze all layeres befor the last 3 blocks + if not re.search("^block[5,6,7]|^top", layer.name): + layer.trainable = False + elif trainable == "frozen": + convnext.trainable = False + else: + raise ValueError(f"{trainable} is not a supported option for 'trainable'.") + + if weights: + for layer in convnext.layers: + if isinstance(layer, layers.experimental.SyncBatchNormalization): + layer.trainable = False + return convnext diff --git a/tests/architectures/test_convnext.py b/tests/architectures/test_convnext.py new file mode 100644 index 00000000..f63edeee --- /dev/null +++ b/tests/architectures/test_convnext.py @@ -0,0 +1,111 @@ +import re + +import pytest +import tensorflow as tf + + +MIN_TF_MAJOR_VERSION = 2 +MIN_TF_MINOR_VERSION = 10 + +major_version = tf.__version__.split(".")[0] +minor_version = tf.__version__.split(".")[1] + +convneXt = pytest.importorskip("tensorflow_similarity.architectures.convnext") + +TF_MAJOR_VERSION = int(tf.__version__.split(".")[0]) +TF_MINOR_VERSION = int(tf.__version__.split(".")[1]) + + +def tf_version_check(major_version, minor_version): + if TF_MAJOR_VERSION <= major_version and TF_MINOR_VERSION < minor_version: + return True + + return False + + +def test_build_convnext_tiny_full(): + input_layer = tf.keras.layers.Input((224, 224, 3)) + output = convneXt.build_convnext("tiny", "imagenet", "full")(input_layer) + + convnext = output._keras_history.layer + assert convnext.name == "convnext_tiny" + assert convnext.trainable + + total_layer_count = 0 + trainable_layer_count = 0 + for layer in convnext._self_tracked_trackables: + total_layer_count += 1 + if layer.trainable: + trainable_layer_count += 1 + + expected_total_layer_count = 151 + expected_trainable_layer_count = 151 + + assert total_layer_count == expected_total_layer_count + assert trainable_layer_count == expected_trainable_layer_count + + +def test_build_convnext_small_partial(): + input_layer = tf.keras.layers.Input((224, 224, 3)) + output = convneXt.build_convnext("small", "imagenet", "partial")(input_layer) + + convnext = output._keras_history.layer + assert convnext.name == "convnext_small" + assert convnext.trainable + + total_layer_count = 0 + trainable_layer_count = 0 + for layer in convnext._self_tracked_trackables: + total_layer_count += 1 + if layer.trainable: + trainable_layer_count += 1 + + expected_total_layer_count = 295 + expected_trainable_layer_count = 0 + + assert total_layer_count == expected_total_layer_count + assert trainable_layer_count == expected_trainable_layer_count + + +def test_build_convnext_base_frozen(): + input_layer = tf.keras.layers.Input((224, 224, 3)) + output = convneXt.build_convnext("base", "imagenet", "frozen")(input_layer) + + convnext = output._keras_history.layer + assert convnext.name == "convnext_base" + assert not convnext.trainable + + total_layer_count = 0 + trainable_layer_count = 0 + for layer in convnext._self_tracked_trackables: + total_layer_count += 1 + if layer.trainable: + trainable_layer_count += 1 + + expected_total_layer_count = 295 + expected_trainable_layer_count = 0 + + assert total_layer_count == expected_total_layer_count + assert trainable_layer_count == expected_trainable_layer_count + + +def test_build_convnext_large_full(): + input_layer = tf.keras.layers.Input((224, 224, 3)) + output = convneXt.build_convnext("large", "imagenet", "full")(input_layer) + + convnext = output._keras_history.layer + assert convnext.name == "convnext_large" + assert convnext.trainable + + total_layer_count = 0 + trainable_layer_count = 0 + for layer in convnext._self_tracked_trackables: + total_layer_count += 1 + if layer.trainable: + trainable_layer_count += 1 + + expected_total_layer_count = 295 + expected_trainable_layer_count = 295 + + assert total_layer_count == expected_total_layer_count + assert trainable_layer_count == expected_trainable_layer_count