From a089a8b7559be3e5d5a49b04df54997a9803cc4d Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 7 Aug 2024 17:57:31 -0700 Subject: [PATCH 01/33] Add VGG16 backbone (#1737) * Agg Vgg16 backbone * update names * update tests * update test * add image classifier * incorporate review comments * Update test case * update backbone test * add image classifier * classifier cleanup * code reformat * add vgg16 image classifier * make vgg generic * update doc string * update docstring * add classifier test * update tests * update docstring * address review comments * code reformat * update the configs * address review comments * fix task saved model test * update init * code reformatted --- keras_nlp/api/models/__init__.py | 3 + keras_nlp/src/models/image_classifier.py | 90 ++++++++++ keras_nlp/src/models/vgg/__init__.py | 13 ++ keras_nlp/src/models/vgg/vgg_backbone.py | 159 ++++++++++++++++++ keras_nlp/src/models/vgg/vgg_backbone_test.py | 48 ++++++ .../src/models/vgg/vgg_image_classifier.py | 124 ++++++++++++++ .../models/vgg/vgg_image_classifier_test.py | 61 +++++++ keras_nlp/src/tests/test_case.py | 30 ++-- 8 files changed, 514 insertions(+), 14 deletions(-) create mode 100644 keras_nlp/src/models/image_classifier.py create mode 100644 keras_nlp/src/models/vgg/__init__.py create mode 100644 keras_nlp/src/models/vgg/vgg_backbone.py create mode 100644 keras_nlp/src/models/vgg/vgg_backbone_test.py create mode 100644 keras_nlp/src/models/vgg/vgg_image_classifier.py create mode 100644 keras_nlp/src/models/vgg/vgg_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 4fb3b3cf00..41f1a47284 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -129,6 +129,7 @@ GPTNeoXPreprocessor, ) from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_nlp.src.models.image_classifier import ImageClassifier from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( @@ -194,6 +195,8 @@ from keras_nlp.src.models.t5.t5_backbone import T5Backbone from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer from keras_nlp.src.models.task import Task +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) diff --git a/keras_nlp/src/models/image_classifier.py b/keras_nlp/src/models/image_classifier.py new file mode 100644 index 0000000000..f0cc031dbc --- /dev/null +++ b/keras_nlp/src/models/image_classifier.py @@ -0,0 +1,90 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.task import Task + + +@keras_nlp_export("keras_nlp.models.ImageClassifier") +class ImageClassifier(Task): + """Base class for all image classification tasks. + + `ImageClassifier` tasks wrap a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be used for + image classification. `ImageClassifier` tasks take an additional + `num_classes` argument, controlling the number of predicted output classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a string and `y` is a integer from `[0, num_classes)`. + + All `ImageClassifier` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `ImageClassifier` task for training. + + The `ImageClassifier` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.SparseCategoricalCrossentropy` loss will be + applied for the classification task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.SparseCategoricalAccuracy` will be + applied to track the accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if loss == "auto": + activation = getattr(self, "activation", None) + activation = keras.activations.get(activation) + from_logits = activation != keras.activations.softmax + loss = keras.losses.SparseCategoricalCrossentropy(from_logits) + if metrics == "auto": + metrics = [keras.metrics.SparseCategoricalAccuracy()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) diff --git a/keras_nlp/src/models/vgg/__init__.py b/keras_nlp/src/models/vgg/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/vgg/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/vgg/vgg_backbone.py b/keras_nlp/src/models/vgg/vgg_backbone.py new file mode 100644 index 0000000000..497381c0fc --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_backbone.py @@ -0,0 +1,159 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + + +@keras_nlp_export("keras_nlp.models.VGGBackbone") +class VGGBackbone(Backbone): + """ + This class represents Keras Backbone of VGG model. + + This class implements a VGG backbone as described in [Very Deep + Convolutional Networks for Large-Scale Image Recognition]( + https://arxiv.org/abs/1409.1556)(ICLR 2015). + + Args: + stackwise_num_repeats: list of ints, number of repeated convolutional + blocks per VGG block. For VGG16 this is [2, 2, 3, 3, 3] and for + VGG19 this is [2, 2, 4, 4, 4]. + stackwise_num_filters: list of ints, filter size for convolutional + blocks per VGG block. For both VGG16 and VGG19 this is [ + 64, 128, 256, 512, 512]. + include_rescaling: bool, whether to rescale the inputs. If set to + True, inputs will be passed through a `Rescaling(1/255.0)` layer. + input_shape: tuple, optional shape tuple, defaults to (224, 224, 3). + pooling: bool, Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + + Examples: + ```python + input_data = np.ones((2, 224, 224, 3), dtype="float32") + + # Pretrained VGG backbone. + model = keras_nlp.models.VGGBackbone.from_preset("vgg16") + model(input_data) + + # Randomly initialized VGG backbone with a custom config. + model = keras_nlp.models.VGGBackbone( + stackwise_num_repeats = [2, 2, 3, 3, 3], + stackwise_num_filters = [64, 128, 256, 512, 512], + input_shape = (224, 224, 3), + include_rescaling = False, + pooling = "avg", + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_repeats, + stackwise_num_filters, + include_rescaling, + input_image_shape=(224, 224, 3), + pooling="avg", + **kwargs, + ): + + # === Functional Model === + img_input = keras.layers.Input(shape=input_image_shape) + x = img_input + + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0)(x) + for stack_index in range(len(stackwise_num_repeats) - 1): + x = apply_vgg_block( + x=x, + num_layers=stackwise_num_repeats[stack_index], + filters=stackwise_num_filters[stack_index], + kernel_size=(3, 3), + activation="relu", + padding="same", + max_pool=True, + name=f"block{stack_index + 1}", + ) + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + super().__init__(inputs=img_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_num_repeats = stackwise_num_repeats + self.stackwise_num_filters = stackwise_num_filters + self.include_rescaling = include_rescaling + self.input_image_shape = input_image_shape + self.pooling = pooling + + def get_config(self): + return { + "stackwise_num_repeats": self.stackwise_num_repeats, + "stackwise_num_filters": self.stackwise_num_filters, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + + +def apply_vgg_block( + x, + num_layers, + filters, + kernel_size, + activation, + padding, + max_pool, + name, +): + """ + Applies VGG block + Args: + x: Tensor, input tensor to pass through network + num_layers: int, number of CNN layers in the block + filters: int, filter size of each CNN layer in block + kernel_size: int (or) tuple, kernel size for CNN layer in block + activation: str (or) callable, activation function for each CNN layer in + block + padding: str (or) callable, padding function for each CNN layer in block + max_pool: bool, whether to add MaxPooling2D layer at end of block + name: str, name of the block + + Returns: + keras.KerasTensor + """ + for num in range(1, num_layers + 1): + x = layers.Conv2D( + filters, + kernel_size, + activation=activation, + padding=padding, + name=f"{name}_conv{num}", + )(x) + if max_pool: + x = layers.MaxPooling2D((2, 2), (2, 2), name=f"{name}_pool")(x) + return x diff --git a/keras_nlp/src/models/vgg/vgg_backbone_test.py b/keras_nlp/src/models/vgg/vgg_backbone_test.py new file mode 100644 index 0000000000..05ed33ba0f --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_backbone_test.py @@ -0,0 +1,48 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class VGGBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_repeats": [2, 3, 3], + "stackwise_num_filters": [8, 64, 64], + "input_image_shape": (16, 16, 3), + "include_rescaling": False, + "pooling": "avg", + } + self.input_data = np.ones((2, 16, 16, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=VGGBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 64), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VGGBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier.py b/keras_nlp/src/models/vgg/vgg_image_classifier.py new file mode 100644 index 0000000000..a26fbfbc30 --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_image_classifier.py @@ -0,0 +1,124 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone + + +@keras_nlp_export("keras_nlp.models.VGGImageClassifier") +class VGGImageClassifier(ImageClassifier): + """VGG16 image classifier task model. + + Args: + backbone: A `keras_nlp.models.VGGBackbone` instance. + num_classes: int, number of classes to predict. + pooling: str, type of pooling layer. Must be one of "avg", "max". + activation: Optional `str` or callable, defaults to "softmax". The + activation function to use on the Dense layer. Set `activation=None` + to return the output logits. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a string and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Examples: + Train from preset + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.VGGImageClassifier.from_preset( + 'vgg_16_image_classifier') + classifier.fit(x=images, y=labels, batch_size=2) + + # Re-compile (e.g., with a new learning rate). + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + jit_compile=True, + ) + + # Access backbone programmatically (e.g., to change `trainable`). + classifier.backbone.trainable = False + # Fit again. + classifier.fit(x=images, y=labels, batch_size=2) + ``` + Custom backbone + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + + backbone = keras_nlp.models.VGGBackbone( + stackwise_num_repeats = [2, 2, 3, 3, 3], + stackwise_num_filters = [64, 128, 256, 512, 512], + input_shape = (224, 224, 3), + include_rescaling = False, + pooling = "avg", + ) + classifier = keras_nlp.models.VGGImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = VGGBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + + # Instantiate using Functional API Model constructor + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier_test.py b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py new file mode 100644 index 0000000000..4a2573e496 --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py @@ -0,0 +1,61 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_nlp.src.tests.test_case import TestCase + + +class VGGImageClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.images = np.ones((2, 4, 4, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = VGGBackbone( + stackwise_num_repeats=[2, 4, 4], + stackwise_num_filters=[2, 16, 16], + input_image_shape=(4, 4, 3), + include_rescaling=False, + pooling="max", + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=VGGImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VGGImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 7e8e0cec95..fc1ce77e1e 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -419,20 +419,22 @@ def run_backbone_test( self.assertEqual(output[key].shape, expected_output_shape[key]) else: self.assertEqual(output.shape, expected_output_shape) - - # Check we can embed tokens eagerly. - output = backbone.token_embedding(ops.zeros((2, 3), dtype="int32")) - - # Check variable length sequences. - if variable_length_data is None: - # If no variable length data passed, assume the second axis of all - # inputs is our sequence axis and create it ourselves. - variable_length_data = [ - tree.map_structure(lambda x: x[:, :seq_length, ...], input_data) - for seq_length in (2, 3, 4) - ] - for batch in variable_length_data: - backbone(batch) + if backbone.token_embedding is not None: + # Check we can embed tokens eagerly. + output = backbone.token_embedding(ops.zeros((2, 3), dtype="int32")) + + # Check variable length sequences. + if variable_length_data is None: + # If no variable length data passed, assume the second axis of all + # inputs is our sequence axis and create it ourselves. + variable_length_data = [ + tree.map_structure( + lambda x: x[:, :seq_length, ...], input_data + ) + for seq_length in (2, 3, 4) + ] + for batch in variable_length_data: + backbone(batch) # Check compiled predict function. backbone.predict(input_data) From 73b7bad007a8c37a54512092c2b8bfe435d21c10 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Tue, 13 Aug 2024 01:09:08 +0800 Subject: [PATCH 02/33] Add `ResNetBackbone` and `ResNetImageClassifier` (#1765) * Add ResNetV1 and ResNetV2 * Address comments --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/src/models/resnet/__init__.py | 13 + .../src/models/resnet/resnet_backbone.py | 544 ++++++++++++++++++ .../src/models/resnet/resnet_backbone_test.py | 75 +++ .../models/resnet/resnet_image_classifier.py | 131 +++++ .../resnet/resnet_image_classifier_test.py | 62 ++ keras_nlp/src/tests/test_case.py | 60 ++ keras_nlp/src/utils/keras_utils.py | 13 + 8 files changed, 902 insertions(+) create mode 100644 keras_nlp/src/models/resnet/__init__.py create mode 100644 keras_nlp/src/models/resnet/resnet_backbone.py create mode 100644 keras_nlp/src/models/resnet/resnet_backbone_test.py create mode 100644 keras_nlp/src/models/resnet/resnet_image_classifier.py create mode 100644 keras_nlp/src/models/resnet/resnet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 41f1a47284..783cfd5087 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -181,6 +181,10 @@ from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_classifier import ( + ResNetImageClassifier, +) from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier from keras_nlp.src.models.roberta.roberta_masked_lm import RobertaMaskedLM diff --git a/keras_nlp/src/models/resnet/__init__.py b/keras_nlp/src/models/resnet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/resnet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py new file mode 100644 index 0000000000..bec5ba60b5 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -0,0 +1,544 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.utils.keras_utils import standardize_data_format + + +@keras_nlp_export("keras_nlp.models.ResNetBackbone") +class ResNetBackbone(Backbone): + """ResNet and ResNetV2 core network with hyperparameters. + + This class implements a ResNet backbone as described in [Deep Residual + Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( + CVPR 2016) and [Identity Mappings in Deep Residual Networks]( + https://arxiv.org/abs/1603.05027)(ECCV 2016). + + The difference in ResNet and ResNetV2 rests in the structure of their + individual building blocks. In ResNetV2, the batch normalization and + ReLU activation precede the convolution layers, as opposed to ResNet where + the batch normalization and ReLU activation are applied after the + convolution layers. + + Args: + stackwise_num_filters: list of ints. The number of filters for each + stack. + stackwise_num_blocks: list of ints. The number of blocks for each stack. + stackwise_num_strides: list of ints. The number of strides for each + stack. + block_type: str. The block type to stack. One of `"basic_block"` or + `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. + Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. + include_rescaling: boolean. If `True`, rescale the input using + `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to + `True`. + input_image_shape: tuple. The input shape without the batch size. + Defaults to `(None, None, 3)`. + pooling: `None` or str. Pooling mode for feature extraction. Defaults + to `"avg"`. + - `None` means that the output of the model will be the 4D tensor + from the last convolutional block. + - `avg` means that global average pooling will be applied to the + output of the last convolutional block, resulting in a 2D + tensor. + - `max` means that global max pooling will be applied to the + output of the last convolutional block, resulting in a 2D + tensor. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + + Examples: + ```python + input_data = np.ones((2, 224, 224, 3), dtype="float32") + + # Pretrained ResNet backbone. + model = keras_nlp.models.ResNetBackbone.from_preset("resnet50") + model(input_data) + + # Randomly initialized ResNetV2 backbone with a custom config. + model = keras_nlp.models.ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + pooling="avg", + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_filters, + stackwise_num_blocks, + stackwise_num_strides, + block_type, + use_pre_activation=False, + include_rescaling=True, + input_image_shape=(None, None, 3), + pooling="avg", + data_format=None, + dtype=None, + **kwargs, + ): + if len(stackwise_num_filters) != len(stackwise_num_blocks) or len( + stackwise_num_filters + ) != len(stackwise_num_strides): + raise ValueError( + "The length of `stackwise_num_filters`, `stackwise_num_blocks` " + "and `stackwise_num_strides` must be the same. Received: " + f"stackwise_num_filters={stackwise_num_filters}, " + f"stackwise_num_blocks={stackwise_num_blocks}, " + f"stackwise_num_strides={stackwise_num_strides}" + ) + if stackwise_num_filters[0] != 64: + raise ValueError( + "The first element of `stackwise_num_filters` must be 64. " + f"Received: stackwise_num_filters={stackwise_num_filters}" + ) + if block_type not in ("basic_block", "bottleneck_block"): + raise ValueError( + '`block_type` must be either `"basic_block"` or ' + f'`"bottleneck_block"`. Received block_type={block_type}.' + ) + version = "v1" if not use_pre_activation else "v2" + data_format = standardize_data_format(data_format) + bn_axis = -1 if data_format == "channels_last" else 1 + num_stacks = len(stackwise_num_filters) + + # === Functional Model === + image_input = layers.Input(shape=input_image_shape) + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) + else: + x = image_input + + x = layers.Conv2D( + 64, + 7, + strides=2, + padding="same", + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name="conv1_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="conv1_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) + + x = layers.MaxPool2D( + 3, + strides=2, + padding="same", + data_format=data_format, + dtype=dtype, + name="pool1_pool", + )(x) + + for stack_index in range(num_stacks): + x = apply_stack( + x, + filters=stackwise_num_filters[stack_index], + blocks=stackwise_num_blocks[stack_index], + stride=stackwise_num_strides[stack_index], + block_type=block_type, + use_pre_activation=use_pre_activation, + first_shortcut=( + block_type == "bottleneck_block" or stack_index > 0 + ), + data_format=data_format, + dtype=dtype, + name=f"{version}_stack{stack_index}", + ) + + if use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="post_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name="post_relu")(x) + + if pooling == "avg": + feature_map_output = layers.GlobalAveragePooling2D( + data_format=data_format, dtype=dtype + )(x) + elif pooling == "max": + feature_map_output = layers.GlobalMaxPooling2D( + data_format=data_format, dtype=dtype + )(x) + else: + feature_map_output = x + + super().__init__( + inputs=image_input, + outputs=feature_map_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_num_blocks = stackwise_num_blocks + self.stackwise_num_strides = stackwise_num_strides + self.block_type = block_type + self.use_pre_activation = use_pre_activation + self.include_rescaling = include_rescaling + self.input_image_shape = input_image_shape + self.pooling = pooling + + def get_config(self): + return { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_strides": self.stackwise_num_strides, + "block_type": self.block_type, + "use_pre_activation": self.use_pre_activation, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + + +def apply_basic_block( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a basic residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the basic residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1.001e-5, + dtype=dtype, + name=f"{name}_use_preactivation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + )(x_preact) + + if conv_shortcut: + shortcut = layers.Conv2D( + filters, + 1, + strides=stride, + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name=f"{name}_0_conv", + )(x_preact if x_preact is not None else x) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + )(shortcut) + else: + if not use_pre_activation or stride == 1: + shortcut = x + else: + shortcut = layers.MaxPooling2D( + 1, + strides=stride, + data_format=data_format, + dtype=dtype, + name=f"{name}_0_max_pooling", + )(x) + + x = layers.Conv2D( + filters, + kernel_size, + strides=stride if not use_pre_activation else 1, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x_preact if x_preact is not None else x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=1 if not use_pre_activation else stride, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + )(x) + x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_bottleneck_block( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a bottleneck residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1.001e-5, + dtype=dtype, + name=f"{name}_use_preactivation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + )(x_preact) + + if conv_shortcut: + shortcut = layers.Conv2D( + 4 * filters, + 1, + strides=stride, + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name=f"{name}_0_conv", + )(x_preact if x_preact is not None else x) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + )(shortcut) + else: + if not use_pre_activation or stride == 1: + shortcut = x + else: + shortcut = layers.MaxPooling2D( + 1, + strides=stride, + data_format=data_format, + dtype=dtype, + name=f"{name}_0_max_pooling", + )(x) + + x = layers.Conv2D( + filters, + 1, + strides=stride if not use_pre_activation else 1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x_preact if x_preact is not None else x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=1 if not use_pre_activation else stride, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + x = layers.Conv2D( + 4 * filters, + 1, + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name=f"{name}_3_conv", + )(x) + + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_3_bn" + )(x) + x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_stack( + x, + filters, + blocks, + stride, + block_type, + use_pre_activation, + first_shortcut=True, + data_format=None, + dtype=None, + name=None, +): + """Applies a set of stacked residual blocks. + + Args: + x: Tensor. The input tensor to pass through the stack. + filters: int. The number of filters in a block. + blocks: int. The number of blocks in the stack. + stride: int. The stride length of the first layer in the first block. + block_type: str. The block type to stack. One of `"basic_block"` or + `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. + Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet and ResNeXt. + first_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `True`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the stack. + + Returns: + Output tensor for the stacked blocks. + """ + if name is None: + version = "v1" if not use_pre_activation else "v2" + name = f"{version}_stack" + + if block_type == "basic_block": + block_fn = apply_basic_block + elif block_type == "bottleneck_block": + block_fn = apply_bottleneck_block + else: + raise ValueError( + '`block_type` must be either `"basic_block"` or ' + f'`"bottleneck_block"`. Received block_type={block_type}.' + ) + x = block_fn( + x, + filters, + stride=stride if not use_pre_activation else 1, + conv_shortcut=first_shortcut, + use_pre_activation=use_pre_activation, + data_format=data_format, + dtype=dtype, + name=f"{name}_block1", + ) + for i in range(2, blocks): + x = block_fn( + x, + filters, + use_pre_activation=use_pre_activation, + data_format=data_format, + dtype=dtype, + name=f"{name}_block{str(i)}", + ) + x = block_fn( + x, + filters, + stride=1 if not use_pre_activation else stride, + use_pre_activation=use_pre_activation, + data_format=data_format, + dtype=dtype, + name=f"{name}_block{str(blocks)}", + ) + return x diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py new file mode 100644 index 0000000000..2113bcd131 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from absl.testing import parameterized +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class ResNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "input_image_shape": (None, None, 3), + "pooling": "avg", + } + self.input_size = (16, 16) + self.input_data = ops.ones((2, 16, 16, 3)) + + @parameterized.named_parameters( + ("v1_basic", False, "basic_block"), + ("v1_bottleneck", False, "bottleneck_block"), + ("v2_basic", True, "basic_block"), + ("v2_bottleneck", True, "bottleneck_block"), + ) + def test_backbone_basics(self, use_pre_activation, block_type): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + {"block_type": block_type, "use_pre_activation": use_pre_activation} + ) + self.run_vision_backbone_test( + cls=ResNetBackbone, + init_kwargs=init_kwargs, + input_data=self.input_data, + expected_output_shape=( + (2, 64) if block_type == "basic_block" else (2, 256) + ), + ) + + @parameterized.named_parameters( + ("v1_basic", False, "basic_block"), + ("v1_bottleneck", False, "bottleneck_block"), + ("v2_basic", True, "basic_block"), + ("v2_bottleneck", True, "bottleneck_block"), + ) + @pytest.mark.large + def test_saved_model(self, use_pre_activation, block_type): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + { + "block_type": block_type, + "use_pre_activation": use_pre_activation, + "input_image_shape": (16, 16, 3), + } + ) + self.run_model_saving_test( + cls=ResNetBackbone, + init_kwargs=init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier.py b/keras_nlp/src/models/resnet/resnet_image_classifier.py new file mode 100644 index 0000000000..02c8c78b27 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_classifier.py @@ -0,0 +1,131 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone + + +@keras_nlp_export("keras_nlp.models.ResNetImageClassifier") +class ResNetImageClassifier(ImageClassifier): + """ResNet image classifier task model. + + Args: + backbone: A `keras_nlp.models.ResNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.ResNetImageClassifier.from_preset("resnet50") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.ResNetImageClassifier.from_preset("resnet50") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.ResNetImageClassifier.from_preset("resnet50") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + include_rescaling=False, + pooling="avg", + ) + classifier = keras_nlp.models.ResNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = ResNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + dtype=self.backbone.dtype_policy, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py new file mode 100644 index 0000000000..bbbda72d64 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_classifier import ( + ResNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class ResNetImageClassifierTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 16, 16, 3)) + self.labels = [0, 3] + self.backbone = ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + input_image_shape=(16, 16, 3), + include_rescaling=False, + pooling="avg", + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = (self.images, self.labels) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=ResNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ResNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index fc1ce77e1e..72653c8b83 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -457,6 +457,66 @@ def run_backbone_test( if run_quantization_check and has_quantization_support(): self.run_quantization_test(backbone, cls, init_kwargs, input_data) + def run_vision_backbone_test( + self, + cls, + init_kwargs, + input_data, + expected_output_shape, + variable_length_data=None, + run_mixed_precision_check=True, + run_quantization_check=True, + run_data_format_check=True, + ): + """Run basic tests for a vision backbone, including compilation.""" + can_run_data_format_check = True + if ( + keras.config.backend() == "tensorflow" + and not tf.config.list_physical_devices("GPU") + ): + # Never test the "channels_first" format on tensorflow CPU. + # Tensorflow lacks support for "channels_first" convolution. + can_run_data_format_check = False + + ori_data_format = keras.config.image_data_format() + keras.config.set_image_data_format("channels_last") + self.run_backbone_test( + cls=cls, + init_kwargs=init_kwargs, + input_data=input_data, + expected_output_shape=expected_output_shape, + variable_length_data=variable_length_data, + run_mixed_precision_check=run_mixed_precision_check, + run_quantization_check=run_quantization_check, + ) + + # Check data_format. We assume that `input_data` is in "channels_last" + # format. + if run_data_format_check and can_run_data_format_check: + keras.config.set_image_data_format("channels_first") + input_data_shape = ops.shape(input_data) + if len(input_data_shape) == 3: + input_data = ops.transpose(input_data, axes=(2, 0, 1)) + elif len(input_data_shape) == 4: + input_data = ops.transpose(input_data, axes=(0, 3, 1, 2)) + if "input_image_shape" in init_kwargs: + init_kwargs = init_kwargs.copy() + init_kwargs["input_image_shape"] = tuple( + reversed(init_kwargs["input_image_shape"]) + ) + self.run_backbone_test( + cls=cls, + init_kwargs=init_kwargs, + input_data=input_data, + expected_output_shape=expected_output_shape, + variable_length_data=variable_length_data, + run_mixed_precision_check=run_mixed_precision_check, + run_quantization_check=run_quantization_check, + ) + + # Restore the original `image_data_format`. + keras.config.set_image_data_format(ori_data_format) + def run_task_test( self, cls, diff --git a/keras_nlp/src/utils/keras_utils.py b/keras_nlp/src/utils/keras_utils.py index 0fb96ccffb..b37b74ad19 100644 --- a/keras_nlp/src/utils/keras_utils.py +++ b/keras_nlp/src/utils/keras_utils.py @@ -115,3 +115,16 @@ def assert_quantization_support(): "Quantization API requires Keras >= 3.4.0 to function " f"correctly. Received: '{keras.version()}'" ) + + +def standardize_data_format(data_format): + if data_format is None: + return keras.config.image_data_format() + data_format = str(data_format).lower() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError( + "The `data_format` argument must be one of " + "{'channels_first', 'channels_last'}. " + f"Received: data_format={data_format}" + ) + return data_format From 26afc7e538927bbb8d588ab72ce50c3a6c1f89b5 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Wed, 14 Aug 2024 18:30:21 -0700 Subject: [PATCH 03/33] Add CSP DarkNet backbone and classifier (#1774) * Add CSP DarkNet * Add CSP DarkNet * snake_case function names * change use_depthwise to block_type --- keras_nlp/api/models/__init__.py | 6 + keras_nlp/src/models/csp_darknet/__init__.py | 13 + .../csp_darknet/csp_darknet_backbone.py | 410 ++++++++++++++++++ .../csp_darknet/csp_darknet_backbone_test.py | 50 +++ .../csp_darknet_image_classifier.py | 133 ++++++ .../csp_darknet_image_classifier_test.py | 65 +++ 6 files changed, 677 insertions(+) create mode 100644 keras_nlp/src/models/csp_darknet/__init__.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 783cfd5087..aca1e28538 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -50,6 +50,12 @@ from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_nlp.src.models.causal_lm import CausalLM from keras_nlp.src.models.classifier import Classifier +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.models.csp_darknet.csp_darknet_image_classifier import ( + CSPDarkNetImageClassifier, +) from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import ( DebertaV3Backbone, ) diff --git a/keras_nlp/src/models/csp_darknet/__init__.py b/keras_nlp/src/models/csp_darknet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py new file mode 100644 index 0000000000..2745f61d01 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py @@ -0,0 +1,410 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + + +@keras_nlp_export("keras_nlp.models.CSPDarkNetBackbone") +class CSPDarkNetBackbone(Backbone): + """This class represents Keras Backbone of CSPDarkNet model. + + This class implements a CSPDarkNet backbone as described in + [CSPNet: A New Backbone that can Enhance Learning Capability of CNN]( + https://arxiv.org/abs/1911.11929). + + Args: + stackwise_num_filters: A list of ints, filter size for each dark + level in the model. + stackwise_depth: A list of ints, the depth for each dark level in the + model. + include_rescaling: boolean. If `True`, rescale the input using + `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to + `True`. + block_type: str. One of `"basic_block"` or `"depthwise_block"`. + Use `"depthwise_block"` for depthwise conv block + `"basic_block"` for basic conv block. + Defaults to "basic_block". + input_image_shape: tuple. The input shape without the batch size. + Defaults to `(None, None, 3)`. + + Examples: + ```python + input_data = np.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone + model = keras_nlp.models.CSPDarkNetBackbone.from_preset( + "csp_darknet_tiny_imagenet" + ) + model(input_data) + + # Randomly initialized backbone with a custom config + model = keras_nlp.models.CSPDarkNetBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_filters, + stackwise_depth, + include_rescaling, + block_type="basic_block", + input_image_shape=(224, 224, 3), + **kwargs, + ): + # === Functional Model === + apply_ConvBlock = ( + apply_darknet_conv_block_depthwise + if block_type == "depthwise_block" + else apply_darknet_conv_block + ) + base_channels = stackwise_num_filters[0] // 2 + + image_input = layers.Input(shape=input_image_shape) + x = image_input + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0)(x) + + x = apply_focus(name="stem_focus")(x) + x = apply_darknet_conv_block( + base_channels, kernel_size=3, strides=1, name="stem_conv" + )(x) + for index, (channels, depth) in enumerate( + zip(stackwise_num_filters, stackwise_depth) + ): + x = apply_ConvBlock( + channels, + kernel_size=3, + strides=2, + name=f"dark{index + 2}_conv", + )(x) + + if index == len(stackwise_depth) - 1: + x = apply_spatial_pyramid_pooling_bottleneck( + channels, + hidden_filters=channels // 2, + name=f"dark{index + 2}_spp", + )(x) + + x = apply_cross_stage_partial( + channels, + num_bottlenecks=depth, + block_type="basic_block", + residual=(index != len(stackwise_depth) - 1), + name=f"dark{index + 2}_csp", + )(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_depth = stackwise_depth + self.include_rescaling = include_rescaling + self.block_type = block_type + self.input_image_shape = input_image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_depth": self.stackwise_depth, + "include_rescaling": self.include_rescaling, + "block_type": self.block_type, + "input_image_shape": self.input_image_shape, + } + ) + return config + + +def apply_focus(name=None): + """A block used in CSPDarknet to focus information into channels of the + image. + + If the dimensions of a batch input is (batch_size, width, height, channels), + this layer converts the image into size (batch_size, width/2, height/2, + 4*channels). See [the original discussion on YoloV5 Focus Layer](https://github.com/ultralytics/yolov5/discussions/3181). + + Args: + name: the name for the lambda layer used in the block. + + Returns: + a function that takes an input Tensor representing a Focus layer. + """ + + def apply(x): + return layers.Concatenate(name=name)( + [ + x[..., ::2, ::2, :], + x[..., 1::2, ::2, :], + x[..., ::2, 1::2, :], + x[..., 1::2, 1::2, :], + ], + ) + + return apply + + +def apply_darknet_conv_block( + filters, kernel_size, strides, use_bias=False, activation="silu", name=None +): + """ + The basic conv block used in Darknet. Applies Conv2D followed by a + BatchNorm. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. Can be a single + integer to specify the same value both dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the height and width. Can be a single + integer to the same value both dimensions. + use_bias: Boolean, whether the layer uses a bias vector. + activation: the activation applied after the BatchNorm layer. One of + "silu", "relu" or "leaky_relu", defaults to "silu". + name: the prefix for the layer names used in the block. + """ + if name is None: + name = f"conv_block{keras.backend.get_uid('conv_block')}" + + def apply(inputs): + x = layers.Conv2D( + filters, + kernel_size, + strides, + padding="same", + use_bias=use_bias, + name=name + "_conv", + )(inputs) + + x = layers.BatchNormalization(name=name + "_bn")(x) + + if activation == "silu": + x = layers.Lambda(lambda x: keras.activations.silu(x))(x) + elif activation == "relu": + x = layers.ReLU()(x) + elif activation == "leaky_relu": + x = layers.LeakyReLU(0.1)(x) + + return x + + return apply + + +def apply_darknet_conv_block_depthwise( + filters, kernel_size, strides, activation="silu", name=None +): + """ + The depthwise conv block used in CSPDarknet. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the final convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. Can be a single + integer to specify the same value both dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the height and width. Can be a single + integer to the same value both dimensions. + activation: the activation applied after the final layer. One of "silu", + "relu" or "leaky_relu", defaults to "silu". + name: the prefix for the layer names used in the block. + + """ + if name is None: + name = f"conv_block{keras.backend.get_uid('conv_block')}" + + def apply(inputs): + x = layers.DepthwiseConv2D( + kernel_size, strides, padding="same", use_bias=False + )(inputs) + x = layers.BatchNormalization()(x) + + if activation == "silu": + x = layers.Lambda(lambda x: keras.activations.swish(x))(x) + elif activation == "relu": + x = layers.ReLU()(x) + elif activation == "leaky_relu": + x = layers.LeakyReLU(0.1)(x) + + x = apply_darknet_conv_block( + filters, kernel_size=1, strides=1, activation=activation + )(x) + + return x + + return apply + + +def apply_spatial_pyramid_pooling_bottleneck( + filters, + hidden_filters=None, + kernel_sizes=(5, 9, 13), + activation="silu", + name=None, +): + """ + Spatial pyramid pooling layer used in YOLOv3-SPP + + Args: + filters: Integer, the dimensionality of the output spaces (i.e. the + number of output filters in used the blocks). + hidden_filters: Integer, the dimensionality of the intermediate + bottleneck space (i.e. the number of output filters in the + bottleneck convolution). If None, it will be equal to filters. + Defaults to None. + kernel_sizes: A list or tuple representing all the pool sizes used for + the pooling layers, defaults to (5, 9, 13). + activation: Activation for the conv layers, defaults to "silu". + name: the prefix for the layer names used in the block. + + Returns: + a function that takes an input Tensor representing an + SpatialPyramidPoolingBottleneck. + """ + if name is None: + name = f"spp{keras.backend.get_uid('spp')}" + + if hidden_filters is None: + hidden_filters = filters + + def apply(x): + x = apply_darknet_conv_block( + hidden_filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv1", + )(x) + x = [x] + + for kernel_size in kernel_sizes: + x.append( + layers.MaxPooling2D( + kernel_size, + strides=1, + padding="same", + name=f"{name}_maxpool_{kernel_size}", + )(x[0]) + ) + + x = layers.Concatenate(name=f"{name}_concat")(x) + x = apply_darknet_conv_block( + filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv2", + )(x) + + return x + + return apply + + +def apply_cross_stage_partial( + filters, + num_bottlenecks, + residual=True, + block_type="basic_block", + activation="silu", + name=None, +): + """A block used in Cross Stage Partial Darknet. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the final convolution). + num_bottlenecks: an integer representing the number of blocks added in + the layer bottleneck. + residual: a boolean representing whether the value tensor before the + bottleneck should be added to the output of the bottleneck as a + residual, defaults to True. + block_type: str. One of `"basic_block"` or `"depthwise_block"`. + Use `"depthwise_block"` for depthwise conv block + `"basic_block"` for basic conv block. + Defaults to "basic_block". + activation: the activation applied after the final layer. One of "silu", + "relu" or "leaky_relu", defaults to "silu". + """ + + if name is None: + name = f"cross_stage_partial_{keras.backend.get_uid('cross_stage_partial')}" + + def apply(inputs): + hidden_channels = filters // 2 + ConvBlock = ( + apply_darknet_conv_block_depthwise + if block_type == "basic_block" + else apply_darknet_conv_block + ) + + x1 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv1", + )(inputs) + + x2 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv2", + )(inputs) + + for i in range(num_bottlenecks): + residual_x = x1 + x1 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_bottleneck_{i}_conv1", + )(x1) + x1 = ConvBlock( + hidden_channels, + kernel_size=3, + strides=1, + activation=activation, + name=f"{name}_bottleneck_{i}_conv2", + )(x1) + if residual: + x1 = layers.Add(name=f"{name}_bottleneck_{i}_add")( + [residual_x, x1] + ) + + x = layers.Concatenate(name=f"{name}_concat")([x1, x2]) + x = apply_darknet_conv_block( + filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv3", + )(x) + + return x + + return apply diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py new file mode 100644 index 0000000000..aaad4fe515 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py @@ -0,0 +1,50 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CSPDarkNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_filters": [32, 64, 128, 256], + "stackwise_depth": [1, 3, 3, 1], + "include_rescaling": False, + "block_type": "basic_block", + "input_image_shape": (224, 224, 3), + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=CSPDarkNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 7, 256), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=CSPDarkNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py new file mode 100644 index 0000000000..6b013bdcc0 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py @@ -0,0 +1,133 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.models.image_classifier import ImageClassifier + + +@keras_nlp_export("keras_nlp.models.CSPDarkNetImageClassifier") +class CSPDarkNetImageClassifier(ImageClassifier): + """CSPDarkNet image classifier task model. + + Args: + backbone: A `keras_nlp.models.CSPDarkNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset( + "csp_darknet_tiny_imagenet") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset( + "csp_darknet_tiny_imagenet") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset( + "csp_darknet_tiny_imagenet") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.CSPDarkNetBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + block_type="basic_block", + input_image_shape = (224, 224, 3), + ) + classifier = keras_nlp.models.CSPDarkNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = CSPDarkNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py new file mode 100644 index 0000000000..a07bb017a3 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py @@ -0,0 +1,65 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.models.csp_darknet.csp_darknet_image_classifier import ( + CSPDarkNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CSPDarkNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.images = np.ones((2, 16, 16, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = CSPDarkNetBackbone( + stackwise_num_filters=[2, 16, 16], + stackwise_depth=[1, 3, 3, 1], + include_rescaling=False, + block_type="basic_block", + input_image_shape=(16, 16, 3), + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=CSPDarkNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=CSPDarkNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From 00ab4d5c4d0350872a64e9a42ad22cf4cb3a43c2 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Fri, 16 Aug 2024 04:29:57 +0800 Subject: [PATCH 04/33] Add `FeaturePyramidBackbone` and port weights from `timm` for `ResNetBackbone` (#1769) * Add FeaturePyramidBackbone and update ResNetBackbone * Simplify the implementation * Fix CI * Make ResNetBackbone compatible with timm and add FeaturePyramidBackbone * Add conversion implementation * Update docstrings * Address comments --- keras_nlp/api/models/__init__.py | 1 + keras_nlp/src/models/backbone.py | 3 + .../src/models/feature_pyramid_backbone.py | 73 +++++ .../src/models/resnet/resnet_backbone.py | 252 +++++++++++------- .../src/models/resnet/resnet_backbone_test.py | 25 +- .../models/resnet/resnet_image_classifier.py | 7 +- .../resnet/resnet_image_classifier_test.py | 4 + keras_nlp/src/utils/preset_utils.py | 4 + keras_nlp/src/utils/timm/__init__.py | 13 + keras_nlp/src/utils/timm/convert.py | 37 +++ keras_nlp/src/utils/timm/convert_resnet.py | 171 ++++++++++++ .../src/utils/timm/convert_resnet_test.py | 28 ++ .../utils/transformers/safetensor_utils.py | 4 +- 13 files changed, 524 insertions(+), 98 deletions(-) create mode 100644 keras_nlp/src/models/feature_pyramid_backbone.py create mode 100644 keras_nlp/src/utils/timm/__init__.py create mode 100644 keras_nlp/src/utils/timm/convert.py create mode 100644 keras_nlp/src/utils/timm/convert_resnet.py create mode 100644 keras_nlp/src/utils/timm/convert_resnet_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index aca1e28538..e079aa7c9e 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -112,6 +112,7 @@ ) from keras_nlp.src.models.falcon.falcon_preprocessor import FalconPreprocessor from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index a58072dfce..0f41c63c81 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -30,6 +30,7 @@ from keras_nlp.src.utils.preset_utils import save_metadata from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.python_utils import classproperty +from keras_nlp.src.utils.timm.convert import load_timm_backbone from keras_nlp.src.utils.transformers.convert import load_transformers_backbone @@ -204,6 +205,8 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from if format == "transformers": return load_transformers_backbone(cls, preset, load_weights) + elif format == "timm": + return load_timm_backbone(cls, preset, load_weights, **kwargs) preset_cls = check_config_class(preset) if not issubclass(preset_cls, cls): diff --git a/keras_nlp/src/models/feature_pyramid_backbone.py b/keras_nlp/src/models/feature_pyramid_backbone.py new file mode 100644 index 0000000000..989d9fbd64 --- /dev/null +++ b/keras_nlp/src/models/feature_pyramid_backbone.py @@ -0,0 +1,73 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + + +@keras_nlp_export("keras_nlp.models.FeaturePyramidBackbone") +class FeaturePyramidBackbone(Backbone): + """A backbone with feature pyramid outputs. + + `FeaturePyramidBackbone` extends `Backbone` with a single `pyramid_outputs` + property for accessing the feature pyramid outputs of the model. Subclassers + should set the `pyramid_outputs` property during the model constructor. + + Example: + + ```python + input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3)) + + # Convert to feature pyramid output format using ResNet. + backbone = ResNetBackbone.from_preset("resnet50") + model = keras.Model( + inputs=backbone.inputs, outputs=backbone.pyramid_outputs + ) + model(input_data) # A dict containing the keys ["P2", "P3", "P4", "P5"] + ``` + """ + + @property + def pyramid_outputs(self): + """A dict for feature pyramid outputs. + + The key is a string represents the name of the feature output and the + value is a `keras.KerasTensor`. A typical feature pyramid has multiple + levels corresponding to scales such as `["P2", "P3", "P4", "P5"]`. Scale + `Pn` represents a feature map `2^n` times smaller in width and height + than the inputs. + """ + return getattr(self, "_pyramid_outputs", {}) + + @pyramid_outputs.setter + def pyramid_outputs(self, value): + if not isinstance(value, dict): + raise TypeError( + "`pyramid_outputs` must be a dictionary. " + f"Received: value={value} of type {type(value)}" + ) + for k, v in value.items(): + if not isinstance(k, str): + raise TypeError( + "The key of `pyramid_outputs` must be a string. " + f"Received: key={k} of type {type(k)}" + ) + if not isinstance(v, keras.KerasTensor): + raise TypeError( + "The value of `pyramid_outputs` must be a " + "`keras.KerasTensor`. " + f"Received: value={v} of type {type(v)}" + ) + self._pyramid_outputs = value diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index bec5ba60b5..0f4d7c139a 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -13,20 +13,23 @@ # limitations under the License. import keras from keras import layers +from keras import ops from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_nlp.src.utils.keras_utils import standardize_data_format @keras_nlp_export("keras_nlp.models.ResNetBackbone") -class ResNetBackbone(Backbone): +class ResNetBackbone(FeaturePyramidBackbone): """ResNet and ResNetV2 core network with hyperparameters. This class implements a ResNet backbone as described in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( - CVPR 2016) and [Identity Mappings in Deep Residual Networks]( - https://arxiv.org/abs/1603.05027)(ECCV 2016). + CVPR 2016), [Identity Mappings in Deep Residual Networks]( + https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An + improved training procedure in timm](https://arxiv.org/abs/2110.00476)( + NeurIPS 2021 Workshop). The difference in ResNet and ResNetV2 rests in the structure of their individual building blocks. In ResNetV2, the batch normalization and @@ -34,6 +37,9 @@ class ResNetBackbone(Backbone): the batch normalization and ReLU activation are applied after the convolution layers. + Note that `ResNetBackbone` expects the inputs to be images with a value + range of `[0, 255]` when `include_rescaling=True`. + Args: stackwise_num_filters: list of ints. The number of filters for each stack. @@ -46,8 +52,8 @@ class ResNetBackbone(Backbone): use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. include_rescaling: boolean. If `True`, rescale the input using - `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to - `True`. + `Rescaling` and `Normalization` layers. If `False`, do nothing. + Defaults to `True`. input_image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. pooling: `None` or str. Pooling mode for feature extraction. Defaults @@ -70,11 +76,11 @@ class ResNetBackbone(Backbone): `~/.keras/keras.json`. If you never set it, then it will be `"channels_last"`. dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype - to use for the models computations and weights. + to use for the model's computations and weights. Examples: ```python - input_data = np.ones((2, 224, 224, 3), dtype="float32") + input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3)) # Pretrained ResNet backbone. model = keras_nlp.models.ResNetBackbone.from_preset("resnet50") @@ -136,34 +142,66 @@ def __init__( image_input = layers.Input(shape=input_image_shape) if include_rescaling: x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) + x = layers.Normalization( + axis=bn_axis, + mean=(0.485, 0.456, 0.406), + variance=(0.229**2, 0.224**2, 0.225**2), + dtype=dtype, + name="normalization", + )(x) else: x = image_input + # The padding between torch and tensorflow/jax differs when `strides>1`. + # Therefore, we need to manually pad the tensor. + x = layers.ZeroPadding2D( + 3, + data_format=data_format, + dtype=dtype, + name="conv1_pad", + )(x) x = layers.Conv2D( 64, 7, strides=2, - padding="same", data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name="conv1_conv", )(x) if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="conv1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="conv1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) - x = layers.MaxPool2D( + if use_pre_activation: + # A workaround for ResNetV2: we need -inf padding to prevent zeros + # from being the max values in the following `MaxPooling2D`. + pad_width = [[1, 1], [1, 1]] + if data_format == "channels_last": + pad_width += [[0, 0]] + else: + pad_width = [[0, 0]] + pad_width + pad_width = [[0, 0]] + pad_width + x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf")) + else: + x = layers.ZeroPadding2D( + 1, data_format=data_format, dtype=dtype, name="pool1_pad" + )(x) + x = layers.MaxPooling2D( 3, strides=2, - padding="same", data_format=data_format, dtype=dtype, name="pool1_pool", )(x) + pyramid_outputs = {} for stack_index in range(num_stacks): x = apply_stack( x, @@ -179,10 +217,15 @@ def __init__( dtype=dtype, name=f"{version}_stack{stack_index}", ) + pyramid_outputs[f"P{stack_index + 2}"] = x if use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="post_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="post_bn", )(x) x = layers.Activation("relu", dtype=dtype, name="post_relu")(x) @@ -213,18 +256,23 @@ def __init__( self.include_rescaling = include_rescaling self.input_image_shape = input_image_shape self.pooling = pooling + self.pyramid_outputs = pyramid_outputs def get_config(self): - return { - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_num_blocks": self.stackwise_num_blocks, - "stackwise_num_strides": self.stackwise_num_strides, - "block_type": self.block_type, - "use_pre_activation": self.use_pre_activation, - "include_rescaling": self.include_rescaling, - "input_image_shape": self.input_image_shape, - "pooling": self.pooling, - } + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_strides": self.stackwise_num_strides, + "block_type": self.block_type, + "use_pre_activation": self.use_pre_activation, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + ) + return config def apply_basic_block( @@ -269,68 +317,81 @@ def apply_basic_block( if use_pre_activation: x_preact = layers.BatchNormalization( axis=bn_axis, - epsilon=1.001e-5, + epsilon=1e-5, + momentum=0.9, dtype=dtype, - name=f"{name}_use_preactivation_bn", + name=f"{name}_pre_activation_bn", )(x) x_preact = layers.Activation( - "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" )(x_preact) if conv_shortcut: + x = x_preact if x_preact is not None else x shortcut = layers.Conv2D( filters, 1, strides=stride, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x_preact if x_preact is not None else x) + )(x) if not use_pre_activation: shortcut = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", )(shortcut) else: - if not use_pre_activation or stride == 1: - shortcut = x - else: - shortcut = layers.MaxPooling2D( - 1, - strides=stride, - data_format=data_format, - dtype=dtype, - name=f"{name}_0_max_pooling", - )(x) + shortcut = x + x = x_preact if x_preact is not None else x + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_1_pad", + )(x) x = layers.Conv2D( filters, kernel_size, - strides=stride if not use_pre_activation else 1, - padding="same", + strides=stride, + padding="valid" if stride > 1 else "same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_1_conv", - )(x_preact if x_preact is not None else x) + )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( filters, kernel_size, - strides=1 if not use_pre_activation else stride, + strides=1, padding="same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_2_conv", )(x) - if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", )(x) x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) @@ -381,79 +442,97 @@ def apply_bottleneck_block( if use_pre_activation: x_preact = layers.BatchNormalization( axis=bn_axis, - epsilon=1.001e-5, + epsilon=1e-5, + momentum=0.9, dtype=dtype, - name=f"{name}_use_preactivation_bn", + name=f"{name}_pre_activation_bn", )(x) x_preact = layers.Activation( - "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" )(x_preact) if conv_shortcut: + x = x_preact if x_preact is not None else x shortcut = layers.Conv2D( 4 * filters, 1, strides=stride, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x_preact if x_preact is not None else x) + )(x) if not use_pre_activation: shortcut = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", )(shortcut) else: - if not use_pre_activation or stride == 1: - shortcut = x - else: - shortcut = layers.MaxPooling2D( - 1, - strides=stride, - data_format=data_format, - dtype=dtype, - name=f"{name}_0_max_pooling", - )(x) + shortcut = x + x = x_preact if x_preact is not None else x x = layers.Conv2D( filters, 1, - strides=stride if not use_pre_activation else 1, + strides=1, data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_1_conv", - )(x_preact if x_preact is not None else x) + )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_2_pad", + )(x) x = layers.Conv2D( filters, kernel_size, - strides=1 if not use_pre_activation else stride, - padding="same", + strides=stride, + padding="valid" if stride > 1 else "same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_2_conv", )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + x = layers.Conv2D( 4 * filters, 1, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_3_conv", )(x) - if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_3_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_3_bn", )(x) x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) @@ -513,32 +592,21 @@ def apply_stack( '`block_type` must be either `"basic_block"` or ' f'`"bottleneck_block"`. Received block_type={block_type}.' ) - x = block_fn( - x, - filters, - stride=stride if not use_pre_activation else 1, - conv_shortcut=first_shortcut, - use_pre_activation=use_pre_activation, - data_format=data_format, - dtype=dtype, - name=f"{name}_block1", - ) - for i in range(2, blocks): + for i in range(blocks): + if i == 0: + stride = stride + conv_shortcut = first_shortcut + else: + stride = 1 + conv_shortcut = False x = block_fn( x, filters, + stride=stride, + conv_shortcut=conv_shortcut, use_pre_activation=use_pre_activation, data_format=data_format, dtype=dtype, name=f"{name}_block{str(i)}", ) - x = block_fn( - x, - filters, - stride=1 if not use_pre_activation else stride, - use_pre_activation=use_pre_activation, - data_format=data_format, - dtype=dtype, - name=f"{name}_block{str(blocks)}", - ) return x diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index 2113bcd131..6d3f774559 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -14,6 +14,7 @@ import pytest from absl.testing import parameterized +from keras import models from keras import ops from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone @@ -29,8 +30,8 @@ def setUp(self): "input_image_shape": (None, None, 3), "pooling": "avg", } - self.input_size = (16, 16) - self.input_data = ops.ones((2, 16, 16, 3)) + self.input_size = 64 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) @parameterized.named_parameters( ("v1_basic", False, "basic_block"), @@ -52,6 +53,24 @@ def test_backbone_basics(self, use_pre_activation, block_type): ), ) + def test_pyramid_output_format(self): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + {"block_type": "basic_block", "use_pre_activation": False} + ) + backbone = ResNetBackbone(**init_kwargs) + model = models.Model(backbone.inputs, backbone.pyramid_outputs) + output_data = model(self.input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual(list(output_data.keys()), ["P2", "P3", "P4"]) + for k, v in output_data.items(): + size = self.input_size // (2 ** int(k[1:])) + self.assertEqual(tuple(v.shape[:3]), (2, size, size)) + @parameterized.named_parameters( ("v1_basic", False, "basic_block"), ("v1_bottleneck", False, "bottleneck_block"), @@ -65,7 +84,7 @@ def test_saved_model(self, use_pre_activation, block_type): { "block_type": block_type, "use_pre_activation": use_pre_activation, - "input_image_shape": (16, 16, 3), + "input_image_shape": (None, None, 3), } ) self.run_model_saving_test( diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier.py b/keras_nlp/src/models/resnet/resnet_image_classifier.py index 02c8c78b27..815dc7fcca 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier.py @@ -28,6 +28,8 @@ class ResNetImageClassifier(ImageClassifier): activation: `None`, str or callable. The activation function to use on the `Dense` layer. Set `activation=None` to return the output logits. Defaults to `"softmax"`. + head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The + dtype to use for the classification head's computations and weights. To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` where `x` is a tensor and `y` is a integer from `[0, num_classes)`. @@ -92,16 +94,19 @@ def __init__( backbone, num_classes, activation="softmax", + head_dtype=None, preprocessor=None, # adding this dummy arg for saved model test # TODO: once preprocessor flow is figured out, this needs to be updated **kwargs, ): + head_dtype = head_dtype or backbone.dtype_policy + # === Layers === self.backbone = backbone self.output_dense = keras.layers.Dense( num_classes, activation=activation, - dtype=self.backbone.dtype_policy, + dtype=head_dtype, name="predictions", ) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index bbbda72d64..f3f63a14a1 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -53,6 +53,10 @@ def test_classifier_basics(self): expected_output_shape=(2, 2), ) + def test_head_dtype(self): + model = ResNetImageClassifier(**self.init_kwargs, head_dtype="bfloat16") + self.assertEqual(model.output_dense.compute_dtype, "bfloat16") + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index f797bf9f18..9e3f51c43a 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -544,6 +544,10 @@ def check_format(preset): if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists( preset, SAFETENSOR_CONFIG_FILE ): + # Determine the format by parsing the config file. + config = load_config(preset, HF_CONFIG_FILE) + if "hf://timm" in preset or "architecture" in config: + return "timm" return "transformers" if not check_file_exists(preset, METADATA_FILE): diff --git a/keras_nlp/src/utils/timm/__init__.py b/keras_nlp/src/utils/timm/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/utils/timm/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/utils/timm/convert.py b/keras_nlp/src/utils/timm/convert.py new file mode 100644 index 0000000000..edfde3316b --- /dev/null +++ b/keras_nlp/src/utils/timm/convert.py @@ -0,0 +1,37 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert timm models to KerasNLP.""" + +from keras_nlp.src.utils.timm.convert_resnet import load_resnet_backbone + + +def load_timm_backbone(cls, preset, load_weights, **kwargs): + """Load a timm model config and weights as a KerasNLP backbone. + + Args: + cls (class): Keras model class. + preset (str): Preset configuration name. + load_weights (bool): Whether to load the weights. + + Returns: + backbone: Initialized Keras model backbone. + """ + if cls is None: + raise ValueError("Backbone class is None") + if cls.__name__ == "ResNetBackbone": + return load_resnet_backbone(cls, preset, load_weights, **kwargs) + raise ValueError( + f"{cls} has not been ported from the Hugging Face format yet. " + "Please check Hugging Face Hub for the Keras model. " + ) diff --git a/keras_nlp/src/utils/timm/convert_resnet.py b/keras_nlp/src/utils/timm/convert_resnet.py new file mode 100644 index 0000000000..de2224eb9e --- /dev/null +++ b/keras_nlp/src/utils/timm/convert_resnet.py @@ -0,0 +1,171 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import jax_memory_cleanup +from keras_nlp.src.utils.preset_utils import load_config +from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + + +def convert_backbone_config(timm_config): + timm_architecture = timm_config["architecture"] + + if "resnetv2_" in timm_architecture: + use_pre_activation = True + else: + use_pre_activation = False + + if timm_architecture == "resnet18": + stackwise_num_blocks = [2, 2, 2, 2] + block_type = "basic_block" + elif timm_architecture == "resnet26": + stackwise_num_blocks = [2, 2, 2, 2] + block_type = "bottleneck_block" + elif timm_architecture == "resnet34": + stackwise_num_blocks = [3, 4, 6, 3] + block_type = "basic_block" + elif timm_architecture in ("resnet50", "resnetv2_50"): + stackwise_num_blocks = [3, 4, 6, 3] + block_type = "bottleneck_block" + elif timm_architecture in ("resnet101", "resnetv2_101"): + stackwise_num_blocks = [3, 4, 23, 3] + block_type = "bottleneck_block" + elif timm_architecture in ("resnet152", "resnetv2_152"): + stackwise_num_blocks = [3, 8, 36, 3] + block_type = "bottleneck_block" + else: + raise ValueError( + f"Currently, the architecture {timm_architecture} is not supported." + ) + + return dict( + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=stackwise_num_blocks, + stackwise_num_strides=[1, 2, 2, 2], + block_type=block_type, + use_pre_activation=use_pre_activation, + ) + + +def convert_weights(backbone, loader, timm_config): + def port_conv2d(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_weight_prefix}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + def port_batch_normalization(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).gamma, + hf_weight_key=f"{hf_weight_prefix}.weight", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).beta, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_mean, + hf_weight_key=f"{hf_weight_prefix}.running_mean", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_variance, + hf_weight_key=f"{hf_weight_prefix}.running_var", + ) + + version = "v1" if not backbone.use_pre_activation else "v2" + block_type = backbone.block_type + + # Stem + if version == "v1": + port_conv2d("conv1_conv", "conv1") + port_batch_normalization("conv1_bn", "bn1") + else: + port_conv2d("conv1_conv", "stem.conv") + + # Stages + num_stacks = len(backbone.stackwise_num_filters) + for stack_index in range(num_stacks): + for block_idx in range(backbone.stackwise_num_blocks[stack_index]): + if version == "v1": + keras_name = f"v1_stack{stack_index}_block{block_idx}" + hf_name = f"layer{stack_index+1}.{block_idx}" + else: + keras_name = f"v2_stack{stack_index}_block{block_idx}" + hf_name = f"stages.{stack_index}.blocks.{block_idx}" + + if version == "v1": + if block_idx == 0 and ( + block_type == "bottleneck_block" or stack_index > 0 + ): + port_conv2d( + f"{keras_name}_0_conv", f"{hf_name}.downsample.0" + ) + port_batch_normalization( + f"{keras_name}_0_bn", f"{hf_name}.downsample.1" + ) + port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.bn1") + port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.bn2") + if block_type == "bottleneck_block": + port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + port_batch_normalization( + f"{keras_name}_3_bn", f"{hf_name}.bn3" + ) + else: + if block_idx == 0 and ( + block_type == "bottleneck_block" or stack_index > 0 + ): + port_conv2d( + f"{keras_name}_0_conv", f"{hf_name}.downsample.conv" + ) + port_batch_normalization( + f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1" + ) + port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + port_batch_normalization( + f"{keras_name}_1_bn", f"{hf_name}.norm2" + ) + port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + if block_type == "bottleneck_block": + port_batch_normalization( + f"{keras_name}_2_bn", f"{hf_name}.norm3" + ) + port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + + # Post + if version == "v2": + port_batch_normalization("post_bn", "norm") + + # Rebuild normalization layer with pretrained mean & std + mean = timm_config["pretrained_cfg"]["mean"] + std = timm_config["pretrained_cfg"]["std"] + normalization_layer = backbone.get_layer("normalization") + normalization_layer.input_mean = mean + normalization_layer.input_variance = [s**2 for s in std] + normalization_layer.build(normalization_layer._build_input_shape) + + +def load_resnet_backbone(cls, preset, load_weights, **kwargs): + timm_config = load_config(preset, HF_CONFIG_FILE) + keras_config = convert_backbone_config(timm_config) + backbone = cls(**keras_config, **kwargs) + if load_weights: + jax_memory_cleanup(backbone) + # Use prefix="" to avoid using `get_prefixed_key`. + with SafetensorLoader(preset, prefix="") as loader: + convert_weights(backbone, loader, timm_config) + return backbone diff --git a/keras_nlp/src/utils/timm/convert_resnet_test.py b/keras_nlp/src/utils/timm/convert_resnet_test.py new file mode 100644 index 0000000000..a30bee46af --- /dev/null +++ b/keras_nlp/src/utils/timm/convert_resnet_test.py @@ -0,0 +1,28 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class TimmResNetBackboneTest(TestCase): + @pytest.mark.large + def test_convert_resnet18_preset(self): + model = ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k") + outputs = model.predict(ops.ones((1, 224, 224, 3))) + self.assertEqual(outputs.shape, (1, 512)) + + # TODO: compare numerics with timm model diff --git a/keras_nlp/src/utils/transformers/safetensor_utils.py b/keras_nlp/src/utils/transformers/safetensor_utils.py index 40ef473ff3..2fbd7e1aba 100644 --- a/keras_nlp/src/utils/transformers/safetensor_utils.py +++ b/keras_nlp/src/utils/transformers/safetensor_utils.py @@ -26,7 +26,7 @@ class SafetensorLoader(contextlib.ExitStack): - def __init__(self, preset): + def __init__(self, preset, prefix=None): super().__init__() if safetensors is None: @@ -42,7 +42,7 @@ def __init__(self, preset): else: self.safetensor_config = None self.safetensor_files = {} - self.prefix = None + self.prefix = prefix def get_prefixed_key(self, hf_weight_key, dict_like): """ From 9860756f183cc4ad9247bc29b6c0ee55ec2db6fc Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 15 Aug 2024 17:38:39 -0700 Subject: [PATCH 05/33] Add DenseNet (#1775) * Add DenseNet * fix testcase * address comments * nit * fix lint errors * move description --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/src/models/densenet/__init__.py | 13 ++ .../src/models/densenet/densenet_backbone.py | 210 ++++++++++++++++++ .../models/densenet/densenet_backbone_test.py | 48 ++++ .../densenet/densenet_image_classifier.py | 131 +++++++++++ .../densenet_image_classifier_test.py | 63 ++++++ 6 files changed, 469 insertions(+) create mode 100644 keras_nlp/src/models/densenet/__init__.py create mode 100644 keras_nlp/src/models/densenet/densenet_backbone.py create mode 100644 keras_nlp/src/models/densenet/densenet_backbone_test.py create mode 100644 keras_nlp/src/models/densenet/densenet_image_classifier.py create mode 100644 keras_nlp/src/models/densenet/densenet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index e079aa7c9e..bf5cc28060 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -74,6 +74,10 @@ from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.densenet.densenet_image_classifier import ( + DenseNetImageClassifier, +) from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone, ) diff --git a/keras_nlp/src/models/densenet/__init__.py b/keras_nlp/src/models/densenet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/densenet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py new file mode 100644 index 0000000000..8456fbcee6 --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -0,0 +1,210 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + +BN_AXIS = 3 +BN_EPSILON = 1.001e-5 + + +@keras_nlp_export("keras_nlp.models.DenseNetBackbone") +class DenseNetBackbone(Backbone): + """Instantiates the DenseNet architecture. + + This class implements a DenseNet backbone as described in + [Densely Connected Convolutional Networks (CVPR 2017)]( + https://arxiv.org/abs/1608.06993 + ). + + Args: + stackwise_num_repeats: list of ints, number of repeated convolutional + blocks per dense block. + include_rescaling: bool, whether to rescale the inputs. If set + to `True`, inputs will be passed through a `Rescaling(1/255.0)` + layer. Defaults to `True`. + input_image_shape: optional shape tuple, defaults to (224, 224, 3). + compression_ratio: float, compression rate at transition layers, + defaults to 0.5. + growth_rate: int, number of filters added by each dense block, + defaults to 32 + + Examples: + ```python + input_data = np.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone + model = keras_nlp.models.DenseNetBackbone.from_preset("densenet121_imagenet") + model(input_data) + + # Randomly initialized backbone with a custom config + model = keras_nlp.models.DenseNetBackbone( + stackwise_num_repeats=[6, 12, 24, 16], + include_rescaling=False, + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_repeats, + include_rescaling=True, + input_image_shape=(224, 224, 3), + compression_ratio=0.5, + growth_rate=32, + **kwargs, + ): + # === Functional Model === + image_input = keras.layers.Input(shape=input_image_shape) + + x = image_input + if include_rescaling: + x = keras.layers.Rescaling(1 / 255.0)(x) + + x = keras.layers.Conv2D( + 64, 7, strides=2, use_bias=False, padding="same", name="conv1_conv" + )(x) + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name="conv1_bn" + )(x) + x = keras.layers.Activation("relu", name="conv1_relu")(x) + x = keras.layers.MaxPooling2D( + 3, strides=2, padding="same", name="pool1" + )(x) + + for stack_index in range(len(stackwise_num_repeats) - 1): + index = stack_index + 2 + x = apply_dense_block( + x, + stackwise_num_repeats[stack_index], + growth_rate, + name=f"conv{index}", + ) + x = apply_transition_block( + x, compression_ratio, name=f"pool{index}" + ) + + x = apply_dense_block( + x, + stackwise_num_repeats[-1], + growth_rate, + name=f"conv{len(stackwise_num_repeats) + 1}", + ) + + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name="bn" + )(x) + x = keras.layers.Activation("relu", name="relu")(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_num_repeats = stackwise_num_repeats + self.include_rescaling = include_rescaling + self.compression_ratio = compression_ratio + self.growth_rate = growth_rate + self.input_image_shape = input_image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_repeats": self.stackwise_num_repeats, + "include_rescaling": self.include_rescaling, + "compression_ratio": self.compression_ratio, + "growth_rate": self.growth_rate, + "input_image_shape": self.input_image_shape, + } + ) + return config + + +def apply_dense_block(x, num_repeats, growth_rate, name=None): + """A dense block. + + Args: + x: input tensor. + num_repeats: int, number of repeated convolutional blocks. + growth_rate: int, number of filters added by each dense block. + name: string, block label. + """ + if name is None: + name = f"dense_block_{keras.backend.get_uid('dense_block')}" + + for i in range(num_repeats): + x = apply_conv_block(x, growth_rate, name=f"{name}_block_{i}") + return x + + +def apply_transition_block(x, compression_ratio, name=None): + """A transition block. + + Args: + x: input tensor. + compression_ratio: float, compression rate at transition layers. + name: string, block label. + """ + if name is None: + name = f"transition_block_{keras.backend.get_uid('transition_block')}" + + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_relu")(x) + x = keras.layers.Conv2D( + int(x.shape[BN_AXIS] * compression_ratio), + 1, + use_bias=False, + name=f"{name}_conv", + )(x) + x = keras.layers.AveragePooling2D(2, strides=2, name=f"{name}_pool")(x) + return x + + +def apply_conv_block(x, growth_rate, name=None): + """A building block for a dense block. + + Args: + x: input tensor. + growth_rate: int, number of filters added by each dense block. + name: string, block label. + """ + if name is None: + name = f"conv_block_{keras.backend.get_uid('conv_block')}" + + shortcut = x + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_0_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_0_relu")(x) + x = keras.layers.Conv2D( + 4 * growth_rate, 1, use_bias=False, name=f"{name}_1_conv" + )(x) + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_1_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x) + x = keras.layers.Conv2D( + growth_rate, + 3, + padding="same", + use_bias=False, + name=f"{name}_2_conv", + )(x) + x = keras.layers.Concatenate(axis=BN_AXIS, name=f"{name}_concat")( + [shortcut, x] + ) + return x diff --git a/keras_nlp/src/models/densenet/densenet_backbone_test.py b/keras_nlp/src/models/densenet/densenet_backbone_test.py new file mode 100644 index 0000000000..f0f8dac875 --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_backbone_test.py @@ -0,0 +1,48 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class DenseNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_repeats": [6, 12, 24, 16], + "include_rescaling": True, + "compression_ratio": 0.5, + "growth_rate": 32, + "input_image_shape": (224, 224, 3), + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=DenseNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 7, 1024), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DenseNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier.py b/keras_nlp/src/models/densenet/densenet_image_classifier.py new file mode 100644 index 0000000000..395e8f754d --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_image_classifier.py @@ -0,0 +1,131 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.image_classifier import ImageClassifier + + +@keras_nlp_export("keras_nlp.models.DenseNetImageClassifier") +class DenseNetImageClassifier(ImageClassifier): + """DenseNet image classifier task model. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Args: + backbone: A `keras_nlp.models.DenseNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.DenseNetBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + block_type="basic_block", + input_image_shape = (224, 224, 3), + ) + classifier = keras_nlp.models.DenseNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = DenseNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py new file mode 100644 index 0000000000..60d77d489c --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py @@ -0,0 +1,63 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.densenet.densenet_image_classifier import ( + DenseNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class DenseNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.images = np.ones((2, 224, 224, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = DenseNetBackbone( + stackwise_num_repeats=[6, 12, 24, 16], + include_rescaling=True, + compression_ratio=0.5, + growth_rate=32, + input_image_shape=(224, 224, 3), + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=DenseNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DenseNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From fd6f977b0136499ad4e1cf78cc8aea69fb3bfc27 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Tue, 20 Aug 2024 12:24:13 -0700 Subject: [PATCH 06/33] Add ViTDetBackbone (#1776) * add vit det vit_det_backbone * update docstring * code reformat * fix tests * address review comments * bump year on all files * address review comments * rename backbone * fix tests * change back to ViT * address review comments * update image shape --- keras_nlp/api/models/__init__.py | 1 + .../src/models/vit_det/vit_det_backbone.py | 204 +++++++ .../models/vit_det/vit_det_backbone_test.py | 54 ++ keras_nlp/src/models/vit_det/vit_layers.py | 565 ++++++++++++++++++ 4 files changed, 824 insertions(+) create mode 100644 keras_nlp/src/models/vit_det/vit_det_backbone.py create mode 100644 keras_nlp/src/models/vit_det/vit_det_backbone_test.py create mode 100644 keras_nlp/src/models/vit_det/vit_layers.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 6f7e08c520..1a6dd2e74f 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -212,6 +212,7 @@ from keras_nlp.src.models.task import Task from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_nlp.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) diff --git a/keras_nlp/src/models/vit_det/vit_det_backbone.py b/keras_nlp/src/models/vit_det/vit_det_backbone.py new file mode 100644 index 0000000000..1e83e94b05 --- /dev/null +++ b/keras_nlp/src/models/vit_det/vit_det_backbone.py @@ -0,0 +1,204 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.vit_det.vit_layers import AddPositionalEmbedding +from keras_nlp.src.models.vit_det.vit_layers import ViTDetPatchingAndEmbedding +from keras_nlp.src.models.vit_det.vit_layers import WindowedTransformerEncoder + + +@keras_nlp_export("keras_nlp.models.ViTDetBackbone") +class ViTDetBackbone(Backbone): + """An implementation of ViT image encoder. + + The ViTDetBackbone uses a windowed transformer encoder and relative + positional encodings. The code has been adapted from [Segment Anything + paper](https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + hidden_size (int): The latent dimensionality to be projected + into in the output of each stacked windowed transformer encoder. + num_layers (int): The number of transformer encoder layers to + stack in the Vision Transformer. + intermediate_dim (int): The dimensionality of the hidden Dense + layer in the transformer MLP head. + num_heads (int): the number of heads to use in the + `MultiHeadAttentionWithRelativePE` layer of each transformer + encoder. + global_attention_layer_indices (list): Indexes for blocks using + global attention. + image_shape (tuple[int], optional): The size of the input image in + `(H, W, C)` format. Defaults to `(1024, 1024, 3)`. + include_rescaling (bool, optional): Whether to rescale the inputs. If + set to `True`, inputs will be passed through a + `Rescaling(1/255.0)` layer. Defaults to `False`. + patch_size (int, optional): the patch size to be supplied to the + Patching layer to turn input images into a flattened sequence of + patches. Defaults to `16`. + num_output_channels (int, optional): The number of channels (features) + in the output (image encodings). Defaults to `256`. + use_bias (bool, optional): Whether to use bias to project the keys, + queries, and values in the attention layer. Defaults to `True`. + use_abs_pos (bool, optional): Whether to add absolute positional + embeddings to the output patches. Defaults to `True`. + use_rel_pos (bool, optional): Whether to use relative positional + emcodings in the attention layer. Defaults to `True`. + window_size (int, optional): The size of the window for windowed + attention in the transformer encoder blocks. Defaults to `14`. + layer_norm_epsilon (int, optional): The epsilon to use in the layer + normalization blocks in transformer encoder. Defaults to `1e-6`. + + Examples: + ```python + input_data = np.ones((2, 224, 224, 3), dtype="float32") + + # Pretrained ViTDetBackbone backbone. + model = keras_nlp.models.ViTDetBackbone.from_preset("vit_det") + model(input_data) + + # Randomly initialized ViTDetBackbone backbone with a custom config. + model = keras_nlp.models.ViTDetBackbone( + image_shape = (16, 16, 3), + patch_size = 2, + hidden_size = 4, + num_layers = 2, + global_attention_layer_indices = [2, 5, 8, 11], + intermediate_dim = 4 * 4, + num_heads = 2, + num_output_channels = 2, + window_size = 2, + ) + model(input_data) + ``` + """ + + def __init__( + self, + hidden_size, + num_layers, + intermediate_dim, + num_heads, + global_attention_layer_indices, + include_rescaling=True, + image_shape=(1024, 1024, 3), + patch_size=16, + num_output_channels=256, + use_bias=True, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + layer_norm_epsilon=1e-6, + **kwargs + ): + # === Functional model === + img_input = keras.layers.Input(shape=image_shape) + # Check that the input image is well specified. + if img_input.shape[-3] is None or img_input.shape[-2] is None: + raise ValueError( + "Height and width of the image must be specified" + " in `image_shape`." + ) + if img_input.shape[-3] != img_input.shape[-2]: + raise ValueError( + "Input image must be square i.e. the height must" + " be equal to the width in the `image_shape`" + " tuple/tensor." + ) + img_size = img_input.shape[-3] + x = img_input + if include_rescaling: + # Use common rescaling strategy across keras_cv + x = keras.layers.Rescaling(1.0 / 255.0)(x) + # VITDet scales inputs based on the standard ImageNet mean/stddev. + x = (x - ops.array([0.485, 0.456, 0.406], dtype=x.dtype)) / ( + ops.array([0.229, 0.224, 0.225], dtype=x.dtype) + ) + x = ViTDetPatchingAndEmbedding( + kernel_size=(patch_size, patch_size), + strides=(patch_size, patch_size), + embed_dim=hidden_size, + )(x) + if use_abs_pos: + x = AddPositionalEmbedding(img_size, patch_size, hidden_size)(x) + for i in range(num_layers): + x = WindowedTransformerEncoder( + project_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_heads=num_heads, + use_bias=use_bias, + use_rel_pos=use_rel_pos, + window_size=( + window_size + if i not in global_attention_layer_indices + else 0 + ), + input_size=(img_size // patch_size, img_size // patch_size), + )(x) + x = keras.layers.Conv2D( + filters=num_output_channels, kernel_size=1, use_bias=False + )(x) + x = keras.layers.LayerNormalization(epsilon=1e-6)(x) + x = keras.layers.Conv2D( + filters=num_output_channels, + kernel_size=3, + padding="same", + use_bias=False, + )(x) + x = keras.layers.LayerNormalization(epsilon=1e-6)(x) + + super().__init__(inputs=img_input, outputs=x, **kwargs) + + # === Config === + self.patch_size = patch_size + self.image_shape = image_shape + self.hidden_size = hidden_size + self.num_layers = num_layers + self.intermediate_dim = intermediate_dim + self.num_heads = num_heads + self.num_output_channels = num_output_channels + self.use_bias = use_bias + self.use_rel_pos = use_rel_pos + self.use_abs_pos = use_abs_pos + self.window_size = window_size + self.global_attention_layer_indices = global_attention_layer_indices + self.layer_norm_epsilon = layer_norm_epsilon + self.include_rescaling = include_rescaling + + def get_config(self): + config = super().get_config() + config.update( + { + "image_shape": self.image_shape, + "include_rescaling": self.include_rescaling, + "patch_size": self.patch_size, + "hidden_size": self.hidden_size, + "num_layers": self.num_layers, + "intermediate_dim": self.intermediate_dim, + "num_heads": self.num_heads, + "num_output_channels": self.num_output_channels, + "use_bias": self.use_bias, + "use_abs_pos": self.use_abs_pos, + "use_rel_pos": self.use_rel_pos, + "window_size": self.window_size, + "global_attention_layer_indices": self.global_attention_layer_indices, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_nlp/src/models/vit_det/vit_det_backbone_test.py b/keras_nlp/src/models/vit_det/vit_det_backbone_test.py new file mode 100644 index 0000000000..0ae277d122 --- /dev/null +++ b/keras_nlp/src/models/vit_det/vit_det_backbone_test.py @@ -0,0 +1,54 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.vit_det.vit_det_backbone import ViTDetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class ViTDetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "include_rescaling": True, + "image_shape": (16, 16, 3), + "patch_size": 2, + "hidden_size": 4, + "num_layers": 2, + "global_attention_layer_indices": [2, 5, 8, 11], + "intermediate_dim": 4 * 4, + "num_heads": 2, + "num_output_channels": 2, + "window_size": 2, + } + self.input_data = np.ones((1, 16, 16, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=ViTDetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(1, 8, 8, 2), + run_mixed_precision_check=False, + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ViTDetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/vit_det/vit_layers.py b/keras_nlp/src/models/vit_det/vit_layers.py new file mode 100644 index 0000000000..e784595371 --- /dev/null +++ b/keras_nlp/src/models/vit_det/vit_layers.py @@ -0,0 +1,565 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import keras +from keras import ops + + +class MLP(keras.layers.Layer): + """A MLP block with architecture. + + The MLP block implements `input_dim -> [intermediate_dim] -> + hidden_dim`. The code has been adapted from [Segment Anything paper]( + https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + intermediate_dim (int): The number of units in the hidden layers. + hidden_dim (int): The number of units in the output layer. + activation (str): Activation to use in the hidden layers. + Default is `"relu"`. + """ + + def __init__( + self, intermediate_dim, hidden_dim, activation="relu", **kwargs + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.activation = activation + h = [intermediate_dim] + self.dense_net = [] + for intermediate_dim in h: + self.dense_net.append(keras.layers.Dense(intermediate_dim)) + self.dense_net.append(keras.layers.Activation(activation)) + self.dense_net.append(keras.layers.Dense(hidden_dim)) + self.dense_net = keras.models.Sequential(self.dense_net) + + def build(self, input_shape): + self.dense_net.build(input_shape) + self.built = True + + def call(self, x): + return self.dense_net(x) + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "hidden_dim": self.hidden_dim, + "activation": self.activation, + } + ) + return config + + +class AddRelativePositionalEmbedding(keras.layers.Layer): + def __init__(self, input_size, key_dim, **kwargs): + super().__init__(**kwargs) + self.input_size = input_size + self.key_dim = key_dim + self.rel_pos_h = self.add_weight( + name="rel_pos_h", + shape=(2 * self.input_size[0] - 1, self.key_dim), + initializer="zeros", + ) + self.rel_pos_w = self.add_weight( + name="rel_pos_w", + shape=(2 * self.input_size[1] - 1, self.key_dim), + initializer="zeros", + ) + self.built = True + + def _get_rel_pos(self, query_size, key_size, rel_pos): + """Get relative positional embeddings. + + Get relative positional embeddings according to the relative positions + of query and key sizes. + + Args: + query_size (int): The number of features of the queries. + key_size (int): The number of features of the keys. + rel_pos (tensor): Relative positional embedding tensor. + + Returns: + tensor: Extracted positional embeddings according to relative + positions. + """ + max_rel_dist = 2 * max(query_size, key_size) - 1 + if ops.shape(rel_pos)[0] != max_rel_dist: + rel_pos_resized = ops.image.resize( + image=ops.reshape( + rel_pos, + (1, ops.shape(rel_pos)[0], ops.shape(rel_pos)[1], 1), + ), + size=(max_rel_dist, ops.shape(rel_pos)[1]), + interpolation="bilinear", + ) + rel_pos_resized = ops.squeeze(rel_pos_resized, axis=(0, -1)) + return rel_pos_resized + else: + rel_pos_resized = rel_pos + # Query coordinates + query_coordinates = ops.cast( + ops.arange(query_size), dtype=self.compute_dtype + )[:, None] * (max(key_size / query_size, 1.0)) + # Key coordinates + key_coordinates = ops.cast( + ops.arange(key_size), dtype=self.compute_dtype + )[None, :] * (max(query_size / key_size, 1.0)) + # Relative coordinates + relative_coordinates = (query_coordinates - key_coordinates) + ( + key_size - 1 + ) * max(query_size / key_size, 1.0) + relative_coordinates = ops.cast(relative_coordinates, dtype="int32") + return ops.take(rel_pos_resized, relative_coordinates, 0) + + def call(self, attention_map, queries, query_size, key_size): + """Calculate decomposed Relative Positional Embeddings + + The code has been adapted based on + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa: E501 + + Args: + attention_map (tensor): Attention map. + queries (tensor): Queries in the attention layer with shape + `(batch, query_height * query_width, channels)`. + query_size (tuple[int, int]): Spatial sequence size of queries with + `(query_height, query_width)`. + key_size (tuple[int, int]): Spatial sequence size of keys with + `(key_height, key_width)`. + + Returns: + tensor: attention map with added relative positional embeddings. + """ + query_height, query_width = query_size[0], query_size[1] + key_height, key_width = key_size[0], key_size[1] + rel_heights = self._get_rel_pos( + query_height, key_height, self.rel_pos_h + ) + rel_widths = self._get_rel_pos(query_width, key_width, self.rel_pos_w) + shape = ops.shape(queries) + batch, channels = shape[0], shape[2] + rel_queries = ops.reshape( + queries, (batch, query_height, query_width, channels) + ) + rel_heights = ops.einsum("bhwc,hkc->bhwk", rel_queries, rel_heights) + rel_widths = ops.einsum("bhwc,wkc->bhwk", rel_queries, rel_widths) + attention_map = ops.reshape( + attention_map, + (batch, query_height, query_width, key_height, key_width), + ) + attention_map = attention_map + rel_heights[..., :, None] + attention_map = attention_map + rel_widths[..., None, :] + attention_map = ops.reshape( + attention_map, + (batch, query_height * query_width, key_height * key_width), + ) + return attention_map + + def get_config(self): + config = super().get_config() + config.update({"input_size": self.input_size, "key_dim": self.key_dim}) + return config + + +class MultiHeadAttentionWithRelativePE(keras.layers.Layer): + """Multi-head Attention block with relative position embeddings. + + The code has been adapted from [Segment Anything paper]( + https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + num_heads (int): Number of attention heads. + key_dim (int): Size of each attention head for query, key, and + value. + use_bias (bool, optional): Whether to use bias when projecting + the queries, keys, and values. Defaults to `True`. + use_rel_pos (bool, optional): Whether to use relative positional + embeddings or not. Defaults to `False`. + input_size (tuple[int, int], optional): Size of the input image. + Must be provided when using relative positional embeddings. + Defaults to `None`. + + Raises: + ValueError: When `input_size = None` with `use_rel_pos = True`. + """ + + def __init__( + self, + num_heads, + key_dim, + use_bias=True, + use_rel_pos=False, + input_size=None, + **kwargs + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.key_dim = key_dim + self.scale = self.key_dim**-0.5 + self.use_bias = use_bias + self.input_size = input_size + self.use_rel_pos = use_rel_pos + self.qkv = keras.layers.Dense( + key_dim * self.num_heads * 3, use_bias=self.use_bias + ) + self.projection = keras.layers.Dense(key_dim * self.num_heads) + if self.use_rel_pos: + if input_size is None: + raise ValueError( + "Input size must be provided if using relative " + "positional encoding." + ) + self.add_decomposed_reative_pe = AddRelativePositionalEmbedding( + self.input_size, self.key_dim + ) + + def build(self, input_shape=None): + self.qkv.build([self.key_dim * self.num_heads]) + self.projection.build([self.key_dim * self.num_heads]) + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + batch, height, width, channels = ops.shape(x) + qkv = ops.transpose( + ops.reshape( + self.qkv(x), + (batch, height * width, 3, self.num_heads, self.key_dim), + ), + axes=(2, 0, 3, 1, 4), + ) + qkv = ops.reshape( + qkv, (3, batch * self.num_heads, height * width, self.key_dim) + ) + queries, keys, values = ops.unstack(qkv, axis=0) + attention_map = (queries * self.scale) @ ops.transpose( + keys, axes=(0, 2, 1) + ) + if self.use_rel_pos: + attention_map = self.add_decomposed_reative_pe( + attention_map, + queries=queries, + query_size=(height, width), + key_size=(height, width), + ) + attention_map = ops.softmax(attention_map, axis=-1) + x = ops.reshape( + attention_map @ values, + (batch, self.num_heads, height, width, self.key_dim), + ) + x = ops.transpose(x, axes=(0, 2, 3, 1, 4)) + x = ops.reshape(x, (batch, height, width, channels)) + x = self.projection(x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "key_dim": self.key_dim, + "use_bias": self.use_bias, + "use_rel_pos": self.use_rel_pos, + "input_size": self.input_size, + } + ) + return config + + +class WindowPartitioning(keras.layers.Layer): + def __init__(self, window_size, **kwargs): + super().__init__(**kwargs) + self.window_size = window_size + self.built = True + + def partition(self, x): + batch, height, width, channels = ops.shape(x) + pad_height = ( + self.window_size - height % self.window_size + ) % self.window_size + pad_width = ( + self.window_size - width % self.window_size + ) % self.window_size + if pad_height > 0 or pad_width > 0: + x = ops.pad(x, ((0, 0), (0, pad_height), (0, pad_width), (0, 0))) + height_padded, width_padded = height + pad_height, width + pad_width + x = ops.reshape( + x, + ( + batch, + height_padded // self.window_size, + self.window_size, + width_padded // self.window_size, + self.window_size, + channels, + ), + ) + windows = ops.reshape( + ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)), + (-1, self.window_size, self.window_size, channels), + ) + return windows, (height_padded, width_padded) + + def unpartition(self, windows, height_width_padded, height_width): + height_padded, width_padded = height_width_padded + height, width = height_width + batch = ops.shape(windows)[0] // ( + (height_padded // self.window_size) + * (width_padded // self.window_size) + ) + x = ops.reshape( + windows, + ( + batch, + height_padded // self.window_size, + width_padded // self.window_size, + self.window_size, + self.window_size, + -1, + ), + ) + x = ops.reshape( + ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)), + (batch, height_padded, width_padded, -1), + ) + return x[:, :height, :width, :] + + def get_config(self): + config = super().get_config() + config.update({"window_size": self.window_size}) + return config + + +class WindowedTransformerEncoder(keras.layers.Layer): + """Implements windowed transformer encoder. + + Transformer blocks with support of window attention and residual + propagation blocks. The code has been adapted from [Segment Anything paper]( + https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + project_dim (int): the dimensionality of the projection of the + encoder, and output of the `MultiHeadAttention`. + intermediate_dim (int): the intermediate dimensionality of the MLP head + before projecting to `project_dim`. + num_heads (int): the number of heads for the `MultiHeadAttention` + layer. + use_bias (bool, optional): Whether to use bias to project the keys, + queries, and values in the attention layer. Defaults to `True`. + use_rel_pos (bool, optional): Whether to use relative positional + emcodings in the attention layer. Defaults to `False`. + window_size (int, optional): Window size for windowed attention. + Defaults to `0`. + input_size (tuple[int, int], optional): Height and width of the input + image as a tuple of integers. Must be provided when using relative + positional embeddings. Defaults to `None`. + activation (str, optional): the activation function to apply in the + MLP head - should be a function. Defaults to `"gelu"`. + layer_norm_epsilon (float, optional): The epsilon to use in the layer + normalization layers. Defaults to `1e-6`. + """ + + def __init__( + self, + project_dim, + intermediate_dim, + num_heads, + use_bias=True, + use_rel_pos=False, + window_size=0, + input_size=None, + activation="gelu", + layer_norm_epsilon=1e-6, + **kwargs + ): + super().__init__(**kwargs) + self.project_dim = project_dim + self.intermediate_dim = intermediate_dim + self.num_heads = num_heads + self.use_bias = use_bias + self.input_size = input_size + self.activation = activation + self.layer_norm_epsilon = layer_norm_epsilon + self.window_size = window_size + self.use_rel_pos = use_rel_pos + + self.layer_norm1 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon + ) + self.layer_norm2 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon + ) + self.attention = MultiHeadAttentionWithRelativePE( + num_heads=self.num_heads, + key_dim=self.project_dim // self.num_heads, + use_bias=use_bias, + use_rel_pos=use_rel_pos, + input_size=( + input_size if window_size == 0 else (window_size, window_size) + ), + ) + self.mlp_block = MLP( + intermediate_dim, + project_dim, + activation="gelu", + ) + self.window_partitioning = WindowPartitioning(window_size) + + def build(self, input_shape=None): + self.layer_norm1.build([None, None, None, self.project_dim]) + self.layer_norm2.build([None, None, None, self.project_dim]) + self.attention.build() + self.mlp_block.build([None, None, None, self.project_dim]) + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + shortcut = x + x = self.layer_norm1(x) + # Window Partition + if self.window_size > 0: + height, width = ops.shape(x)[1], ops.shape(x)[2] + x, height_width_padded = self.window_partitioning.partition(x) + + x = self.attention(x) + # Reverse Window Partition + if self.window_size > 0: + x = self.window_partitioning.unpartition( + x, + height_width_padded=height_width_padded, + height_width=(height, width), + ) + x = shortcut + x + x = x + self.mlp_block(self.layer_norm2(x)) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "project_dim": self.project_dim, + "intermediate_dim": self.intermediate_dim, + "num_heads": self.num_heads, + "use_bias": self.use_bias, + "use_rel_pos": self.use_rel_pos, + "window_size": self.window_size, + "input_size": self.input_size, + "activation": self.activation, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + + +class ViTDetPatchingAndEmbedding(keras.layers.Layer): + """ + Implements a image patch and embedding layer. + + Image to Patch Embedding using only a conv layer (without + layer normalization).The code has been adapted from [Segment Anything + paper](https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + kernel_size (tuple[int, int], optional): Kernel size of the + projection layer. Defaults to `(16, 16)`. + strides (tuple, optional): Strides of the projection layer. + Defaults to `(16, 16)`. + embed_dim (int, optional): Number of filters to use in the + projection layer i.e. projection size. Defaults to `768`. + """ + + def __init__( + self, kernel_size=(16, 16), strides=(16, 16), embed_dim=768, **kwargs + ): + super().__init__(**kwargs) + + self.projection = keras.layers.Conv2D( + embed_dim, kernel_size=kernel_size, strides=strides + ) + self.kernel_size = kernel_size + self.strides = strides + self.embed_dim = embed_dim + + def build(self, input_shape): + self.projection.build(input_shape) + self.built = True + + def compute_output_shape(self, input_shape): + return self.projection.compute_output_shape(input_shape) + + def call(self, x): + x = self.projection(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "kernel_size": self.kernel_size, + "strides": self.strides, + "embed_dim": self.embed_dim, + } + ) + return config + + +class AddPositionalEmbedding(keras.layers.Layer): + def __init__(self, img_size, patch_size, embed_dim, **kwargs): + super().__init__(**kwargs) + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.pos_embed = self.add_weight( + name="pos_embed", + shape=( + 1, + img_size // patch_size, + img_size // patch_size, + embed_dim, + ), + initializer="zeros", + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + return x + self.pos_embed + + def get_confg(self): + config = super().get_config() + config.update( + { + "img_size": self.img_size, + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + } + ) + return config From fc485d6a259be75ce1103a3c114fe56d06cc5940 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 20 Aug 2024 12:26:55 -0700 Subject: [PATCH 07/33] Add Mix transformer (#1780) * Add MixTransformer * fix testcase * test changes and comments * lint fix * update config list * modify testcase for 2 layers --- keras_nlp/api/models/__init__.py | 6 + .../src/models/mix_transformer/__init__.py | 13 + .../mix_transformer_backbone.py | 181 +++++++++++ .../mix_transformer_backbone_test.py | 75 +++++ .../mix_transformer_classifier.py | 133 ++++++++ .../mix_transformer_classifier_test.py | 70 ++++ .../mix_transformer/mix_transformer_layers.py | 300 ++++++++++++++++++ 7 files changed, 778 insertions(+) create mode 100644 keras_nlp/src/models/mix_transformer/__init__.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_layers.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 1a6dd2e74f..c6d8ed7d32 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -165,6 +165,12 @@ MistralPreprocessor, ) from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( + MiTImageClassifier, +) from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM from keras_nlp.src.models.opt.opt_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/mix_transformer/__init__.py b/keras_nlp/src/models/mix_transformer/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py new file mode 100644 index 0000000000..2cfe7f6761 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py @@ -0,0 +1,181 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +import numpy as np +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_nlp.src.models.mix_transformer.mix_transformer_layers import ( + HierarchicalTransformerEncoder, +) +from keras_nlp.src.models.mix_transformer.mix_transformer_layers import ( + OverlappingPatchingAndEmbedding, +) + + +@keras_nlp_export("keras_nlp.models.MiTBackbone") +class MiTBackbone(FeaturePyramidBackbone): + def __init__( + self, + depths, + num_layers, + blockwise_num_heads, + blockwise_sr_ratios, + end_value, + patch_sizes, + strides, + include_rescaling=True, + image_shape=(224, 224, 3), + hidden_dims=None, + **kwargs, + ): + """A Backbone implementing the MixTransformer. + + This architecture to be used as a backbone for the SegFormer + architecture [SegFormer: Simple and Efficient Design for Semantic + Segmentation with Transformers](https://arxiv.org/abs/2105.15203) + [Based on the TensorFlow implementation from DeepVision]( + https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer) + + Args: + depths: The number of transformer encoders to be used per layer in the + network. + num_layers: int. The number of Transformer layers. + blockwise_num_heads: list of integers, the number of heads to use + in the attention computation for each layer. + blockwise_sr_ratios: list of integers, the sequence reduction + ratio to perform for each layer on the sequence before key and + value projections. If set to > 1, a `Conv2D` layer is used to + reduce the length of the sequence. + end_value: The end value of the sequence. + include_rescaling: bool, whether to rescale the inputs. If set + to `True`, inputs will be passed through a `Rescaling(1/255.0)` + layer. Defaults to `True`. + image_shape: optional shape tuple, defaults to (224, 224, 3). + hidden_dims: the embedding dims per hierarchical layer, used as + the levels of the feature pyramid. + patch_sizes: list of integers, the patch_size to apply for each layer. + strides: list of integers, stride to apply for each layer. + + Examples: + + Using the class with a `backbone`: + + ```python + images = np.ones(shape=(1, 96, 96, 3)) + labels = np.zeros(shape=(1, 96, 96, 1)) + backbone = keras_nlp.models.MiTBackbone.from_preset("mit_b0_imagenet") + + # Evaluate model + model(images) + + # Train model + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(from_logits=False), + metrics=["accuracy"], + ) + model.fit(images, labels, epochs=3) + ``` + """ + dpr = [x for x in np.linspace(0.0, end_value, sum(depths))] + + # === Layers === + cur = 0 + patch_embedding_layers = [] + transformer_blocks = [] + layer_norms = [] + + for i in range(num_layers): + patch_embed_layer = OverlappingPatchingAndEmbedding( + project_dim=hidden_dims[i], + patch_size=patch_sizes[i], + stride=strides[i], + name=f"patch_and_embed_{i}", + ) + patch_embedding_layers.append(patch_embed_layer) + + transformer_block = [ + HierarchicalTransformerEncoder( + project_dim=hidden_dims[i], + num_heads=blockwise_num_heads[i], + sr_ratio=blockwise_sr_ratios[i], + drop_prob=dpr[cur + k], + name=f"hierarchical_encoder_{i}_{k}", + ) + for k in range(depths[i]) + ] + transformer_blocks.append(transformer_block) + cur += depths[i] + layer_norms.append(keras.layers.LayerNormalization()) + + # === Functional Model === + image_input = keras.layers.Input(shape=image_shape) + x = image_input + + if include_rescaling: + x = keras.layers.Rescaling(scale=1 / 255)(x) + + pyramid_outputs = {} + for i in range(num_layers): + # Compute new height/width after the `proj` + # call in `OverlappingPatchingAndEmbedding` + stride = strides[i] + new_height, new_width = ( + int(ops.shape(x)[1] / stride), + int(ops.shape(x)[2] / stride), + ) + + x = patch_embedding_layers[i](x) + for blk in transformer_blocks[i]: + x = blk(x) + x = layer_norms[i](x) + x = keras.layers.Reshape( + (new_height, new_width, -1), name=f"output_level_{i}" + )(x) + pyramid_outputs[f"P{i + 1}"] = x + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.depths = depths + self.include_rescaling = include_rescaling + self.image_shape = image_shape + self.hidden_dims = hidden_dims + self.pyramid_outputs = pyramid_outputs + self.num_layers = num_layers + self.blockwise_num_heads = blockwise_num_heads + self.blockwise_sr_ratios = blockwise_sr_ratios + self.end_value = end_value + self.patch_sizes = patch_sizes + self.strides = strides + + def get_config(self): + config = super().get_config() + config.update( + { + "depths": self.depths, + "include_rescaling": self.include_rescaling, + "hidden_dims": self.hidden_dims, + "image_shape": self.image_shape, + "num_layers": self.num_layers, + "blockwise_num_heads": self.blockwise_num_heads, + "blockwise_sr_ratios": self.blockwise_sr_ratios, + "end_value": self.end_value, + "patch_sizes": self.patch_sizes, + "strides": self.strides, + } + ) + return config diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py new file mode 100644 index 0000000000..4f1955297f --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from keras import models + +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.tests.test_case import TestCase + + +class MiTBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "depths": [2, 2], + "include_rescaling": True, + "image_shape": (16, 16, 3), + "hidden_dims": [4, 8], + "num_layers": 2, + "blockwise_num_heads": [1, 2], + "blockwise_sr_ratios": [8, 4], + "end_value": 0.1, + "patch_sizes": [7, 3], + "strides": [4, 2], + } + self.input_size = 16 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) + + def test_backbone_basics(self): + self.run_backbone_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 2, 2, 8), + run_quantization_check=False, + run_mixed_precision_check=False, + ) + + def test_pyramid_output_format(self): + init_kwargs = self.init_kwargs + backbone = MiTBackbone(**init_kwargs) + model = models.Model(backbone.inputs, backbone.pyramid_outputs) + output_data = model(self.input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual(list(output_data.keys()), ["P1", "P2"]) + for k, v in output_data.items(): + size = self.input_size // (2 ** (int(k[1:]) + 1)) + self.assertEqual(tuple(v.shape[:3]), (2, size, size)) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py new file mode 100644 index 0000000000..a9a51b63ba --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py @@ -0,0 +1,133 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) + + +@keras_nlp_export("keras_nlp.models.MiTImageClassifier") +class MiTImageClassifier(ImageClassifier): + """MiTImageClassifier image classifier model. + + Args: + backbone: A `keras_nlp.models.MiTBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.MiTImageClassifier.from_preset( + "mit_b0_imagenet") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset( + "mit_b0_imagenet") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.MiTImageClassifier.from_preset( + "mit_b0_imagenet") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.MiTBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + block_type="basic_block", + image_shape = (224, 224, 3), + ) + classifier = keras_nlp.models.MiTImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = MiTBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py new file mode 100644 index 0000000000..57b0671be2 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py @@ -0,0 +1,70 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( + MiTImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class MiTImageClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.images = np.ones((2, 16, 16, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = MiTBackbone( + depths=[2, 2, 2, 2], + include_rescaling=True, + image_shape=(16, 16, 3), + hidden_dims=[4, 8], + num_layers=2, + blockwise_num_heads=[1, 2], + blockwise_sr_ratios=[8, 4], + end_value=0.1, + patch_sizes=[7, 3], + strides=[4, 2], + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=MiTImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MiTImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py new file mode 100644 index 0000000000..53d99fe484 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py @@ -0,0 +1,300 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import keras +from keras import ops +from keras import random + + +class OverlappingPatchingAndEmbedding(keras.layers.Layer): + def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs): + """Overlapping Patching and Embedding layer. + + Differs from `PatchingAndEmbedding` in that the patch size does not + affect the sequence length. It's fully derived from the `stride` + parameter. Additionally, no positional embedding is done + as part of the layer - only a projection using a `Conv2D` layer. + + Args: + project_dim: integer, the dimensionality of the projection. + Defaults to `32`. + patch_size: integer, the size of the patches to encode. + Defaults to `7`. + stride: integer, the stride to use for the patching before + projection. Defaults to `5`. + """ + super().__init__(**kwargs) + + self.project_dim = project_dim + self.patch_size = patch_size + self.stride = stride + + self.proj = keras.layers.Conv2D( + filters=project_dim, + kernel_size=patch_size, + strides=stride, + padding="same", + ) + self.norm = keras.layers.LayerNormalization() + + def call(self, x): + x = self.proj(x) + # B, H, W, C + shape = x.shape + x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3])) + x = self.norm(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "project_dim": self.project_dim, + "patch_size": self.patch_size, + "stride": self.stride, + } + ) + return config + + +class HierarchicalTransformerEncoder(keras.layers.Layer): + """Hierarchical transformer encoder block implementation as a Keras Layer. + + The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention` + alternative for computational efficiency, and is meant to be used + within the SegFormer architecture. + + Args: + project_dim: integer, the dimensionality of the projection of the + encoder, and output of the `SegFormerMultiheadAttention` layer. + Due to the residual addition the input dimensionality has to be + equal to the output dimensionality. + num_heads: integer, the number of heads for the + `SegFormerMultiheadAttention` layer. + drop_prob: float, the probability of dropping a random + sample using the `DropPath` layer. Defaults to `0.0`. + layer_norm_epsilon: float, the epsilon for + `LayerNormalization` layers. Defaults to `1e-06` + sr_ratio: integer, the ratio to use within + `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D` + layer is used to reduce the length of the sequence. Defaults to `1`. + """ + + def __init__( + self, + project_dim, + num_heads, + sr_ratio=1, + drop_prob=0.0, + layer_norm_epsilon=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.project_dim = project_dim + self.num_heads = num_heads + self.drop_prop = drop_prob + + self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) + self.attn = SegFormerMultiheadAttention( + project_dim, num_heads, sr_ratio + ) + self.drop_path = DropPath(drop_prob) + self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) + self.mlp = MixFFN( + channels=project_dim, + mid_channels=int(project_dim * 4), + ) + + def build(self, input_shape): + super().build(input_shape) + self.H = ops.sqrt(ops.cast(input_shape[1], "float32")) + self.W = ops.sqrt(ops.cast(input_shape[2], "float32")) + + def call(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "mlp": keras.saving.serialize_keras_object(self.mlp), + "project_dim": self.project_dim, + "num_heads": self.num_heads, + "drop_prop": self.drop_prop, + } + ) + return config + + +class MixFFN(keras.layers.Layer): + def __init__(self, channels, mid_channels): + super().__init__() + self.fc1 = keras.layers.Dense(mid_channels) + self.dwconv = keras.layers.DepthwiseConv2D( + kernel_size=3, + strides=1, + padding="same", + ) + self.fc2 = keras.layers.Dense(channels) + + def call(self, x): + x = self.fc1(x) + shape = ops.shape(x) + H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1])) + B, C = shape[0], shape[2] + x = ops.reshape(x, (B, H, W, C)) + x = self.dwconv(x) + x = ops.reshape(x, (B, -1, C)) + x = ops.nn.gelu(x) + x = self.fc2(x) + return x + + +class SegFormerMultiheadAttention(keras.layers.Layer): + def __init__(self, project_dim, num_heads, sr_ratio): + """Efficient MultiHeadAttention implementation as a Keras layer. + + A huge bottleneck in scaling transformers is the self-attention layer + with an O(n^2) complexity. + + SegFormerMultiheadAttention performs a sequence reduction (SR) operation + with a given ratio, to reduce the sequence length before performing key + and value projections, reducing the O(n^2) complexity to O(n^2/R) where + R is the sequence reduction ratio. + + Args: + project_dim: integer, the dimensionality of the projection + of the `SegFormerMultiheadAttention` layer. + num_heads: integer, the number of heads to use in the + attention computation. + sr_ratio: integer, the sequence reduction ratio to perform + on the sequence before key and value projections. + """ + super().__init__() + self.num_heads = num_heads + self.sr_ratio = sr_ratio + self.scale = (project_dim // num_heads) ** -0.5 + self.q = keras.layers.Dense(project_dim) + self.k = keras.layers.Dense(project_dim) + self.v = keras.layers.Dense(project_dim) + self.proj = keras.layers.Dense(project_dim) + + if sr_ratio > 1: + self.sr = keras.layers.Conv2D( + filters=project_dim, + kernel_size=sr_ratio, + strides=sr_ratio, + padding="same", + ) + self.norm = keras.layers.LayerNormalization() + + def call(self, x): + input_shape = ops.shape(x) + H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1])) + B, C = input_shape[0], input_shape[2] + + q = self.q(x) + q = ops.reshape( + q, + ( + input_shape[0], + input_shape[1], + self.num_heads, + input_shape[2] // self.num_heads, + ), + ) + q = ops.transpose(q, [0, 2, 1, 3]) + + if self.sr_ratio > 1: + x = ops.reshape( + ops.transpose(x, [0, 2, 1]), + (B, H, W, C), + ) + x = self.sr(x) + x = ops.reshape(x, [input_shape[0], input_shape[2], -1]) + x = ops.transpose(x, [0, 2, 1]) + x = self.norm(x) + + k = self.k(x) + v = self.v(x) + + k = ops.transpose( + ops.reshape( + k, + [B, -1, self.num_heads, C // self.num_heads], + ), + [0, 2, 1, 3], + ) + + v = ops.transpose( + ops.reshape( + v, + [B, -1, self.num_heads, C // self.num_heads], + ), + [0, 2, 1, 3], + ) + + attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale + attn = ops.nn.softmax(attn, axis=-1) + + attn = attn @ v + attn = ops.reshape( + ops.transpose(attn, [0, 2, 1, 3]), + [input_shape[0], input_shape[1], input_shape[2]], + ) + + x = self.proj(attn) + return x + + +class DropPath(keras.layers.Layer): + """Implements the DropPath layer. + + DropPath randomly drops samples during + training with a probability of `rate`. Note that this layer drops individual + samples within a batch and not the entire batch, whereas StochasticDepth + randomly drops the entire batch. + + Args: + rate: float, the probability of the residual branch being dropped. + seed: (Optional) integer. Used to create a random seed. + """ + + def __init__(self, rate=0.5, seed=None, **kwargs): + super().__init__(**kwargs) + self.rate = rate + self._seed_val = seed + self.seed = random.SeedGenerator(seed=seed) + + def call(self, x, training=None): + if self.rate == 0.0 or not training: + return x + else: + batch_size = x.shape[0] or ops.shape(x)[0] + drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1) + drop_map = ops.cast( + random.uniform(drop_map_shape, seed=self.seed) > self.rate, + x.dtype, + ) + x = x / (1.0 - self.rate) + x = x * drop_map + return x + + def get_config(self): + config = super().get_config() + config.update({"rate": self.rate, "seed": self._seed_val}) + return config From 2797851c259ce36bb51c6e93baeb3b282b152663 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Tue, 20 Aug 2024 23:00:43 -0700 Subject: [PATCH 08/33] update input_image_shape -> image_shape (#1785) * update input_image_shape -> image_shape * update docstring example * code reformat * update tests --- .../models/csp_darknet/csp_darknet_backbone.py | 10 +++++----- .../csp_darknet/csp_darknet_backbone_test.py | 2 +- .../csp_darknet/csp_darknet_image_classifier.py | 2 +- .../csp_darknet_image_classifier_test.py | 2 +- .../src/models/densenet/densenet_backbone.py | 10 +++++----- .../src/models/densenet/densenet_backbone_test.py | 2 +- .../models/densenet/densenet_image_classifier.py | 2 +- .../densenet/densenet_image_classifier_test.py | 2 +- keras_nlp/src/models/resnet/resnet_backbone.py | 10 +++++----- .../src/models/resnet/resnet_backbone_test.py | 4 ++-- .../models/resnet/resnet_image_classifier_test.py | 2 +- keras_nlp/src/models/vgg/vgg_backbone.py | 15 +++++++-------- keras_nlp/src/models/vgg/vgg_backbone_test.py | 2 +- keras_nlp/src/models/vgg/vgg_image_classifier.py | 2 +- .../src/models/vgg/vgg_image_classifier_test.py | 2 +- keras_nlp/src/tests/test_case.py | 6 +++--- 16 files changed, 37 insertions(+), 38 deletions(-) diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py index 2745f61d01..607c6895ba 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py @@ -38,7 +38,7 @@ class CSPDarkNetBackbone(Backbone): Use `"depthwise_block"` for depthwise conv block `"basic_block"` for basic conv block. Defaults to "basic_block". - input_image_shape: tuple. The input shape without the batch size. + image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. Examples: @@ -67,7 +67,7 @@ def __init__( stackwise_depth, include_rescaling, block_type="basic_block", - input_image_shape=(224, 224, 3), + image_shape=(224, 224, 3), **kwargs, ): # === Functional Model === @@ -78,7 +78,7 @@ def __init__( ) base_channels = stackwise_num_filters[0] // 2 - image_input = layers.Input(shape=input_image_shape) + image_input = layers.Input(shape=image_shape) x = image_input if include_rescaling: x = layers.Rescaling(scale=1 / 255.0)(x) @@ -119,7 +119,7 @@ def __init__( self.stackwise_depth = stackwise_depth self.include_rescaling = include_rescaling self.block_type = block_type - self.input_image_shape = input_image_shape + self.image_shape = image_shape def get_config(self): config = super().get_config() @@ -129,7 +129,7 @@ def get_config(self): "stackwise_depth": self.stackwise_depth, "include_rescaling": self.include_rescaling, "block_type": self.block_type, - "input_image_shape": self.input_image_shape, + "image_shape": self.image_shape, } ) return config diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py index aaad4fe515..857e06039d 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py @@ -28,7 +28,7 @@ def setUp(self): "stackwise_depth": [1, 3, 3, 1], "include_rescaling": False, "block_type": "basic_block", - "input_image_shape": (224, 224, 3), + "image_shape": (224, 224, 3), } self.input_data = np.ones((2, 224, 224, 3), dtype="float32") diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py index 6b013bdcc0..09a7022122 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py @@ -78,7 +78,7 @@ class CSPDarkNetImageClassifier(ImageClassifier): stackwise_depth=[3, 9, 9, 3], include_rescaling=False, block_type="basic_block", - input_image_shape = (224, 224, 3), + image_shape = (224, 224, 3), ) classifier = keras_nlp.models.CSPDarkNetImageClassifier( backbone=backbone, diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py index a07bb017a3..33261c25b6 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py @@ -33,7 +33,7 @@ def setUp(self): stackwise_depth=[1, 3, 3, 1], include_rescaling=False, block_type="basic_block", - input_image_shape=(16, 16, 3), + image_shape=(16, 16, 3), ) self.init_kwargs = { "backbone": self.backbone, diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py index 8456fbcee6..60a5b28849 100644 --- a/keras_nlp/src/models/densenet/densenet_backbone.py +++ b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -35,7 +35,7 @@ class DenseNetBackbone(Backbone): include_rescaling: bool, whether to rescale the inputs. If set to `True`, inputs will be passed through a `Rescaling(1/255.0)` layer. Defaults to `True`. - input_image_shape: optional shape tuple, defaults to (224, 224, 3). + image_shape: optional shape tuple, defaults to (224, 224, 3). compression_ratio: float, compression rate at transition layers, defaults to 0.5. growth_rate: int, number of filters added by each dense block, @@ -62,13 +62,13 @@ def __init__( self, stackwise_num_repeats, include_rescaling=True, - input_image_shape=(224, 224, 3), + image_shape=(224, 224, 3), compression_ratio=0.5, growth_rate=32, **kwargs, ): # === Functional Model === - image_input = keras.layers.Input(shape=input_image_shape) + image_input = keras.layers.Input(shape=image_shape) x = image_input if include_rescaling: @@ -116,7 +116,7 @@ def __init__( self.include_rescaling = include_rescaling self.compression_ratio = compression_ratio self.growth_rate = growth_rate - self.input_image_shape = input_image_shape + self.image_shape = image_shape def get_config(self): config = super().get_config() @@ -126,7 +126,7 @@ def get_config(self): "include_rescaling": self.include_rescaling, "compression_ratio": self.compression_ratio, "growth_rate": self.growth_rate, - "input_image_shape": self.input_image_shape, + "image_shape": self.image_shape, } ) return config diff --git a/keras_nlp/src/models/densenet/densenet_backbone_test.py b/keras_nlp/src/models/densenet/densenet_backbone_test.py index f0f8dac875..63f358035c 100644 --- a/keras_nlp/src/models/densenet/densenet_backbone_test.py +++ b/keras_nlp/src/models/densenet/densenet_backbone_test.py @@ -26,7 +26,7 @@ def setUp(self): "include_rescaling": True, "compression_ratio": 0.5, "growth_rate": 32, - "input_image_shape": (224, 224, 3), + "image_shape": (224, 224, 3), } self.input_data = np.ones((2, 224, 224, 3), dtype="float32") diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier.py b/keras_nlp/src/models/densenet/densenet_image_classifier.py index 395e8f754d..130904be70 100644 --- a/keras_nlp/src/models/densenet/densenet_image_classifier.py +++ b/keras_nlp/src/models/densenet/densenet_image_classifier.py @@ -76,7 +76,7 @@ class DenseNetImageClassifier(ImageClassifier): stackwise_depth=[3, 9, 9, 3], include_rescaling=False, block_type="basic_block", - input_image_shape = (224, 224, 3), + image_shape = (224, 224, 3), ) classifier = keras_nlp.models.DenseNetImageClassifier( backbone=backbone, diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py index 60d77d489c..439a60008d 100644 --- a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py +++ b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py @@ -31,7 +31,7 @@ def setUp(self): include_rescaling=True, compression_ratio=0.5, growth_rate=32, - input_image_shape=(224, 224, 3), + image_shape=(224, 224, 3), ) self.init_kwargs = { "backbone": self.backbone, diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index 0f4d7c139a..31698e0a1c 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -54,7 +54,7 @@ class ResNetBackbone(FeaturePyramidBackbone): include_rescaling: boolean. If `True`, rescale the input using `Rescaling` and `Normalization` layers. If `False`, do nothing. Defaults to `True`. - input_image_shape: tuple. The input shape without the batch size. + image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. pooling: `None` or str. Pooling mode for feature extraction. Defaults to `"avg"`. @@ -107,7 +107,7 @@ def __init__( block_type, use_pre_activation=False, include_rescaling=True, - input_image_shape=(None, None, 3), + image_shape=(None, None, 3), pooling="avg", data_format=None, dtype=None, @@ -139,7 +139,7 @@ def __init__( num_stacks = len(stackwise_num_filters) # === Functional Model === - image_input = layers.Input(shape=input_image_shape) + image_input = layers.Input(shape=image_shape) if include_rescaling: x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) x = layers.Normalization( @@ -254,7 +254,7 @@ def __init__( self.block_type = block_type self.use_pre_activation = use_pre_activation self.include_rescaling = include_rescaling - self.input_image_shape = input_image_shape + self.image_shape = image_shape self.pooling = pooling self.pyramid_outputs = pyramid_outputs @@ -268,7 +268,7 @@ def get_config(self): "block_type": self.block_type, "use_pre_activation": self.use_pre_activation, "include_rescaling": self.include_rescaling, - "input_image_shape": self.input_image_shape, + "image_shape": self.image_shape, "pooling": self.pooling, } ) diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index 6d3f774559..a6a30362cd 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -27,7 +27,7 @@ def setUp(self): "stackwise_num_filters": [64, 64, 64], "stackwise_num_blocks": [2, 2, 2], "stackwise_num_strides": [1, 2, 2], - "input_image_shape": (None, None, 3), + "image_shape": (None, None, 3), "pooling": "avg", } self.input_size = 64 @@ -84,7 +84,7 @@ def test_saved_model(self, use_pre_activation, block_type): { "block_type": block_type, "use_pre_activation": use_pre_activation, - "input_image_shape": (None, None, 3), + "image_shape": (None, None, 3), } ) self.run_model_saving_test( diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index f3f63a14a1..893ec42487 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -31,7 +31,7 @@ def setUp(self): stackwise_num_strides=[1, 2, 2], block_type="basic_block", use_pre_activation=True, - input_image_shape=(16, 16, 3), + image_shape=(16, 16, 3), include_rescaling=False, pooling="avg", ) diff --git a/keras_nlp/src/models/vgg/vgg_backbone.py b/keras_nlp/src/models/vgg/vgg_backbone.py index 497381c0fc..b215261fed 100644 --- a/keras_nlp/src/models/vgg/vgg_backbone.py +++ b/keras_nlp/src/models/vgg/vgg_backbone.py @@ -20,8 +20,7 @@ @keras_nlp_export("keras_nlp.models.VGGBackbone") class VGGBackbone(Backbone): - """ - This class represents Keras Backbone of VGG model. + """This class represents Keras Backbone of VGG model. This class implements a VGG backbone as described in [Very Deep Convolutional Networks for Large-Scale Image Recognition]( @@ -36,7 +35,7 @@ class VGGBackbone(Backbone): 64, 128, 256, 512, 512]. include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a `Rescaling(1/255.0)` layer. - input_shape: tuple, optional shape tuple, defaults to (224, 224, 3). + image_shape: tuple, optional shape tuple, defaults to (224, 224, 3). pooling: bool, Optional pooling mode for feature extraction when `include_top` is `False`. - `None` means that the output of the model will be @@ -61,7 +60,7 @@ class VGGBackbone(Backbone): model = keras_nlp.models.VGGBackbone( stackwise_num_repeats = [2, 2, 3, 3, 3], stackwise_num_filters = [64, 128, 256, 512, 512], - input_shape = (224, 224, 3), + image_shape = (224, 224, 3), include_rescaling = False, pooling = "avg", ) @@ -74,13 +73,13 @@ def __init__( stackwise_num_repeats, stackwise_num_filters, include_rescaling, - input_image_shape=(224, 224, 3), + image_shape=(224, 224, 3), pooling="avg", **kwargs, ): # === Functional Model === - img_input = keras.layers.Input(shape=input_image_shape) + img_input = keras.layers.Input(shape=image_shape) x = img_input if include_rescaling: @@ -107,7 +106,7 @@ def __init__( self.stackwise_num_repeats = stackwise_num_repeats self.stackwise_num_filters = stackwise_num_filters self.include_rescaling = include_rescaling - self.input_image_shape = input_image_shape + self.image_shape = image_shape self.pooling = pooling def get_config(self): @@ -115,7 +114,7 @@ def get_config(self): "stackwise_num_repeats": self.stackwise_num_repeats, "stackwise_num_filters": self.stackwise_num_filters, "include_rescaling": self.include_rescaling, - "input_image_shape": self.input_image_shape, + "image_shape": self.image_shape, "pooling": self.pooling, } diff --git a/keras_nlp/src/models/vgg/vgg_backbone_test.py b/keras_nlp/src/models/vgg/vgg_backbone_test.py index 05ed33ba0f..d5521ca92d 100644 --- a/keras_nlp/src/models/vgg/vgg_backbone_test.py +++ b/keras_nlp/src/models/vgg/vgg_backbone_test.py @@ -24,7 +24,7 @@ def setUp(self): self.init_kwargs = { "stackwise_num_repeats": [2, 3, 3], "stackwise_num_filters": [8, 64, 64], - "input_image_shape": (16, 16, 3), + "image_shape": (16, 16, 3), "include_rescaling": False, "pooling": "avg", } diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier.py b/keras_nlp/src/models/vgg/vgg_image_classifier.py index a26fbfbc30..d849586ed8 100644 --- a/keras_nlp/src/models/vgg/vgg_image_classifier.py +++ b/keras_nlp/src/models/vgg/vgg_image_classifier.py @@ -65,7 +65,7 @@ class VGGImageClassifier(ImageClassifier): backbone = keras_nlp.models.VGGBackbone( stackwise_num_repeats = [2, 2, 3, 3, 3], stackwise_num_filters = [64, 128, 256, 512, 512], - input_shape = (224, 224, 3), + image_shape = (224, 224, 3), include_rescaling = False, pooling = "avg", ) diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier_test.py b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py index 4a2573e496..20d855cb66 100644 --- a/keras_nlp/src/models/vgg/vgg_image_classifier_test.py +++ b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py @@ -27,7 +27,7 @@ def setUp(self): self.backbone = VGGBackbone( stackwise_num_repeats=[2, 4, 4], stackwise_num_filters=[2, 16, 16], - input_image_shape=(4, 4, 3), + image_shape=(4, 4, 3), include_rescaling=False, pooling="max", ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 9a94d6357f..8e63bc19d9 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -501,10 +501,10 @@ def run_vision_backbone_test( input_data = ops.transpose(input_data, axes=(2, 0, 1)) elif len(input_data_shape) == 4: input_data = ops.transpose(input_data, axes=(0, 3, 1, 2)) - if "input_image_shape" in init_kwargs: + if "image_shape" in init_kwargs: init_kwargs = init_kwargs.copy() - init_kwargs["input_image_shape"] = tuple( - reversed(init_kwargs["input_image_shape"]) + init_kwargs["image_shape"] = tuple( + reversed(init_kwargs["image_shape"]) ) self.run_backbone_test( cls=cls, From 18f88803daa194b65914a77d6a25cade585aaf6c Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Wed, 21 Aug 2024 17:13:02 -0700 Subject: [PATCH 09/33] Create __init__.py (#1788) add missing __init__ file to vit_det --- keras_nlp/src/models/vit_det/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 keras_nlp/src/models/vit_det/__init__.py diff --git a/keras_nlp/src/models/vit_det/__init__.py b/keras_nlp/src/models/vit_det/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/vit_det/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From 2ee893ce7ceb292f98e6dc184faa74e959df8836 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:29:21 -0700 Subject: [PATCH 10/33] Hack package build script to rename to keras-hub (#1793) This is a temporary way to test out the keras-hub branch. - Does a global rename of all symbols during package build. - Registers the "old" name on symbol export for saving compat. - Adds a github action to publish every commit to keras-hub as a new package. - Removes our descriptions on PyPI temporarily, until we want to message this more broadly. --- .github/workflows/publish-hub-to-pypi.yml | 43 +++++++++++++++++++++++ keras_nlp/src/api_export.py | 5 +++ keras_nlp/src/utils/preset_utils.py | 3 +- pip_build.py | 19 ++++++---- setup.py | 7 ++-- 5 files changed, 65 insertions(+), 12 deletions(-) create mode 100644 .github/workflows/publish-hub-to-pypi.yml diff --git a/.github/workflows/publish-hub-to-pypi.yml b/.github/workflows/publish-hub-to-pypi.yml new file mode 100644 index 0000000000..838ca9b698 --- /dev/null +++ b/.github/workflows/publish-hub-to-pypi.yml @@ -0,0 +1,43 @@ +name: Publish Hub to PyPI + +on: + push: + branches: + - keras-hub + +permissions: + contents: read + +jobs: + build-and-publish: + name: Build and publish Hub to PyPI + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.9 + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "::set-output name=dir::$(pip cache dir)" + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off + - name: Build a binary wheel and a source tarball + run: >- + python pip_build.py + - name: Publish distribution to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN_HUB }} diff --git a/keras_nlp/src/api_export.py b/keras_nlp/src/api_export.py index cfa3519ce9..93e7b54c2f 100644 --- a/keras_nlp/src/api_export.py +++ b/keras_nlp/src/api_export.py @@ -24,6 +24,11 @@ def maybe_register_serializable(symbol): if isinstance(symbol, types.FunctionType) or hasattr(symbol, "get_config"): + # We register twice, first with the old name, second with the new name, + # so loading still works under the old name. + # TODO replace compat_package_name with keras-nlp after rename. + compat_name = "compat_package_name" + keras.saving.register_keras_serializable(package=compat_name)(symbol) keras.saving.register_keras_serializable(package="keras_nlp")(symbol) diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 0277eb74a7..297e4bcb7f 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -99,7 +99,8 @@ def list_presets(cls): def list_subclasses(cls): """Find all registered subclasses of a class.""" - custom_objects = keras.saving.get_custom_objects().values() + # Deduplicate the lists, since we have to register object twice for compat. + custom_objects = set(keras.saving.get_custom_objects().values()) subclasses = [] for x in custom_objects: if inspect.isclass(x) and x != cls and issubclass(x, cls): diff --git a/pip_build.py b/pip_build.py index 7fed385c71..f14f24312b 100644 --- a/pip_build.py +++ b/pip_build.py @@ -36,7 +36,7 @@ import re import shutil -package = "keras_nlp" +package = "keras_hub" build_directory = "tmp_build_dir" dist_directory = "dist" to_copy = ["setup.py", "setup.cfg", "README.md"] @@ -48,15 +48,15 @@ def ignore_files(_, filenames): def export_version_string(version, is_nightly=False): """Export Version and Package Name.""" + date = datetime.datetime.now() + version += f".dev{date.strftime('%Y%m%d%H%M%S')}" if is_nightly: - date = datetime.datetime.now() - version += f".dev{date.strftime('%Y%m%d%H')}" - # Replaces `name="keras-nlp"` in `setup.py` with `keras-nlp-nightly` + # Replaces `name="keras-hub"` in `setup.py` with `keras-hub-nightly` with open("setup.py") as f: setup_contents = f.read() with open("setup.py", "w") as f: setup_contents = setup_contents.replace( - 'name="keras-nlp"', 'name="keras-nlp-nightly"' + 'name="keras-hub"', 'name="keras-hub-nightly"' ) f.write(setup_contents) @@ -78,11 +78,18 @@ def copy_source_to_build_directory(root_path): os.chdir(root_path) os.mkdir(build_directory) shutil.copytree( - package, os.path.join(build_directory, package), ignore=ignore_files + "keras_nlp", os.path.join(build_directory, package), ignore=ignore_files ) for fname in to_copy: shutil.copy(fname, os.path.join(f"{build_directory}", fname)) os.chdir(build_directory) + # TODO: remove all of this when our code is actually renamed in the repo. + os.system("grep -lR 'keras_nlp' . | xargs sed -i 's/keras_nlp/keras_hub/g'") + os.system("grep -lR 'keras-nlp' . | xargs sed -i 's/keras-nlp/keras-hub/g'") + os.system("grep -lR 'KerasNLP' . | xargs sed -i 's/KerasNLP/KerasHub/g'") + os.system( + "grep -lR 'compat_package_name' . | xargs sed -i 's/compat_package_name/keras_nlp/g'" + ) def build(root_path, is_nightly=False): diff --git a/setup.py b/setup.py index f664aa5a61..f2ec7b84be 100644 --- a/setup.py +++ b/setup.py @@ -44,11 +44,8 @@ def get_version(rel_path): setup( name="keras-nlp", - description=( - "Industry-strength Natural Language Processing extensions for Keras." - ), - long_description=README, - long_description_content_type="text/markdown", + description="🚧🚧🚧 Work in progress. 🚧🚧🚧 More details soon!", + long_description="🚧🚧🚧 Work in progress. 🚧🚧🚧 More details soon!", version=VERSION, url="https://github.com/keras-team/keras-nlp", author="Keras team", From fdf6b6bdc75249c8beff4317a768c14101c7699f Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Tue, 27 Aug 2024 02:30:08 +0800 Subject: [PATCH 11/33] Add CLIP and T5XXL for StableDiffusionV3 (#1790) * Add `CLIPTokenizer`, `T5XXLTokenizer`, `CLIPTextEncoder` and `T5XXLTextEncoder`. * Make CLIPTextEncoder as Backbone * Add `T5XXLPreprocessor` and remove `T5XXLTokenizer` Add `CLIPPreprocessor` * Use `tf = None` at the top * Replace manual implementation of `CLIPAttention` with `MultiHeadAttention` --- .../models/stable_diffusion_v3/__init__.py | 13 ++ .../stable_diffusion_v3/clip_encoder_block.py | 103 +++++++++++ .../stable_diffusion_v3/clip_preprocessor.py | 104 +++++++++++ .../clip_preprocessor_test.py | 78 ++++++++ .../stable_diffusion_v3/clip_text_encoder.py | 141 +++++++++++++++ .../stable_diffusion_v3/clip_tokenizer.py | 167 ++++++++++++++++++ .../clip_tokenizer_test.py | 69 ++++++++ .../t5_xxl_preprocessor.py | 84 +++++++++ .../t5_xxl_preprocessor_test.py | 74 ++++++++ .../t5_xxl_text_encoder.py | 148 ++++++++++++++++ 10 files changed, 981 insertions(+) create mode 100644 keras_nlp/src/models/stable_diffusion_v3/__init__.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor_test.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer_test.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py diff --git a/keras_nlp/src/models/stable_diffusion_v3/__init__.py b/keras_nlp/src/models/stable_diffusion_v3/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py b/keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py new file mode 100644 index 0000000000..c4e16f8626 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py @@ -0,0 +1,103 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras import layers +from keras import ops + + +def quick_gelu(x): + return x * ops.sigmoid(1.702 * x) + + +class CLIPEncoderBlock(layers.Layer): + def __init__( + self, + hidden_dim, + num_heads, + intermediate_dim, + intermediate_activation="quick_gelu", + **kwargs, + ): + super().__init__(**kwargs) + if hidden_dim % num_heads != 0: + raise ValueError( + "`hidden_dim` must be divisible by `num_heads`. " + f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}" + ) + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.intermediate_dim = intermediate_dim + self.intermediate_activation = intermediate_activation + + if intermediate_activation == "quick_gelu": + intermediate_activation = quick_gelu + + self.layer_norm_1 = layers.LayerNormalization( + epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1" + ) + self.attention = layers.MultiHeadAttention( + num_heads, + hidden_dim // num_heads, + dtype=self.dtype_policy, + name="attention", + ) + self.layer_norm_2 = layers.LayerNormalization( + epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_2" + ) + self.dense_1 = layers.Dense( + self.intermediate_dim, dtype=self.dtype_policy, name="dense_1" + ) + self.activation = layers.Activation( + intermediate_activation, dtype=self.dtype_policy, name="activation" + ) + self.dense_2 = layers.Dense( + self.hidden_dim, dtype=self.dtype_policy, name="dense_2" + ) + + def build(self, input_shape): + self.layer_norm_1.build(input_shape) + self.attention.build(input_shape, input_shape, input_shape) + self.layer_norm_2.build(input_shape) + self.dense_1.build(input_shape) + input_shape = self.dense_1.compute_output_shape(input_shape) + self.dense_2.build(input_shape) + + def compute_output_shape(self, inputs_shape): + outputs_shape = list(inputs_shape) + outputs_shape[-1] = self.hidden_dim + return outputs_shape + + def call(self, x, training=None): + residual = x + x = self.layer_norm_1(x) + x = self.attention(x, x, x, training=training, use_causal_mask=True) + x = ops.add(residual, x) + + residual = x + x = self.dense_1(self.layer_norm_2(residual)) + x = self.activation(x) + x = self.dense_2(x) + x = ops.add(residual, x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "intermediate_activation": self.intermediate_activation, + } + ) + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py b/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py new file mode 100644 index 0000000000..ca1b6c598e --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py @@ -0,0 +1,104 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import ( + CLIPTokenizer, +) +from keras_nlp.src.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) + +try: + import tensorflow as tf +except ImportError: + tf = None + + +class CLIPPreprocessor(Preprocessor): + tokenizer_cls = CLIPTokenizer + + def __init__( + self, + tokenizer, + sequence_length=77, + add_start_token=True, + add_end_token=False, + to_lower=True, + pad_with_end_token=True, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + self.to_lower = to_lower + self.pad_with_end_token = pad_with_end_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + pad_value = self.tokenizer.pad_token_id + if self.pad_with_end_token: + pad_value = self.tokenizer.end_token_id + + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=pad_value, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + # TODO: Use `@tf_preprocessing_function` after rebasing. + def call(self, x, y=None, sample_weight=None, sequence_length=None): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "T5XXL requires each input feature to contain only " + f"one segment, but received {len(x)}. If you are using T5XXL" + " for a multi-segment classification task, please refer to " + "classification models like BERT or RoBERTa." + ) + if self.to_lower: + x = tf.strings.lower(x) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + "to_lower": self.to_lower, + "pad_with_end_token": self.pad_with_end_token, + } + ) + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor_test.py b/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor_test.py new file mode 100644 index 0000000000..4365a14673 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.src.models.stable_diffusion_v3.clip_preprocessor import ( + CLIPPreprocessor, +) +from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import ( + CLIPTokenizer, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CLIPPreprocessorTest(TestCase): + def setUp(self): + vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + vocab = dict([(token, i + 1) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + self.tokenizer = CLIPTokenizer(vocabulary=vocab, merges=merges) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = [" airplane airport"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=CLIPPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[5, 1, 2, 1, 3, 4, 4, 4]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = [" airplane airport"] * 4 + preprocessor = CLIPPreprocessor( + tokenizer=self.tokenizer, + sequence_length=8, + add_start_token=False, + add_end_token=False, + pad_with_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = " airplane airport" + preprocessor = CLIPPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [5, 1, 2, 1]) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + self.skipTest("TODO") + for preset in CLIPPreprocessor.presets: + self.run_preset_test( + cls=CLIPPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py b/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py new file mode 100644 index 0000000000..d4a5cbc94f --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py @@ -0,0 +1,141 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras import layers +from keras import ops + +from keras_nlp.src.layers.modeling.token_and_position_embedding import ( + TokenAndPositionEmbedding, +) +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.stable_diffusion_v3.clip_encoder_block import ( + CLIPEncoderBlock, +) + + +class CLIPTextEncoder(Backbone): + def __init__( + self, + embedding_dim, + hidden_dim, + num_layers, + num_heads, + intermediate_dim, + intermediate_activation="quick_gelu", + intermediate_output_index=None, + vocabulary_size=49408, + sequence_length=77, + dtype=None, + **kwargs, + ): + if ( + intermediate_output_index is not None + and intermediate_output_index < 0 + ): + intermediate_output_index += num_layers + + # === Layers === + self.embedding = TokenAndPositionEmbedding( + vocabulary_size=vocabulary_size, + sequence_length=sequence_length, + embedding_dim=embedding_dim, + dtype=dtype, + name="embedding", + ) + self.encoder_layers = [ + CLIPEncoderBlock( + hidden_dim, + num_heads, + intermediate_dim, + intermediate_activation, + dtype=dtype, + ) + for _ in range(num_layers) + ] + self.layer_norm = layers.LayerNormalization( + epsilon=0.00001, dtype=dtype, name="layer_norm" + ) + self.text_projection = layers.Dense( + hidden_dim, + use_bias=False, + dtype=dtype, + name="text_projection", + ) + + # === Functional Model === + encoder_token_ids = layers.Input( + shape=(sequence_length,), dtype="int32", name="encoder_token_ids" + ) + x = self.embedding(encoder_token_ids) + encoder_intermediate_output = None + # Encoder. + for i, block in enumerate(self.encoder_layers): + x = block(x) + if i == intermediate_output_index: + encoder_intermediate_output = x + x = self.layer_norm(x) + encoder_output = x + if encoder_intermediate_output is not None: + encoder_intermediate_output = self.layer_norm( + encoder_intermediate_output + ) + # Projection. + indices = ops.expand_dims( + ops.cast(ops.argmax(encoder_token_ids, axis=-1), "int32"), axis=-1 + ) + pooled_output = ops.take_along_axis(x, indices[:, :, None], axis=1) + pooled_output = ops.squeeze(pooled_output, axis=1) + projection_output = self.text_projection(pooled_output) + + outputs = { + "encoder_sequence_output": encoder_output, + "encoder_pooled_output": pooled_output, + "encoder_projection_output": projection_output, + } + if intermediate_output_index is not None: + outputs["encoder_intermediate_output"] = encoder_intermediate_output + + super().__init__( + inputs={"encoder_token_ids": encoder_token_ids}, + outputs=outputs, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.intermediate_dim = intermediate_dim + self.intermediate_activation = intermediate_activation + self.intermediate_output_index = intermediate_output_index + self.vocabulary_size = vocabulary_size + self.sequence_length = sequence_length + + def get_config(self): + config = super().get_config() + config.update( + { + "embedding_dim": self.embedding_dim, + "hidden_dim": self.hidden_dim, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "intermediate_activation": self.intermediate_activation, + "intermediate_output_index": self.intermediate_output_index, + "vocabulary_size": self.vocabulary_size, + "sequence_length": self.sequence_length, + } + ) + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer.py b/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer.py new file mode 100644 index 0000000000..59c046d9f5 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer.py @@ -0,0 +1,167 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer +from keras_nlp.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch +from keras_nlp.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe + +try: + import tensorflow as tf +except ImportError: + tf = None + + +class CLIPTokenizer(BytePairTokenizer): + def __init__(self, vocabulary=None, merges=None, **kwargs): + self.start_token = "<|startoftext|>" + self.end_token = "<|endoftext|>" + + super().__init__( + vocabulary=vocabulary, + merges=merges, + unsplittable_tokens=[self.start_token, self.end_token], + **kwargs, + ) + + def set_vocabulary_and_merges(self, vocabulary, merges): + super().set_vocabulary_and_merges(vocabulary, merges) + + if vocabulary is not None: + # Check for necessary special tokens. + if self.end_token not in self.get_vocabulary(): + raise ValueError( + f"Cannot find token `'{self.end_token}'` in the provided " + f"`vocabulary`. Please provide `'{self.end_token}'` in " + "your `vocabulary` or use a pretrained `vocabulary` name." + ) + + self.start_token_id = self.token_to_id(self.start_token) + self.end_token_id = self.token_to_id(self.end_token) + self.pad_token_id = 0 + else: + self.end_token_id = None + self.start_token_id = None + self.pad_token_id = None + + def _bpe_merge_and_update_cache(self, tokens): + """Process unseen tokens and add to cache.""" + words = self._transform_bytes(tokens) + + # In StableDiffusionV3, we need to add `` to the last word. + words = tf.strings.reduce_join(words, axis=1, separator=" ") + words = tf.strings.join([words, ""]) + words = tf.strings.split(words, sep=" ") + + tokenized_words = self._bpe_merge(words) + + # For each word, join all its token by a whitespace, + # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. + tokenized_words = tf.strings.reduce_join( + tokenized_words, axis=1, separator=" " + ) + self.cache.insert(tokens, tokenized_words) + + def tokenize(self, inputs): + self._check_vocabulary() + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + + if self.add_prefix_space: + inputs = tf.strings.join([" ", inputs]) + + scalar_input = inputs.shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) + + raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) + + # Strip and remove empty tokens. + raw_tokens = tf.strings.strip(raw_tokens) + raw_tokens = tf.ragged.boolean_mask(raw_tokens, raw_tokens != "") + + token_row_splits = raw_tokens.row_splits + flat_tokens = raw_tokens.flat_values + + # Check cache. + cache_lookup = self.cache.lookup(flat_tokens) + cache_mask = cache_lookup == "" + + has_unseen_words = tf.math.reduce_any( + (cache_lookup == "") & (flat_tokens != "") + ) + + def process_unseen_tokens(): + unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) + self._bpe_merge_and_update_cache(unseen_tokens) + return self.cache.lookup(flat_tokens) + + # If `has_unseen_words == True`, it means not all tokens are in cache, + # we will process the unseen tokens. Otherwise return the cache lookup. + tokenized_words = tf.cond( + has_unseen_words, + process_unseen_tokens, + lambda: cache_lookup, + ) + + tokens = tf.strings.split(tokenized_words, sep=" ") + if self.compute_dtype != tf.string: + # Encode merged tokens. + tokens = self.token_to_id_map.lookup(tokens) + + # Unflatten to match input. + tokens = tf.RaggedTensor.from_row_splits( + tokens.flat_values, + tf.gather(tokens.row_splits, token_row_splits), + ) + + # Convert to a dense output if `sequence_length` is set. + if self.sequence_length: + output_shape = tokens.shape.as_list() + output_shape[-1] = self.sequence_length + tokens = tokens.to_tensor(shape=output_shape) + + # Convert to a dense output if input in scalar + if scalar_input: + tokens = tf.squeeze(tokens, 0) + tf.ensure_shape(tokens, shape=[self.sequence_length]) + + return tokens + + def detokenize(self, inputs): + self._check_vocabulary() + inputs, unbatched, _ = convert_to_ragged_batch(inputs) + inputs = tf.cast(inputs, self.dtype) + unicode_text = tf.strings.reduce_join( + self.id_to_token_map.lookup(inputs), axis=-1 + ) + + # When detokenizing, we need to remove and extra whitespace. + unicode_text = tf.strings.regex_replace(unicode_text, r"", " ") + unicode_text = tf.strings.strip(unicode_text) + + split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8") + outputs = tf.strings.reduce_join( + self.unicode2byte.lookup(split_unicode_text), axis=-1 + ) + + if unbatched: + outputs = tf.squeeze(outputs, 0) + return outputs + + def get_config(self): + config = super().get_config() + # In the constructor, we pass the list of special tokens to the + # `unsplittable_tokens` arg of the superclass' constructor. Hence, we + # delete it from the config here. + del config["unsplittable_tokens"] + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer_test.py b/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer_test.py new file mode 100644 index 0000000000..4ceaea8057 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer_test.py @@ -0,0 +1,69 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import ( + CLIPTokenizer, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CLIPTokenizerTest(TestCase): + def setUp(self): + vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + self.merges = merges + self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.input_data = ["airplane ", " airport"] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=CLIPTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + # Whitespaces should be removed. + expected_output=[[0, 1], [0, 2]], + expected_detokenize_output=["airplane", "airport"], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + CLIPTokenizer(vocabulary={"foo": 0, "bar": 1}, merges=["fo o"]) + + @pytest.mark.large + def test_smallest_preset(self): + self.skipTest( + "TODO: Add preset from `hf://openai/clip-vit-large-patch14`" + ) + self.run_preset_test( + cls=CLIPTokenizer, + preset="llama3_8b_en", + input_data=["The quick brown fox."], + expected_output=[[791, 4062, 14198, 39935, 13]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + self.skipTest( + "TODO: Add preset from `hf://openai/clip-vit-large-patch14`" + ) + for preset in CLIPTokenizer.presets: + self.run_preset_test( + cls=CLIPTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py new file mode 100644 index 0000000000..c8b6ef7566 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py @@ -0,0 +1,84 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.src.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) + + +class T5XXLPreprocessor(Preprocessor): + tokenizer_cls = T5Tokenizer + + def __init__( + self, + tokenizer, + sequence_length=256, + add_start_token=False, + add_end_token=True, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def call(self, x, y=None, sample_weight=None, sequence_length=None): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "T5XXL requires each input feature to contain only " + f"one segment, but received {len(x)}. If you are using T5XXL" + " for a multi-segment classification task, please refer to " + "classification models like BERT or RoBERTa." + ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py new file mode 100644 index 0000000000..90b7dfaf9c --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py @@ -0,0 +1,74 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest + +from keras_nlp.src.models.stable_diffusion_v3.t5_xxl_preprocessor import ( + T5XXLPreprocessor, +) +from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class T5XXLPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = T5Tokenizer( + proto=os.path.join(self.get_test_data_dir(), "t5_test_vocab.spm") + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 10, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=T5XXLPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[4, 9, 5, 7, 1, 0, 0, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + preprocessor = T5XXLPreprocessor( + tokenizer=self.tokenizer, + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "the quick brown fox" + preprocessor = T5XXLPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [4, 9, 5, 1]) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + self.skipTest("TODO") + for preset in T5XXLPreprocessor.presets: + self.run_preset_test( + cls=T5XXLPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py new file mode 100644 index 0000000000..9f4e5ae3a1 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py @@ -0,0 +1,148 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.t5.t5_layer_norm import T5LayerNorm +from keras_nlp.src.models.t5.t5_transformer_layer import T5TransformerLayer + + +class T5XXLTextEncoder(Backbone): + def __init__( + self, + vocabulary_size, + num_layers, + num_heads, + hidden_dim, + intermediate_dim, + key_value_dim=None, + dropout=0.1, + activation="relu", + use_gated_activation=True, + layer_norm_epsilon=1e-06, + tie_embedding_weights=True, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=tie_embedding_weights, + embeddings_initializer=keras.initializers.TruncatedNormal(1.0), + dtype=dtype, + name="token_embedding", + ) + self.encoder_embedding_dropout = keras.layers.Dropout( + dropout, + dtype=dtype, + name="encoder_embedding_dropout", + ) + self.encoder_transformer_layers = [] + for i in range(num_layers): + layer = T5TransformerLayer( + is_decoder=False, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + key_value_dim=key_value_dim or hidden_dim // num_heads, + dropout=dropout, + activation=activation, + layer_norm_epsilon=layer_norm_epsilon, + num_heads=num_heads, + use_gated_activation=use_gated_activation, + use_relative_attention_bias=bool(i == 0), + dtype=dtype, + name=f"transformer_encoder_layer_{i}", + ) + self.encoder_transformer_layers.append(layer) + self.encoder_layer_norm = T5LayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="encoder_output_layer_norm", + ) + self.encoder_dropout = keras.layers.Dropout( + dropout, + dtype=dtype, + name="encoder_output_dropout", + ) + + # === Functional Model === + encoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_token_ids" + ) + encoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_padding_mask" + ) + # Encoder. + x = self.token_embedding(encoder_token_id_input) + x = self.encoder_embedding_dropout(x) + encoder_attention_mask = encoder_padding_mask_input[:, None, :] + position_bias = None + for transformer_layer in self.encoder_transformer_layers: + output = transformer_layer( + x, + attention_mask=encoder_attention_mask, + position_bias=position_bias, + use_causal_mask=False, + ) + if isinstance(output, tuple): + x, position_bias = output + x = self.encoder_layer_norm(x) + x = self.encoder_dropout(x) + encoder_output = x + + super().__init__( + { + "encoder_token_ids": encoder_token_id_input, + "encoder_padding_mask": encoder_padding_mask_input, + }, + outputs=encoder_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.activation = keras.activations.get(activation) + self.key_value_dim = key_value_dim + self.dropout = dropout + self.use_gated_activation = use_gated_activation + self.layer_norm_epsilon = layer_norm_epsilon + self.tie_embedding_weights = tie_embedding_weights + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "activation": keras.activations.serialize(self.activation), + "key_value_dim": self.key_value_dim, + "dropout": self.dropout, + "use_gated_activation": self.use_gated_activation, + "layer_norm_epsilon": self.layer_norm_epsilon, + "tie_embedding_weights": self.tie_embedding_weights, + } + ) + return config From 18dddf4a592e97a2d4c083ec613701c70128e90c Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Mon, 26 Aug 2024 21:19:47 +0000 Subject: [PATCH 12/33] Add DeepLabV3Plus segmentation --- .../src/models/deeplab_v3_plus/__init__.py | 13 + .../deeplab_v3_plus/deeplab_v3_plus_layers.py | 184 ++++++++++++++ .../deeplab_v3_plus_segmenter.py | 233 ++++++++++++++++++ .../deeplab_v3_plus/deeplab_v3_plus_test.py | 47 ++++ 4 files changed, 477 insertions(+) create mode 100644 keras_nlp/src/models/deeplab_v3_plus/__init__.py create mode 100644 keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_layers.py create mode 100644 keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py create mode 100644 keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py diff --git a/keras_nlp/src/models/deeplab_v3_plus/__init__.py b/keras_nlp/src/models/deeplab_v3_plus/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/deeplab_v3_plus/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_layers.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_layers.py new file mode 100644 index 0000000000..ba46a11652 --- /dev/null +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_layers.py @@ -0,0 +1,184 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import List +from typing import Mapping + +import keras +from keras import ops + + +class SpatialPyramidPooling(keras.layers.Layer): + """Implements the Atrous Spatial Pyramid Pooling. + + Reference for Atrous Spatial Pyramid Pooling [Rethinking Atrous Convolution + for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf) and + [Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation](https://arxiv.org/pdf/1802.02611.pdf) + + inp = keras.layers.Input((384, 384, 3)) + backbone = keras.applications.EfficientNetB0( + input_tensor=inp, + include_top=False) + output = backbone(inp) + output = SpatialPyramidPooling( + dilation_rates=[6, 12, 18])(output) + + # output[4].shape = [None, 16, 16, 256] + """ + + def __init__( + self, + dilation_rates: List[int], + num_channels: int = 256, + activation: str = "relu", + dropout: float = 0.0, + **kwargs, + ): + """Initializes an Atrous Spatial Pyramid Pooling layer. + + Args: + dilation_rates: A `list` of integers for parallel dilated conv. + Usually a sample choice of rates are [6, 12, 18]. + num_channels: An `int` number of output channels, defaults to 256. + activation: A `str` activation to be used, defaults to 'relu'. + dropout: A `float` for the dropout rate of the final projection + output after the activations and batch norm, defaults to 0.0, + which means no dropout is applied to the output. + **kwargs: Additional keyword arguments to be passed. + """ + super().__init__(**kwargs) + self.dilation_rates = dilation_rates + self.num_channels = num_channels + self.activation = activation + self.dropout = dropout + + def build(self, input_shape): + channels = input_shape[3] + + # This is the parallel networks that process the input features with + # different dilation rates. The output from each channel will be merged + # together and feed to the output. + self.aspp_parallel_channels = [] + + # Channel1 with Conv2D and 1x1 kernel size. + conv_sequential = keras.Sequential( + [ + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(1, 1), + use_bias=False, + ), + keras.layers.BatchNormalization(), + keras.layers.Activation(self.activation), + ] + ) + conv_sequential.build(input_shape) + self.aspp_parallel_channels.append(conv_sequential) + + # Channel 2 and afterwards are based on self.dilation_rates, and each of + # them will have conv2D with 3x3 kernel size. + for dilation_rate in self.dilation_rates: + conv_sequential = keras.Sequential( + [ + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(3, 3), + padding="same", + dilation_rate=dilation_rate, + use_bias=False, + ), + keras.layers.BatchNormalization(), + keras.layers.Activation(self.activation), + ] + ) + conv_sequential.build(input_shape) + self.aspp_parallel_channels.append(conv_sequential) + + # Last channel is the global average pooling with conv2D 1x1 kernel. + pool_sequential = keras.Sequential( + [ + keras.layers.GlobalAveragePooling2D(), + keras.layers.Reshape((1, 1, channels)), + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(1, 1), + use_bias=False, + ), + keras.layers.BatchNormalization(), + keras.layers.Activation(self.activation), + ] + ) + pool_sequential.build(input_shape) + self.aspp_parallel_channels.append(pool_sequential) + + # Final projection layers + projection = keras.Sequential( + [ + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(1, 1), + use_bias=False, + ), + keras.layers.BatchNormalization(), + keras.layers.Activation(self.activation), + keras.layers.Dropout(rate=self.dropout), + ], + ) + projection_input_channels = ( + 2 + len(self.dilation_rates) + ) * self.num_channels + projection.build(tuple(input_shape[:-1]) + (projection_input_channels,)) + self.projection = projection + + def call(self, inputs, training=None): + """Calls the Atrous Spatial Pyramid Pooling layer on an input. + + Args: + inputs: A tensor of shape [batch, height, width, channels] + + Returns: + A tensor of shape [batch, height, width, num_channels] + """ + result = [] + + for channel in self.aspp_parallel_channels: + temp = ops.cast(channel(inputs, training=training), inputs.dtype) + result.append(temp) + + image_shape = ops.shape(inputs) + height, width = image_shape[1], image_shape[2] + result[-1] = keras.layers.Resizing( + height, + width, + interpolation="bilinear", + )(result[-1]) + + result = ops.concatenate(result, axis=-1) + result = self.projection(result, training=training) + return result + + def compute_output_shape(self, input_shape): + return tuple(input_shape[:-1]) + (self.num_channels,) + + def get_config(self) -> Mapping[str, Any]: + config = { + "dilation_rates": self.dilation_rates, + "num_channels": self.num_channels, + "activation": self.activation, + "dropout": self.dropout, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py new file mode 100644 index 0000000000..759c264279 --- /dev/null +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py @@ -0,0 +1,233 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_layers import ( + SpatialPyramidPooling, +) +from keras_nlp.src.models.task import Task + +@keras_nlp_export( + [ + "keras_nlp.models.DeepLabV3Plus", + "keras_nlp.models.segmentation.DeepLabV3Plus", + ] +) +class DeepLabV3Plus(Task): + """DeepLabV3+ architecture for semantic segmentation. + + This class implements a DeepLabV3+ architecture as described in + [Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation](https://arxiv.org/abs/1802.02611)(ECCV 2018) + and [Rethinking Atrous Convolution for Semantic Image Segmentation]( + https://arxiv.org/abs/1706.05587)(CVPR 2017) + + Args: + backbone: `keras.Model`. The backbone network for the model that is + used as a feature extractor for the DeepLabV3+ Encoder. Should + either be a `keras_nlp.models.backbones.backbone.Backbone` or a + `keras.Model` that implements the `pyramid_outputs` + property with keys "P2", "P3" etc as values. A + somewhat sensible backbone to use in many cases is the + `keras_nlp.models.ResNetBackbone.from_preset("resnet_v2_50")`. + num_classes: int, the number of classes for the detection model. Note + that the `num_classes` contains the background class, and the + classes from the data should be represented by integers with range + [0, `num_classes`). + projection_filters: int, number of filters in the convolution layer + projecting low-level features from the `backbone`. + low_level_feature_key: str, layer level to extract the feature from one of the + key from the `backbone` `pyramid_outputs` + property such as "P2", "P3" etc. + spatial_pyramid_pooling_key: str, layer level to extract and perform + `spatial_pyramid_pooling`, one of the key from the `backbone` `pyramid_outputs` + property such as "P4", "P5" etc. + spatial_pyramid_pooling: (Optional) a `keras.layers.Layer`. Also known + as Atrous Spatial Pyramid Pooling (ASPP). Performs spatial pooling + on different spatial levels in the pyramid, with dilation. If + provided, the feature map from the backbone is passed to it inside + the DeepLabV3 Encoder, otherwise SpatialPyramidPooling layer is used. + dialtion_rates: (Optional) A `list` of integers for parallel dilated conv. + Applied only when Default `SpatialPyramidPooling` is used. Usually a + sample choice of rates are [6, 12, 18]. + segmentation_head: (Optional) a `keras.layers.Layer`. If provided, the + outputs of the DeepLabV3 encoder is passed to this layer and it + should predict the segmentation mask based on feature from backbone + and feature from decoder, otherwise a default DeepLabV3 + convolutional head is used. + + Example: + ```python + images = np.ones(shape=(1, 96, 96, 3)) + labels = np.zeros(shape=(1, 96, 96, 1)) + backbone = keras_nlp.models.ResNetBackbone.from_preset("resnet_v2_50") + + model = keras_hub.models.DeepLabV3Plus( + backbone= backbone, + num_classes=3, + projection_filters=48, + low_level_feature_key="P2", + spatial_pyramid_pooling_key="P5", + ) + + # Evaluate model + model(images) + + # Train model + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(from_logits=True), + metrics=["accuracy"], + ) + model.fit(images, labels, epochs=3) + ``` + """ + + def __init__( + self, + backbone, + num_classes, + low_level_feature_key, + spatial_pyramid_pooling_key, + projection_filters=48, + spatial_pyramid_pooling=None, + dialtion_rates=None, + segmentation_head=None, + **kwargs, + ): + if not isinstance(backbone, keras.layers.Layer) or not isinstance( + backbone, keras.Model + ): + raise ValueError( + "Argument `backbone` must be a `keras.layers.Layer` instance " + f" or `keras.Model`. Received instead " + f"backbone={backbone} (of type {type(backbone)})." + ) + + # === Functional Model === + inputs = backbone.input + + if spatial_pyramid_pooling is None: + spatial_pyramid_pooling = SpatialPyramidPooling( + dilation_rates=dialtion_rates + ) + spatial_backbone_features = backbone.pyramid_outputs[spatial_pyramid_pooling_key] + spp_outputs = spatial_pyramid_pooling(spatial_backbone_features) + + low_level_backbone_feature = backbone.pyramid_outputs[low_level_feature_key] + low_level_projected_features = apply_low_level_feature_network(low_level_backbone_feature, projection_filters) + + encoder_outputs = keras.layers.UpSampling2D( + size=(8, 8), + interpolation="bilinear", + name="encoder_output_upsampling", + )(spp_outputs) + + combined_encoder_outputs = keras.layers.Concatenate(axis=-1)( + [encoder_outputs, low_level_projected_features] + ) + + if segmentation_head is None: + x = keras.layers.Conv2D( + name="segmentation_head_conv", + filters=256, + kernel_size=1, + padding="same", + use_bias=False, + )(combined_encoder_outputs) + x = keras.layers.BatchNormalization( + name="segmentation_head_norm" + )(x) + x = keras.layers.ReLU(name="segmentation_head_relu")(x) + x = keras.layers.UpSampling2D( + size=(4, 4), interpolation="bilinear" + )(x) + # Classification layer + outputs = keras.layers.Conv2D( + name="segmentation_output", + filters=num_classes, + kernel_size=1, + use_bias=False, + padding="same", + # Force the dtype of the classification layer to float32 + # to avoid the NAN loss issue when used with mixed + # precision API. + dtype="float32", + )(x) + else: + outputs = segmentation_head(combined_encoder_outputs) + + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + # === Config === + self.num_classes = num_classes + self.backbone = backbone + self.spatial_pyramid_pooling = spatial_pyramid_pooling + self.projection_filters = projection_filters + self.segmentation_head = segmentation_head + self.dialtion_rates = dialtion_rates + self.low_level_feature_key = low_level_feature_key + self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key + + def get_config(self): + return { + "num_classes": self.num_classes, + "backbone": keras.saving.serialize_keras_object(self.backbone), + "spatial_pyramid_pooling": keras.saving.serialize_keras_object( + self.spatial_pyramid_pooling + ), + "projection_filters": self.projection_filters, + "segmentation_head": keras.saving.serialize_keras_object( + self.segmentation_head + ), + "dialtion_rates": self.dialtion_rates, + "low_level_feature_key": self.low_level_feature_key, + "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key, + } + + + @classmethod + def from_config(cls, config): + if "backbone" in config and isinstance(config["backbone"], dict): + config["backbone"] = keras.layers.deserialize(config["backbone"]) + if "spatial_pyramid_pooling" in config and isinstance( + config["spatial_pyramid_pooling"], dict + ): + config["spatial_pyramid_pooling"] = keras.layers.deserialize( + config["spatial_pyramid_pooling"] + ) + if "segmentation_head" in config and isinstance( + config["segmentation_head"], dict + ): + config["segmentation_head"] = keras.layers.deserialize( + config["segmentation_head"] + ) + return super().from_config(config) + +def apply_low_level_feature_network(input_tensor, projection_filters): + x = keras.layers.Conv2D( + name="low_level_feature_conv", + filters=projection_filters, + kernel_size=1, + padding="same", + use_bias=False + )(input_tensor) + + x = keras.layers.BatchNormalization(name="low_level_feature_norm")(x) + x = keras.layers.ReLU(name="low_level_feature_relu")(x) + return x + + + diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py new file mode 100644 index 0000000000..7e50d4ec8e --- /dev/null +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py @@ -0,0 +1,47 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from absl.testing import parameterized +from keras import models +from keras import ops +import numpy as np + +from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_segmenter import DeepLabV3Plus +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.tests.test_case import TestCase + +class DeepLabV3PlusTest(TestCase): + def setUp(self): + self.init_kwargs = { + "backbone": ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k"), + "num_classes":2, + "low_level_feature_key":"P2", + "spatial_pyramid_pooling_key":"P5", + "projection_filters":48, + "spatial_pyramid_pooling":None, + "dialtion_rates":[6, 12, 18], + "segmentation_head":None, + } + self.images = np.ones((2, 96, 96, 3), dtype="float32") + self.labels = np.zeros((2, 96, 96, 1), dtype="float32") + + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DeepLabV3Plus, + init_kwargs=self.init_kwargs, + ) + \ No newline at end of file From 744b2336489dc1bb96130d815458e564d878b726 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Mon, 26 Aug 2024 21:23:23 +0000 Subject: [PATCH 13/33] init file --- .../models/deeplab_v3_plus/deeplab_v3_plus_test.py | 9 ++++++++- keras_nlp/src/models/vit_det/__init__.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 keras_nlp/src/models/vit_det/__init__.py diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py index 7e50d4ec8e..02c15a811d 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py @@ -36,7 +36,14 @@ def setUp(self): } self.images = np.ones((2, 96, 96, 3), dtype="float32") self.labels = np.zeros((2, 96, 96, 1), dtype="float32") - + + def test_task_basics(self): + self.run_task_test( + cls=DeepLabV3Plus, + init_kwargs=self.init_kwargs, + train_data=(self.images, self.labels), + expected_output_shape = (2, 96, 96, 1), + ) @pytest.mark.large def test_saved_model(self): diff --git a/keras_nlp/src/models/vit_det/__init__.py b/keras_nlp/src/models/vit_det/__init__.py new file mode 100644 index 0000000000..2351a1b7b4 --- /dev/null +++ b/keras_nlp/src/models/vit_det/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file From 98c081141465545221a5cfcc322cc683518a9309 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Mon, 26 Aug 2024 21:31:28 +0000 Subject: [PATCH 14/33] api gen --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/api/models/segmentation/__init__.py | 22 ++++++ .../deeplab_v3_plus_segmenter.py | 76 ++++++++++--------- .../deeplab_v3_plus/deeplab_v3_plus_test.py | 37 ++++----- keras_nlp/src/models/vit_det/__init__.py | 2 +- 5 files changed, 86 insertions(+), 55 deletions(-) create mode 100644 keras_nlp/api/models/segmentation/__init__.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index c6d8ed7d32..624c66445d 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -17,6 +17,7 @@ since your modifications would be overwritten. """ +from keras_nlp.api.models import segmentation from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone from keras_nlp.src.models.albert.albert_classifier import AlbertClassifier from keras_nlp.src.models.albert.albert_masked_lm import AlbertMaskedLM @@ -74,6 +75,9 @@ from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) +from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_segmenter import ( + DeepLabV3Plus, +) from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_nlp.src.models.densenet.densenet_image_classifier import ( DenseNetImageClassifier, diff --git a/keras_nlp/api/models/segmentation/__init__.py b/keras_nlp/api/models/segmentation/__init__.py new file mode 100644 index 0000000000..c1da3bb01a --- /dev/null +++ b/keras_nlp/api/models/segmentation/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_segmenter import ( + DeepLabV3Plus, +) diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py index 759c264279..eaae8911a5 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py @@ -19,6 +19,7 @@ ) from keras_nlp.src.models.task import Task + @keras_nlp_export( [ "keras_nlp.models.DeepLabV3Plus", @@ -28,8 +29,8 @@ class DeepLabV3Plus(Task): """DeepLabV3+ architecture for semantic segmentation. - This class implements a DeepLabV3+ architecture as described in - [Encoder-Decoder with Atrous Separable Convolution for Semantic Image + This class implements a DeepLabV3+ architecture as described in + [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)(ECCV 2018) and [Rethinking Atrous Convolution for Semantic Image Segmentation]( https://arxiv.org/abs/1706.05587)(CVPR 2017) @@ -51,7 +52,7 @@ class DeepLabV3Plus(Task): low_level_feature_key: str, layer level to extract the feature from one of the key from the `backbone` `pyramid_outputs` property such as "P2", "P3" etc. - spatial_pyramid_pooling_key: str, layer level to extract and perform + spatial_pyramid_pooling_key: str, layer level to extract and perform `spatial_pyramid_pooling`, one of the key from the `backbone` `pyramid_outputs` property such as "P4", "P5" etc. spatial_pyramid_pooling: (Optional) a `keras.layers.Layer`. Also known @@ -115,19 +116,25 @@ def __init__( f" or `keras.Model`. Received instead " f"backbone={backbone} (of type {type(backbone)})." ) - + # === Functional Model === inputs = backbone.input - + if spatial_pyramid_pooling is None: spatial_pyramid_pooling = SpatialPyramidPooling( dilation_rates=dialtion_rates - ) - spatial_backbone_features = backbone.pyramid_outputs[spatial_pyramid_pooling_key] + ) + spatial_backbone_features = backbone.pyramid_outputs[ + spatial_pyramid_pooling_key + ] spp_outputs = spatial_pyramid_pooling(spatial_backbone_features) - low_level_backbone_feature = backbone.pyramid_outputs[low_level_feature_key] - low_level_projected_features = apply_low_level_feature_network(low_level_backbone_feature, projection_filters) + low_level_backbone_feature = backbone.pyramid_outputs[ + low_level_feature_key + ] + low_level_projected_features = apply_low_level_feature_network( + low_level_backbone_feature, projection_filters + ) encoder_outputs = keras.layers.UpSampling2D( size=(8, 8), @@ -141,37 +148,37 @@ def __init__( if segmentation_head is None: x = keras.layers.Conv2D( - name="segmentation_head_conv", - filters=256, - kernel_size=1, - padding="same", - use_bias=False, + name="segmentation_head_conv", + filters=256, + kernel_size=1, + padding="same", + use_bias=False, )(combined_encoder_outputs) - x = keras.layers.BatchNormalization( - name="segmentation_head_norm" - )(x) + x = keras.layers.BatchNormalization(name="segmentation_head_norm")( + x + ) x = keras.layers.ReLU(name="segmentation_head_relu")(x) x = keras.layers.UpSampling2D( - size=(4, 4), interpolation="bilinear" + size=(4, 4), interpolation="bilinear" )(x) # Classification layer outputs = keras.layers.Conv2D( - name="segmentation_output", - filters=num_classes, - kernel_size=1, - use_bias=False, - padding="same", - # Force the dtype of the classification layer to float32 - # to avoid the NAN loss issue when used with mixed - # precision API. - dtype="float32", + name="segmentation_output", + filters=num_classes, + kernel_size=1, + use_bias=False, + padding="same", + # Force the dtype of the classification layer to float32 + # to avoid the NAN loss issue when used with mixed + # precision API. + dtype="float32", )(x) else: outputs = segmentation_head(combined_encoder_outputs) - + super().__init__(inputs=inputs, outputs=outputs, **kwargs) - - # === Config === + + # === Config === self.num_classes = num_classes self.backbone = backbone self.spatial_pyramid_pooling = spatial_pyramid_pooling @@ -197,7 +204,6 @@ def get_config(self): "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key, } - @classmethod def from_config(cls, config): if "backbone" in config and isinstance(config["backbone"], dict): @@ -216,18 +222,16 @@ def from_config(cls, config): ) return super().from_config(config) + def apply_low_level_feature_network(input_tensor, projection_filters): x = keras.layers.Conv2D( name="low_level_feature_conv", filters=projection_filters, kernel_size=1, padding="same", - use_bias=False + use_bias=False, )(input_tensor) - + x = keras.layers.BatchNormalization(name="low_level_feature_norm")(x) x = keras.layers.ReLU(name="low_level_feature_relu")(x) return x - - - diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py index 02c15a811d..8151012798 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py @@ -12,43 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -from absl.testing import parameterized -from keras import models -from keras import ops import numpy as np +import pytest -from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_segmenter import DeepLabV3Plus +from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_segmenter import ( + DeepLabV3Plus, +) from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone from keras_nlp.src.tests.test_case import TestCase + class DeepLabV3PlusTest(TestCase): def setUp(self): self.init_kwargs = { - "backbone": ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k"), - "num_classes":2, - "low_level_feature_key":"P2", - "spatial_pyramid_pooling_key":"P5", - "projection_filters":48, - "spatial_pyramid_pooling":None, - "dialtion_rates":[6, 12, 18], - "segmentation_head":None, + "backbone": ResNetBackbone.from_preset( + "hf://timm/resnet18.a1_in1k" + ), + "num_classes": 2, + "low_level_feature_key": "P2", + "spatial_pyramid_pooling_key": "P5", + "projection_filters": 48, + "spatial_pyramid_pooling": None, + "dialtion_rates": [6, 12, 18], + "segmentation_head": None, } self.images = np.ones((2, 96, 96, 3), dtype="float32") self.labels = np.zeros((2, 96, 96, 1), dtype="float32") - - def test_task_basics(self): + + def test_task_basics(self): self.run_task_test( cls=DeepLabV3Plus, init_kwargs=self.init_kwargs, train_data=(self.images, self.labels), - expected_output_shape = (2, 96, 96, 1), + expected_output_shape=(2, 96, 96, 1), ) - + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( cls=DeepLabV3Plus, init_kwargs=self.init_kwargs, ) - \ No newline at end of file diff --git a/keras_nlp/src/models/vit_det/__init__.py b/keras_nlp/src/models/vit_det/__init__.py index 2351a1b7b4..3364a6bd16 100644 --- a/keras_nlp/src/models/vit_det/__init__.py +++ b/keras_nlp/src/models/vit_det/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. From b40617ca62d1c9dbbfd171dd54ea1e78d9d2dae6 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Mon, 26 Aug 2024 23:54:18 +0000 Subject: [PATCH 15/33] Add Segmentation base class --- .../deeplab_v3_plus_segmenter.py | 4 +- .../deeplab_v3_plus/deeplab_v3_plus_test.py | 8 +- keras_nlp/src/models/segmentation.py | 105 ++++++++++++++++++ keras_nlp/src/tests/test_case.py | 29 +++++ 4 files changed, 140 insertions(+), 6 deletions(-) create mode 100644 keras_nlp/src/models/segmentation.py diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py index eaae8911a5..e4e7a0a8b4 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py @@ -17,7 +17,7 @@ from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_layers import ( SpatialPyramidPooling, ) -from keras_nlp.src.models.task import Task +from keras_nlp.src.models.segmentation import Segmentation @keras_nlp_export( @@ -26,7 +26,7 @@ "keras_nlp.models.segmentation.DeepLabV3Plus", ] ) -class DeepLabV3Plus(Task): +class DeepLabV3Plus(Segmentation): """DeepLabV3+ architecture for semantic segmentation. This class implements a DeepLabV3+ architecture as described in diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py index 8151012798..67eb3a1061 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py @@ -37,14 +37,14 @@ def setUp(self): "segmentation_head": None, } self.images = np.ones((2, 96, 96, 3), dtype="float32") - self.labels = np.zeros((2, 96, 96, 1), dtype="float32") + self.labels = np.zeros((2, 96, 96, 2), dtype="float32") - def test_task_basics(self): - self.run_task_test( + def test_segmentation_basics(self): + self.run_segmentation_test( cls=DeepLabV3Plus, init_kwargs=self.init_kwargs, train_data=(self.images, self.labels), - expected_output_shape=(2, 96, 96, 1), + expected_output_shape=(2, 96, 96, 2), ) @pytest.mark.large diff --git a/keras_nlp/src/models/segmentation.py b/keras_nlp/src/models/segmentation.py new file mode 100644 index 0000000000..6827c7ab42 --- /dev/null +++ b/keras_nlp/src/models/segmentation.py @@ -0,0 +1,105 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.task import Task + + +@keras_nlp_export("keras_nlp.models.Segmentation") +class Segmentation(Task): + """Base class for all segmentation tasks. + + `Segmentation` tasks wrap a `keras_nlp.models.Backbone` to create a model + that can be used for segmentation. + `Segmentation` tasks take an additional + `num_classes` argument, the number of segmentation classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a image and `y` is a label from `[0, num_classes)`. + + All `Segmentation` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Example: + ```python + model = keras_nlp.models.Segmentation.from_preset( + "basnet_resnet", + num_classes=2, + ) + images = np.ones(shape=(1, 288, 288, 3)) + labels = np.zeros(shape=(1, 288, 288, 1)) + + output = model(images) + pred_labels = output[0] + + model.fit(images, labels, epochs=3) + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `Segmenter` task for training. + + The `Segmenter` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.BinaryCrossentropy` loss will be + applied for the segmentation task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.Accuracy` will be + applied to track the accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if loss == "auto": + activation = getattr(self, "activation", None) + activation = keras.activations.get(activation) + from_logits = activation != keras.activations.softmax + loss = keras.losses.BinaryCrossentropy(from_logits) + if metrics == "auto": + metrics = [keras.metrics.Accuracy()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) \ No newline at end of file diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 8e63bc19d9..69787c1968 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -557,6 +557,35 @@ def run_task_test( task.preprocessor = None task.fit(ds.map(preprocessor)) task.preprocessor = preprocessor + + def run_segmentation_test( + self, + cls, + init_kwargs, + train_data, + expected_output_shape=None, + batch_size=2, + ): + """Run basic tests for a backbone, including compilation.""" + task = cls(**init_kwargs) + # Check serialization (without a full save). + self.run_serialization_test(task) + ds = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size) + x, y, sw = keras.utils.unpack_x_y_sample_weight(train_data) + + # Test predict. + output = task.predict(x) + if expected_output_shape is not None: + output_shape = tree.map_structure(lambda x: x.shape, output) + self.assertAllClose(output_shape, expected_output_shape) + # With a dataset. + output_ds = task.predict(ds) + self.assertAllClose(output, output_ds) + + # Test fit. + task.fit(x, y, sample_weight=sw) + # With a dataset. + task.fit(ds) def run_preset_test( self, From 7470b848069926c1639a488d131fcc323091c82f Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 27 Aug 2024 00:38:44 +0000 Subject: [PATCH 16/33] format fix --- keras_nlp/api/models/__init__.py | 1 + keras_nlp/src/models/segmentation.py | 6 +++--- keras_nlp/src/tests/test_case.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 624c66445d..fe12c3218b 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -216,6 +216,7 @@ RobertaPreprocessor, ) from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer +from keras_nlp.src.models.segmentation import Segmentation from keras_nlp.src.models.seq_2_seq_lm import Seq2SeqLM from keras_nlp.src.models.t5.t5_backbone import T5Backbone from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer diff --git a/keras_nlp/src/models/segmentation.py b/keras_nlp/src/models/segmentation.py index 6827c7ab42..bba775667d 100644 --- a/keras_nlp/src/models/segmentation.py +++ b/keras_nlp/src/models/segmentation.py @@ -40,10 +40,10 @@ class Segmentation(Task): ) images = np.ones(shape=(1, 288, 288, 3)) labels = np.zeros(shape=(1, 288, 288, 1)) - + output = model(images) pred_labels = output[0] - + model.fit(images, labels, epochs=3) ``` """ @@ -102,4 +102,4 @@ def compile( loss=loss, metrics=metrics, **kwargs, - ) \ No newline at end of file + ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 69787c1968..797b81335f 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -557,7 +557,7 @@ def run_task_test( task.preprocessor = None task.fit(ds.map(preprocessor)) task.preprocessor = preprocessor - + def run_segmentation_test( self, cls, From 68a5a627ac0cff68d82213080cdef1d48c32390b Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 27 Aug 2024 17:49:48 +0000 Subject: [PATCH 17/33] add dependency package --- requirements-common.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-common.txt b/requirements-common.txt index 4e90ca9fab..27596bfb13 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -2,6 +2,7 @@ dm-tree regex rich +huggingface_hub kagglehub # Tooling deps. astor From 84731701c6ce7f5485d0a532e6e11dadebcccf7a Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Wed, 28 Aug 2024 17:47:08 +0000 Subject: [PATCH 18/33] nit --- .../deeplab_v3_plus/deeplab_v3_plus_segmenter.py | 10 +++++----- .../src/models/deeplab_v3_plus/deeplab_v3_plus_test.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py index e4e7a0a8b4..c41b6b9cb7 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py @@ -60,7 +60,7 @@ class DeepLabV3Plus(Segmentation): on different spatial levels in the pyramid, with dilation. If provided, the feature map from the backbone is passed to it inside the DeepLabV3 Encoder, otherwise SpatialPyramidPooling layer is used. - dialtion_rates: (Optional) A `list` of integers for parallel dilated conv. + dilation_rates: (Optional) A `list` of integers for parallel dilated conv. Applied only when Default `SpatialPyramidPooling` is used. Usually a sample choice of rates are [6, 12, 18]. segmentation_head: (Optional) a `keras.layers.Layer`. If provided, the @@ -104,7 +104,7 @@ def __init__( spatial_pyramid_pooling_key, projection_filters=48, spatial_pyramid_pooling=None, - dialtion_rates=None, + dilation_rates=None, segmentation_head=None, **kwargs, ): @@ -122,7 +122,7 @@ def __init__( if spatial_pyramid_pooling is None: spatial_pyramid_pooling = SpatialPyramidPooling( - dilation_rates=dialtion_rates + dilation_rates=dilation_rates ) spatial_backbone_features = backbone.pyramid_outputs[ spatial_pyramid_pooling_key @@ -184,7 +184,7 @@ def __init__( self.spatial_pyramid_pooling = spatial_pyramid_pooling self.projection_filters = projection_filters self.segmentation_head = segmentation_head - self.dialtion_rates = dialtion_rates + self.dilation_rates = dilation_rates self.low_level_feature_key = low_level_feature_key self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key @@ -199,7 +199,7 @@ def get_config(self): "segmentation_head": keras.saving.serialize_keras_object( self.segmentation_head ), - "dialtion_rates": self.dialtion_rates, + "dilation_rates": self.dilation_rates, "low_level_feature_key": self.low_level_feature_key, "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key, } diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py index 67eb3a1061..641c402309 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py +++ b/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py @@ -33,7 +33,7 @@ def setUp(self): "spatial_pyramid_pooling_key": "P5", "projection_filters": 48, "spatial_pyramid_pooling": None, - "dialtion_rates": [6, 12, 18], + "dilation_rates": [6, 12, 18], "segmentation_head": None, } self.images = np.ones((2, 96, 96, 3), dtype="float32") From beae2f410195fa399058dd27db9abb6e0eaf4eba Mon Sep 17 00:00:00 2001 From: Siva Sravana Kumar Neeli <113718461+sineeli@users.noreply.github.com> Date: Wed, 28 Aug 2024 15:55:14 -0700 Subject: [PATCH 19/33] Add Bounding Box Utils (#1791) * Bounding box utils * - Correct test cases * - Remove hard tensorflow dtype * - fix api gen * - Fix import for test cases - Use setup for converters test case * - fix api_gen issue * - FIx api gen * - Fix api gen error * - Correct test cases as per new api changes --- keras_nlp/api/__init__.py | 1 + keras_nlp/api/bounding_box/__init__.py | 23 + keras_nlp/src/bounding_box/__init__.py | 13 + keras_nlp/src/bounding_box/converters.py | 529 ++++++++++++++++++ keras_nlp/src/bounding_box/converters_test.py | 365 ++++++++++++ keras_nlp/src/bounding_box/to_dense.py | 95 ++++ keras_nlp/src/bounding_box/to_dense_test.py | 37 ++ keras_nlp/src/bounding_box/to_ragged.py | 99 ++++ keras_nlp/src/bounding_box/to_ragged_test.py | 101 ++++ keras_nlp/src/bounding_box/validate_format.py | 99 ++++ 10 files changed, 1362 insertions(+) create mode 100644 keras_nlp/api/bounding_box/__init__.py create mode 100644 keras_nlp/src/bounding_box/__init__.py create mode 100644 keras_nlp/src/bounding_box/converters.py create mode 100644 keras_nlp/src/bounding_box/converters_test.py create mode 100644 keras_nlp/src/bounding_box/to_dense.py create mode 100644 keras_nlp/src/bounding_box/to_dense_test.py create mode 100644 keras_nlp/src/bounding_box/to_ragged.py create mode 100644 keras_nlp/src/bounding_box/to_ragged_test.py create mode 100644 keras_nlp/src/bounding_box/validate_format.py diff --git a/keras_nlp/api/__init__.py b/keras_nlp/api/__init__.py index d0dc4576c6..46b16d5d8c 100644 --- a/keras_nlp/api/__init__.py +++ b/keras_nlp/api/__init__.py @@ -17,6 +17,7 @@ since your modifications would be overwritten. """ +from keras_nlp.api import bounding_box from keras_nlp.api import layers from keras_nlp.api import metrics from keras_nlp.api import models diff --git a/keras_nlp/api/bounding_box/__init__.py b/keras_nlp/api/bounding_box/__init__.py new file mode 100644 index 0000000000..18be1cd9aa --- /dev/null +++ b/keras_nlp/api/bounding_box/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras_nlp.src.bounding_box.converters import convert_format +from keras_nlp.src.bounding_box.to_dense import to_dense +from keras_nlp.src.bounding_box.to_ragged import to_ragged +from keras_nlp.src.bounding_box.validate_format import validate_format diff --git a/keras_nlp/src/bounding_box/__init__.py b/keras_nlp/src/bounding_box/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/bounding_box/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/bounding_box/converters.py b/keras_nlp/src/bounding_box/converters.py new file mode 100644 index 0000000000..0e363fc6f7 --- /dev/null +++ b/keras_nlp/src/bounding_box/converters.py @@ -0,0 +1,529 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Converter functions for working with bounding box formats.""" + +import keras +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export + +try: + import tensorflow as tf +except ImportError: + tf = None + + +# Internal exception to propagate the fact images was not passed to a converter +# that needs it. +class RequiresImagesException(Exception): + pass + + +ALL_AXES = 4 + + +def _encode_box_to_deltas( + anchors, + boxes, + anchor_format: str, + box_format: str, + variance=None, + image_shape=None, +): + """Converts bounding_boxes from `center_yxhw` to delta format.""" + if variance is not None: + variance = ops.convert_to_tensor(variance, "float32") + var_len = variance.shape[-1] + + if var_len != 4: + raise ValueError(f"`variance` must be length 4, got {variance}") + encoded_anchors = convert_format( + anchors, + source=anchor_format, + target="center_yxhw", + image_shape=image_shape, + ) + boxes = convert_format( + boxes, source=box_format, target="center_yxhw", image_shape=image_shape + ) + anchor_dimensions = ops.maximum( + encoded_anchors[..., 2:], keras.backend.epsilon() + ) + box_dimensions = ops.maximum(boxes[..., 2:], keras.backend.epsilon()) + # anchors be unbatched, boxes can either be batched or unbatched. + boxes_delta = ops.concatenate( + [ + (boxes[..., :2] - encoded_anchors[..., :2]) / anchor_dimensions, + ops.log(box_dimensions / anchor_dimensions), + ], + axis=-1, + ) + if variance is not None: + boxes_delta /= variance + return boxes_delta + + +def _decode_deltas_to_boxes( + anchors, + boxes_delta, + anchor_format: str, + box_format: str, + variance=None, + image_shape=None, +): + """Converts bounding_boxes from delta format to `center_yxhw`.""" + if variance is not None: + variance = ops.convert_to_tensor(variance, "float32") + var_len = variance.shape[-1] + + if var_len != 4: + raise ValueError(f"`variance` must be length 4, got {variance}") + + def decode_single_level(anchor, box_delta): + encoded_anchor = convert_format( + anchor, + source=anchor_format, + target="center_yxhw", + image_shape=image_shape, + ) + if variance is not None: + box_delta = box_delta * variance + # anchors be unbatched, boxes can either be batched or unbatched. + box = ops.concatenate( + [ + box_delta[..., :2] * encoded_anchor[..., 2:] + + encoded_anchor[..., :2], + ops.exp(box_delta[..., 2:]) * encoded_anchor[..., 2:], + ], + axis=-1, + ) + box = convert_format( + box, + source="center_yxhw", + target=box_format, + image_shape=image_shape, + ) + return box + + if isinstance(anchors, dict) and isinstance(boxes_delta, dict): + boxes = {} + for lvl, anchor in anchors.items(): + boxes[lvl] = decode_single_level(anchor, boxes_delta[lvl]) + return boxes + else: + return decode_single_level(anchors, boxes_delta) + + +def _center_yxhw_to_xyxy(boxes, images=None, image_shape=None): + y, x, height, width = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [x - width / 2.0, y - height / 2.0, x + width / 2.0, y + height / 2.0], + axis=-1, + ) + + +def _center_xywh_to_xyxy(boxes, images=None, image_shape=None): + x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [x - width / 2.0, y - height / 2.0, x + width / 2.0, y + height / 2.0], + axis=-1, + ) + + +def _xywh_to_xyxy(boxes, images=None, image_shape=None): + x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate([x, y, x + width, y + height], axis=-1) + + +def _xyxy_to_center_yxhw(boxes, images=None, image_shape=None): + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [ + (top + bottom) / 2.0, + (left + right) / 2.0, + bottom - top, + right - left, + ], + axis=-1, + ) + + +def _rel_xywh_to_xyxy(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [ + image_width * x, + image_height * y, + image_width * (x + width), + image_height * (y + height), + ], + axis=-1, + ) + + +def _xyxy_no_op(boxes, images=None, image_shape=None): + return boxes + + +def _xyxy_to_xywh(boxes, images=None, image_shape=None): + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [left, top, right - left, bottom - top], + axis=-1, + ) + + +def _xyxy_to_rel_xywh(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + left, right = ( + left / image_width, + right / image_width, + ) + top, bottom = top / image_height, bottom / image_height + return ops.concatenate( + [left, top, right - left, bottom - top], + axis=-1, + ) + + +def _xyxy_to_center_xywh(boxes, images=None, image_shape=None): + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [ + (left + right) / 2.0, + (top + bottom) / 2.0, + right - left, + bottom - top, + ], + axis=-1, + ) + + +def _rel_xyxy_to_xyxy(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + left, top, right, bottom = ops.split( + boxes, + ALL_AXES, + axis=-1, + ) + left, right = left * image_width, right * image_width + top, bottom = top * image_height, bottom * image_height + return ops.concatenate( + [left, top, right, bottom], + axis=-1, + ) + + +def _xyxy_to_rel_xyxy(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + left, top, right, bottom = ops.split( + boxes, + ALL_AXES, + axis=-1, + ) + left, right = left / image_width, right / image_width + top, bottom = top / image_height, bottom / image_height + return ops.concatenate( + [left, top, right, bottom], + axis=-1, + ) + + +def _yxyx_to_xyxy(boxes, images=None, image_shape=None): + y1, x1, y2, x2 = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate([x1, y1, x2, y2], axis=-1) + + +def _rel_yxyx_to_xyxy(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + top, left, bottom, right = ops.split( + boxes, + ALL_AXES, + axis=-1, + ) + left, right = left * image_width, right * image_width + top, bottom = top * image_height, bottom * image_height + return ops.concatenate( + [left, top, right, bottom], + axis=-1, + ) + + +def _xyxy_to_yxyx(boxes, images=None, image_shape=None): + x1, y1, x2, y2 = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate([y1, x1, y2, x2], axis=-1) + + +def _xyxy_to_rel_yxyx(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + left, right = left / image_width, right / image_width + top, bottom = top / image_height, bottom / image_height + return ops.concatenate( + [top, left, bottom, right], + axis=-1, + ) + + +TO_XYXY_CONVERTERS = { + "xywh": _xywh_to_xyxy, + "center_xywh": _center_xywh_to_xyxy, + "center_yxhw": _center_yxhw_to_xyxy, + "rel_xywh": _rel_xywh_to_xyxy, + "xyxy": _xyxy_no_op, + "rel_xyxy": _rel_xyxy_to_xyxy, + "yxyx": _yxyx_to_xyxy, + "rel_yxyx": _rel_yxyx_to_xyxy, +} + +FROM_XYXY_CONVERTERS = { + "xywh": _xyxy_to_xywh, + "center_xywh": _xyxy_to_center_xywh, + "center_yxhw": _xyxy_to_center_yxhw, + "rel_xywh": _xyxy_to_rel_xywh, + "xyxy": _xyxy_no_op, + "rel_xyxy": _xyxy_to_rel_xyxy, + "yxyx": _xyxy_to_yxyx, + "rel_yxyx": _xyxy_to_rel_yxyx, +} + + +@keras_nlp_export("keras_nlp.bounding_box.convert_format") +def convert_format( + boxes, source, target, images=None, image_shape=None, dtype="float32" +): + f"""Converts bounding_boxes from one format to another. + + Supported formats are: + - `"xyxy"`, also known as `corners` format. In this format the first four + axes represent `[left, top, right, bottom]` in that order. + - `"rel_xyxy"`. In this format, the axes are the same as `"xyxy"` but the x + coordinates are normalized using the image width, and the y axes the + image height. All values in `rel_xyxy` are in the range `(0, 1)`. + - `"xywh"`. In this format the first four axes represent + `[left, top, width, height]`. + - `"rel_xywh". In this format the first four axes represent + [left, top, width, height], just like `"xywh"`. Unlike `"xywh"`, the + values are in the range (0, 1) instead of absolute pixel values. + - `"center_xyWH"`. In this format the first two coordinates represent the x + and y coordinates of the center of the bounding box, while the last two + represent the width and height of the bounding box. + - `"center_yxHW"`. In this format the first two coordinates represent the y + and x coordinates of the center of the bounding box, while the last two + represent the height and width of the bounding box. + - `"yxyx"`. In this format the first four axes represent + [top, left, bottom, right] in that order. + - `"rel_yxyx"`. In this format, the axes are the same as `"yxyx"` but the x + coordinates are normalized using the image width, and the y axes the + image height. All values in `rel_yxyx` are in the range (0, 1). + Formats are case insensitive. It is recommended that you capitalize width + and height to maximize the visual difference between `"xyWH"` and `"xyxy"`. + + Relative formats, abbreviated `rel`, make use of the shapes of the `images` + passed. In these formats, the coordinates, widths, and heights are all + specified as percentages of the host image. `images` may be a ragged + Tensor. Note that using a ragged Tensor for images may cause a substantial + performance loss, as each image will need to be processed separately due to + the mismatching image shapes. + + Example: + + ```python + boxes = load_coco_dataset() + boxes_in_xywh = keras_nlp.bounding_box.convert_format( + boxes, + source='xyxy', + target='xyWH' + ) + ``` + + Args: + boxes: tensor representing bounding boxes in the format specified in + the `source` parameter. `boxes` can optionally have extra + dimensions stacked on the final axis to store metadata. boxes + should be a 3D tensor, with the shape `[batch_size, num_boxes, 4]`. + Alternatively, boxes can be a dictionary with key 'boxes' containing + a tensor matching the aforementioned spec. + source:One of {" ".join([f'"{f}"' for f in TO_XYXY_CONVERTERS.keys()])}. + Used to specify the original format of the `boxes` parameter. + target:One of {" ".join([f'"{f}"' for f in TO_XYXY_CONVERTERS.keys()])}. + Used to specify the destination format of the `boxes` parameter. + images: (Optional) a batch of images aligned with `boxes` on the first + axis. Should be at least 3 dimensions, with the first 3 dimensions + representing: `[batch_size, height, width]`. Used in some + converters to compute relative pixel values of the bounding box + dimensions. Required when transforming from a rel format to a + non-rel format. + dtype: the data type to use when transforming the boxes, defaults to + `"float32"`. + """ + if isinstance(boxes, dict): + converted_boxes = boxes.copy() + converted_boxes["boxes"] = convert_format( + boxes["boxes"], + source=source, + target=target, + images=images, + image_shape=image_shape, + dtype=dtype, + ) + return converted_boxes + + if boxes.shape[-1] is not None and boxes.shape[-1] != 4: + raise ValueError( + "Expected `boxes` to be a Tensor with a final dimension of " + f"`4`. Instead, got `boxes.shape={boxes.shape}`." + ) + if images is not None and image_shape is not None: + raise ValueError( + "convert_format() expects either `images` or `image_shape`, but " + f"not both. Received images={images} image_shape={image_shape}" + ) + + _validate_image_shape(image_shape) + + source = source.lower() + target = target.lower() + if source not in TO_XYXY_CONVERTERS: + raise ValueError( + "`convert_format()` received an unsupported format for the " + "argument `source`. `source` should be one of " + f"{TO_XYXY_CONVERTERS.keys()}. Got source={source}" + ) + if target not in FROM_XYXY_CONVERTERS: + raise ValueError( + "`convert_format()` received an unsupported format for the " + "argument `target`. `target` should be one of " + f"{FROM_XYXY_CONVERTERS.keys()}. Got target={target}" + ) + + boxes = ops.cast(boxes, dtype) + if source == target: + return boxes + + # rel->rel conversions should not require images + if source.startswith("rel") and target.startswith("rel"): + source = source.replace("rel_", "", 1) + target = target.replace("rel_", "", 1) + + boxes, images, squeeze = _format_inputs(boxes, images) + to_xyxy_fn = TO_XYXY_CONVERTERS[source] + from_xyxy_fn = FROM_XYXY_CONVERTERS[target] + + try: + in_xyxy = to_xyxy_fn(boxes, images=images, image_shape=image_shape) + result = from_xyxy_fn(in_xyxy, images=images, image_shape=image_shape) + except RequiresImagesException: + raise ValueError( + "convert_format() must receive `images` or `image_shape` when " + "transforming between relative and absolute formats." + f"convert_format() received source=`{format}`, target=`{format}, " + f"but images={images} and image_shape={image_shape}." + ) + + return _format_outputs(result, squeeze) + + +def _format_inputs(boxes, images): + boxes_rank = len(boxes.shape) + if boxes_rank > 3: + raise ValueError( + "Expected len(boxes.shape)=2, or len(boxes.shape)=3, got " + f"len(boxes.shape)={boxes_rank}" + ) + boxes_includes_batch = boxes_rank == 3 + # Determine if images needs an expand_dims() call + if images is not None: + images_rank = len(images.shape) + if images_rank > 4: + raise ValueError( + "Expected len(images.shape)=2, or len(images.shape)=3, got " + f"len(images.shape)={images_rank}" + ) + images_include_batch = images_rank == 4 + if boxes_includes_batch != images_include_batch: + raise ValueError( + "convert_format() expects both boxes and images to be batched, " + "or both boxes and images to be unbatched. Received " + f"len(boxes.shape)={boxes_rank}, " + f"len(images.shape)={images_rank}. Expected either " + "len(boxes.shape)=2 AND len(images.shape)=3, or " + "len(boxes.shape)=3 AND len(images.shape)=4." + ) + if not images_include_batch: + images = ops.expand_dims(images, axis=0) + + if not boxes_includes_batch: + return ops.expand_dims(boxes, axis=0), images, True + return boxes, images, False + + +def _validate_image_shape(image_shape): + # Escape early if image_shape is None and skip validation. + if image_shape is None: + return + # tuple/list + if isinstance(image_shape, (tuple, list)): + if len(image_shape) != 3: + raise ValueError( + "image_shape should be of length 3, but got " + f"image_shape={image_shape}" + ) + return + + # tensor + if ops.is_tensor(image_shape): + if len(image_shape.shape) > 1: + raise ValueError( + "image_shape.shape should be (3), but got " + f"image_shape.shape={image_shape.shape}" + ) + if image_shape.shape[0] != 3: + raise ValueError( + "image_shape.shape should be (3), but got " + f"image_shape.shape={image_shape.shape}" + ) + return + + # Warn about failure cases + raise ValueError( + "Expected image_shape to be either a tuple, list, Tensor. " + f"Received image_shape={image_shape}" + ) + + +def _format_outputs(boxes, squeeze): + if squeeze: + return ops.squeeze(boxes, axis=0) + return boxes + + +def _image_shape(images, image_shape, boxes): + if images is None and image_shape is None: + raise RequiresImagesException() + + if image_shape is None: + if not isinstance(images, tf.RaggedTensor): + image_shape = ops.shape(images) + height, width = image_shape[1], image_shape[2] + else: + height = ops.reshape(images.row_lengths(), (-1, 1)) + width = ops.reshape(ops.max(images.row_lengths(axis=2), 1), (-1, 1)) + height = ops.expand_dims(height, axis=-1) + width = ops.expand_dims(width, axis=-1) + else: + height, width = image_shape[0], image_shape[1] + return ops.cast(height, boxes.dtype), ops.cast(width, boxes.dtype) diff --git a/keras_nlp/src/bounding_box/converters_test.py b/keras_nlp/src/bounding_box/converters_test.py new file mode 100644 index 0000000000..f6f3adfa17 --- /dev/null +++ b/keras_nlp/src/bounding_box/converters_test.py @@ -0,0 +1,365 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized +from keras import backend + +from keras_nlp.src.bounding_box import converters +from keras_nlp.src.bounding_box import to_dense +from keras_nlp.src.bounding_box import to_ragged +from keras_nlp.src.tests.test_case import TestCase + + +class ConvertersTestCase(TestCase): + def setUp(self): + xyxy_box = np.array( + [[[10, 20, 110, 120], [20, 30, 120, 130]]], dtype="float32" + ) + yxyx_box = np.array( + [[[20, 10, 120, 110], [30, 20, 130, 120]]], dtype="float32" + ) + rel_xyxy_box = np.array( + [[[0.01, 0.02, 0.11, 0.12], [0.02, 0.03, 0.12, 0.13]]], + dtype="float32", + ) + rel_xyxy_box_ragged_images = np.array( + [[[0.10, 0.20, 1.1, 1.20], [0.40, 0.6, 2.40, 2.6]]], dtype="float32" + ) + rel_yxyx_box = np.array( + [[[0.02, 0.01, 0.12, 0.11], [0.03, 0.02, 0.13, 0.12]]], + dtype="float32", + ) + rel_yxyx_box_ragged_images = np.array( + [[[0.2, 0.1, 1.2, 1.1], [0.6, 0.4, 2.6, 2.4]]], dtype="float32" + ) + center_xywh_box = np.array( + [[[60, 70, 100, 100], [70, 80, 100, 100]]], dtype="float32" + ) + xywh_box = np.array( + [[[10, 20, 100, 100], [20, 30, 100, 100]]], dtype="float32" + ) + rel_xywh_box = np.array( + [[[0.01, 0.02, 0.1, 0.1], [0.02, 0.03, 0.1, 0.1]]], dtype="float32" + ) + rel_xywh_box_ragged_images = np.array( + [[[0.1, 0.2, 1, 1], [0.4, 0.6, 2, 2]]], dtype="float32" + ) + + self.ragged_images = tf.ragged.constant( + [ + np.ones(shape=[100, 100, 3]), + np.ones(shape=[50, 50, 3]), + ], # 2 images + ragged_rank=2, + ) + + self.images = np.ones([2, 1000, 1000, 3]) + + self.ragged_classes = tf.ragged.constant([[0], [0]], dtype="float32") + + self.boxes = { + "xyxy": xyxy_box, + "center_xywh": center_xywh_box, + "rel_xywh": rel_xywh_box, + "xywh": xywh_box, + "rel_xyxy": rel_xyxy_box, + "yxyx": yxyx_box, + "rel_yxyx": rel_yxyx_box, + } + + self.boxes_ragged_images = { + "xyxy": xyxy_box, + "center_xywh": center_xywh_box, + "rel_xywh": rel_xywh_box_ragged_images, + "xywh": xywh_box, + "rel_xyxy": rel_xyxy_box_ragged_images, + "yxyx": yxyx_box, + "rel_yxyx": rel_yxyx_box_ragged_images, + } + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + def test_converters(self, source, target): + source, target + source_box = self.boxes[source] + target_box = self.boxes[target] + + self.assertAllClose( + converters.convert_format( + source_box, source=source, target=target, images=self.images + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_converters_ragged_images(self, source, target): + source_box = _raggify(self.boxes_ragged_images[source]) + target_box = _raggify(self.boxes_ragged_images[target]) + self.assertAllClose( + converters.convert_format( + source_box, + source=source, + target=target, + images=self.ragged_images, + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + def test_converters_unbatched(self, source, target): + source_box = self.boxes[source][0] + target_box = self.boxes[target][0] + + self.assertAllClose( + converters.convert_format( + source_box, source=source, target=target, images=self.images[0] + ), + target_box, + ) + + def test_raises_with_different_image_rank(self): + source_box = self.boxes["xyxy"][0] + with self.assertRaises(ValueError): + converters.convert_format( + source_box, source="xyxy", target="xywh", images=self.images + ) + + def test_without_images(self): + source_box = self.boxes["xyxy"] + target_box = self.boxes["xywh"] + self.assertAllClose( + converters.convert_format(source_box, source="xyxy", target="xywh"), + target_box, + ) + + def test_rel_to_rel_without_images(self): + source_box = self.boxes["rel_xyxy"] + target_box = self.boxes["rel_yxyx"] + self.assertAllClose( + converters.convert_format( + source_box, source="rel_xyxy", target="rel_yxyx" + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_ragged_bounding_box(self, source, target): + source_box = _raggify(self.boxes[source]) + target_box = _raggify(self.boxes[target]) + self.assertAllClose( + converters.convert_format( + source_box, source=source, target=target, images=self.images + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_ragged_bounding_box_ragged_images(self, source, target): + source_box = _raggify(self.boxes_ragged_images[source]) + target_box = _raggify(self.boxes_ragged_images[target]) + self.assertAllClose( + converters.convert_format( + source_box, + source=source, + target=target, + images=self.ragged_images, + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_ragged_bounding_box_with_image_shape(self, source, target): + source_box = _raggify(self.boxes[source]) + target_box = _raggify(self.boxes[target]) + self.assertAllClose( + converters.convert_format( + source_box, + source=source, + target=target, + image_shape=(1000, 1000, 3), + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_dense_bounding_box_with_ragged_images(self, source, target): + source_box = _raggify(self.boxes_ragged_images[source]) + target_box = _raggify(self.boxes_ragged_images[target]) + source_bounding_boxes = { + "boxes": source_box, + "classes": self.ragged_classes, + } + source_bounding_boxes = to_dense.to_dense(source_bounding_boxes) + + result_bounding_boxes = converters.convert_format( + source_bounding_boxes, + source=source, + target=target, + images=self.ragged_images, + ) + result_bounding_boxes = to_ragged.to_ragged(result_bounding_boxes) + + self.assertAllClose( + result_bounding_boxes["boxes"], + target_box, + ) + + +def _raggify(tensor): + tensor = tf.squeeze(tensor, axis=0) + tensor = tf.RaggedTensor.from_row_lengths(tensor, [1, 1]) + return tensor diff --git a/keras_nlp/src/bounding_box/to_dense.py b/keras_nlp/src/bounding_box/to_dense.py new file mode 100644 index 0000000000..3c42d09f4f --- /dev/null +++ b/keras_nlp/src/bounding_box/to_dense.py @@ -0,0 +1,95 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras_nlp.src.bounding_box.validate_format as validate_format +from keras_nlp.src.api_export import keras_nlp_export + +try: + import tensorflow as tf +except ImportError: + tf = None + + +def _box_shape(batched, boxes_shape, max_boxes): + # ensure we dont drop the final axis in RaggedTensor mode + if max_boxes is None: + shape = list(boxes_shape) + shape[-1] = 4 + return shape + if batched: + return [None, max_boxes, 4] + return [max_boxes, 4] + + +def _classes_shape(batched, classes_shape, max_boxes): + if max_boxes is None: + return None + if batched: + return [None, max_boxes] + classes_shape[2:] + return [max_boxes] + classes_shape[2:] + + +@keras_nlp_export("keras_nlp.bounding_box.to_dense") +def to_dense(bounding_boxes, max_boxes=None, default_value=-1): + """to_dense converts bounding boxes to Dense tensors + + Args: + bounding_boxes: bounding boxes in KerasCV dictionary format. + max_boxes: the maximum number of boxes, used to pad tensors to a given + shape. This can be used to make object detection pipelines TPU + compatible. + default_value: the default value to pad bounding boxes with. defaults + to -1. + """ + info = validate_format.validate_format(bounding_boxes) + + # guards against errors in metrics regarding modification of inputs. + # also guards against unexpected behavior when modifying downstream + bounding_boxes = bounding_boxes.copy() + + # Already running in masked mode + if not info["ragged"]: + # even if already ragged, still copy the dictionary for API consistency + return bounding_boxes + + if isinstance(bounding_boxes["classes"], tf.RaggedTensor): + bounding_boxes["classes"] = bounding_boxes["classes"].to_tensor( + default_value=default_value, + shape=_classes_shape( + info["is_batched"], bounding_boxes["classes"].shape, max_boxes + ), + ) + + if isinstance(bounding_boxes["boxes"], tf.RaggedTensor): + bounding_boxes["boxes"] = bounding_boxes["boxes"].to_tensor( + default_value=default_value, + shape=_box_shape( + info["is_batched"], bounding_boxes["boxes"].shape, max_boxes + ), + ) + + if "confidence" in bounding_boxes: + if isinstance(bounding_boxes["confidence"], tf.RaggedTensor): + bounding_boxes["confidence"] = bounding_boxes[ + "confidence" + ].to_tensor( + default_value=default_value, + shape=_classes_shape( + info["is_batched"], + bounding_boxes["confidence"].shape, + max_boxes, + ), + ) + + return bounding_boxes diff --git a/keras_nlp/src/bounding_box/to_dense_test.py b/keras_nlp/src/bounding_box/to_dense_test.py new file mode 100644 index 0000000000..4bb795659b --- /dev/null +++ b/keras_nlp/src/bounding_box/to_dense_test.py @@ -0,0 +1,37 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import tensorflow as tf +from keras import backend + +from keras_nlp.src.bounding_box import to_dense +from keras_nlp.src.tests.test_case import TestCase + + +class ToDenseTest(TestCase): + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_converts_to_dense(self): + bounding_boxes = { + "boxes": tf.ragged.constant( + [[[0, 0, 1, 1]], [[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 1, 1]]] + ), + "classes": tf.ragged.constant([[0], [1, 2, 3]]), + } + bounding_boxes = to_dense.to_dense(bounding_boxes) + self.assertEqual(bounding_boxes["boxes"].shape, [2, 3, 4]) + self.assertEqual(bounding_boxes["classes"].shape, [2, 3]) diff --git a/keras_nlp/src/bounding_box/to_ragged.py b/keras_nlp/src/bounding_box/to_ragged.py new file mode 100644 index 0000000000..2ebd4a00f4 --- /dev/null +++ b/keras_nlp/src/bounding_box/to_ragged.py @@ -0,0 +1,99 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +import keras_nlp.src.bounding_box.validate_format as validate_format +from keras_nlp.src.api_export import keras_nlp_export + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_nlp_export("keras_nlp.bounding_box.to_ragged") +def to_ragged(bounding_boxes, sentinel=-1, dtype="float32"): + """converts a Dense padded bounding box `tf.Tensor` to a `tf.RaggedTensor`. + + Bounding boxes are ragged tensors in most use cases. Converting them to a + dense tensor makes it easier to work with Tensorflow ecosystem. + This function can be used to filter out the masked out bounding boxes by + checking for padded sentinel value of the class_id axis of the + bounding_boxes. + + Example: + ```python + bounding_boxes = { + "boxes": tf.constant([[2, 3, 4, 5], [0, 1, 2, 3]]), + "classes": tf.constant([[-1, 1]]), + } + bounding_boxes = bounding_box.to_ragged(bounding_boxes) + print(bounding_boxes) + # { + # "boxes": [[0, 1, 2, 3]], + # "classes": [[1]] + # } + ``` + + Args: + bounding_boxes: a Tensor of bounding boxes. May be batched, or + unbatched. + sentinel: The value indicating that a bounding box does not exist at the + current index, and the corresponding box is padding, defaults to -1. + dtype: the data type to use for the underlying Tensors. + Returns: + dictionary of `tf.RaggedTensor` or 'tf.Tensor' containing the filtered + bounding boxes. + """ + if keras.config.backend() != "tensorflow": + raise NotImplementedError( + "`bounding_box.to_ragged` was called using a backend which does " + "not support ragged tensors. " + f"Current backend: {keras.backend.backend()}." + ) + + info = validate_format.validate_format(bounding_boxes) + + if info["ragged"]: + return bounding_boxes + + boxes = bounding_boxes.get("boxes") + classes = bounding_boxes.get("classes") + confidence = bounding_boxes.get("confidence", None) + + mask = classes != sentinel + + boxes = tf.ragged.boolean_mask(boxes, mask) + classes = tf.ragged.boolean_mask(classes, mask) + if confidence is not None: + confidence = tf.ragged.boolean_mask(confidence, mask) + + if isinstance(boxes, tf.Tensor): + boxes = tf.RaggedTensor.from_tensor(boxes) + + if isinstance(classes, tf.Tensor) and len(classes.shape) > 1: + classes = tf.RaggedTensor.from_tensor(classes) + + if confidence is not None: + if isinstance(confidence, tf.Tensor) and len(confidence.shape) > 1: + confidence = tf.RaggedTensor.from_tensor(confidence) + + result = bounding_boxes.copy() + result["boxes"] = tf.cast(boxes, dtype) + result["classes"] = tf.cast(classes, dtype) + + if confidence is not None: + result["confidence"] = tf.cast(confidence, dtype) + + return result diff --git a/keras_nlp/src/bounding_box/to_ragged_test.py b/keras_nlp/src/bounding_box/to_ragged_test.py new file mode 100644 index 0000000000..cbe5146d11 --- /dev/null +++ b/keras_nlp/src/bounding_box/to_ragged_test.py @@ -0,0 +1,101 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from keras import backend + +from keras_nlp.src.bounding_box import to_dense +from keras_nlp.src.bounding_box import to_ragged +from keras_nlp.src.tests.test_case import TestCase + + +class ToRaggedTest(TestCase): + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_converts_to_ragged(self): + bounding_boxes = { + "boxes": np.array( + [[[0, 0, 0, 0], [0, 0, 0, 0]], [[2, 3, 4, 5], [0, 1, 2, 3]]] + ), + "classes": np.array([[-1, -1], [-1, 1]]), + "confidence": np.array([[0.5, 0.7], [0.23, 0.12]]), + } + bounding_boxes = to_ragged.to_ragged(bounding_boxes) + + self.assertEqual(bounding_boxes["boxes"][1].shape, [1, 4]) + self.assertEqual(bounding_boxes["classes"][1].shape, [1]) + self.assertEqual( + bounding_boxes["confidence"][1].shape, + [ + 1, + ], + ) + + self.assertEqual(bounding_boxes["classes"][0].shape, [0]) + self.assertEqual(bounding_boxes["boxes"][0].shape, [0, 4]) + self.assertEqual( + bounding_boxes["confidence"][0].shape, + [ + 0, + ], + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_round_trip(self): + original = { + "boxes": np.array( + [ + [[0, 0, 0, 0], [-1, -1, -1, -1]], + [[-1, -1, -1, -1], [-1, -1, -1, -1]], + ] + ), + "classes": np.array([[1, -1], [-1, -1]]), + "confidence": np.array([[0.5, -1], [-1, -1]]), + } + bounding_boxes = to_ragged.to_ragged(original) + bounding_boxes = to_dense.to_dense(bounding_boxes, max_boxes=2) + + self.assertEqual(bounding_boxes["boxes"][1].shape, [2, 4]) + self.assertEqual(bounding_boxes["classes"][1].shape, [2]) + self.assertEqual(bounding_boxes["classes"][0].shape, [2]) + self.assertEqual(bounding_boxes["boxes"][0].shape, [2, 4]) + self.assertEqual(bounding_boxes["confidence"][0].shape, [2]) + + self.assertAllEqual(bounding_boxes["boxes"], original["boxes"]) + self.assertAllEqual(bounding_boxes["classes"], original["classes"]) + self.assertAllEqual( + bounding_boxes["confidence"], original["confidence"] + ) + + @pytest.mark.skipif( + backend.backend() == "tensorflow", + reason="Only applies to backends which don't support raggeds", + ) + def test_backend_without_raggeds_throws(self): + bounding_boxes = { + "boxes": np.array( + [[[0, 0, 0, 0], [0, 0, 0, 0]], [[2, 3, 4, 5], [0, 1, 2, 3]]] + ), + "classes": np.array([[-1, -1], [-1, 1]]), + "confidence": np.array([[0.5, 0.7], [0.23, 0.12]]), + } + + with self.assertRaisesRegex(NotImplementedError, "support ragged"): + to_ragged.to_ragged(bounding_boxes) diff --git a/keras_nlp/src/bounding_box/validate_format.py b/keras_nlp/src/bounding_box/validate_format.py new file mode 100644 index 0000000000..51fb310807 --- /dev/null +++ b/keras_nlp/src/bounding_box/validate_format.py @@ -0,0 +1,99 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.src.api_export import keras_nlp_export + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_nlp_export("keras_nlp.bounding_box.validate_format") +def validate_format(bounding_boxes, variable_name="bounding_boxes"): + """validates that a given set of bounding boxes complies with KerasNLP + format. + + For a set of bounding boxes to be valid it must satisfy the following + conditions: + - `bounding_boxes` must be a dictionary + - contains keys `"boxes"` and `"classes"` + - each entry must have matching first two dimensions; representing the batch + axis and the number of boxes per image axis. + - either both `"boxes"` and `"classes"` are batched, or both are unbatched. + + Additionally, one of the following must be satisfied: + - `"boxes"` and `"classes"` are both Ragged + - `"boxes"` and `"classes"` are both Dense + - `"boxes"` and `"classes"` are unbatched + + Args: + bounding_boxes: dictionary of bounding boxes according to KerasCV + format. + + Raises: + ValueError if any of the above conditions are not met + """ + if not isinstance(bounding_boxes, dict): + raise ValueError( + f"Expected `{variable_name}` to be a dictionary, got " + f"`{variable_name}={bounding_boxes}`." + ) + if not all([x in bounding_boxes for x in ["boxes", "classes"]]): + raise ValueError( + f"Expected `{variable_name}` to be a dictionary containing keys " + "`'classes'` and `'boxes'`. Got " + f"`{variable_name}.keys()={bounding_boxes.keys()}`." + ) + + boxes = bounding_boxes.get("boxes") + classes = bounding_boxes.get("classes") + info = {} + + is_batched = len(boxes.shape) == 3 + info["is_batched"] = is_batched + info["ragged"] = isinstance(boxes, tf.RaggedTensor) + + if not is_batched: + if boxes.shape[:1] != classes.shape[:1]: + raise ValueError( + "Expected `boxes` and `classes` to have matching dimensions " + "on the first axis when operating in unbatched mode. Got " + f"`boxes.shape={boxes.shape}`, `classes.shape={classes.shape}`." + ) + + info["classes_one_hot"] = len(classes.shape) == 2 + # No Ragged checks needed in unbatched mode. + return info + + info["classes_one_hot"] = len(classes.shape) == 3 + + if isinstance(boxes, tf.RaggedTensor) != isinstance( + classes, tf.RaggedTensor + ): + raise ValueError( + "Either both `boxes` and `classes` " + "should be Ragged, or neither should be ragged." + f" Got `type(boxes)={type(boxes)}`, type(classes)={type(classes)}." + ) + + # Batched mode checks + if boxes.shape[:2] != classes.shape[:2]: + raise ValueError( + "Expected `boxes` and `classes` to have matching dimensions " + "on the first two axes when operating in batched mode. " + f"Got `boxes.shape={boxes.shape}`, `classes.shape={classes.shape}`." + ) + + return info From 9289ab79b8fa4cd00926b8536bad07fd755714d6 Mon Sep 17 00:00:00 2001 From: Usha Rengaraju <34335028+ushareng@users.noreply.github.com> Date: Thu, 29 Aug 2024 04:26:16 +0530 Subject: [PATCH 20/33] mobilenet_v3 added in keras-nlp (#1782) * mobilenet_v3 added in keras-nlp * minor bug fixed in mobilenet_v3_backbone * formatting corrected * refactoring backbone * correct_pad_downsample method added * refactoring backbone * parameters updated * Testcaseupdated, expected output shape corrected * code formatted with black * testcase updated * refactoring and description added * comments updated * added mobilenet v1 and v2 * merge conflict resolved * version arg removed, and config options added * input_shape changed to image_shape in arg * config updated * input shape corrected * comments resolved * activation function format changed * minor bug fixed * minor bug fixed * added vision_backbone_test * channel_first bug resolved * channel_first cases working * comments resolved * formatting fixed * refactoring --------- Co-authored-by: ushareng --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/src/models/mobilenet/__init__.py | 13 + .../models/mobilenet/mobilenet_backbone.py | 530 ++++++++++++++++++ .../mobilenet/mobilenet_backbone_test.py | 58 ++ .../mobilenet/mobilenet_image_classifier.py | 114 ++++ .../mobilenet_image_classifier_test.py | 71 +++ 6 files changed, 790 insertions(+) create mode 100644 keras_nlp/src/models/mobilenet/__init__.py create mode 100644 keras_nlp/src/models/mobilenet/mobilenet_backbone.py create mode 100644 keras_nlp/src/models/mobilenet/mobilenet_backbone_test.py create mode 100644 keras_nlp/src/models/mobilenet/mobilenet_image_classifier.py create mode 100644 keras_nlp/src/models/mobilenet/mobilenet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index c6d8ed7d32..17b00c1f05 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -171,6 +171,10 @@ from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( MiTImageClassifier, ) +from keras_nlp.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_nlp.src.models.mobilenet.mobilenet_image_classifier import ( + MobileNetImageClassifier, +) from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM from keras_nlp.src.models.opt.opt_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/mobilenet/__init__.py b/keras_nlp/src/models/mobilenet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/mobilenet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/mobilenet/mobilenet_backbone.py b/keras_nlp/src/models/mobilenet/mobilenet_backbone.py new file mode 100644 index 0000000000..4054b6d76f --- /dev/null +++ b/keras_nlp/src/models/mobilenet/mobilenet_backbone.py @@ -0,0 +1,530 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + +BN_EPSILON = 1e-3 +BN_MOMENTUM = 0.999 + + +@keras_nlp_export("keras_nlp.models.MobileNetBackbone") +class MobileNetBackbone(Backbone): + """Instantiates the MobileNet architecture. + + MobileNet is a lightweight convolutional neural network (CNN) + optimized for mobile and edge devices, striking a balance between + accuracy and efficiency. By employing depthwise separable convolutions + and techniques like Squeeze-and-Excitation (SE) blocks, + MobileNet models are highly suitable for real-time applications on + resource-constrained devices. + + References: + - [MobileNets: Efficient Convolutional Neural Networks + for Mobile Vision Applications]( + https://arxiv.org/abs/1704.04861) + - [MobileNetV2: Inverted Residuals and Linear Bottlenecks]( + https://arxiv.org/abs/1801.04381) (CVPR 2018) + - [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) + (ICCV 2019) + + Args: + stackwise_expansion: list of ints or floats, the expansion ratio for + each inverted residual block in the model. + stackwise_num_filters: list of ints, number of filters for each inverted + residual block in the model. + stackwise_kernel_size: list of ints, kernel size for each inverted + residual block in the model. + stackwise_num_strides: list of ints, stride length for each inverted + residual block in the model. + stackwise_se_ratio: se ratio for each inverted residual block in the + model. 0 if dont want to add Squeeze and Excite layer. + stackwise_activation: list of activation functions, for each inverted + residual block in the model. + include_rescaling: bool, whether to rescale the inputs. If set to True, + inputs will be passed through a `Rescaling(scale=1 / 255)` + layer. + image_shape: optional shape tuple, defaults to (224, 224, 3). + depth_multiplier: float, controls the width of the network. + - If `depth_multiplier` < 1.0, proportionally decreases the number + of filters in each layer. + - If `depth_multiplier` > 1.0, proportionally increases the number + of filters in each layer. + - If `depth_multiplier` = 1, default number of filters from the paper + are used at each layer. + input_num_filters: number of filters in first convolution layer + output_num_filters: specifies whether to add conv and batch_norm in the end, + if set to None, it will not add these layers in the end. + 'None' for MobileNetV1 + input_activation: activation function to be used in the input layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + output_activation: activation function to be used in the output layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + inverted_res_block: whether to use inverted residual blocks or not, + 'False' for MobileNetV1, + 'True' for MobileNetV2 and MobileNetV3 + + + Example: + ```python + input_data = tf.ones(shape=(8, 224, 224, 3)) + + # Randomly initialized backbone with a custom config + model = MobileNetBackbone( + stackwise_expansion=[1, 4, 6], + stackwise_num_filters=[4, 8, 16], + stackwise_kernel_size=[3, 3, 5], + stackwise_num_strides=[2, 2, 1], + stackwise_se_ratio=[0.25, None, 0.25], + stackwise_activation=["relu", "relu6", "hard_swish"], + include_rescaling=False, + output_num_filters=1280, + input_activation='hard_swish', + output_activation='hard_swish', + inverted_res_block=True, + + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + stackwise_expansion, + stackwise_num_filters, + stackwise_kernel_size, + stackwise_num_strides, + stackwise_se_ratio, + stackwise_activation, + include_rescaling, + output_num_filters, + inverted_res_block, + image_shape=(224, 224, 3), + input_activation="hard_swish", + output_activation="hard_swish", + depth_multiplier=1.0, + input_num_filters=16, + **kwargs, + ): + # === Functional Model === + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + + inputs = keras.layers.Input(shape=image_shape) + x = inputs + + if include_rescaling: + x = keras.layers.Rescaling(scale=1 / 255)(x) + + input_num_filters = adjust_channels(input_num_filters) + x = keras.layers.Conv2D( + input_num_filters, + kernel_size=3, + strides=(2, 2), + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name="input_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="input_batch_norm", + )(x) + x = keras.layers.Activation(input_activation)(x) + + for stack_index in range(len(stackwise_num_filters)): + filters = adjust_channels( + (stackwise_num_filters[stack_index]) * depth_multiplier + ) + + if inverted_res_block: + x = apply_inverted_res_block( + x, + expansion=stackwise_expansion[stack_index], + filters=filters, + kernel_size=stackwise_kernel_size[stack_index], + stride=stackwise_num_strides[stack_index], + se_ratio=(stackwise_se_ratio[stack_index]), + activation=stackwise_activation[stack_index], + expansion_index=stack_index, + ) + else: + x = apply_depthwise_conv_block( + x, + filters=filters, + kernel_size=3, + stride=stackwise_num_strides[stack_index], + depth_multiplier=depth_multiplier, + block_id=stack_index, + ) + + if output_num_filters is not None: + last_conv_ch = adjust_channels(x.shape[channel_axis] * 6) + + x = keras.layers.Conv2D( + last_conv_ch, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name="output_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="output_batch_norm", + )(x) + x = keras.layers.Activation(output_activation)(x) + + super().__init__(inputs=inputs, outputs=x, **kwargs) + + # === Config === + self.stackwise_expansion = stackwise_expansion + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_kernel_size = stackwise_kernel_size + self.stackwise_num_strides = stackwise_num_strides + self.stackwise_se_ratio = stackwise_se_ratio + self.stackwise_activation = stackwise_activation + self.include_rescaling = include_rescaling + self.depth_multiplier = depth_multiplier + self.input_num_filters = input_num_filters + self.output_num_filters = output_num_filters + self.input_activation = keras.activations.get(input_activation) + self.output_activation = keras.activations.get(output_activation) + self.inverted_res_block = inverted_res_block + self.image_shape = image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_expansion": self.stackwise_expansion, + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_kernel_size": self.stackwise_kernel_size, + "stackwise_num_strides": self.stackwise_num_strides, + "stackwise_se_ratio": self.stackwise_se_ratio, + "stackwise_activation": self.stackwise_activation, + "include_rescaling": self.include_rescaling, + "image_shape": self.image_shape, + "depth_multiplier": self.depth_multiplier, + "input_num_filters": self.input_num_filters, + "output_num_filters": self.output_num_filters, + "input_activation": keras.activations.serialize( + activation=self.input_activation + ), + "output_activation": keras.activations.serialize( + activation=self.output_activation + ), + "inverted_res_block": self.inverted_res_block, + } + ) + return config + + +def adjust_channels(x, divisor=8, min_value=None): + """Ensure that all layers have a channel number divisible by the `divisor`. + + Args: + x: integer, input value. + divisor: integer, the value by which a channel number should be + divisible, defaults to 8. + min_value: float, optional minimum value for the new tensor. If None, + defaults to value of divisor. + + Returns: + the updated input scalar. + """ + + if min_value is None: + min_value = divisor + + new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) + + # make sure that round down does not go down by more than 10%. + if new_x < 0.9 * x: + new_x += divisor + return new_x + + +def apply_inverted_res_block( + x, + expansion, + filters, + kernel_size, + stride, + se_ratio, + activation, + expansion_index, +): + """An Inverted Residual Block. + + Args: + x: input tensor. + expansion: integer, the expansion ratio, multiplied with infilters to + get the minimum value passed to adjust_channels. + filters: integer, number of filters for convolution layer. + kernel_size: integer, the kernel size for DepthWise Convolutions. + stride: integer, the stride length for DepthWise Convolutions. + se_ratio: float, ratio for bottleneck filters. Number of bottleneck + filters = filters * se_ratio. + activation: the activation layer to use. + expansion_index: integer, a unique identification if you want to use + expanded convolutions. If greater than 0, an additional Conv+BN + layer is added after the expanded convolutional layer. + + Returns: + the updated input tensor. + """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + activation = keras.activations.get(activation) + shortcut = x + prefix = "expanded_conv_" + infilters = x.shape[channel_axis] + + if expansion_index > 0: + prefix = f"expanded_conv_{expansion_index}_" + + x = keras.layers.Conv2D( + adjust_channels(infilters * expansion), + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name=prefix + "expand", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=prefix + "expand_BatchNorm", + )(x) + x = keras.layers.Activation(activation=activation)(x) + + if stride == 2: + x = keras.layers.ZeroPadding2D( + padding=correct_pad_downsample(x, kernel_size), + name=prefix + "depthwise_pad", + )(x) + + x = keras.layers.DepthwiseConv2D( + kernel_size, + strides=stride, + padding="same" if stride == 1 else "valid", + data_format=keras.config.image_data_format(), + use_bias=False, + name=prefix + "depthwise", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=prefix + "depthwise_BatchNorm", + )(x) + x = keras.layers.Activation(activation=activation)(x) + + if se_ratio: + se_filters = adjust_channels(infilters * expansion) + x = SqueezeAndExcite2D( + input=x, + filters=se_filters, + bottleneck_filters=adjust_channels(se_filters * se_ratio), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + ) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name=prefix + "project", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=prefix + "project_BatchNorm", + )(x) + + if stride == 1 and infilters == filters: + x = keras.layers.Add(name=prefix + "Add")([shortcut, x]) + + return x + + +def apply_depthwise_conv_block( + x, + filters, + kernel_size=3, + depth_multiplier=1, + stride=1, + block_id=1, +): + """Adds a depthwise convolution block. + + A depthwise convolution block consists of a depthwise conv, + batch normalization, relu6, pointwise convolution, + batch normalization and relu6 activation. + + Args: + x: Input tensor of shape `(rows, cols, channels) + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the pointwise convolution). + depth_multiplier: controls the width of the network. + - If `depth_multiplier` < 1.0, proportionally decreases the number + of filters in each layer. + - If `depth_multiplier` > 1.0, proportionally increases the number + of filters in each layer. + - If `depth_multiplier` = 1, default number of filters from the + paper are used at each layer. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the width and height. + Can be a single integer to specify the same value for + all spatial dimensions. Specifying any stride value != 1 is + incompatible with specifying any `dilation_rate` value != 1. + block_id: Integer, a unique identification designating the block number. + + Input shape: + 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" + 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" + Returns: + Output tensor of block. + """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + if stride == 2: + x = keras.layers.ZeroPadding2D( + padding=correct_pad_downsample(x, kernel_size), + name="conv_pad_%d" % block_id, + )(x) + + x = keras.layers.DepthwiseConv2D( + kernel_size, + strides=stride, + padding="same" if stride == 1 else "valid", + data_format=keras.config.image_data_format(), + depth_multiplier=depth_multiplier, + use_bias=False, + name="depthwise_%d" % block_id, + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="depthwise_BatchNorm_%d" % block_id, + )(x) + x = keras.layers.ReLU(6.0)(x) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name="conv_%d" % block_id, + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="BatchNorm_%d" % block_id, + )(x) + return keras.layers.ReLU(6.0)(x) + + +def SqueezeAndExcite2D( + input, + filters, + bottleneck_filters=None, + squeeze_activation="relu", + excite_activation="sigmoid", +): + """ + Description: + This layer applies a content-aware mechanism to adaptively assign + channel-wise weights. It uses global average pooling to compress + feature maps into single values, which are then processed by + two Conv1D layers: the first reduces the dimensionality, and + the second restores it. + Args: + filters: Number of input and output filters. The number of input and + output filters is same. + bottleneck_filters: (Optional) Number of bottleneck filters. Defaults + to `0.25 * filters` + squeeze_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after squeeze convolution. + Defaults to `relu`. + excite_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after excite convolution. + Defaults to `sigmoid`. + """ + if not bottleneck_filters: + bottleneck_filters = filters // 4 + + x = keras.layers.GlobalAveragePooling2D(keepdims=True)(input) + + x = keras.layers.Conv2D( + bottleneck_filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=squeeze_activation, + )(x) + x = keras.layers.Conv2D( + filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=excite_activation, + )(x) + + x = ops.multiply(x, input) + return x + + +def correct_pad_downsample(inputs, kernel_size): + """Returns a tuple for zero-padding for 2D convolution with downsampling. + + Args: + inputs: Input tensor. + kernel_size: An integer or tuple/list of 2 integers. + + Returns: + A tuple. + """ + img_dim = 1 + input_size = inputs.shape[img_dim : (img_dim + 2)] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if input_size[0] is None: + adjust = (1, 1) + else: + adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + return ( + (correct[0] - adjust[0], correct[0]), + (correct[1] - adjust[1], correct[1]), + ) diff --git a/keras_nlp/src/models/mobilenet/mobilenet_backbone_test.py b/keras_nlp/src/models/mobilenet/mobilenet_backbone_test.py new file mode 100644 index 0000000000..80225abe04 --- /dev/null +++ b/keras_nlp/src/models/mobilenet/mobilenet_backbone_test.py @@ -0,0 +1,58 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class MobileNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_expansion": [1, 4, 6], + "stackwise_num_filters": [4, 8, 16], + "stackwise_kernel_size": [3, 3, 5], + "stackwise_num_strides": [2, 2, 1], + "stackwise_se_ratio": [0.25, None, 0.25], + "stackwise_activation": ["relu", "relu", "hard_swish"], + "include_rescaling": False, + "output_num_filters": 1280, + "input_activation": "hard_swish", + "output_activation": "hard_swish", + "inverted_res_block": True, + "input_num_filters": 16, + "image_shape": (224, 224, 3), + "depth_multiplier": 1, + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=MobileNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 28, 28, 96), + run_mixed_precision_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MobileNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/mobilenet/mobilenet_image_classifier.py b/keras_nlp/src/models/mobilenet/mobilenet_image_classifier.py new file mode 100644 index 0000000000..3e08f3482c --- /dev/null +++ b/keras_nlp/src/models/mobilenet/mobilenet_image_classifier.py @@ -0,0 +1,114 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone + + +@keras_nlp_export("keras_nlp.models.MobileNetImageClassifier") +class MobileNetImageClassifier(ImageClassifier): + """MobileNetV3 image classifier task model. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Args: + backbone: A `keras_nlp.models.MobileNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.MobileNetImageClassifier.from_preset( + "mobilenet_v3_small_imagenet") + classifier.predict(images) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + model = MobileNetBackbone( + stackwise_expansion = [1, 4, 6], + stackwise_filters = [4, 8, 16], + stackwise_kernel_size = [3, 3, 5], + stackwise_stride = [2, 2, 1], + stackwise_se_ratio = [ 0.25, None, 0.25], + stackwise_activation = ["relu", "relu", "hard_swish"], + include_rescaling = False, + output_filter=1280, + activation="hard_swish", + inverted_res_block=True, + ) + classifier = keras_nlp.models.MobileNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = MobileNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_nlp/src/models/mobilenet/mobilenet_image_classifier_test.py new file mode 100644 index 0000000000..29d00e6d24 --- /dev/null +++ b/keras_nlp/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -0,0 +1,71 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_nlp.src.models.mobilenet.mobilenet_image_classifier import ( + MobileNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class MobileNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.images = np.ones((2, 224, 224, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = MobileNetBackbone( + stackwise_expansion=[1, 4, 6], + stackwise_num_filters=[4, 8, 16], + stackwise_kernel_size=[3, 3, 5], + stackwise_num_strides=[2, 2, 1], + stackwise_se_ratio=[0.25, None, 0.25], + stackwise_activation=["relu", "relu", "hard_swish"], + include_rescaling=False, + output_num_filters=1280, + input_activation="hard_swish", + output_activation="hard_swish", + inverted_res_block=True, + input_num_filters=16, + image_shape=(224, 224, 3), + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=MobileNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MobileNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From 09f470f726417faee597dff54e636afc3854aeed Mon Sep 17 00:00:00 2001 From: pkgoogle <132095473+pkgoogle@users.noreply.github.com> Date: Wed, 28 Aug 2024 22:56:37 +0000 Subject: [PATCH 21/33] Pkgoogle/efficient net migration (#1778) * migrating efficientnet models to keras-hub * merging changes from other sources * autoformatting pass * initial consolidation of efficientnet_backbone * most updates and removing separate implementation * cleanup, autoformatting, keras generalization * removed layer examples outside of effiicient net * many, mainly documentation changes, small test fixes --- keras_nlp/api/models/__init__.py | 3 + keras_nlp/src/models/efficientnet/__init__.py | 13 + .../efficientnet/efficientnet_backbone.py | 569 ++++++++++++++++++ .../efficientnet_backbone_test.py | 146 +++++ .../src/models/efficientnet/fusedmbconv.py | 229 +++++++ .../models/efficientnet/fusedmbconv_test.py | 46 ++ keras_nlp/src/models/efficientnet/mbconv.py | 238 ++++++++ .../src/models/efficientnet/mbconv_test.py | 44 ++ 8 files changed, 1288 insertions(+) create mode 100644 keras_nlp/src/models/efficientnet/__init__.py create mode 100644 keras_nlp/src/models/efficientnet/efficientnet_backbone.py create mode 100644 keras_nlp/src/models/efficientnet/efficientnet_backbone_test.py create mode 100644 keras_nlp/src/models/efficientnet/fusedmbconv.py create mode 100644 keras_nlp/src/models/efficientnet/fusedmbconv_test.py create mode 100644 keras_nlp/src/models/efficientnet/mbconv.py create mode 100644 keras_nlp/src/models/efficientnet/mbconv_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 17b00c1f05..061fbb90b6 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -96,6 +96,9 @@ from keras_nlp.src.models.distil_bert.distil_bert_tokenizer import ( DistilBertTokenizer, ) +from keras_nlp.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone from keras_nlp.src.models.electra.electra_preprocessor import ( ElectraPreprocessor, diff --git a/keras_nlp/src/models/efficientnet/__init__.py b/keras_nlp/src/models/efficientnet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/efficientnet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/efficientnet/efficientnet_backbone.py b/keras_nlp/src/models/efficientnet/efficientnet_backbone.py new file mode 100644 index 0000000000..2d7940d6df --- /dev/null +++ b/keras_nlp/src/models/efficientnet/efficientnet_backbone.py @@ -0,0 +1,569 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.efficientnet.fusedmbconv import FusedMBConvBlock +from keras_nlp.src.models.efficientnet.mbconv import MBConvBlock +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone + + +@keras_nlp_export("keras_nlp.models.EfficientNetBackbone") +class EfficientNetBackbone(FeaturePyramidBackbone): + """An EfficientNet backbone model. + + This class encapsulates the architectures for both EfficientNetV1 and + EfficientNetV2. EfficientNetV2 uses Fused-MBConv Blocks and Neural + Architecture Search (NAS) to make models sizes much smaller while still + improving overall model quality. + + References: + - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks] + (https://arxiv.org/abs/1905.11946) (ICML 2019) + - [Based on the original keras.applications EfficientNet] + (https://github.com/keras-team/keras/blob/master/keras/applications/efficientnet.py) + - [EfficientNetV2: Smaller Models and Faster Training] + (https://arxiv.org/abs/2104.00298) (ICML 2021) + + Args: + width_coefficient: float, scaling coefficient for network width. + depth_coefficient: float, scaling coefficient for network depth. + dropout: float, dropout rate at skip connections. The default + value is set to 0.2. + depth_divisor: integer, a unit of network width. The default value is + set to 8. + activation: activation function to use between each convolutional layer. + input_shape: optional shape tuple, it should have exactly 3 input + channels. + stackwise_kernel_sizes: list of ints, the kernel sizes used for each + conv block. + stackwise_num_repeats: list of ints, number of times to repeat each + conv block. + stackwise_input_filters: list of ints, number of input filters for + each conv block. + stackwise_output_filters: list of ints, number of output filters for + each stack in the conv blocks model. + stackwise_expansion_ratios: list of floats, expand ratio passed to the + squeeze and excitation blocks. + stackwise_strides: list of ints, stackwise_strides for each conv block. + stackwise_squeeze_and_excite_ratios: list of ints, the squeeze and + excite ratios passed to the squeeze and excitation blocks. + stackwise_block_types: list of strings. Each value is either 'v1', + 'unfused' or 'fused' depending on the desired blocks. 'v1' uses the + original efficientnet block. FusedMBConvBlock is similar to + MBConvBlock, but instead of using a depthwise convolution and a 1x1 + output convolution blocks fused blocks use a single 3x3 convolution + block. + include_rescaling: bool, whether to rescale the inputs. If set to + True, inputs will be passed through a `Rescaling(1/255.0)` layer. + min_depth: integer, minimum number of filters. Can be None and ignored + if use_depth_divisor_as_min_depth is set to True. + include_initial_padding: bool, whether to include initial zero padding + (as per v1). + use_depth_divisor_as_min_depth: bool, whether to use depth_divisor as + the minimum depth instead of min_depth (as per v1). + cap_round_filter_decrease: bool, whether to cap the max decrease in the + number of filters the rounding process potentially produces + (as per v1). + stem_conv_padding: str, can be 'same' or 'valid'. Padding for the stem. + batch_norm_momentum: float, momentum for the moving average calcualtion + in the batch normalization layers. + + Example: + ```python + # You can customize the EfficientNet architecture: + model = EfficientNetBackbone( + stackwise_kernel_sizes=[3, 3, 3, 3, 3, 3], + stackwise_num_repeats=[2, 4, 4, 6, 9, 15], + stackwise_input_filters=[24, 24, 48, 64, 128, 160], + stackwise_output_filters=[24, 48, 64, 128, 160, 256], + stackwise_expansion_ratios=[1, 4, 4, 4, 6, 6], + stackwise_squeeze_and_excite_ratios=[0.0, 0.0, 0, 0.25, 0.25, 0.25], + stackwise_strides=[1, 2, 2, 2, 1, 2], + stackwise_block_types=[["fused"] * 3 + ["unfused"] * 3], + width_coefficient=1.0, + depth_coefficient=1.0, + include_rescaling=False, + ) + images = np.ones((1, 256, 256, 3)) + outputs = efficientnet.predict(images) + ``` + """ + + def __init__( + self, + *, + width_coefficient, + depth_coefficient, + stackwise_kernel_sizes, + stackwise_num_repeats, + stackwise_input_filters, + stackwise_output_filters, + stackwise_expansion_ratios, + stackwise_squeeze_and_excite_ratios, + stackwise_strides, + stackwise_block_types, + include_rescaling=True, + dropout=0.2, + depth_divisor=8, + min_depth=8, + input_shape=(None, None, 3), + activation="swish", + include_initial_padding=False, + use_depth_divisor_as_min_depth=False, + cap_round_filter_decrease=False, + stem_conv_padding="same", + batch_norm_momentum=0.9, + **kwargs, + ): + img_input = keras.layers.Input(shape=input_shape) + + x = img_input + + if include_rescaling: + # Use common rescaling strategy across keras + x = keras.layers.Rescaling(scale=1.0 / 255.0)(x) + + if include_initial_padding: + x = keras.layers.ZeroPadding2D( + padding=self._correct_pad_downsample(x, 3), + name="stem_conv_pad", + )(x) + + # Build stem + stem_filters = round_filters( + filters=stackwise_input_filters[0], + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, + cap_round_filter_decrease=cap_round_filter_decrease, + ) + + x = keras.layers.Conv2D( + filters=stem_filters, + kernel_size=3, + strides=2, + padding=stem_conv_padding, + use_bias=False, + kernel_initializer=conv_kernel_initializer(), + name="stem_conv", + )(x) + + x = keras.layers.BatchNormalization( + momentum=batch_norm_momentum, + name="stem_bn", + )(x) + x = keras.layers.Activation(activation, name="stem_activation")(x) + + # Build blocks + block_id = 0 + blocks = float(sum(stackwise_num_repeats)) + + self._pyramid_outputs = {} + curr_pyramid_level = 1 + + for i in range(len(stackwise_kernel_sizes)): + num_repeats = stackwise_num_repeats[i] + input_filters = stackwise_input_filters[i] + output_filters = stackwise_output_filters[i] + + # Update block input and output filters based on depth multiplier. + input_filters = round_filters( + filters=input_filters, + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, + cap_round_filter_decrease=cap_round_filter_decrease, + ) + output_filters = round_filters( + filters=output_filters, + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, + cap_round_filter_decrease=cap_round_filter_decrease, + ) + + repeats = round_repeats( + repeats=num_repeats, + depth_coefficient=depth_coefficient, + ) + strides = stackwise_strides[i] + squeeze_and_excite_ratio = stackwise_squeeze_and_excite_ratios[i] + + for j in range(repeats): + # The first block needs to take care of stride and filter size + # increase. + if j > 0: + strides = 1 + input_filters = output_filters + + if strides != 1: + self._pyramid_outputs[f"P{curr_pyramid_level}"] = x + curr_pyramid_level += 1 + + # 97 is the start of the lowercase alphabet. + letter_identifier = chr(j + 97) + stackwise_block_type = stackwise_block_types[i] + block_name = f"block{i + 1}{letter_identifier}_" + if stackwise_block_type == "v1": + x = self._apply_efficientnet_block( + inputs=x, + filters_in=input_filters, + filters_out=output_filters, + kernel_size=stackwise_kernel_sizes[i], + strides=strides, + expand_ratio=stackwise_expansion_ratios[i], + se_ratio=squeeze_and_excite_ratio, + activation=activation, + dropout=dropout * block_id / blocks, + name=block_name, + ) + else: + block = get_conv_constructor(stackwise_block_type)( + input_filters=input_filters, + output_filters=output_filters, + expand_ratio=stackwise_expansion_ratios[i], + kernel_size=stackwise_kernel_sizes[i], + strides=strides, + se_ratio=squeeze_and_excite_ratio, + activation=activation, + dropout=dropout * block_id / blocks, + batch_norm_momentum=batch_norm_momentum, + name=block_name, + ) + x = block(x) + block_id += 1 + + # Build top + top_filters = round_filters( + filters=1280, + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, + cap_round_filter_decrease=cap_round_filter_decrease, + ) + + x = keras.layers.Conv2D( + filters=top_filters, + kernel_size=1, + padding="same", + strides=1, + kernel_initializer=conv_kernel_initializer(), + use_bias=False, + name="top_conv", + data_format="channels_last", + )(x) + x = keras.layers.BatchNormalization( + momentum=batch_norm_momentum, + name="top_bn", + )(x) + x = keras.layers.Activation( + activation=activation, name="top_activation" + )(x) + + self._pyramid_outputs[f"P{curr_pyramid_level}"] = x + curr_pyramid_level += 1 + + # Create model. + super().__init__(inputs=img_input, outputs=x, **kwargs) + + # === Config === + self.include_rescaling = include_rescaling + self.width_coefficient = width_coefficient + self.depth_coefficient = depth_coefficient + self.dropout = dropout + self.depth_divisor = depth_divisor + self.min_depth = min_depth + self.activation = activation + self.stackwise_kernel_sizes = stackwise_kernel_sizes + self.stackwise_num_repeats = stackwise_num_repeats + self.stackwise_input_filters = stackwise_input_filters + self.stackwise_output_filters = stackwise_output_filters + self.stackwise_expansion_ratios = stackwise_expansion_ratios + self.stackwise_squeeze_and_excite_ratios = ( + stackwise_squeeze_and_excite_ratios + ) + self.stackwise_strides = stackwise_strides + self.stackwise_block_types = stackwise_block_types + + self.include_initial_padding = include_initial_padding + self.use_depth_divisor_as_min_depth = use_depth_divisor_as_min_depth + self.cap_round_filter_decrease = cap_round_filter_decrease + self.stem_conv_padding = stem_conv_padding + self.batch_norm_momentum = batch_norm_momentum + + def get_config(self): + config = super().get_config() + config.update( + { + "include_rescaling": self.include_rescaling, + "width_coefficient": self.width_coefficient, + "depth_coefficient": self.depth_coefficient, + "dropout": self.dropout, + "depth_divisor": self.depth_divisor, + "min_depth": self.min_depth, + "activation": self.activation, + "input_shape": self.input_shape[1:], + "stackwise_kernel_sizes": self.stackwise_kernel_sizes, + "stackwise_num_repeats": self.stackwise_num_repeats, + "stackwise_input_filters": self.stackwise_input_filters, + "stackwise_output_filters": self.stackwise_output_filters, + "stackwise_expansion_ratios": self.stackwise_expansion_ratios, + "stackwise_squeeze_and_excite_ratios": self.stackwise_squeeze_and_excite_ratios, + "stackwise_strides": self.stackwise_strides, + "stackwise_block_types": self.stackwise_block_types, + "include_initial_padding": self.include_initial_padding, + "use_depth_divisor_as_min_depth": self.use_depth_divisor_as_min_depth, + "cap_round_filter_decrease": self.cap_round_filter_decrease, + "stem_conv_padding": self.stem_conv_padding, + "batch_norm_momentum": self.batch_norm_momentum, + } + ) + return config + + def _correct_pad_downsample(self, inputs, kernel_size): + """Returns a tuple for zero-padding for 2D convolution with downsampling. + + Args: + inputs: Input tensor. + kernel_size: An integer or tuple/list of 2 integers. + + Returns: + A tuple. + """ + img_dim = 1 + input_size = inputs.shape[img_dim : (img_dim + 2)] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if input_size[0] is None: + adjust = (1, 1) + else: + adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + return ( + (correct[0] - adjust[0], correct[0]), + (correct[1] - adjust[1], correct[1]), + ) + + def _apply_efficientnet_block( + self, + inputs, + filters_in=32, + filters_out=16, + kernel_size=3, + strides=1, + activation="swish", + expand_ratio=1, + se_ratio=0.0, + dropout=0.0, + name="", + ): + """An inverted residual block. + + Args: + inputs: Tensor, The input tensor of the block + filters_in: integer, the number of input filters. + filters_out: integer, the number of output filters. + kernel_size: integer, the dimension of the convolution window. + strides: integer, the stride of the convolution. + activation: activation function to use between each convolutional layer. + expand_ratio: integer, scaling coefficient for the input filters. + se_ratio: float between 0 and 1, fraction to squeeze the input filters. + dropout: float between 0 and 1, fraction of the input units to drop. + name: string, block label. + + Returns: + output tensor for the block. + """ + filters = filters_in * expand_ratio + if expand_ratio != 1: + x = keras.layers.Conv2D( + filters=filters, + kernel_size=1, + strides=1, + padding="same", + use_bias=False, + kernel_initializer=conv_kernel_initializer(), + name=name + "expand_conv", + )(inputs) + x = keras.layers.BatchNormalization( + axis=3, + name=name + "expand_bn", + )(x) + x = keras.layers.Activation( + activation, name=name + "expand_activation" + )(x) + else: + x = inputs + + # Depthwise Convolution + if strides == 2: + x = keras.layers.ZeroPadding2D( + padding=self._correct_pad_downsample(x, kernel_size), + name=name + "dwconv_pad", + )(x) + conv_pad = "valid" + else: + conv_pad = "same" + + x = keras.layers.DepthwiseConv2D( + kernel_size=kernel_size, + strides=strides, + padding=conv_pad, + use_bias=False, + depthwise_initializer=conv_kernel_initializer(), + name=name + "dwconv", + )(x) + x = keras.layers.BatchNormalization( + axis=3, + name=name + "dwconv_bn", + )(x) + x = keras.layers.Activation( + activation, name=name + "dwconv_activation" + )(x) + + # Squeeze and Excitation phase + if 0 < se_ratio <= 1: + filters_se = max(1, int(filters_in * se_ratio)) + se = keras.layers.GlobalAveragePooling2D(name=name + "se_squeeze")( + x + ) + se_shape = (1, 1, filters) + se = keras.layers.Reshape(se_shape, name=name + "se_reshape")(se) + se = keras.layers.Conv2D( + filters_se, + 1, + padding="same", + activation=activation, + kernel_initializer=conv_kernel_initializer(), + name=name + "se_reduce", + )(se) + se = keras.layers.Conv2D( + filters, + 1, + padding="same", + activation="sigmoid", + kernel_initializer=conv_kernel_initializer(), + name=name + "se_expand", + )(se) + x = keras.layers.multiply([x, se], name=name + "se_excite") + + # Output phase + x = keras.layers.Conv2D( + filters=filters_out, + kernel_size=1, + strides=1, + padding="same", + use_bias=False, + kernel_initializer=conv_kernel_initializer(), + name=name + "project", + )(x) + x = keras.layers.BatchNormalization( + axis=3, + name=name + "project_bn", + )(x) + x = keras.layers.Activation( + activation, name=name + "project_activation" + )(x) + + if strides == 1 and filters_in == filters_out: + if dropout > 0: + x = keras.layers.Dropout( + dropout, + noise_shape=(None, 1, 1, 1), + name=name + "drop", + )(x) + x = keras.layers.Add(name=name + "add")([x, inputs]) + + return x + + +def conv_kernel_initializer(scale=2.0): + return keras.initializers.VarianceScaling( + scale=scale, mode="fan_out", distribution="truncated_normal" + ) + + +def round_filters( + filters, + width_coefficient, + min_depth, + depth_divisor, + use_depth_divisor_as_min_depth, + cap_round_filter_decrease, +): + """Round number of filters based on depth multiplier. + + Args: + filters: int, number of filters for Conv layer + width_coefficient: float, denotes the scaling coefficient of network + width + depth_divisor: int, a unit of network width + use_depth_divisor_as_min_depth: bool, whether to use depth_divisor as + the minimum depth instead of min_depth (as per v1) + max_round_filter_decrease: bool, whether to cap the decrease in the + number of filters this process produces (as per v1) + + Returns: + int, new rounded filters value for Conv layer + """ + filters *= width_coefficient + + if use_depth_divisor_as_min_depth: + min_depth = depth_divisor + + new_filters = max( + min_depth, + int(filters + depth_divisor / 2) // depth_divisor * depth_divisor, + ) + + if cap_round_filter_decrease: + # Make sure that round down does not go down by more than 10%. + if new_filters < 0.9 * filters: + new_filters += depth_divisor + + return int(new_filters) + + +def round_repeats(repeats, depth_coefficient): + """Round number of repeats based on depth multiplier. + + Args: + repeats: int, number of repeats of efficientnet block + depth_coefficient: float, denotes the scaling coefficient of network + depth + + Returns: + int, rounded repeats + """ + return int(math.ceil(depth_coefficient * repeats)) + + +def get_conv_constructor(conv_type): + if conv_type == "unfused": + return MBConvBlock + elif conv_type == "fused": + return FusedMBConvBlock + else: + raise ValueError( + "Expected `conv_type` to be " + "one of 'unfused', 'fused', but got " + f"`conv_type={conv_type}`" + ) diff --git a/keras_nlp/src/models/efficientnet/efficientnet_backbone_test.py b/keras_nlp/src/models/efficientnet/efficientnet_backbone_test.py new file mode 100644 index 0000000000..8705ed7af1 --- /dev/null +++ b/keras_nlp/src/models/efficientnet/efficientnet_backbone_test.py @@ -0,0 +1,146 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras +import pytest +from absl.testing import parameterized + +from keras_nlp.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) +from keras_nlp.src.tests.test_case import TestCase + + +class EfficientNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3], + "stackwise_num_repeats": [2, 4, 4, 6, 9, 15], + "stackwise_input_filters": [24, 24, 48, 64, 128, 160], + "stackwise_output_filters": [24, 48, 64, 128, 160, 256], + "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6], + "stackwise_squeeze_and_excite_ratios": [ + 0.0, + 0.0, + 0, + 0.25, + 0.25, + 0.25, + ], + "stackwise_strides": [1, 2, 2, 2, 1, 2], + "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3, + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + "include_rescaling": False, + } + self.input_data = keras.ops.ones(shape=(8, 224, 224, 3)) + + def test_backbone_basics(self): + self.run_backbone_test( + cls=EfficientNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + run_mixed_precision_check=False, + expected_output_shape=(8, 7, 7, 1280), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=EfficientNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_valid_call(self): + model = EfficientNetBackbone(**self.init_kwargs) + model(self.input_data) + + def test_valid_call_original_v1(self): + original_v1_kwargs = { + "stackwise_kernel_sizes": [3, 3, 5, 3, 5, 5, 3], + "stackwise_num_repeats": [1, 2, 2, 3, 3, 4, 1], + "stackwise_input_filters": [32, 16, 24, 40, 80, 112, 192], + "stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320], + "stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6], + "stackwise_strides": [1, 2, 2, 2, 1, 2, 1], + "stackwise_squeeze_and_excite_ratios": [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + ], + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + "include_rescaling": False, + "stackwise_block_types": ["v1"] * 7, + "min_depth": None, + "include_initial_padding": True, + "use_depth_divisor_as_min_depth": True, + "cap_round_filter_decrease": True, + "stem_conv_padding": "valid", + "batch_norm_momentum": 0.99, + } + model = EfficientNetBackbone(**original_v1_kwargs) + model(self.input_data) + + def test_valid_call_with_rescaling(self): + test_kwargs = self.init_kwargs.copy() + test_kwargs["include_rescaling"] = True + model = EfficientNetBackbone(**test_kwargs) + model(self.input_data) + + def test_feature_pyramid_outputs(self): + backbone = EfficientNetBackbone(**self.init_kwargs) + model = keras.Model( + inputs=backbone.inputs, outputs=backbone.pyramid_outputs + ) + batch_size = 8 + height = width = 256 + outputs = model(keras.ops.ones(shape=(batch_size, height, width, 3))) + levels = ["P1", "P2", "P3", "P4", "P5"] + self.assertEquals(list(outputs.keys()), levels) + self.assertEquals( + outputs["P1"].shape, + (batch_size, height // 2**1, width // 2**1, 24), + ) + self.assertEquals( + outputs["P2"].shape, + (batch_size, height // 2**2, width // 2**2, 48), + ) + self.assertEquals( + outputs["P3"].shape, + (batch_size, height // 2**3, width // 2**3, 64), + ) + self.assertEquals( + outputs["P4"].shape, + (batch_size, height // 2**4, width // 2**4, 160), + ) + self.assertEquals( + outputs["P5"].shape, + (batch_size, height // 2**5, width // 2**5, 1280), + ) + + @parameterized.named_parameters( + ("one_channel", 1), + ("four_channels", 4), + ) + def test_application_variable_input_channels(self, num_channels): + test_kwargs = self.init_kwargs.copy() + test_kwargs["input_shape"] = (None, None, num_channels) + model = EfficientNetBackbone(**test_kwargs) + self.assertEqual(model.output_shape, (None, None, None, 1280)) diff --git a/keras_nlp/src/models/efficientnet/fusedmbconv.py b/keras_nlp/src/models/efficientnet/fusedmbconv.py new file mode 100644 index 0000000000..5c3817c30e --- /dev/null +++ b/keras_nlp/src/models/efficientnet/fusedmbconv.py @@ -0,0 +1,229 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +BN_AXIS = 3 + +CONV_KERNEL_INITIALIZER = { + "class_name": "VarianceScaling", + "config": { + "scale": 2.0, + "mode": "fan_out", + "distribution": "truncated_normal", + }, +} + + +class FusedMBConvBlock(keras.layers.Layer): + """Implementation of the FusedMBConv block + + Also known as a Fused Mobile Inverted Residual Bottleneck block from: + [EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML] + (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) + [EfficientNetV2: Smaller Models and Faster Training] + (https://arxiv.org/abs/2104.00298v3). + + FusedMBConv blocks are based on MBConv blocks, and replace the depthwise and + 1x1 output convolution blocks with a single 3x3 convolution block, fusing + them together - hence the name "FusedMBConv". Alongside MBConv blocks, they + can be used in mobile-oriented and efficient architectures, and are present + in architectures EfficientNet. + + FusedMBConv blocks follow a narrow-wide-narrow structure - expanding a 1x1 + convolution, performing Squeeze-Excitation and then applying a 3x3 + convolution, which is a more efficient operation than conventional + wide-narrow-wide structures. + + As they're frequently used for models to be deployed to edge devices, + they're implemented as a layer for ease of use and re-use. + + Args: + input_filters: int, the number of input filters + output_filters: int, the number of output filters + expand_ratio: default 1, the ratio by which input_filters are multiplied + to expand the structure in the middle expansion phase + kernel_size: default 3, the kernel_size to apply to the expansion phase + convolutions + strides: default 1, the strides to apply to the expansion phase + convolutions + se_ratio: default 0.0, The filters used in the Squeeze-Excitation phase, + and are chosen as the maximum between 1 and input_filters*se_ratio + batch_norm_momentum: default 0.9, the BatchNormalization momentum + activation: default "swish", the activation function used between + convolution operations + dropout: float, the optional dropout rate to apply before the output + convolution, defaults to 0.2 + + Returns: + A tensor representing a feature map, passed through the FusedMBConv + block + + Note: + Not intended to be used outside of the EfficientNet architecture. + """ + + def __init__( + self, + input_filters, + output_filters, + expand_ratio=1, + kernel_size=3, + strides=1, + se_ratio=0.0, + batch_norm_momentum=0.9, + activation="swish", + dropout=0.2, + **kwargs + ): + super().__init__(**kwargs) + self.input_filters = input_filters + self.output_filters = output_filters + self.expand_ratio = expand_ratio + self.kernel_size = kernel_size + self.strides = strides + self.se_ratio = se_ratio + self.batch_norm_momentum = batch_norm_momentum + self.activation = activation + self.dropout = dropout + self.filters = self.input_filters * self.expand_ratio + self.filters_se = max(1, int(input_filters * se_ratio)) + + self.conv1 = keras.layers.Conv2D( + filters=self.filters, + kernel_size=kernel_size, + strides=strides, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "expand_conv", + ) + self.bn1 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "expand_bn", + ) + self.act = keras.layers.Activation( + self.activation, name=self.name + "expand_activation" + ) + + self.bn2 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "bn", + ) + + self.se_conv1 = keras.layers.Conv2D( + self.filters_se, + 1, + padding="same", + activation=self.activation, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=self.name + "se_reduce", + ) + + self.se_conv2 = keras.layers.Conv2D( + self.filters, + 1, + padding="same", + activation="sigmoid", + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=self.name + "se_expand", + ) + + self.output_conv = keras.layers.Conv2D( + filters=self.output_filters, + kernel_size=1 if expand_ratio != 1 else kernel_size, + strides=1, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "project_conv", + ) + + self.bn3 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "project_bn", + ) + + if self.dropout: + self.dropout_layer = keras.layers.Dropout( + self.dropout, + noise_shape=(None, 1, 1, 1), + name=self.name + "drop", + ) + + def build(self, input_shape): + if self.name is None: + self.name = keras.backend.get_uid("block0") + + def call(self, inputs): + # Expansion phase + if self.expand_ratio != 1: + x = self.conv1(inputs) + x = self.bn1(x) + x = self.act(x) + else: + x = inputs + + # Squeeze and excite + if 0 < self.se_ratio <= 1: + se = keras.layers.GlobalAveragePooling2D( + name=self.name + "se_squeeze" + )(x) + if BN_AXIS == 1: + se_shape = (self.filters, 1, 1) + else: + se_shape = (1, 1, self.filters) + + se = keras.layers.Reshape(se_shape, name=self.name + "se_reshape")( + se + ) + + se = self.se_conv1(se) + se = self.se_conv2(se) + + x = keras.layers.multiply([x, se], name=self.name + "se_excite") + + # Output phase: + x = self.output_conv(x) + x = self.bn3(x) + if self.expand_ratio == 1: + x = self.act(x) + + # Residual: + if self.strides == 1 and self.input_filters == self.output_filters: + if self.dropout: + x = self.dropout_layer(x) + x = keras.layers.Add(name=self.name + "add")([x, inputs]) + return x + + def get_config(self): + config = { + "input_filters": self.input_filters, + "output_filters": self.output_filters, + "expand_ratio": self.expand_ratio, + "kernel_size": self.kernel_size, + "strides": self.strides, + "se_ratio": self.se_ratio, + "batch_norm_momentum": self.batch_norm_momentum, + "activation": self.activation, + "dropout": self.dropout, + } + + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_nlp/src/models/efficientnet/fusedmbconv_test.py b/keras_nlp/src/models/efficientnet/fusedmbconv_test.py new file mode 100644 index 0000000000..e59e251156 --- /dev/null +++ b/keras_nlp/src/models/efficientnet/fusedmbconv_test.py @@ -0,0 +1,46 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +from keras_nlp.src.models.efficientnet.fusedmbconv import FusedMBConvBlock +from keras_nlp.src.tests.test_case import TestCase + + +class FusedMBConvBlockTest(TestCase): + def test_same_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = FusedMBConvBlock(input_filters=32, output_filters=32) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 32)) + self.assertLen(output, 1) + + def test_different_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = FusedMBConvBlock(input_filters=32, output_filters=48) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 48)) + self.assertLen(output, 1) + + def test_squeeze_excitation_ratio(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = FusedMBConvBlock( + input_filters=32, output_filters=48, se_ratio=0.25 + ) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 48)) + self.assertLen(output, 1) diff --git a/keras_nlp/src/models/efficientnet/mbconv.py b/keras_nlp/src/models/efficientnet/mbconv.py new file mode 100644 index 0000000000..4889606f8f --- /dev/null +++ b/keras_nlp/src/models/efficientnet/mbconv.py @@ -0,0 +1,238 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +BN_AXIS = 3 + +CONV_KERNEL_INITIALIZER = { + "class_name": "VarianceScaling", + "config": { + "scale": 2.0, + "mode": "fan_out", + "distribution": "truncated_normal", + }, +} + + +class MBConvBlock(keras.layers.Layer): + def __init__( + self, + input_filters, + output_filters, + expand_ratio=1, + kernel_size=3, + strides=1, + se_ratio=0.0, + batch_norm_momentum=0.9, + activation="swish", + dropout=0.2, + **kwargs + ): + """Implementation of the MBConv block + + Also known as a Mobile Inverted Residual Bottleneck block from: + [MobileNetV2: Inverted Residuals and Linear Bottlenecks] + (https://arxiv.org/abs/1801.04381v4). + + MBConv blocks are common blocks used in mobile-oriented and efficient + architectures, present in architectures such as MobileNet, EfficientNet, + MaxViT, etc. + + MBConv blocks follow a narrow-wide-narrow structure - expanding a 1x1 + convolution, applying depthwise convolution, and narrowing back to a 1x1 + convolution, which is a more efficient operation than conventional + wide-narrow-wide structures. + + As they're frequently used for models to be deployed to edge devices, + they're implemented as a layer for ease of use and re-use. + + Args: + input_filters: int, the number of input filters + output_filters: int, the optional number of output filters after + Squeeze-Excitation + expand_ratio: default 1, the ratio by which input_filters are + multiplied to expand the structure in the middle expansion phase + kernel_size: default 3, the kernel_size to apply to the expansion + phase convolutions + strides: default 1, the strides to apply to the expansion phase + convolutions + se_ratio: default 0.0, Squeeze-Excitation happens before depthwise + convolution and before output convolution only if the se_ratio + is above 0. The filters used in this phase are chosen as the + maximum between 1 and input_filters*se_ratio + batch_norm_momentum: default 0.9, the BatchNormalization momentum + activation: default "swish", the activation function used between + convolution operations + dropout: float, the optional dropout rate to apply before the output + convolution, defaults to 0.2 + + Returns: + A tensor representing a feature map, passed through the MBConv + block + + + Note: + Not intended to be used outside of the EfficientNet architecture. + """ + + super().__init__(**kwargs) + self.input_filters = input_filters + self.output_filters = output_filters + self.expand_ratio = expand_ratio + self.kernel_size = kernel_size + self.strides = strides + self.se_ratio = se_ratio + self.batch_norm_momentum = batch_norm_momentum + self.activation = activation + self.dropout = dropout + self.filters = self.input_filters * self.expand_ratio + self.filters_se = max(1, int(input_filters * se_ratio)) + + self.conv1 = keras.layers.Conv2D( + filters=self.filters, + kernel_size=1, + strides=1, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "expand_conv", + ) + self.bn1 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "expand_bn", + ) + self.act = keras.layers.Activation( + self.activation, name=self.name + "activation" + ) + self.depthwise = keras.layers.DepthwiseConv2D( + kernel_size=self.kernel_size, + strides=self.strides, + depthwise_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "dwconv2", + ) + + self.bn2 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "bn", + ) + + self.se_conv1 = keras.layers.Conv2D( + self.filters_se, + 1, + padding="same", + activation=self.activation, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=self.name + "se_reduce", + ) + + self.se_conv2 = keras.layers.Conv2D( + self.filters, + 1, + padding="same", + activation="sigmoid", + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=self.name + "se_expand", + ) + + self.output_conv = keras.layers.Conv2D( + filters=self.output_filters, + kernel_size=1 if expand_ratio != 1 else kernel_size, + strides=1, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "project_conv", + ) + + self.bn3 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "project_bn", + ) + + if self.dropout: + self.dropout_layer = keras.layers.Dropout( + self.dropout, + noise_shape=(None, 1, 1, 1), + name=self.name + "drop", + ) + + def build(self, input_shape): + if self.name is None: + self.name = keras.backend.get_uid("block0") + + def call(self, inputs): + # Expansion phase + if self.expand_ratio != 1: + x = self.conv1(inputs) + x = self.bn1(x) + x = self.act(x) + else: + x = inputs + + # Depthwise conv + x = self.depthwise(x) + x = self.bn2(x) + x = self.act(x) + + # Squeeze and excite + if 0 < self.se_ratio <= 1: + se = keras.layers.GlobalAveragePooling2D( + name=self.name + "se_squeeze" + )(x) + if BN_AXIS == 1: + se_shape = (self.filters, 1, 1) + else: + se_shape = (1, 1, self.filters) + se = keras.layers.Reshape(se_shape, name=self.name + "se_reshape")( + se + ) + + se = self.se_conv1(se) + se = self.se_conv2(se) + + x = keras.layers.multiply([x, se], name=self.name + "se_excite") + + # Output phase + x = self.output_conv(x) + x = self.bn3(x) + + if self.strides == 1 and self.input_filters == self.output_filters: + if self.dropout: + x = self.dropout_layer(x) + x = keras.layers.Add(name=self.name + "add")([x, inputs]) + return x + + def get_config(self): + config = { + "input_filters": self.input_filters, + "output_filters": self.output_filters, + "expand_ratio": self.expand_ratio, + "kernel_size": self.kernel_size, + "strides": self.strides, + "se_ratio": self.se_ratio, + "batch_norm_momentum": self.batch_norm_momentum, + "activation": self.activation, + "dropout": self.dropout, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_nlp/src/models/efficientnet/mbconv_test.py b/keras_nlp/src/models/efficientnet/mbconv_test.py new file mode 100644 index 0000000000..d4ba2b1f73 --- /dev/null +++ b/keras_nlp/src/models/efficientnet/mbconv_test.py @@ -0,0 +1,44 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +from keras_nlp.src.models.efficientnet.mbconv import MBConvBlock +from keras_nlp.src.tests.test_case import TestCase + + +class MBConvTest(TestCase): + def test_same_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = MBConvBlock(input_filters=32, output_filters=32) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 32)) + self.assertLen(output, 1) + + def test_different_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = MBConvBlock(input_filters=32, output_filters=48) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 48)) + self.assertLen(output, 1) + + def test_squeeze_excitation_ratio(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = MBConvBlock(input_filters=32, output_filters=48, se_ratio=0.25) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 48)) + self.assertLen(output, 1) From be8888d4841be85da7fd35d06fbdf5a9538ee39a Mon Sep 17 00:00:00 2001 From: gowthamkpr <47574994+gowthamkpr@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:02:52 -0700 Subject: [PATCH 22/33] Add the ResNet_vd backbone (#1766) * Add ResNet_vd to ResNet backbone * Addressed requested parameter changes * Fixed tests and updated comments * Added new parameters to docstring --- .../src/models/resnet/resnet_backbone.py | 413 ++++++++++++++++-- .../src/models/resnet/resnet_backbone_test.py | 29 +- .../resnet/resnet_image_classifier_test.py | 2 + 3 files changed, 416 insertions(+), 28 deletions(-) diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index 31698e0a1c..ca1de9b090 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -27,9 +27,10 @@ class ResNetBackbone(FeaturePyramidBackbone): This class implements a ResNet backbone as described in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( CVPR 2016), [Identity Mappings in Deep Residual Networks]( - https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An + https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An improved training procedure in timm](https://arxiv.org/abs/2110.00476)( - NeurIPS 2021 Workshop). + NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with + Convolutional Neural Networks](https://arxiv.org/abs/1812.01187). The difference in ResNet and ResNetV2 rests in the structure of their individual building blocks. In ResNetV2, the batch normalization and @@ -37,18 +38,31 @@ class ResNetBackbone(FeaturePyramidBackbone): the batch normalization and ReLU activation are applied after the convolution layers. + ResNetVd introduces two key modifications to the standard ResNet. First, + the initial convolutional layer is replaced by a series of three + successive convolutional layers. Second, shortcut connections use an + additional pooling operation rather than performing downsampling within + the convolutional layers themselves. + Note that `ResNetBackbone` expects the inputs to be images with a value range of `[0, 255]` when `include_rescaling=True`. Args: + input_conv_filters: list of ints. The number of filters of the initial + convolution(s). + input_conv_kernel_sizes: list of ints. The kernel sizes of the initial + convolution(s). stackwise_num_filters: list of ints. The number of filters for each stack. stackwise_num_blocks: list of ints. The number of blocks for each stack. stackwise_num_strides: list of ints. The number of strides for each stack. block_type: str. The block type to stack. One of `"basic_block"` or - `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. - Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + `"bottleneck_block"`, `"basic_block_vd"` or + `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and + ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and + ResNet152 and the `"_vd"` prefix for the respective ResNet_vd + variants. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. include_rescaling: boolean. If `True`, rescale the input using @@ -88,6 +102,8 @@ class ResNetBackbone(FeaturePyramidBackbone): # Randomly initialized ResNetV2 backbone with a custom config. model = keras_nlp.models.ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], stackwise_num_filters=[64, 64, 64], stackwise_num_blocks=[2, 2, 2], stackwise_num_strides=[1, 2, 2], @@ -101,6 +117,8 @@ class ResNetBackbone(FeaturePyramidBackbone): def __init__( self, + input_conv_filters, + input_conv_kernel_sizes, stackwise_num_filters, stackwise_num_blocks, stackwise_num_strides, @@ -113,6 +131,13 @@ def __init__( dtype=None, **kwargs, ): + if len(input_conv_filters) != len(input_conv_kernel_sizes): + raise ValueError( + "The length of `input_conv_filters` and" + "`input_conv_kernel_sizes` must be the same. " + f"Received: input_conv_filters={input_conv_filters}, " + f"input_conv_kernel_sizes={input_conv_kernel_sizes}." + ) if len(stackwise_num_filters) != len(stackwise_num_blocks) or len( stackwise_num_filters ) != len(stackwise_num_strides): @@ -128,14 +153,20 @@ def __init__( "The first element of `stackwise_num_filters` must be 64. " f"Received: stackwise_num_filters={stackwise_num_filters}" ) - if block_type not in ("basic_block", "bottleneck_block"): + if block_type not in ( + "basic_block", + "bottleneck_block", + "basic_block_vd", + "bottleneck_block_vd", + ): raise ValueError( - '`block_type` must be either `"basic_block"` or ' - f'`"bottleneck_block"`. Received block_type={block_type}.' + '`block_type` must be either `"basic_block"`, ' + '`"bottleneck_block"`, `"basic_block_vd"` or ' + f'`"bottleneck_block_vd"`. Received block_type={block_type}.' ) - version = "v1" if not use_pre_activation else "v2" data_format = standardize_data_format(data_format) bn_axis = -1 if data_format == "channels_last" else 1 + num_input_convs = len(input_conv_filters) num_stacks = len(stackwise_num_filters) # === Functional Model === @@ -155,29 +186,56 @@ def __init__( # The padding between torch and tensorflow/jax differs when `strides>1`. # Therefore, we need to manually pad the tensor. x = layers.ZeroPadding2D( - 3, + (input_conv_kernel_sizes[0] - 1) // 2, data_format=data_format, dtype=dtype, name="conv1_pad", )(x) x = layers.Conv2D( - 64, - 7, + input_conv_filters[0], + input_conv_kernel_sizes[0], strides=2, data_format=data_format, use_bias=False, + padding="valid", dtype=dtype, name="conv1_conv", )(x) + for conv_index in range(1, num_input_convs): + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"conv{conv_index}_bn", + )(x) + x = layers.Activation( + "relu", dtype=dtype, name=f"conv{conv_index}_relu" + )(x) + x = layers.Conv2D( + input_conv_filters[conv_index], + input_conv_kernel_sizes[conv_index], + strides=1, + data_format=data_format, + use_bias=False, + padding="same", + dtype=dtype, + name=f"conv{conv_index+1}_conv", + )(x) + if not use_pre_activation: x = layers.BatchNormalization( axis=bn_axis, epsilon=1e-5, momentum=0.9, dtype=dtype, - name="conv1_bn", + name=f"conv{num_input_convs}_bn", + )(x) + x = layers.Activation( + "relu", + dtype=dtype, + name=f"conv{num_input_convs}_relu", )(x) - x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) if use_pre_activation: # A workaround for ResNetV2: we need -inf padding to prevent zeros @@ -210,12 +268,10 @@ def __init__( stride=stackwise_num_strides[stack_index], block_type=block_type, use_pre_activation=use_pre_activation, - first_shortcut=( - block_type == "bottleneck_block" or stack_index > 0 - ), + first_shortcut=(block_type != "basic_block" or stack_index > 0), data_format=data_format, dtype=dtype, - name=f"{version}_stack{stack_index}", + name=f"stack{stack_index}", ) pyramid_outputs[f"P{stack_index + 2}"] = x @@ -248,6 +304,8 @@ def __init__( ) # === Config === + self.input_conv_filters = input_conv_filters + self.input_conv_kernel_sizes = input_conv_kernel_sizes self.stackwise_num_filters = stackwise_num_filters self.stackwise_num_blocks = stackwise_num_blocks self.stackwise_num_strides = stackwise_num_strides @@ -262,6 +320,8 @@ def get_config(self): config = super().get_config() config.update( { + "input_conv_filters": self.input_conv_filters, + "input_conv_kernel_sizes": self.input_conv_kernel_sizes, "stackwise_num_filters": self.stackwise_num_filters, "stackwise_num_blocks": self.stackwise_num_blocks, "stackwise_num_strides": self.stackwise_num_strides, @@ -327,7 +387,10 @@ def apply_basic_block( )(x_preact) if conv_shortcut: - x = x_preact if x_preact is not None else x + if x_preact is not None: + shortcut = x_preact + else: + shortcut = x shortcut = layers.Conv2D( filters, 1, @@ -336,7 +399,7 @@ def apply_basic_block( use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x) + )(shortcut) if not use_pre_activation: shortcut = layers.BatchNormalization( axis=bn_axis, @@ -452,7 +515,10 @@ def apply_bottleneck_block( )(x_preact) if conv_shortcut: - x = x_preact if x_preact is not None else x + if x_preact is not None: + shortcut = x_preact + else: + shortcut = x shortcut = layers.Conv2D( 4 * filters, 1, @@ -461,7 +527,295 @@ def apply_bottleneck_block( use_bias=False, dtype=dtype, name=f"{name}_0_conv", + )(shortcut) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", + )(shortcut) + else: + shortcut = x + + x = x_preact if x_preact is not None else x + x = layers.Conv2D( + filters, + 1, + strides=1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_2_pad", )(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=stride, + padding="valid" if stride > 1 else "same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + + x = layers.Conv2D( + 4 * filters, + 1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_3_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_3_bn", + )(x) + x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_basic_block_vd( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a basic residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the basic residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_pre_activation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" + )(x_preact) + + if conv_shortcut: + if x_preact is not None: + shortcut = x_preact + elif stride > 1: + shortcut = layers.AveragePooling2D( + 2, + strides=stride, + data_format=data_format, + dtype=dtype, + padding="same", + )(x) + else: + shortcut = x + shortcut = layers.Conv2D( + filters, + 1, + strides=1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_0_conv", + )(shortcut) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", + )(shortcut) + else: + shortcut = x + + x = x_preact if x_preact is not None else x + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_1_pad", + )(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=stride, + padding="valid" if stride > 1 else "same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + x = layers.Conv2D( + filters, + kernel_size, + strides=1, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", + )(x) + x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_bottleneck_block_vd( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a bottleneck residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_pre_activation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" + )(x_preact) + + if conv_shortcut: + if x_preact is not None: + shortcut = x_preact + elif stride > 1: + shortcut = layers.AveragePooling2D( + 2, + strides=stride, + data_format=data_format, + dtype=dtype, + padding="same", + )(x) + else: + shortcut = x + shortcut = layers.Conv2D( + 4 * filters, + 1, + strides=1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_0_conv", + )(shortcut) if not use_pre_activation: shortcut = layers.BatchNormalization( axis=bn_axis, @@ -561,8 +915,11 @@ def apply_stack( blocks: int. The number of blocks in the stack. stride: int. The stride length of the first layer in the first block. block_type: str. The block type to stack. One of `"basic_block"` or - `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. - Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + `"bottleneck_block"`, `"basic_block_vd"` or + `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and + ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and + ResNet152 and the `"_vd"` prefix for the respective ResNet_vd + variants. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet and ResNeXt. first_shortcut: bool. If `True`, use a convolution shortcut. If `False`, @@ -580,17 +937,21 @@ def apply_stack( Output tensor for the stacked blocks. """ if name is None: - version = "v1" if not use_pre_activation else "v2" - name = f"{version}_stack" + name = "stack" if block_type == "basic_block": block_fn = apply_basic_block elif block_type == "bottleneck_block": block_fn = apply_bottleneck_block + elif block_type == "basic_block_vd": + block_fn = apply_basic_block_vd + elif block_type == "bottleneck_block_vd": + block_fn = apply_bottleneck_block_vd else: raise ValueError( - '`block_type` must be either `"basic_block"` or ' - f'`"bottleneck_block"`. Received block_type={block_type}.' + '`block_type` must be either `"basic_block"`, ' + '`"bottleneck_block"`, `"basic_block_vd"` or ' + f'`"bottleneck_block_vd"`. Received block_type={block_type}.' ) for i in range(blocks): if i == 0: diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index a6a30362cd..f52800801f 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -24,6 +24,8 @@ class ResNetBackboneTest(TestCase): def setUp(self): self.init_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], "stackwise_num_filters": [64, 64, 64], "stackwise_num_blocks": [2, 2, 2], "stackwise_num_strides": [1, 2, 2], @@ -38,18 +40,32 @@ def setUp(self): ("v1_bottleneck", False, "bottleneck_block"), ("v2_basic", True, "basic_block"), ("v2_bottleneck", True, "bottleneck_block"), + ("vd_basic", False, "basic_block_vd"), + ("vd_bottleneck", False, "bottleneck_block_vd"), ) def test_backbone_basics(self, use_pre_activation, block_type): init_kwargs = self.init_kwargs.copy() init_kwargs.update( - {"block_type": block_type, "use_pre_activation": use_pre_activation} + { + "block_type": block_type, + "use_pre_activation": use_pre_activation, + } ) + if block_type in ("basic_block_vd", "bottleneck_block_vd"): + init_kwargs.update( + { + "input_conv_filters": [32, 32, 64], + "input_conv_kernel_sizes": [3, 3, 3], + } + ) self.run_vision_backbone_test( cls=ResNetBackbone, init_kwargs=init_kwargs, input_data=self.input_data, expected_output_shape=( - (2, 64) if block_type == "basic_block" else (2, 256) + (2, 64) + if block_type in ("basic_block", "basic_block_vd") + else (2, 256) ), ) @@ -76,6 +92,8 @@ def test_pyramid_output_format(self): ("v1_bottleneck", False, "bottleneck_block"), ("v2_basic", True, "basic_block"), ("v2_bottleneck", True, "bottleneck_block"), + ("vd_basic", False, "basic_block_vd"), + ("vd_bottleneck", False, "bottleneck_block_vd"), ) @pytest.mark.large def test_saved_model(self, use_pre_activation, block_type): @@ -87,6 +105,13 @@ def test_saved_model(self, use_pre_activation, block_type): "image_shape": (None, None, 3), } ) + if block_type in ("basic_block_vd", "bottleneck_block_vd"): + init_kwargs.update( + { + "input_conv_filters": [32, 32, 64], + "input_conv_kernel_sizes": [3, 3, 3], + } + ) self.run_model_saving_test( cls=ResNetBackbone, init_kwargs=init_kwargs, diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index 893ec42487..da06c80320 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -26,6 +26,8 @@ def setUp(self): self.images = ops.ones((2, 16, 16, 3)) self.labels = [0, 3] self.backbone = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], stackwise_num_filters=[64, 64, 64], stackwise_num_blocks=[2, 2, 2], stackwise_num_strides=[1, 2, 2], From 536474a0ab2d55365ddf2e5faaf5968f7e70a767 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Thu, 29 Aug 2024 07:05:59 +0800 Subject: [PATCH 23/33] Add `VAEImageDecoder` for StableDiffusionV3 (#1796) * Add `VAEImageDecoder` for StableDiffusionV3 * Use `keras.Model` for `VAEImageDecoder` and follows the coding style in `VAEAttention` --- .../stable_diffusion_v3/vae_attention.py | 126 +++++++++++++ .../stable_diffusion_v3/vae_image_decoder.py | 177 ++++++++++++++++++ 2 files changed, 303 insertions(+) create mode 100644 keras_nlp/src/models/stable_diffusion_v3/vae_attention.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py diff --git a/keras_nlp/src/models/stable_diffusion_v3/vae_attention.py b/keras_nlp/src/models/stable_diffusion_v3/vae_attention.py new file mode 100644 index 0000000000..1fba90d681 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/vae_attention.py @@ -0,0 +1,126 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras import layers +from keras import ops + +from keras_nlp.src.utils.keras_utils import standardize_data_format + + +class VAEAttention(layers.Layer): + def __init__(self, filters, groups=32, data_format=None, **kwargs): + super().__init__(**kwargs) + self.filters = filters + self.data_format = standardize_data_format(data_format) + gn_axis = -1 if self.data_format == "channels_last" else 1 + + self.group_norm = layers.GroupNormalization( + groups=groups, + axis=gn_axis, + epsilon=1e-6, + dtype=self.dtype_policy, + name="group_norm", + ) + self.query_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="query_conv2d", + ) + self.key_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="key_conv2d", + ) + self.value_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="value_conv2d", + ) + self.softmax = layers.Softmax(dtype="float32") + self.output_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="output_conv2d", + ) + + self.groups = groups + self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters)) + + def build(self, input_shape): + self.group_norm.build(input_shape) + self.query_conv2d.build(input_shape) + self.key_conv2d.build(input_shape) + self.value_conv2d.build(input_shape) + self.output_conv2d.build(input_shape) + + def call(self, inputs, training=None): + x = self.group_norm(inputs) + query = self.query_conv2d(x) + key = self.key_conv2d(x) + value = self.value_conv2d(x) + + if self.data_format == "channels_first": + query = ops.transpose(query, (0, 2, 3, 1)) + key = ops.transpose(key, (0, 2, 3, 1)) + value = ops.transpose(value, (0, 2, 3, 1)) + shape = ops.shape(inputs) + b = shape[0] + query = ops.reshape(query, (b, -1, self.filters)) + key = ops.reshape(key, (b, -1, self.filters)) + value = ops.reshape(value, (b, -1, self.filters)) + + # Compute attention. + query = ops.multiply( + query, ops.cast(self._inverse_sqrt_filters, query.dtype) + ) + # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1] + attention_scores = ops.einsum("abc,adc->abd", query, key) + attention_scores = ops.cast( + self.softmax(attention_scores), self.compute_dtype + ) + # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C] + attention_output = ops.einsum("abc,adb->adc", value, attention_scores) + x = ops.reshape(attention_output, shape) + + x = self.output_conv2d(x) + if self.data_format == "channels_first": + x = ops.transpose(x, (0, 3, 1, 2)) + x = ops.add(x, inputs) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "groups": self.groups, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py b/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py new file mode 100644 index 0000000000..3f058addb7 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py @@ -0,0 +1,177 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.models.stable_diffusion_v3.vae_attention import VAEAttention +from keras_nlp.src.utils.keras_utils import standardize_data_format + + +class VAEImageDecoder(keras.Model): + def __init__( + self, + stackwise_num_filters, + stackwise_num_blocks, + output_channels=3, + latent_shape=(None, None, 16), + data_format=None, + dtype=None, + **kwargs, + ): + data_format = standardize_data_format(data_format) + gn_axis = -1 if data_format == "channels_last" else 1 + + # === Functional Model === + latent_inputs = layers.Input(shape=latent_shape) + + x = layers.Conv2D( + stackwise_num_filters[0], + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name="input_projection", + )(latent_inputs) + x = apply_resnet_block( + x, + stackwise_num_filters[0], + data_format=data_format, + dtype=dtype, + name="input_block0", + ) + x = VAEAttention( + stackwise_num_filters[0], + data_format=data_format, + dtype=dtype, + name="input_attention", + )(x) + x = apply_resnet_block( + x, + stackwise_num_filters[0], + data_format=data_format, + dtype=dtype, + name="input_block1", + ) + + # Stacks. + for i, filters in enumerate(stackwise_num_filters): + for j in range(stackwise_num_blocks[i]): + x = apply_resnet_block( + x, + filters, + data_format=data_format, + dtype=dtype, + name=f"block{i}_{j}", + ) + if i != len(stackwise_num_filters) - 1: + # No upsamling in the last blcok. + x = layers.UpSampling2D( + 2, + data_format=data_format, + dtype=dtype, + name=f"upsample_{i}", + )(x) + x = layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"upsample_{i}_conv", + )(x) + + # Ouput block. + x = layers.GroupNormalization( + groups=32, + axis=gn_axis, + epsilon=1e-6, + dtype=dtype, + name="output_norm", + )(x) + x = layers.Activation("swish", dtype=dtype, name="output_activation")(x) + image_outputs = layers.Conv2D( + output_channels, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name="output_projection", + )(x) + super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs) + + # === Config === + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_num_blocks = stackwise_num_blocks + self.output_channels = output_channels + self.latent_shape = latent_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "output_channels": self.output_channels, + "image_shape": self.latent_shape, + } + ) + return config + + +def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None): + data_format = standardize_data_format(data_format) + gn_axis = -1 if data_format == "channels_last" else 1 + input_filters = x.shape[gn_axis] + + residual = x + x = layers.GroupNormalization( + groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1" + )(x) + x = layers.Activation("swish", dtype=dtype)(x) + x = layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"{name}_conv1", + )(x) + x = layers.GroupNormalization( + groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2" + )(x) + x = layers.Activation("swish")(x) + x = layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"{name}_conv2", + )(x) + if input_filters != filters: + residual = layers.Conv2D( + filters, + 1, + 1, + data_format=data_format, + dtype=dtype, + name=f"{name}_residual_projection", + )(residual) + x = layers.Add(dtype=dtype)([residual, x]) + return x From 0fbd84bbedbe597d404803633c511ec48b54b755 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Thu, 29 Aug 2024 07:06:53 +0800 Subject: [PATCH 24/33] Replace `Backbone` with `keras.Model` in `CLIPTextEncoder` and `T5XXLTextEncoder` (#1802) --- .../stable_diffusion_v3/clip_text_encoder.py | 14 +++++++++++--- .../stable_diffusion_v3/t5_xxl_text_encoder.py | 13 ++++++++++--- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py b/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py index d4a5cbc94f..899ae665c7 100644 --- a/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import keras from keras import layers from keras import ops from keras_nlp.src.layers.modeling.token_and_position_embedding import ( TokenAndPositionEmbedding, ) -from keras_nlp.src.models.backbone import Backbone from keras_nlp.src.models.stable_diffusion_v3.clip_encoder_block import ( CLIPEncoderBlock, ) -class CLIPTextEncoder(Backbone): +class CLIPTextEncoder(keras.Model): def __init__( self, embedding_dim, @@ -108,7 +108,6 @@ def __init__( super().__init__( inputs={"encoder_token_ids": encoder_token_ids}, outputs=outputs, - dtype=dtype, **kwargs, ) @@ -123,6 +122,15 @@ def __init__( self.vocabulary_size = vocabulary_size self.sequence_length = sequence_length + if dtype is not None: + try: + self.dtype_policy = keras.dtype_policies.get(dtype) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + if isinstance(dtype, keras.DTypePolicy): + dtype = dtype.name + self.dtype_policy = keras.DTypePolicy(dtype) + def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py index 9f4e5ae3a1..5c44395489 100644 --- a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +++ b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py @@ -16,12 +16,11 @@ from keras_nlp.src.layers.modeling.reversible_embedding import ( ReversibleEmbedding, ) -from keras_nlp.src.models.backbone import Backbone from keras_nlp.src.models.t5.t5_layer_norm import T5LayerNorm from keras_nlp.src.models.t5.t5_transformer_layer import T5TransformerLayer -class T5XXLTextEncoder(Backbone): +class T5XXLTextEncoder(keras.Model): def __init__( self, vocabulary_size, @@ -111,7 +110,6 @@ def __init__( "encoder_padding_mask": encoder_padding_mask_input, }, outputs=encoder_output, - dtype=dtype, **kwargs, ) @@ -128,6 +126,15 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon self.tie_embedding_weights = tie_embedding_weights + if dtype is not None: + try: + self.dtype_policy = keras.dtype_policies.get(dtype) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + if isinstance(dtype, keras.DTypePolicy): + dtype = dtype.name + self.dtype_policy = keras.DTypePolicy(dtype) + def get_config(self): config = super().get_config() config.update( From 91434680da6aad4e80f36f539f1fbbb3ad5bfc9a Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 3 Sep 2024 09:38:21 -0700 Subject: [PATCH 25/33] Add pyramid output for densenet, cspDarknet (#1801) * add pyramid outputs * fix testcase * format fix * make common testcase for pyramid outputs * change default shape * simplify testcase * test case change and add channel axis --- .../csp_darknet/csp_darknet_backbone.py | 69 +++++++++++++++---- .../csp_darknet/csp_darknet_backbone_test.py | 17 +++-- .../src/models/densenet/densenet_backbone.py | 68 ++++++++++++------ .../models/densenet/densenet_backbone_test.py | 19 +++-- .../mix_transformer_backbone.py | 4 +- .../mix_transformer_backbone_test.py | 21 ++---- .../src/models/resnet/resnet_backbone_test.py | 21 +----- keras_nlp/src/tests/test_case.py | 21 ++++++ 8 files changed, 152 insertions(+), 88 deletions(-) diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py index 607c6895ba..40efb6de04 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py @@ -15,11 +15,11 @@ from keras import layers from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone @keras_nlp_export("keras_nlp.models.CSPDarkNetBackbone") -class CSPDarkNetBackbone(Backbone): +class CSPDarkNetBackbone(FeaturePyramidBackbone): """This class represents Keras Backbone of CSPDarkNet model. This class implements a CSPDarkNet backbone as described in @@ -65,12 +65,15 @@ def __init__( self, stackwise_num_filters, stackwise_depth, - include_rescaling, + include_rescaling=True, block_type="basic_block", - image_shape=(224, 224, 3), + image_shape=(None, None, 3), **kwargs, ): # === Functional Model === + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) apply_ConvBlock = ( apply_darknet_conv_block_depthwise if block_type == "depthwise_block" @@ -83,15 +86,22 @@ def __init__( if include_rescaling: x = layers.Rescaling(scale=1 / 255.0)(x) - x = apply_focus(name="stem_focus")(x) + x = apply_focus(channel_axis, name="stem_focus")(x) x = apply_darknet_conv_block( - base_channels, kernel_size=3, strides=1, name="stem_conv" + base_channels, + channel_axis, + kernel_size=3, + strides=1, + name="stem_conv", )(x) + + pyramid_outputs = {} for index, (channels, depth) in enumerate( zip(stackwise_num_filters, stackwise_depth) ): x = apply_ConvBlock( channels, + channel_axis, kernel_size=3, strides=2, name=f"dark{index + 2}_conv", @@ -100,17 +110,20 @@ def __init__( if index == len(stackwise_depth) - 1: x = apply_spatial_pyramid_pooling_bottleneck( channels, + channel_axis, hidden_filters=channels // 2, name=f"dark{index + 2}_spp", )(x) x = apply_cross_stage_partial( channels, + channel_axis, num_bottlenecks=depth, block_type="basic_block", residual=(index != len(stackwise_depth) - 1), name=f"dark{index + 2}_csp", )(x) + pyramid_outputs[f"P{index + 2}"] = x super().__init__(inputs=image_input, outputs=x, **kwargs) @@ -120,6 +133,7 @@ def __init__( self.include_rescaling = include_rescaling self.block_type = block_type self.image_shape = image_shape + self.pyramid_outputs = pyramid_outputs def get_config(self): config = super().get_config() @@ -135,7 +149,7 @@ def get_config(self): return config -def apply_focus(name=None): +def apply_focus(channel_axis, name=None): """A block used in CSPDarknet to focus information into channels of the image. @@ -151,7 +165,7 @@ def apply_focus(name=None): """ def apply(x): - return layers.Concatenate(name=name)( + return layers.Concatenate(axis=channel_axis, name=name)( [ x[..., ::2, ::2, :], x[..., 1::2, ::2, :], @@ -164,7 +178,13 @@ def apply(x): def apply_darknet_conv_block( - filters, kernel_size, strides, use_bias=False, activation="silu", name=None + filters, + channel_axis, + kernel_size, + strides, + use_bias=False, + activation="silu", + name=None, ): """ The basic conv block used in Darknet. Applies Conv2D followed by a @@ -193,11 +213,12 @@ def apply(inputs): kernel_size, strides, padding="same", + data_format=keras.config.image_data_format(), use_bias=use_bias, name=name + "_conv", )(inputs) - x = layers.BatchNormalization(name=name + "_bn")(x) + x = layers.BatchNormalization(axis=channel_axis, name=name + "_bn")(x) if activation == "silu": x = layers.Lambda(lambda x: keras.activations.silu(x))(x) @@ -212,7 +233,7 @@ def apply(inputs): def apply_darknet_conv_block_depthwise( - filters, kernel_size, strides, activation="silu", name=None + filters, channel_axis, kernel_size, strides, activation="silu", name=None ): """ The depthwise conv block used in CSPDarknet. @@ -236,9 +257,13 @@ def apply_darknet_conv_block_depthwise( def apply(inputs): x = layers.DepthwiseConv2D( - kernel_size, strides, padding="same", use_bias=False + kernel_size, + strides, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, )(inputs) - x = layers.BatchNormalization()(x) + x = layers.BatchNormalization(axis=channel_axis)(x) if activation == "silu": x = layers.Lambda(lambda x: keras.activations.swish(x))(x) @@ -248,7 +273,11 @@ def apply(inputs): x = layers.LeakyReLU(0.1)(x) x = apply_darknet_conv_block( - filters, kernel_size=1, strides=1, activation=activation + filters, + channel_axis, + kernel_size=1, + strides=1, + activation=activation, )(x) return x @@ -258,6 +287,7 @@ def apply(inputs): def apply_spatial_pyramid_pooling_bottleneck( filters, + channel_axis, hidden_filters=None, kernel_sizes=(5, 9, 13), activation="silu", @@ -291,6 +321,7 @@ def apply_spatial_pyramid_pooling_bottleneck( def apply(x): x = apply_darknet_conv_block( hidden_filters, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -304,13 +335,15 @@ def apply(x): kernel_size, strides=1, padding="same", + data_format=keras.config.image_data_format(), name=f"{name}_maxpool_{kernel_size}", )(x[0]) ) - x = layers.Concatenate(name=f"{name}_concat")(x) + x = layers.Concatenate(axis=channel_axis, name=f"{name}_concat")(x) x = apply_darknet_conv_block( filters, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -324,6 +357,7 @@ def apply(x): def apply_cross_stage_partial( filters, + channel_axis, num_bottlenecks, residual=True, block_type="basic_block", @@ -361,6 +395,7 @@ def apply(inputs): x1 = apply_darknet_conv_block( hidden_channels, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -369,6 +404,7 @@ def apply(inputs): x2 = apply_darknet_conv_block( hidden_channels, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -379,6 +415,7 @@ def apply(inputs): residual_x = x1 x1 = apply_darknet_conv_block( hidden_channels, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -386,6 +423,7 @@ def apply(inputs): )(x1) x1 = ConvBlock( hidden_channels, + channel_axis, kernel_size=3, strides=1, activation=activation, @@ -399,6 +437,7 @@ def apply(inputs): x = layers.Concatenate(name=f"{name}_concat")([x1, x2]) x = apply_darknet_conv_block( filters, + channel_axis, kernel_size=1, strides=1, activation=activation, diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py index 857e06039d..ed6dc7b525 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py @@ -24,21 +24,26 @@ class CSPDarkNetBackboneTest(TestCase): def setUp(self): self.init_kwargs = { - "stackwise_num_filters": [32, 64, 128, 256], + "stackwise_num_filters": [2, 4, 6, 8], "stackwise_depth": [1, 3, 3, 1], - "include_rescaling": False, "block_type": "basic_block", - "image_shape": (224, 224, 3), + "image_shape": (32, 32, 3), } - self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + self.input_size = 32 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) def test_backbone_basics(self): - self.run_backbone_test( + self.run_vision_backbone_test( cls=CSPDarkNetBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 7, 7, 256), + expected_output_shape=(2, 1, 1, 8), + expected_pyramid_output_keys=["P2", "P3", "P4", "P5"], + expected_pyramid_image_sizes=[(8, 8), (4, 4), (2, 2), (1, 1)], run_mixed_precision_check=False, + run_data_format_check=False, ) @pytest.mark.large diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py index 60a5b28849..13e3d8597f 100644 --- a/keras_nlp/src/models/densenet/densenet_backbone.py +++ b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -14,14 +14,13 @@ import keras from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone -BN_AXIS = 3 BN_EPSILON = 1.001e-5 @keras_nlp_export("keras_nlp.models.DenseNetBackbone") -class DenseNetBackbone(Backbone): +class DenseNetBackbone(FeaturePyramidBackbone): """Instantiates the DenseNet architecture. This class implements a DenseNet backbone as described in @@ -35,7 +34,7 @@ class DenseNetBackbone(Backbone): include_rescaling: bool, whether to rescale the inputs. If set to `True`, inputs will be passed through a `Rescaling(1/255.0)` layer. Defaults to `True`. - image_shape: optional shape tuple, defaults to (224, 224, 3). + image_shape: optional shape tuple, defaults to (None, None, 3). compression_ratio: float, compression rate at transition layers, defaults to 0.5. growth_rate: int, number of filters added by each dense block, @@ -62,12 +61,14 @@ def __init__( self, stackwise_num_repeats, include_rescaling=True, - image_shape=(224, 224, 3), + image_shape=(None, None, 3), compression_ratio=0.5, growth_rate=32, **kwargs, ): # === Functional Model === + data_format = keras.config.image_data_format() + channel_axis = -1 if data_format == "channels_last" else 1 image_input = keras.layers.Input(shape=image_shape) x = image_input @@ -75,37 +76,47 @@ def __init__( x = keras.layers.Rescaling(1 / 255.0)(x) x = keras.layers.Conv2D( - 64, 7, strides=2, use_bias=False, padding="same", name="conv1_conv" + 64, + 7, + strides=2, + use_bias=False, + padding="same", + data_format=data_format, + name="conv1_conv", )(x) x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name="conv1_bn" + axis=channel_axis, epsilon=BN_EPSILON, name="conv1_bn" )(x) x = keras.layers.Activation("relu", name="conv1_relu")(x) x = keras.layers.MaxPooling2D( - 3, strides=2, padding="same", name="pool1" + 3, strides=2, padding="same", data_format=data_format, name="pool1" )(x) + pyramid_outputs = {} for stack_index in range(len(stackwise_num_repeats) - 1): index = stack_index + 2 x = apply_dense_block( x, + channel_axis, stackwise_num_repeats[stack_index], growth_rate, name=f"conv{index}", ) + pyramid_outputs[f"P{index}"] = x x = apply_transition_block( - x, compression_ratio, name=f"pool{index}" + x, channel_axis, compression_ratio, name=f"pool{index}" ) x = apply_dense_block( x, + channel_axis, stackwise_num_repeats[-1], growth_rate, name=f"conv{len(stackwise_num_repeats) + 1}", ) - + pyramid_outputs[f"P{len(stackwise_num_repeats) +1}"] = x x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name="bn" + axis=channel_axis, epsilon=BN_EPSILON, name="bn" )(x) x = keras.layers.Activation("relu", name="relu")(x) @@ -117,6 +128,7 @@ def __init__( self.compression_ratio = compression_ratio self.growth_rate = growth_rate self.image_shape = image_shape + self.pyramid_outputs = pyramid_outputs def get_config(self): config = super().get_config() @@ -132,7 +144,7 @@ def get_config(self): return config -def apply_dense_block(x, num_repeats, growth_rate, name=None): +def apply_dense_block(x, channel_axis, num_repeats, growth_rate, name=None): """A dense block. Args: @@ -145,11 +157,13 @@ def apply_dense_block(x, num_repeats, growth_rate, name=None): name = f"dense_block_{keras.backend.get_uid('dense_block')}" for i in range(num_repeats): - x = apply_conv_block(x, growth_rate, name=f"{name}_block_{i}") + x = apply_conv_block( + x, channel_axis, growth_rate, name=f"{name}_block_{i}" + ) return x -def apply_transition_block(x, compression_ratio, name=None): +def apply_transition_block(x, channel_axis, compression_ratio, name=None): """A transition block. Args: @@ -157,24 +171,28 @@ def apply_transition_block(x, compression_ratio, name=None): compression_ratio: float, compression rate at transition layers. name: string, block label. """ + data_format = keras.config.image_data_format() if name is None: name = f"transition_block_{keras.backend.get_uid('transition_block')}" x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_bn" + axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_bn" )(x) x = keras.layers.Activation("relu", name=f"{name}_relu")(x) x = keras.layers.Conv2D( - int(x.shape[BN_AXIS] * compression_ratio), + int(x.shape[channel_axis] * compression_ratio), 1, use_bias=False, + data_format=data_format, name=f"{name}_conv", )(x) - x = keras.layers.AveragePooling2D(2, strides=2, name=f"{name}_pool")(x) + x = keras.layers.AveragePooling2D( + 2, strides=2, data_format=data_format, name=f"{name}_pool" + )(x) return x -def apply_conv_block(x, growth_rate, name=None): +def apply_conv_block(x, channel_axis, growth_rate, name=None): """A building block for a dense block. Args: @@ -182,19 +200,24 @@ def apply_conv_block(x, growth_rate, name=None): growth_rate: int, number of filters added by each dense block. name: string, block label. """ + data_format = keras.config.image_data_format() if name is None: name = f"conv_block_{keras.backend.get_uid('conv_block')}" shortcut = x x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_0_bn" + axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_0_bn" )(x) x = keras.layers.Activation("relu", name=f"{name}_0_relu")(x) x = keras.layers.Conv2D( - 4 * growth_rate, 1, use_bias=False, name=f"{name}_1_conv" + 4 * growth_rate, + 1, + use_bias=False, + data_format=data_format, + name=f"{name}_1_conv", )(x) x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_1_bn" + axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_1_bn" )(x) x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x) x = keras.layers.Conv2D( @@ -202,9 +225,10 @@ def apply_conv_block(x, growth_rate, name=None): 3, padding="same", use_bias=False, + data_format=data_format, name=f"{name}_2_conv", )(x) - x = keras.layers.Concatenate(axis=BN_AXIS, name=f"{name}_concat")( + x = keras.layers.Concatenate(axis=channel_axis, name=f"{name}_concat")( [shortcut, x] ) return x diff --git a/keras_nlp/src/models/densenet/densenet_backbone_test.py b/keras_nlp/src/models/densenet/densenet_backbone_test.py index 63f358035c..7720411319 100644 --- a/keras_nlp/src/models/densenet/densenet_backbone_test.py +++ b/keras_nlp/src/models/densenet/densenet_backbone_test.py @@ -22,21 +22,26 @@ class DenseNetBackboneTest(TestCase): def setUp(self): self.init_kwargs = { - "stackwise_num_repeats": [6, 12, 24, 16], - "include_rescaling": True, + "stackwise_num_repeats": [2, 4, 6, 4], "compression_ratio": 0.5, - "growth_rate": 32, - "image_shape": (224, 224, 3), + "growth_rate": 2, + "image_shape": (32, 32, 3), } - self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + self.input_size = 32 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) def test_backbone_basics(self): - self.run_backbone_test( + self.run_vision_backbone_test( cls=DenseNetBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 7, 7, 1024), + expected_output_shape=(2, 1, 1, 24), + expected_pyramid_output_keys=["P2", "P3", "P4", "P5"], + expected_pyramid_image_sizes=[(8, 8), (4, 4), (2, 2), (1, 1)], run_mixed_precision_check=False, + run_data_format_check=False, ) @pytest.mark.large diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py index 2cfe7f6761..35c5f7fd5a 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py @@ -37,7 +37,7 @@ def __init__( patch_sizes, strides, include_rescaling=True, - image_shape=(224, 224, 3), + image_shape=(None, None, 3), hidden_dims=None, **kwargs, ): @@ -63,7 +63,7 @@ def __init__( include_rescaling: bool, whether to rescale the inputs. If set to `True`, inputs will be passed through a `Rescaling(1/255.0)` layer. Defaults to `True`. - image_shape: optional shape tuple, defaults to (224, 224, 3). + image_shape: optional shape tuple, defaults to (None, None, 3). hidden_dims: the embedding dims per hierarchical layer, used as the levels of the feature pyramid. patch_sizes: list of integers, the patch_size to apply for each layer. diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py index 4f1955297f..280adca065 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -14,7 +14,6 @@ import numpy as np import pytest -from keras import models from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( MiTBackbone, @@ -42,30 +41,18 @@ def setUp(self): ) def test_backbone_basics(self): - self.run_backbone_test( + self.run_vision_backbone_test( cls=MiTBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, 2, 2, 8), + expected_pyramid_output_keys=["P1", "P2"], + expected_pyramid_image_sizes=[(4, 4), (2, 2)], run_quantization_check=False, run_mixed_precision_check=False, + run_data_format_check=False, ) - def test_pyramid_output_format(self): - init_kwargs = self.init_kwargs - backbone = MiTBackbone(**init_kwargs) - model = models.Model(backbone.inputs, backbone.pyramid_outputs) - output_data = model(self.input_data) - - self.assertIsInstance(output_data, dict) - self.assertEqual( - list(output_data.keys()), list(backbone.pyramid_outputs.keys()) - ) - self.assertEqual(list(output_data.keys()), ["P1", "P2"]) - for k, v in output_data.items(): - size = self.input_size // (2 ** (int(k[1:]) + 1)) - self.assertEqual(tuple(v.shape[:3]), (2, size, size)) - @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index f52800801f..33d4debac1 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -14,7 +14,6 @@ import pytest from absl.testing import parameterized -from keras import models from keras import ops from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone @@ -67,26 +66,10 @@ def test_backbone_basics(self, use_pre_activation, block_type): if block_type in ("basic_block", "basic_block_vd") else (2, 256) ), + expected_pyramid_output_keys=["P2", "P3", "P4"], + expected_pyramid_image_sizes=[(16, 16), (8, 8), (4, 4)], ) - def test_pyramid_output_format(self): - init_kwargs = self.init_kwargs.copy() - init_kwargs.update( - {"block_type": "basic_block", "use_pre_activation": False} - ) - backbone = ResNetBackbone(**init_kwargs) - model = models.Model(backbone.inputs, backbone.pyramid_outputs) - output_data = model(self.input_data) - - self.assertIsInstance(output_data, dict) - self.assertEqual( - list(output_data.keys()), list(backbone.pyramid_outputs.keys()) - ) - self.assertEqual(list(output_data.keys()), ["P2", "P3", "P4"]) - for k, v in output_data.items(): - size = self.input_size // (2 ** int(k[1:])) - self.assertEqual(tuple(v.shape[:3]), (2, size, size)) - @parameterized.named_parameters( ("v1_basic", False, "basic_block"), ("v1_bottleneck", False, "bottleneck_block"), diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 8e63bc19d9..6dad75b84d 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -465,6 +465,8 @@ def run_vision_backbone_test( init_kwargs, input_data, expected_output_shape, + expected_pyramid_output_keys=None, + expected_pyramid_image_sizes=None, variable_length_data=None, run_mixed_precision_check=True, run_quantization_check=True, @@ -491,6 +493,25 @@ def run_vision_backbone_test( run_mixed_precision_check=run_mixed_precision_check, run_quantization_check=run_quantization_check, ) + if expected_pyramid_output_keys: + backbone = cls(**init_kwargs) + model = keras.models.Model( + backbone.inputs, backbone.pyramid_outputs + ) + output_data = model(input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual( + list(output_data.keys()), expected_pyramid_output_keys + ) + # check height and width of each level. + for i, (k, v) in enumerate(output_data.items()): + self.assertEqual( + tuple(v.shape[1:3]), expected_pyramid_image_sizes[i] + ) # Check data_format. We assume that `input_data` is in "channels_last" # format. From 791d7f615a7ff87d6c8f3b6f6ceaec86e82fe2d5 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Thu, 5 Sep 2024 01:15:40 +0800 Subject: [PATCH 26/33] Add `MMDiT` for StableDiffusionV3 (#1806) * Add `MMDiT` * Update * Update * Update implementation --- .../src/models/stable_diffusion_v3/mmdit.py | 427 ++++++++++++++++++ .../models/stable_diffusion_v3/mmdit_block.py | 317 +++++++++++++ .../stable_diffusion_v3/vae_image_decoder.py | 9 + 3 files changed, 753 insertions(+) create mode 100644 keras_nlp/src/models/stable_diffusion_v3/mmdit.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/mmdit_block.py diff --git a/keras_nlp/src/models/stable_diffusion_v3/mmdit.py b/keras_nlp/src/models/stable_diffusion_v3/mmdit.py new file mode 100644 index 0000000000..b26f0d04b3 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/mmdit.py @@ -0,0 +1,427 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import keras +from keras import layers +from keras import models +from keras import ops + +from keras_nlp.src.layers.modeling.position_embedding import PositionEmbedding +from keras_nlp.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock +from keras_nlp.src.utils.keras_utils import standardize_data_format + + +class PatchEmbedding(layers.Layer): + def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs): + super().__init__(**kwargs) + self.patch_size = int(patch_size) + self.hidden_dim = int(hidden_dim) + data_format = standardize_data_format(data_format) + + self.patch_embedding = layers.Conv2D( + hidden_dim, + kernel_size=patch_size, + strides=patch_size, + data_format=data_format, + dtype=self.dtype_policy, + name="patch_embedding", + ) + + def build(self, input_shape): + self.patch_embedding.build(input_shape) + + def call(self, inputs): + x = self.patch_embedding(inputs) + x_shape = ops.shape(x) + x = ops.reshape(x, (x_shape[0], x_shape[1] * x_shape[2], x_shape[3])) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "hidden_dim": self.hidden_dim, + } + ) + return config + + +class AdjustablePositionEmbedding(PositionEmbedding): + def __init__( + self, + height, + width, + initializer="glorot_uniform", + **kwargs, + ): + height = int(height) + width = int(width) + sequence_length = height * width + super().__init__(sequence_length, initializer, **kwargs) + self.height = height + self.width = width + + def call(self, inputs, height=None, width=None): + height = height or self.height + width = width or self.width + shape = ops.shape(inputs) + feature_length = shape[-1] + top = ops.floor_divide(self.height - height, 2) + left = ops.floor_divide(self.width - width, 2) + position_embedding = ops.convert_to_tensor(self.position_embeddings) + position_embedding = ops.reshape( + position_embedding, (self.height, self.width, feature_length) + ) + position_embedding = ops.slice( + position_embedding, + (top, left, 0), + (height, width, feature_length), + ) + position_embedding = ops.reshape( + position_embedding, (height * width, feature_length) + ) + position_embedding = ops.expand_dims(position_embedding, axis=0) + return position_embedding + + def compute_output_shape(self, input_shape): + return input_shape + + +class TimestepEmbedding(layers.Layer): + def __init__( + self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs + ): + super().__init__(**kwargs) + self.embedding_dim = int(embedding_dim) + self.frequency_dim = int(frequency_dim) + self.max_period = float(max_period) + self.half_frequency_dim = self.frequency_dim // 2 + + self.mlp = models.Sequential( + [ + layers.Dense( + embedding_dim, activation="silu", dtype=self.dtype_policy + ), + layers.Dense( + embedding_dim, activation=None, dtype=self.dtype_policy + ), + ], + name="mlp", + ) + + def build(self, inputs_shape): + embedding_shape = list(inputs_shape)[1:] + embedding_shape.append(self.frequency_dim) + self.mlp.build(embedding_shape) + + def _create_timestep_embedding(self, inputs): + compute_dtype = keras.backend.result_type(self.compute_dtype, "float32") + x = ops.cast(inputs, compute_dtype) + freqs = ops.exp( + ops.divide( + ops.multiply( + -math.log(self.max_period), + ops.arange(0, self.half_frequency_dim, dtype="float32"), + ), + self.half_frequency_dim, + ) + ) + freqs = ops.cast(freqs, compute_dtype) + x = ops.multiply(x, ops.expand_dims(freqs, axis=0)) + embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1) + if self.frequency_dim % 2 != 0: + embedding = ops.pad(embedding, [[0, 0], [0, 1]]) + return ops.cast(embedding, self.compute_dtype) + + def call(self, inputs, training=None): + timestep_embedding = self._create_timestep_embedding(inputs) + return self.mlp(timestep_embedding, training=training) + + def get_config(self): + config = super().get_config() + config.update( + { + "embedding_dim": self.embedding_dim, + "max_period": self.max_period, + } + ) + return config + + def compute_output_shape(self, inputs_shape): + output_shape = list(inputs_shape)[1:] + output_shape.append(self.embedding_dim) + return output_shape + + +class OutputLayer(layers.Layer): + def __init__(self, hidden_dim, output_dim, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.output_dim = output_dim + num_modulation = 2 + + self.adaptive_norm_modulation = models.Sequential( + [ + layers.Activation("silu", dtype=self.dtype_policy), + layers.Dense( + num_modulation * hidden_dim, dtype=self.dtype_policy + ), + ], + name="adaptive_norm_modulation", + ) + self.norm = layers.LayerNormalization( + epsilon=1e-6, + center=False, + scale=False, + dtype=self.dtype_policy, + name="norm", + ) + self.output_dense = layers.Dense( + output_dim, # patch_size ** 2 * input_channels + use_bias=True, + dtype=self.dtype_policy, + name="output_dense", + ) + + def build(self, inputs_shape, timestep_embedding_shape): + self.adaptive_norm_modulation.build(timestep_embedding_shape) + self.norm.build(inputs_shape) + self.output_dense.build(inputs_shape) + + def _modulate(self, inputs, shift, scale): + shift = ops.expand_dims(shift, axis=1) + scale = ops.expand_dims(scale, axis=1) + return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift) + + def call(self, inputs, timestep_embedding, training=None): + x = inputs + modulation = self.adaptive_norm_modulation( + timestep_embedding, training=training + ) + modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim)) + shift, scale = ops.unstack(modulation, 2, axis=1) + x = self._modulate(self.norm(x), shift, scale) + x = self.output_dense(x, training=training) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + } + ) + return config + + +class Unpatch(layers.Layer): + def __init__(self, patch_size, output_dim, **kwargs): + super().__init__(**kwargs) + self.patch_size = int(patch_size) + self.output_dim = int(output_dim) + + def call(self, inputs, height, width): + patch_size = self.patch_size + output_dim = self.output_dim + x = ops.reshape( + inputs, + (-1, height, width, patch_size, patch_size, output_dim), + ) + # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o) + x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) + return ops.reshape( + x, + (-1, height * patch_size, width * patch_size, output_dim), + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "output_dim": self.output_dim, + } + ) + return config + + def compute_output_shape(self, inputs_shape): + inputs_shape = list(inputs_shape) + return [inputs_shape[0], None, None, self.output_dim] + + +class MMDiT(keras.Model): + def __init__( + self, + patch_size, + num_heads, + hidden_dim, + depth, + position_size, + output_dim, + mlp_ratio=4.0, + latent_shape=(64, 64, 16), + context_shape=(1024, 4096), + pooled_projection_shape=(2048,), + data_format=None, + dtype=None, + **kwargs, + ): + if None in latent_shape: + raise ValueError( + "`latent_shape` must be fully specified. " + f"Received: latent_shape={latent_shape}" + ) + image_height = latent_shape[0] // patch_size + image_width = latent_shape[1] // patch_size + output_dim_in_final = patch_size**2 * output_dim + data_format = standardize_data_format(data_format) + if data_format != "channels_last": + raise NotImplementedError( + "Currently only 'channels_last' is supported." + ) + + # === Layers === + self.patch_embedding = PatchEmbedding( + patch_size, + hidden_dim, + data_format=data_format, + dtype=dtype, + name="patch_embedding", + ) + self.position_embedding_add = layers.Add( + dtype=dtype, name="position_embedding_add" + ) + self.position_embedding = AdjustablePositionEmbedding( + position_size, position_size, dtype=dtype, name="position_embedding" + ) + self.context_embedding = layers.Dense( + hidden_dim, + dtype=dtype, + name="context_embedding", + ) + self.vector_embedding = models.Sequential( + [ + layers.Dense(hidden_dim, activation="silu", dtype=dtype), + layers.Dense(hidden_dim, activation=None, dtype=dtype), + ], + name="vector_embedding", + ) + self.vector_embedding_add = layers.Add( + dtype=dtype, name="vector_embedding_add" + ) + self.timestep_embedding = TimestepEmbedding( + hidden_dim, dtype=dtype, name="timestep_embedding" + ) + self.joint_blocks = [ + MMDiTBlock( + num_heads, + hidden_dim, + mlp_ratio, + use_context_projection=not (i == depth - 1), + dtype=dtype, + name=f"joint_block_{i}", + ) + for i in range(depth) + ] + self.output_layer = OutputLayer( + hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer" + ) + self.unpatch = Unpatch( + patch_size, output_dim, dtype=dtype, name="unpatch" + ) + + # === Functional Model === + latent_inputs = layers.Input(shape=latent_shape, name="latent") + context_inputs = layers.Input(shape=context_shape, name="context") + pooled_projection_inputs = layers.Input( + shape=pooled_projection_shape, name="pooled_projection" + ) + timestep_inputs = layers.Input(shape=(1,), name="timestep") + + # Embeddings. + x = self.patch_embedding(latent_inputs) + position_embedding = self.position_embedding( + x, height=image_height, width=image_width + ) + x = self.position_embedding_add([x, position_embedding]) + context = self.context_embedding(context_inputs) + pooled_projection = self.vector_embedding(pooled_projection_inputs) + timestep_embedding = self.timestep_embedding(timestep_inputs) + timestep_embedding = self.vector_embedding_add( + [timestep_embedding, pooled_projection] + ) + + # Blocks. + for block in self.joint_blocks: + if block.use_context_projection: + x, context = block(x, context, timestep_embedding) + else: + x = block(x, context, timestep_embedding) + + # Output layer. + x = self.output_layer(x, timestep_embedding) + outputs = self.unpatch(x, height=image_height, width=image_width) + + super().__init__( + inputs={ + "latent": latent_inputs, + "context": context_inputs, + "pooled_projection": pooled_projection_inputs, + "timestep": timestep_inputs, + }, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.depth = depth + self.position_size = position_size + self.output_dim = output_dim + self.mlp_ratio = mlp_ratio + self.latent_shape = latent_shape + self.context_shape = context_shape + self.pooled_projection_shape = pooled_projection_shape + + if dtype is not None: + try: + self.dtype_policy = keras.dtype_policies.get(dtype) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + if isinstance(dtype, keras.DTypePolicy): + dtype = dtype.name + self.dtype_policy = keras.DTypePolicy(dtype) + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "depth": self.depth, + "position_size": self.position_size, + "output_dim": self.output_dim, + "mlp_ratio": self.mlp_ratio, + "latent_shape": self.latent_shape, + "context_shape": self.context_shape, + "pooled_projection_shape": self.pooled_projection_shape, + } + ) + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/mmdit_block.py b/keras_nlp/src/models/stable_diffusion_v3/mmdit_block.py new file mode 100644 index 0000000000..d537e856ef --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/mmdit_block.py @@ -0,0 +1,317 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras import layers +from keras import models +from keras import ops + +from keras_nlp.src.utils.keras_utils import gelu_approximate + + +class DismantledBlock(layers.Layer): + def __init__( + self, + num_heads, + hidden_dim, + mlp_ratio=4.0, + use_projection=True, + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_ratio = mlp_ratio + self.use_projection = use_projection + + head_dim = hidden_dim // num_heads + self.head_dim = head_dim + mlp_hidden_dim = int(hidden_dim * mlp_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + num_modulations = 6 if use_projection else 2 + self.num_modulations = num_modulations + + self.adaptive_norm_modulation = models.Sequential( + [ + layers.Activation("silu", dtype=self.dtype_policy), + layers.Dense( + num_modulations * hidden_dim, dtype=self.dtype_policy + ), + ], + name="adaptive_norm_modulation", + ) + self.norm1 = layers.LayerNormalization( + epsilon=1e-6, + center=False, + scale=False, + dtype=self.dtype_policy, + name="norm1", + ) + self.attention_qkv = layers.Dense( + hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv" + ) + if use_projection: + self.attention_proj = layers.Dense( + hidden_dim, dtype=self.dtype_policy, name="attention_proj" + ) + self.norm2 = layers.LayerNormalization( + epsilon=1e-6, + center=False, + scale=False, + dtype=self.dtype_policy, + name="norm2", + ) + self.mlp = models.Sequential( + [ + layers.Dense( + mlp_hidden_dim, + activation=gelu_approximate, + dtype=self.dtype_policy, + ), + layers.Dense( + hidden_dim, + dtype=self.dtype_policy, + ), + ], + name="mlp", + ) + + def build(self, inputs_shape, timestep_embedding): + self.adaptive_norm_modulation.build(timestep_embedding) + self.attention_qkv.build(inputs_shape) + self.norm1.build(inputs_shape) + if self.use_projection: + self.attention_proj.build(inputs_shape) + self.norm2.build(inputs_shape) + self.mlp.build(inputs_shape) + + def _modulate(self, inputs, shift, scale): + shift = ops.expand_dims(shift, axis=1) + scale = ops.expand_dims(scale, axis=1) + return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift) + + def _compute_pre_attention(self, inputs, timestep_embedding, training=None): + batch_size = ops.shape(inputs)[0] + if self.use_projection: + modulation = self.adaptive_norm_modulation( + timestep_embedding, training=training + ) + modulation = ops.reshape( + modulation, (batch_size, 6, self.hidden_dim) + ) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = ops.unstack(modulation, 6, axis=1) + qkv = self.attention_qkv( + self._modulate(self.norm1(inputs), shift_msa, scale_msa), + training=training, + ) + qkv = ops.reshape( + qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) + ) + q, k, v = ops.unstack(qkv, 3, axis=2) + return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp) + else: + modulation = self.adaptive_norm_modulation( + timestep_embedding, training=training + ) + modulation = ops.reshape( + modulation, (batch_size, 2, self.hidden_dim) + ) + shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1) + qkv = self.attention_qkv( + self._modulate(self.norm1(inputs), shift_msa, scale_msa), + training=training, + ) + qkv = ops.reshape( + qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) + ) + q, k, v = ops.unstack(qkv, 3, axis=2) + return (q, k, v) + + def _compute_post_attention( + self, inputs, inputs_intermediates, training=None + ): + x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates + attn = self.attention_proj(inputs, training=training) + x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn)) + x = ops.add( + x, + ops.multiply( + ops.expand_dims(gate_mlp, axis=1), + self.mlp( + self._modulate(self.norm2(x), shift_mlp, scale_mlp), + training=training, + ), + ), + ) + return x + + def call( + self, + inputs, + timestep_embedding=None, + inputs_intermediates=None, + pre_attention=True, + training=None, + ): + if pre_attention: + return self._compute_pre_attention( + inputs, timestep_embedding, training=training + ) + else: + return self._compute_post_attention( + inputs, inputs_intermediates, training=training + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_ratio": self.mlp_ratio, + "use_projection": self.use_projection, + } + ) + return config + + +class MMDiTBlock(layers.Layer): + def __init__( + self, + num_heads, + hidden_dim, + mlp_ratio=4.0, + use_context_projection=True, + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_ratio = mlp_ratio + self.use_context_projection = use_context_projection + + head_dim = hidden_dim // num_heads + self.head_dim = head_dim + self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim) + self._dot_product_equation = "aecd,abcd->acbe" + self._combine_equation = "acbe,aecd->abcd" + + self.x_block = DismantledBlock( + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_ratio=mlp_ratio, + use_projection=True, + dtype=self.dtype_policy, + name="x_block", + ) + self.context_block = DismantledBlock( + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_ratio=mlp_ratio, + use_projection=use_context_projection, + dtype=self.dtype_policy, + name="context_block", + ) + + def build(self, inputs_shape, context_shape, timestep_embedding_shape): + self.x_block.build(inputs_shape, timestep_embedding_shape) + self.context_block.build(context_shape, timestep_embedding_shape) + + def _compute_attention(self, query, key, value): + query = ops.multiply( + query, ops.cast(self._inverse_sqrt_key_dim, query.dtype) + ) + attention_scores = ops.einsum(self._dot_product_equation, key, query) + attention_scores = ops.nn.softmax(attention_scores, axis=-1) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + batch_size = ops.shape(attention_output)[0] + attention_output = ops.reshape( + attention_output, (batch_size, -1, self.num_heads * self.head_dim) + ) + return attention_output + + def call(self, inputs, context, timestep_embedding, training=None): + # Compute pre-attention. + x = inputs + if self.use_context_projection: + context_qkv, context_intermediates = self.context_block( + context, + timestep_embedding=timestep_embedding, + training=training, + ) + else: + context_qkv = self.context_block( + context, + timestep_embedding=timestep_embedding, + training=training, + ) + context_len = ops.shape(context_qkv[0])[1] + x_qkv, x_intermediates = self.x_block( + x, timestep_embedding=timestep_embedding, training=training + ) + q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1) + k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1) + v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1) + + # Compute attention. + attention = self._compute_attention(q, k, v) + context_attention = attention[:, :context_len] + x_attention = attention[:, context_len:] + + # Compute post-attention. + x = self.x_block( + x_attention, + inputs_intermediates=x_intermediates, + pre_attention=False, + training=training, + ) + if self.use_context_projection: + context = self.context_block( + context_attention, + inputs_intermediates=context_intermediates, + pre_attention=False, + training=training, + ) + return x, context + else: + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_ratio": self.mlp_ratio, + "use_context_projection": self.use_context_projection, + } + ) + return config + + def compute_output_shape( + self, inputs_shape, context_shape, timestep_embedding_shape + ): + if self.use_context_projection: + return inputs_shape, context_shape + else: + return inputs_shape diff --git a/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py b/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py index 3f058addb7..9cfd6d4ff6 100644 --- a/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py +++ b/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py @@ -119,6 +119,15 @@ def __init__( self.output_channels = output_channels self.latent_shape = latent_shape + if dtype is not None: + try: + self.dtype_policy = keras.dtype_policies.get(dtype) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + if isinstance(dtype, keras.DTypePolicy): + dtype = dtype.name + self.dtype_policy = keras.DTypePolicy(dtype) + def get_config(self): config = super().get_config() config.update( From 339669f505b4e373477e8832052c926ba3465b20 Mon Sep 17 00:00:00 2001 From: Siva Sravana Kumar Neeli <113718461+sineeli@users.noreply.github.com> Date: Wed, 4 Sep 2024 10:19:46 -0700 Subject: [PATCH 27/33] Add remaining bbox utils (#1804) * - Add formats, iou, utils for bounding box * - Add `AnchorGenerator`, `BoxMatcher` and `NonMaxSupression` layers * - Remove scope_name not required. * use default keras name scope * - Correct format error * - Remove layers as of now and keep them at model level till keras core supports them * - Correct api_gen --- keras_nlp/api/bounding_box/__init__.py | 13 + keras_nlp/src/bounding_box/formats.py | 162 +++++++++++ keras_nlp/src/bounding_box/iou.py | 263 ++++++++++++++++++ keras_nlp/src/bounding_box/iou_test.py | 161 +++++++++++ keras_nlp/src/bounding_box/utils.py | 194 +++++++++++++ keras_nlp/src/bounding_box/utils_test.py | 161 +++++++++++ .../src/bounding_box/validate_format_test.py | 47 ++++ 7 files changed, 1001 insertions(+) create mode 100644 keras_nlp/src/bounding_box/formats.py create mode 100644 keras_nlp/src/bounding_box/iou.py create mode 100644 keras_nlp/src/bounding_box/iou_test.py create mode 100644 keras_nlp/src/bounding_box/utils.py create mode 100644 keras_nlp/src/bounding_box/utils_test.py create mode 100644 keras_nlp/src/bounding_box/validate_format_test.py diff --git a/keras_nlp/api/bounding_box/__init__.py b/keras_nlp/api/bounding_box/__init__.py index 18be1cd9aa..8488f76e6f 100644 --- a/keras_nlp/api/bounding_box/__init__.py +++ b/keras_nlp/api/bounding_box/__init__.py @@ -18,6 +18,19 @@ """ from keras_nlp.src.bounding_box.converters import convert_format +from keras_nlp.src.bounding_box.formats import CENTER_XYWH +from keras_nlp.src.bounding_box.formats import REL_XYWH +from keras_nlp.src.bounding_box.formats import REL_XYXY +from keras_nlp.src.bounding_box.formats import REL_YXYX +from keras_nlp.src.bounding_box.formats import XYWH +from keras_nlp.src.bounding_box.formats import XYXY +from keras_nlp.src.bounding_box.formats import YXYX +from keras_nlp.src.bounding_box.iou import compute_ciou +from keras_nlp.src.bounding_box.iou import compute_iou from keras_nlp.src.bounding_box.to_dense import to_dense from keras_nlp.src.bounding_box.to_ragged import to_ragged +from keras_nlp.src.bounding_box.utils import as_relative +from keras_nlp.src.bounding_box.utils import clip_boxes +from keras_nlp.src.bounding_box.utils import clip_to_image +from keras_nlp.src.bounding_box.utils import is_relative from keras_nlp.src.bounding_box.validate_format import validate_format diff --git a/keras_nlp/src/bounding_box/formats.py b/keras_nlp/src/bounding_box/formats.py new file mode 100644 index 0000000000..fda64a860e --- /dev/null +++ b/keras_nlp/src/bounding_box/formats.py @@ -0,0 +1,162 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +formats.py contains axis information for each supported format. +""" + +from keras_nlp.src.api_export import keras_nlp_export + + +@keras_nlp_export("keras_nlp.bounding_box.XYXY") +class XYXY: + """XYXY contains axis indices for the XYXY format. + + All values in the XYXY format should be absolute pixel values. + + The XYXY format consists of the following required indices: + + - LEFT: left of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + """ + + LEFT = 0 + TOP = 1 + RIGHT = 2 + BOTTOM = 3 + + +@keras_nlp_export("keras_nlp.bounding_box.REL_XYXY") +class REL_XYXY: + """REL_XYXY contains axis indices for the REL_XYXY format. + + REL_XYXY is like XYXY, but each value is relative to the width and height of + the origin image. Values are percentages of the origin images' width and + height respectively. + + The REL_XYXY format consists of the following required indices: + + - LEFT: left of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + """ + + LEFT = 0 + TOP = 1 + RIGHT = 2 + BOTTOM = 3 + + +@keras_nlp_export("keras_nlp.bounding_box.CENTER_XYWH") +class CENTER_XYWH: + """CENTER_XYWH contains axis indices for the CENTER_XYWH format. + + All values in the CENTER_XYWH format should be absolute pixel values. + + The CENTER_XYWH format consists of the following required indices: + + - X: X coordinate of the center of the bounding box + - Y: Y coordinate of the center of the bounding box + - WIDTH: width of the bounding box + - HEIGHT: height of the bounding box + """ + + X = 0 + Y = 1 + WIDTH = 2 + HEIGHT = 3 + + +@keras_nlp_export("keras_nlp.bounding_box.XYWH") +class XYWH: + """XYWH contains axis indices for the XYWH format. + + All values in the XYWH format should be absolute pixel values. + + The XYWH format consists of the following required indices: + + - X: X coordinate of the left of the bounding box + - Y: Y coordinate of the top of the bounding box + - WIDTH: width of the bounding box + - HEIGHT: height of the bounding box + """ + + X = 0 + Y = 1 + WIDTH = 2 + HEIGHT = 3 + + +@keras_nlp_export("keras_nlp.bounding_box.REL_XYWH") +class REL_XYWH: + """REL_XYWH contains axis indices for the XYWH format. + + REL_XYXY is like XYWH, but each value is relative to the width and height of + the origin image. Values are percentages of the origin images' width and + height respectively. + + - X: X coordinate of the left of the bounding box + - Y: Y coordinate of the top of the bounding box + - WIDTH: width of the bounding box + - HEIGHT: height of the bounding box + """ + + X = 0 + Y = 1 + WIDTH = 2 + HEIGHT = 3 + + +@keras_nlp_export("keras_nlp.bounding_box.YXYX") +class YXYX: + """YXYX contains axis indices for the YXYX format. + + All values in the YXYX format should be absolute pixel values. + + The YXYX format consists of the following required indices: + + - TOP: top of the bounding box + - LEFT: left of the bounding box + - BOTTOM: bottom of the bounding box + - RIGHT: right of the bounding box + """ + + TOP = 0 + LEFT = 1 + BOTTOM = 2 + RIGHT = 3 + + +@keras_nlp_export("keras_nlp.bounding_box.REL_YXYX") +class REL_YXYX: + """REL_YXYX contains axis indices for the REL_YXYX format. + + REL_YXYX is like YXYX, but each value is relative to the width and height of + the origin image. Values are percentages of the origin images' width and + height respectively. + + The REL_YXYX format consists of the following required indices: + + - TOP: top of the bounding box + - LEFT: left of the bounding box + - BOTTOM: bottom of the bounding box + - RIGHT: right of the bounding box + """ + + TOP = 0 + LEFT = 1 + BOTTOM = 2 + RIGHT = 3 diff --git a/keras_nlp/src/bounding_box/iou.py b/keras_nlp/src/bounding_box/iou.py new file mode 100644 index 0000000000..46ea2a34b4 --- /dev/null +++ b/keras_nlp/src/bounding_box/iou.py @@ -0,0 +1,263 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains functions to compute ious of bounding boxes.""" +import math + +import keras +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.bounding_box.converters import convert_format +from keras_nlp.src.bounding_box.utils import as_relative +from keras_nlp.src.bounding_box.utils import is_relative + + +def _compute_area(box): + """Computes area for bounding boxes + + Args: + box: [N, 4] or [batch_size, N, 4] float Tensor, either batched + or unbatched boxes. + Returns: + a float Tensor of [N] or [batch_size, N] + """ + y_min, x_min, y_max, x_max = ops.split(box[..., :4], 4, axis=-1) + return ops.squeeze((y_max - y_min) * (x_max - x_min), axis=-1) + + +def _compute_intersection(boxes1, boxes2): + """Computes intersection area between two sets of boxes. + + Args: + boxes1: [N, 4] or [batch_size, N, 4] float Tensor boxes. + boxes2: [M, 4] or [batch_size, M, 4] float Tensor boxes. + Returns: + a [N, M] or [batch_size, N, M] float Tensor. + """ + y_min1, x_min1, y_max1, x_max1 = ops.split(boxes1[..., :4], 4, axis=-1) + y_min2, x_min2, y_max2, x_max2 = ops.split(boxes2[..., :4], 4, axis=-1) + boxes2_rank = len(boxes2.shape) + perm = [1, 0] if boxes2_rank == 2 else [0, 2, 1] + # [N, M] or [batch_size, N, M] + intersect_ymax = ops.minimum(y_max1, ops.transpose(y_max2, perm)) + intersect_ymin = ops.maximum(y_min1, ops.transpose(y_min2, perm)) + intersect_xmax = ops.minimum(x_max1, ops.transpose(x_max2, perm)) + intersect_xmin = ops.maximum(x_min1, ops.transpose(x_min2, perm)) + + intersect_height = intersect_ymax - intersect_ymin + intersect_width = intersect_xmax - intersect_xmin + zeros_t = ops.cast(0, intersect_height.dtype) + intersect_height = ops.maximum(zeros_t, intersect_height) + intersect_width = ops.maximum(zeros_t, intersect_width) + + return intersect_height * intersect_width + + +@keras_nlp_export("keras_nlp.bounding_box.compute_iou") +def compute_iou( + boxes1, + boxes2, + bounding_box_format, + use_masking=False, + mask_val=-1, + images=None, + image_shape=None, +): + """Computes a lookup table vector containing the ious for a given set boxes. + + The lookup vector is to be indexed by [`boxes1_index`,`boxes2_index`] if + boxes are unbatched and by [`batch`, `boxes1_index`,`boxes2_index`] if the + boxes are batched. + + The users can pass `boxes1` and `boxes2` to be different ranks. For example: + 1) `boxes1`: [batch_size, M, 4], `boxes2`: [batch_size, N, 4] -> return + [batch_size, M, N]. + 2) `boxes1`: [batch_size, M, 4], `boxes2`: [N, 4] -> return + [batch_size, M, N] + 3) `boxes1`: [M, 4], `boxes2`: [batch_size, N, 4] -> return + [batch_size, M, N] + 4) `boxes1`: [M, 4], `boxes2`: [N, 4] -> return [M, N] + + Args: + boxes1: a list of bounding boxes in 'corners' format. Can be batched or + unbatched. + boxes2: a list of bounding boxes in 'corners' format. Can be batched or + unbatched. + bounding_box_format: a case-insensitive string which is one of `"xyxy"`, + `"rel_xyxy"`, `"xyWH"`, `"center_xyWH"`, `"yxyx"`, `"rel_yxyx"`. + For detailed information on the supported format, see the + [KerasCV bounding box documentation](https://keras.io/api/keras_cv/bounding_box/formats/). + use_masking: whether masking will be applied. This will mask all `boxes1` + or `boxes2` that have values less than 0 in all its 4 dimensions. + Default to `False`. + mask_val: int to mask those returned IOUs if the masking is True, defaults + to -1. + + Returns: + iou_lookup_table: a vector containing the pairwise ious of boxes1 and + boxes2. + """ # noqa: E501 + + boxes1_rank = len(boxes1.shape) + boxes2_rank = len(boxes2.shape) + + if boxes1_rank not in [2, 3]: + raise ValueError( + "compute_iou() expects boxes1 to be batched, or to be unbatched. " + f"Received len(boxes1.shape)={boxes1_rank}, " + f"len(boxes2.shape)={boxes2_rank}. Expected either " + "len(boxes1.shape)=2 AND or len(boxes1.shape)=3." + ) + if boxes2_rank not in [2, 3]: + raise ValueError( + "compute_iou() expects boxes2 to be batched, or to be unbatched. " + f"Received len(boxes1.shape)={boxes1_rank}, " + f"len(boxes2.shape)={boxes2_rank}. Expected either " + "len(boxes2.shape)=2 AND or len(boxes2.shape)=3." + ) + + target_format = "yxyx" + if is_relative(bounding_box_format): + target_format = as_relative(target_format) + + boxes1 = convert_format( + boxes1, + source=bounding_box_format, + target=target_format, + images=images, + image_shape=image_shape, + ) + + boxes2 = convert_format( + boxes2, + source=bounding_box_format, + target=target_format, + images=images, + image_shape=image_shape, + ) + + intersect_area = _compute_intersection(boxes1, boxes2) + boxes1_area = _compute_area(boxes1) + boxes2_area = _compute_area(boxes2) + boxes2_area_rank = len(boxes2_area.shape) + boxes2_axis = 1 if (boxes2_area_rank == 2) else 0 + boxes1_area = ops.expand_dims(boxes1_area, axis=-1) + boxes2_area = ops.expand_dims(boxes2_area, axis=boxes2_axis) + union_area = boxes1_area + boxes2_area - intersect_area + res = ops.divide(intersect_area, union_area + keras.backend.epsilon()) + + if boxes1_rank == 2: + perm = [1, 0] + else: + perm = [0, 2, 1] + + if not use_masking: + return res + + mask_val_t = ops.cast(mask_val, res.dtype) * ops.ones_like(res) + boxes1_mask = ops.less(ops.max(boxes1, axis=-1, keepdims=True), 0.0) + boxes2_mask = ops.less(ops.max(boxes2, axis=-1, keepdims=True), 0.0) + background_mask = ops.logical_or( + boxes1_mask, ops.transpose(boxes2_mask, perm) + ) + iou_lookup_table = ops.where(background_mask, mask_val_t, res) + return iou_lookup_table + + +@keras_nlp_export("keras_nlp.bounding_box.compute_ciou") +def compute_ciou(boxes1, boxes2, bounding_box_format): + """ + Computes the Complete IoU (CIoU) between two bounding boxes or between + two batches of bounding boxes. + + CIoU loss is an extension of GIoU loss, which further improves the IoU + optimization for object detection. CIoU loss not only penalizes the + bounding box coordinates but also considers the aspect ratio and center + distance of the boxes. The length of the last dimension should be 4 to + represent the bounding boxes. + + Args: + box1 (tensor): tensor representing the first bounding box with + shape (..., 4). + box2 (tensor): tensor representing the second bounding box with + shape (..., 4). + bounding_box_format: a case-insensitive string (for example, "xyxy"). + Each bounding box is defined by these 4 values. For detailed + information on the supported formats, see the [KerasCV bounding box + documentation](https://keras.io/api/keras_cv/bounding_box/formats/). + + Returns: + tensor: The CIoU distance between the two bounding boxes. + """ + target_format = "xyxy" + if is_relative(bounding_box_format): + target_format = as_relative(target_format) + + boxes1 = convert_format( + boxes1, source=bounding_box_format, target=target_format + ) + + boxes2 = convert_format( + boxes2, source=bounding_box_format, target=target_format + ) + + x_min1, y_min1, x_max1, y_max1 = ops.split(boxes1[..., :4], 4, axis=-1) + x_min2, y_min2, x_max2, y_max2 = ops.split(boxes2[..., :4], 4, axis=-1) + + width_1 = x_max1 - x_min1 + height_1 = y_max1 - y_min1 + keras.backend.epsilon() + width_2 = x_max2 - x_min2 + height_2 = y_max2 - y_min2 + keras.backend.epsilon() + + intersection_area = ops.maximum( + ops.minimum(x_max1, x_max2) - ops.maximum(x_min1, x_min2), 0 + ) * ops.maximum( + ops.minimum(y_max1, y_max2) - ops.maximum(y_min1, y_min2), 0 + ) + union_area = ( + width_1 * height_1 + + width_2 * height_2 + - intersection_area + + keras.backend.epsilon() + ) + iou = ops.squeeze( + ops.divide(intersection_area, union_area + keras.backend.epsilon()), + axis=-1, + ) + + convex_width = ops.maximum(x_max1, x_max2) - ops.minimum(x_min1, x_min2) + convex_height = ops.maximum(y_max1, y_max2) - ops.minimum(y_min1, y_min2) + convex_diagonal_squared = ops.squeeze( + convex_width**2 + convex_height**2 + keras.backend.epsilon(), + axis=-1, + ) + centers_distance_squared = ops.squeeze( + ((x_min1 + x_max1) / 2 - (x_min2 + x_max2) / 2) ** 2 + + ((y_min1 + y_max1) / 2 - (y_min2 + y_max2) / 2) ** 2, + axis=-1, + ) + + v = ops.squeeze( + ops.power( + (4 / math.pi**2) + * (ops.arctan(width_2 / height_2) - ops.arctan(width_1 / height_1)), + 2, + ), + axis=-1, + ) + alpha = v / (v - iou + (1 + keras.backend.epsilon())) + + return iou - ( + centers_distance_squared / convex_diagonal_squared + v * alpha + ) diff --git a/keras_nlp/src/bounding_box/iou_test.py b/keras_nlp/src/bounding_box/iou_test.py new file mode 100644 index 0000000000..ffd3b61cf3 --- /dev/null +++ b/keras_nlp/src/bounding_box/iou_test.py @@ -0,0 +1,161 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for iou functions.""" + +import numpy as np + +from keras_nlp.src.bounding_box import iou as iou_lib +from keras_nlp.src.tests.test_case import TestCase + + +class IoUTest(TestCase): + def test_compute_single_iou(self): + bb1 = np.array([[100, 101, 200, 201]]) + bb1_off_by_1 = np.array([[101, 102, 201, 202]]) + # area of bb1 and bb1_off_by_1 are each 10000. + # intersection area is 99*99=9801 + # iou=9801/(2*10000 - 9801)=0.96097656633 + self.assertAllClose( + iou_lib.compute_iou(bb1, bb1_off_by_1, "yxyx")[0], [0.96097656633] + ) + + def test_compute_iou(self): + bb1 = [100, 101, 200, 201] + bb1_off_by_1_pred = [101, 102, 201, 202] + iou_bb1_bb1_off = 0.96097656633 + top_left_bounding_box = [0, 2, 1, 3] + far_away_box = [1300, 1400, 1500, 1401] + another_far_away_pred = [1000, 1400, 1200, 1401] + + # Rows represent predictions, columns ground truths + expected_result = np.array( + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + dtype=np.float32, + ) + + sample_y_true = np.array([bb1, top_left_bounding_box, far_away_box]) + sample_y_pred = np.array( + [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred], + ) + + result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(expected_result, result) + + def test_batched_compute_iou(self): + bb1 = [100, 101, 200, 201] + bb1_off_by_1_pred = [101, 102, 201, 202] + iou_bb1_bb1_off = 0.96097656633 + top_left_bounding_box = [0, 2, 1, 3] + far_away_box = [1300, 1400, 1500, 1401] + another_far_away_pred = [1000, 1400, 1200, 1401] + + # Rows represent predictions, columns ground truths + expected_result = np.array( + [ + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + ) + + sample_y_true = np.array( + [ + [bb1, top_left_bounding_box, far_away_box], + [bb1, top_left_bounding_box, far_away_box], + ], + ) + sample_y_pred = np.array( + [ + [ + bb1_off_by_1_pred, + top_left_bounding_box, + another_far_away_pred, + ], + [ + bb1_off_by_1_pred, + top_left_bounding_box, + another_far_away_pred, + ], + ], + ) + + result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(expected_result, result) + + def test_batched_boxes1_unbatched_boxes2(self): + bb1 = [100, 101, 200, 201] + bb1_off_by_1_pred = [101, 102, 201, 202] + iou_bb1_bb1_off = 0.96097656633 + top_left_bounding_box = [0, 2, 1, 3] + far_away_box = [1300, 1400, 1500, 1401] + another_far_away_pred = [1000, 1400, 1200, 1401] + + # Rows represent predictions, columns ground truths + expected_result = np.array( + [ + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + ) + + sample_y_true = np.array( + [ + [bb1, top_left_bounding_box, far_away_box], + [bb1, top_left_bounding_box, far_away_box], + ], + ) + sample_y_pred = np.array( + [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred], + ) + + result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(expected_result, result) + + def test_unbatched_boxes1_batched_boxes2(self): + bb1 = [100, 101, 200, 201] + bb1_off_by_1_pred = [101, 102, 201, 202] + iou_bb1_bb1_off = 0.96097656633 + top_left_bounding_box = [0, 2, 1, 3] + far_away_box = [1300, 1400, 1500, 1401] + another_far_away_pred = [1000, 1400, 1200, 1401] + + # Rows represent predictions, columns ground truths + expected_result = np.array( + [ + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + ) + + sample_y_true = np.array( + [ + [bb1, top_left_bounding_box, far_away_box], + ], + ) + sample_y_pred = np.array( + [ + [ + bb1_off_by_1_pred, + top_left_bounding_box, + another_far_away_pred, + ], + [ + bb1_off_by_1_pred, + top_left_bounding_box, + another_far_away_pred, + ], + ], + ) + + result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(expected_result, result) diff --git a/keras_nlp/src/bounding_box/utils.py b/keras_nlp/src/bounding_box/utils.py new file mode 100644 index 0000000000..a96c284a6c --- /dev/null +++ b/keras_nlp/src/bounding_box/utils.py @@ -0,0 +1,194 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions for working with bounding boxes.""" + +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.bounding_box import converters +from keras_nlp.src.bounding_box.formats import XYWH + + +@keras_nlp_export("keras_nlp.bounding_box.is_relative") +def is_relative(bounding_box_format): + """A util to check if a bounding box format uses relative coordinates""" + if bounding_box_format.lower() not in converters.TO_XYXY_CONVERTERS: + raise ValueError( + "`is_relative()` received an unsupported format for the argument " + f"`bounding_box_format`. `bounding_box_format` should be one of " + f"{converters.TO_XYXY_CONVERTERS.keys()}. " + f"Got bounding_box_format={bounding_box_format}" + ) + + return bounding_box_format.startswith("rel") + + +@keras_nlp_export("keras_nlp.bounding_box.as_relative") +def as_relative(bounding_box_format): + """A util to get the relative equivalent of a provided bounding box format. + + If the specified format is already a relative format, + it will be returned unchanged. + """ + + if not is_relative(bounding_box_format): + return "rel_" + bounding_box_format + + return bounding_box_format + + +def _relative_area(boxes, bounding_box_format): + boxes = converters.convert_format( + boxes, + source=bounding_box_format, + target="rel_xywh", + ) + widths = boxes[..., XYWH.WIDTH] + heights = boxes[..., XYWH.HEIGHT] + # handle corner case where shear performs a full inversion. + return ops.where( + ops.logical_and(widths > 0, heights > 0), widths * heights, 0.0 + ) + + +@keras_nlp_export("keras_nlp.bounding_box.clip_to_image") +def clip_to_image( + bounding_boxes, bounding_box_format, images=None, image_shape=None +): + """clips bounding boxes to image boundaries. + + `clip_to_image()` clips bounding boxes that have coordinates out of bounds + of an image down to the boundaries of the image. This is done by converting + the bounding box to relative formats, then clipping them to the `[0, 1]` + range. Additionally, bounding boxes that end up with a zero area have their + class ID set to -1, indicating that there is no object present in them. + + Args: + bounding_boxes: bounding box tensor to clip. + bounding_box_format: the KerasCV bounding box format the bounding boxes + are in. + images: list of images to clip the bounding boxes to. + image_shape: the shape of the images to clip the bounding boxes to. + """ + boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"] + + boxes = converters.convert_format( + boxes, + source=bounding_box_format, + target="rel_xyxy", + images=images, + image_shape=image_shape, + ) + boxes, classes, images, squeeze = _format_inputs(boxes, classes, images) + x1, y1, x2, y2 = ops.split(boxes, 4, axis=-1) + clipped_bounding_boxes = ops.concatenate( + [ + ops.clip(x1, 0, 1), + ops.clip(y1, 0, 1), + ops.clip(x2, 0, 1), + ops.clip(y2, 0, 1), + ], + axis=-1, + ) + areas = _relative_area( + clipped_bounding_boxes, bounding_box_format="rel_xyxy" + ) + clipped_bounding_boxes = converters.convert_format( + clipped_bounding_boxes, + source="rel_xyxy", + target=bounding_box_format, + images=images, + image_shape=image_shape, + ) + clipped_bounding_boxes = ops.where( + ops.expand_dims(areas > 0.0, axis=-1), clipped_bounding_boxes, -1.0 + ) + classes = ops.where(areas > 0.0, classes, -1) + nan_indices = ops.any(ops.isnan(clipped_bounding_boxes), axis=-1) + classes = ops.where(nan_indices, -1, classes) + + # TODO update dict and return + clipped_bounding_boxes, classes = _format_outputs( + clipped_bounding_boxes, classes, squeeze + ) + + bounding_boxes.update({"boxes": clipped_bounding_boxes, "classes": classes}) + + return bounding_boxes + + +@keras_nlp_export("keras_nlp.bounding_box.clip_boxes") +def clip_boxes(boxes, image_shape): + """Clip boxes to the boundaries of the image shape""" + if boxes.shape[-1] != 4: + raise ValueError( + "boxes.shape[-1] is {:d}, but must be 4.".format(boxes.shape[-1]) + ) + + if isinstance(image_shape, list) or isinstance(image_shape, tuple): + height, width, _ = image_shape + max_length = ops.stack([height, width, height, width], axis=-1) + else: + image_shape = ops.cast(image_shape, dtype=boxes.dtype) + height = image_shape[0] + width = image_shape[1] + max_length = ops.stack([height, width, height, width], axis=-1) + + clipped_boxes = ops.maximum(ops.minimum(boxes, max_length), 0.0) + return clipped_boxes + + +def _format_inputs(boxes, classes, images): + boxes_rank = len(boxes.shape) + if boxes_rank > 3: + raise ValueError( + "Expected len(boxes.shape)=2, or len(boxes.shape)=3, got " + f"len(boxes.shape)={boxes_rank}" + ) + boxes_includes_batch = boxes_rank == 3 + # Determine if images needs an expand_dims() call + if images is not None: + images_rank = len(images.shape) + if images_rank > 4: + raise ValueError( + "Expected len(images.shape)=2, or len(images.shape)=3, got " + f"len(images.shape)={images_rank}" + ) + images_include_batch = images_rank == 4 + if boxes_includes_batch != images_include_batch: + raise ValueError( + "clip_to_image() expects both boxes and images to be batched, " + "or both boxes and images to be unbatched. Received " + f"len(boxes.shape)={boxes_rank}, " + f"len(images.shape)={images_rank}. Expected either " + "len(boxes.shape)=2 AND len(images.shape)=3, or " + "len(boxes.shape)=3 AND len(images.shape)=4." + ) + if not images_include_batch: + images = ops.expand_dims(images, axis=0) + + if not boxes_includes_batch: + return ( + ops.expand_dims(boxes, axis=0), + ops.expand_dims(classes, axis=0), + images, + True, + ) + return boxes, classes, images, False + + +def _format_outputs(boxes, classes, squeeze): + if squeeze: + return ops.squeeze(boxes, axis=0), ops.squeeze(classes, axis=0) + return boxes, classes diff --git a/keras_nlp/src/bounding_box/utils_test.py b/keras_nlp/src/bounding_box/utils_test.py new file mode 100644 index 0000000000..fcbab8dae8 --- /dev/null +++ b/keras_nlp/src/bounding_box/utils_test.py @@ -0,0 +1,161 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from keras import ops + +from keras_nlp.src.bounding_box import utils +from keras_nlp.src.tests.test_case import TestCase + + +class BoundingBoxUtilTest(TestCase): + def test_clip_to_image_standard(self): + # Test xyxy format unbatched + height = 256 + width = 256 + bounding_boxes = { + "boxes": np.array([[200, 200, 400, 400], [100, 100, 300, 300]]), + "classes": np.array([0, 0]), + } + image = ops.ones(shape=(height, width, 3)) + bounding_boxes = utils.clip_to_image( + bounding_boxes, bounding_box_format="xyxy", images=image + ) + boxes = bounding_boxes["boxes"] + self.assertAllGreaterEqual(boxes, 0) + ( + x1, + y1, + x2, + y2, + ) = ops.split(boxes, 4, axis=1) + self.assertAllLessEqual(ops.concatenate([x1, x2], axis=1), width) + self.assertAllLessEqual(ops.concatenate([y1, y2], axis=1), height) + # Test relative format batched + image = ops.ones(shape=(1, height, width, 3)) + + bounding_boxes = { + "boxes": np.array([[[0.2, -1, 1.2, 0.3], [0.4, 1.5, 0.2, 0.3]]]), + "classes": np.array([[0, 0]]), + } + bounding_boxes = utils.clip_to_image( + bounding_boxes, bounding_box_format="rel_xyxy", images=image + ) + self.assertAllLessEqual(bounding_boxes["boxes"], 1) + + def test_clip_to_image_filters_fully_out_bounding_boxes(self): + # Test xyxy format unbatched + height = 256 + width = 256 + bounding_boxes = { + "boxes": np.array([[257, 257, 400, 400], [100, 100, 300, 300]]), + "classes": np.array([0, 0]), + } + image = ops.ones(shape=(height, width, 3)) + bounding_boxes = utils.clip_to_image( + bounding_boxes, bounding_box_format="xyxy", images=image + ) + + self.assertAllEqual( + bounding_boxes["boxes"], + np.array([[-1, -1, -1, -1], [100, 100, 256, 256]]), + ), + self.assertAllEqual( + bounding_boxes["classes"], + np.array([-1, 0]), + ) + + def test_clip_to_image_filters_fully_out_bounding_boxes_negative_area(self): + # Test xyxy format unbatched + height = 256 + width = 256 + bounding_boxes = { + "boxes": np.array([[110, 120, 100, 100], [100, 100, 300, 300]]), + "classes": np.array([0, 0]), + } + image = ops.ones(shape=(height, width, 3)) + bounding_boxes = utils.clip_to_image( + bounding_boxes, bounding_box_format="xyxy", images=image + ) + self.assertAllEqual( + bounding_boxes["boxes"], + np.array( + [ + [ + -1, + -1, + -1, + -1, + ], + [ + 100, + 100, + 256, + 256, + ], + ] + ), + ) + self.assertAllEqual( + bounding_boxes["classes"], + np.array([-1, 0]), + ) + + def test_clip_to_image_filters_nans(self): + # Test xyxy format unbatched + height = 256 + width = 256 + bounding_boxes = { + "boxes": np.array( + [[0, float("NaN"), 100, 100], [100, 100, 300, 300]] + ), + "classes": np.array([0, 0]), + } + image = ops.ones(shape=(height, width, 3)) + bounding_boxes = utils.clip_to_image( + bounding_boxes, bounding_box_format="xyxy", images=image + ) + self.assertAllEqual( + bounding_boxes["boxes"], + np.array( + [ + [ + -1, + -1, + -1, + -1, + ], + [ + 100, + 100, + 256, + 256, + ], + ] + ), + ) + self.assertAllEqual( + bounding_boxes["classes"], + np.array([-1, 0]), + ) + + def test_is_relative_util(self): + self.assertTrue(utils.is_relative("rel_xyxy")) + self.assertFalse(utils.is_relative("xyxy")) + + with self.assertRaises(ValueError): + _ = utils.is_relative("bad_format") + + def test_as_relative_util(self): + self.assertEqual(utils.as_relative("yxyx"), "rel_yxyx") + self.assertEqual(utils.as_relative("rel_xywh"), "rel_xywh") diff --git a/keras_nlp/src/bounding_box/validate_format_test.py b/keras_nlp/src/bounding_box/validate_format_test.py new file mode 100644 index 0000000000..020279f334 --- /dev/null +++ b/keras_nlp/src/bounding_box/validate_format_test.py @@ -0,0 +1,47 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from keras_nlp.src.bounding_box import validate_format +from keras_nlp.src.tests.test_case import TestCase + + +class ValidateTest(TestCase): + def test_raises_nondict(self): + with self.assertRaisesRegex( + ValueError, "Expected `bounding_boxes` to be a dictionary, got " + ): + validate_format.validate_format(np.ones((4, 3, 6))) + + def test_mismatch_dimensions(self): + with self.assertRaisesRegex( + ValueError, + "Expected `boxes` and `classes` to have matching dimensions", + ): + validate_format.validate_format( + {"boxes": np.ones((4, 3, 6)), "classes": np.ones((4, 6))} + ) + + def test_bad_keys(self): + with self.assertRaisesRegex(ValueError, "containing keys"): + validate_format.validate_format( + { + "box": [ + 1, + 2, + 3, + ], + "class": [1234], + } + ) From f31ad9c6d1594c7deae732f3f3b0880fe771a0c8 Mon Sep 17 00:00:00 2001 From: sachin prasad Date: Tue, 17 Sep 2024 13:47:57 -0700 Subject: [PATCH 28/33] Add Deeplabv3 and v3plus in the same backbone and segmenter --- keras_nlp/api/models/__init__.py | 9 +- keras_nlp/api/models/segmentation/__init__.py | 22 -- .../__init__.py | 0 .../deeplab_v3_backbone.py} | 190 +++++++++--------- .../deeplab_v3_backbone_test.py} | 44 ++-- .../deeplab_v3_layers.py} | 90 ++++++--- .../models/deeplab_v3/deeplab_v3_segmenter.py | 115 +++++++++++ .../deeplab_v3/deeplab_v3_segmenter_test.py | 77 +++++++ keras_nlp/src/models/image_segmenter.py | 105 ++++++++++ 9 files changed, 480 insertions(+), 172 deletions(-) delete mode 100644 keras_nlp/api/models/segmentation/__init__.py rename keras_nlp/src/models/{deeplab_v3_plus => deeplab_v3}/__init__.py (100%) rename keras_nlp/src/models/{deeplab_v3_plus/deeplab_v3_plus_segmenter.py => deeplab_v3/deeplab_v3_backbone.py} (54%) rename keras_nlp/src/models/{deeplab_v3_plus/deeplab_v3_plus_test.py => deeplab_v3/deeplab_v3_backbone_test.py} (50%) rename keras_nlp/src/models/{deeplab_v3_plus/deeplab_v3_plus_layers.py => deeplab_v3/deeplab_v3_layers.py} (70%) create mode 100644 keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter.py create mode 100644 keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter_test.py create mode 100644 keras_nlp/src/models/image_segmenter.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index be74ec6370..7d985b5c49 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -17,7 +17,6 @@ since your modifications would be overwritten. """ -from keras_nlp.api.models import segmentation from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone from keras_nlp.src.models.albert.albert_classifier import AlbertClassifier from keras_nlp.src.models.albert.albert_masked_lm import AlbertMaskedLM @@ -75,8 +74,11 @@ from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) -from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_segmenter import ( - DeepLabV3Plus, +from keras_nlp.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_nlp.src.models.deeplab_v3.deeplab_v3_segmenter import ( + DeepLabV3ImageSegmenter, ) from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_nlp.src.models.densenet.densenet_image_classifier import ( @@ -148,6 +150,7 @@ ) from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.image_segmenter import Segmenter from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( diff --git a/keras_nlp/api/models/segmentation/__init__.py b/keras_nlp/api/models/segmentation/__init__.py deleted file mode 100644 index c1da3bb01a..0000000000 --- a/keras_nlp/api/models/segmentation/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2024 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""DO NOT EDIT. - -This file was autogenerated. Do not edit it by hand, -since your modifications would be overwritten. -""" - -from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_segmenter import ( - DeepLabV3Plus, -) diff --git a/keras_nlp/src/models/deeplab_v3_plus/__init__.py b/keras_nlp/src/models/deeplab_v3/__init__.py similarity index 100% rename from keras_nlp/src/models/deeplab_v3_plus/__init__.py rename to keras_nlp/src/models/deeplab_v3/__init__.py diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py b/keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone.py similarity index 54% rename from keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py rename to keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone.py index c41b6b9cb7..9d6e2f6c25 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_segmenter.py +++ b/keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone.py @@ -14,138 +14,131 @@ import keras from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_layers import ( +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.deeplab_v3.deeplab_v3_layers import ( SpatialPyramidPooling, ) -from keras_nlp.src.models.segmentation import Segmentation -@keras_nlp_export( - [ - "keras_nlp.models.DeepLabV3Plus", - "keras_nlp.models.segmentation.DeepLabV3Plus", - ] -) -class DeepLabV3Plus(Segmentation): - """DeepLabV3+ architecture for semantic segmentation. +@keras_nlp_export("keras_nlp.models.DeepLabV3Backbone") +class DeepLabV3Backbone(Backbone): + """DeepLabV3 & DeepLabV3Plus architecture for semantic segmentation. - This class implements a DeepLabV3+ architecture as described in + This class implements a DeepLabV3 & DeepLabV3Plus architecture as described in [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)(ECCV 2018) and [Rethinking Atrous Convolution for Semantic Image Segmentation]( https://arxiv.org/abs/1706.05587)(CVPR 2017) Args: - backbone: `keras.Model`. The backbone network for the model that is - used as a feature extractor for the DeepLabV3+ Encoder. Should + image_encoder: `keras.Model`. The backbone network for the model that is + used as a feature extractor for the Encoder. Should either be a `keras_nlp.models.backbones.backbone.Backbone` or a `keras.Model` that implements the `pyramid_outputs` property with keys "P2", "P3" etc as values. A somewhat sensible backbone to use in many cases is the `keras_nlp.models.ResNetBackbone.from_preset("resnet_v2_50")`. - num_classes: int, the number of classes for the detection model. Note - that the `num_classes` contains the background class, and the - classes from the data should be represented by integers with range - [0, `num_classes`). projection_filters: int, number of filters in the convolution layer projecting low-level features from the `backbone`. - low_level_feature_key: str, layer level to extract the feature from one of the - key from the `backbone` `pyramid_outputs` - property such as "P2", "P3" etc. spatial_pyramid_pooling_key: str, layer level to extract and perform - `spatial_pyramid_pooling`, one of the key from the `backbone` `pyramid_outputs` + `spatial_pyramid_pooling`, one of the key from the `backbone` + `pyramid_outputs` property such as "P4", "P5" etc. + upsampling_size: Int, or tuple of 2 integers. The upsampling factors for + rows and columns of `spatial_pyramid_pooling` layer. + If `low_level_feature_key` is given then `spatial_pyramid_pooling`s + layer resolution should match with the `low_level_feature`s layer + resolution to concatenate both the layers for combined encoder outputs. + dilation_rates: A `list` of integers for parallel dilated conv. + Applied only when Default `SpatialPyramidPooling` is used. Usually a + sample choice of rates are [6, 12, 18]. + low_level_feature_key: (Optional) str, layer level to extract the feature + from one of the key from the `backbone`s `pyramid_outputs` + property such as "P2", "P3" etc which will be the Decoder block. + Required only when the DeepLabV3Plus architecture needs to be applied. + activation: str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. spatial_pyramid_pooling: (Optional) a `keras.layers.Layer`. Also known as Atrous Spatial Pyramid Pooling (ASPP). Performs spatial pooling on different spatial levels in the pyramid, with dilation. If provided, the feature map from the backbone is passed to it inside the DeepLabV3 Encoder, otherwise SpatialPyramidPooling layer is used. - dilation_rates: (Optional) A `list` of integers for parallel dilated conv. - Applied only when Default `SpatialPyramidPooling` is used. Usually a - sample choice of rates are [6, 12, 18]. segmentation_head: (Optional) a `keras.layers.Layer`. If provided, the outputs of the DeepLabV3 encoder is passed to this layer and it - should predict the segmentation mask based on feature from backbone - and feature from decoder, otherwise a default DeepLabV3 - convolutional head is used. + will be considered as the last layer before final segmentaion layer , + otherwise a default DeepLabV3 convolutional head is used. Example: ```python - images = np.ones(shape=(1, 96, 96, 3)) - labels = np.zeros(shape=(1, 96, 96, 1)) - backbone = keras_nlp.models.ResNetBackbone.from_preset("resnet_v2_50") + image_encoder = keras_nlp.models.ResNetBackbone.from_preset("resnet_v2_50") - model = keras_hub.models.DeepLabV3Plus( - backbone= backbone, - num_classes=3, + model = keras_nlp.models.DeepLabV3Backbone( + image_encoder= image_encoder, projection_filters=48, low_level_feature_key="P2", spatial_pyramid_pooling_key="P5", ) - - # Evaluate model - model(images) - - # Train model - model.compile( - optimizer="adam", - loss=keras.losses.BinaryCrossentropy(from_logits=True), - metrics=["accuracy"], - ) - model.fit(images, labels, epochs=3) ``` """ def __init__( self, - backbone, - num_classes, - low_level_feature_key, + image_encoder, spatial_pyramid_pooling_key, + upsampling_size, + dilation_rates, + low_level_feature_key=None, projection_filters=48, spatial_pyramid_pooling=None, - dilation_rates=None, segmentation_head=None, **kwargs, ): - if not isinstance(backbone, keras.layers.Layer) or not isinstance( - backbone, keras.Model - ): + if not isinstance(image_encoder, keras.Model): raise ValueError( - "Argument `backbone` must be a `keras.layers.Layer` instance " - f" or `keras.Model`. Received instead " - f"backbone={backbone} (of type {type(backbone)})." + "Argument `image_encoder` must be a `keras.Model` instance. Received instead " + f"backbone={image_encoder} (of type {type(image_encoder)})." ) - + data_format = keras.config.image_data_format() + channel_axis = -1 if data_format == "channels_last" else 1 # === Functional Model === - inputs = backbone.input + inputs = keras.layers.Input((None, None, 3)) + + fpn_model = keras.Model( + image_encoder.inputs, image_encoder.pyramid_outputs + ) + + fpn_outputs = fpn_model(inputs) if spatial_pyramid_pooling is None: spatial_pyramid_pooling = SpatialPyramidPooling( dilation_rates=dilation_rates ) - spatial_backbone_features = backbone.pyramid_outputs[ - spatial_pyramid_pooling_key - ] + spatial_backbone_features = fpn_outputs[spatial_pyramid_pooling_key] spp_outputs = spatial_pyramid_pooling(spatial_backbone_features) - low_level_backbone_feature = backbone.pyramid_outputs[ - low_level_feature_key - ] - low_level_projected_features = apply_low_level_feature_network( - low_level_backbone_feature, projection_filters - ) - encoder_outputs = keras.layers.UpSampling2D( - size=(8, 8), + size=upsampling_size, interpolation="bilinear", name="encoder_output_upsampling", + data_format=data_format, )(spp_outputs) - combined_encoder_outputs = keras.layers.Concatenate(axis=-1)( - [encoder_outputs, low_level_projected_features] - ) + if low_level_feature_key: + decoder_feature = fpn_outputs[low_level_feature_key] + low_level_projected_features = apply_low_level_feature_network( + decoder_feature, projection_filters, channel_axis + ) + encoder_outputs = keras.layers.Concatenate(axis=channel_axis)( + [encoder_outputs, low_level_projected_features] + ) + # upsampling to the original image size + upsampling = (2 ** int(spatial_pyramid_pooling_key[-1])) // ( + int(upsampling_size[0]) + if isinstance(upsampling_size, tuple) + else upsampling_size + ) if segmentation_head is None: x = keras.layers.Conv2D( name="segmentation_head_conv", @@ -153,36 +146,27 @@ def __init__( kernel_size=1, padding="same", use_bias=False, - )(combined_encoder_outputs) - x = keras.layers.BatchNormalization(name="segmentation_head_norm")( - x - ) + data_format=data_format, + )(encoder_outputs) + x = keras.layers.BatchNormalization( + name="segmentation_head_norm", axis=channel_axis + )(x) x = keras.layers.ReLU(name="segmentation_head_relu")(x) x = keras.layers.UpSampling2D( - size=(4, 4), interpolation="bilinear" - )(x) - # Classification layer - outputs = keras.layers.Conv2D( - name="segmentation_output", - filters=num_classes, - kernel_size=1, - use_bias=False, - padding="same", - # Force the dtype of the classification layer to float32 - # to avoid the NAN loss issue when used with mixed - # precision API. - dtype="float32", + size=upsampling, + interpolation="bilinear", + data_format=data_format, )(x) else: - outputs = segmentation_head(combined_encoder_outputs) + x = segmentation_head(encoder_outputs) - super().__init__(inputs=inputs, outputs=outputs, **kwargs) + super().__init__(inputs=inputs, outputs=x, **kwargs) # === Config === - self.num_classes = num_classes - self.backbone = backbone + self.image_encoder = image_encoder self.spatial_pyramid_pooling = spatial_pyramid_pooling self.projection_filters = projection_filters + self.upsampling_size = upsampling_size self.segmentation_head = segmentation_head self.dilation_rates = dilation_rates self.low_level_feature_key = low_level_feature_key @@ -190,8 +174,9 @@ def __init__( def get_config(self): return { - "num_classes": self.num_classes, - "backbone": keras.saving.serialize_keras_object(self.backbone), + "image_encoder": keras.saving.serialize_keras_object( + self.image_encoder + ), "spatial_pyramid_pooling": keras.saving.serialize_keras_object( self.spatial_pyramid_pooling ), @@ -200,14 +185,19 @@ def get_config(self): self.segmentation_head ), "dilation_rates": self.dilation_rates, + "upsampling_size": self.upsampling_size, "low_level_feature_key": self.low_level_feature_key, "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key, } @classmethod def from_config(cls, config): - if "backbone" in config and isinstance(config["backbone"], dict): - config["backbone"] = keras.layers.deserialize(config["backbone"]) + if "image_encoder" in config and isinstance( + config["image_encoder"], dict + ): + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) if "spatial_pyramid_pooling" in config and isinstance( config["spatial_pyramid_pooling"], dict ): @@ -223,15 +213,21 @@ def from_config(cls, config): return super().from_config(config) -def apply_low_level_feature_network(input_tensor, projection_filters): +def apply_low_level_feature_network( + input_tensor, projection_filters, channel_axis +): + data_format = keras.config.image_data_format() x = keras.layers.Conv2D( name="low_level_feature_conv", filters=projection_filters, kernel_size=1, padding="same", use_bias=False, + data_format=data_format, )(input_tensor) - x = keras.layers.BatchNormalization(name="low_level_feature_norm")(x) + x = keras.layers.BatchNormalization( + name="low_level_feature_norm", axis=channel_axis + )(x) x = keras.layers.ReLU(name="low_level_feature_relu")(x) return x diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py b/keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone_test.py similarity index 50% rename from keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py rename to keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone_test.py index 641c402309..a7919ffc0e 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_test.py +++ b/keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone_test.py @@ -15,41 +15,49 @@ import numpy as np import pytest -from keras_nlp.src.models.deeplab_v3_plus.deeplab_v3_plus_segmenter import ( - DeepLabV3Plus, +from keras_nlp.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, ) from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone from keras_nlp.src.tests.test_case import TestCase -class DeepLabV3PlusTest(TestCase): +class DeepLabV3Test(TestCase): def setUp(self): + self.resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "pooling": "avg", + "block_type": "basic_block", + "use_pre_activation": False, + } + self.image_encoder = ResNetBackbone(**self.resnet_kwargs) self.init_kwargs = { - "backbone": ResNetBackbone.from_preset( - "hf://timm/resnet18.a1_in1k" - ), - "num_classes": 2, + "image_encoder": self.image_encoder, "low_level_feature_key": "P2", - "spatial_pyramid_pooling_key": "P5", - "projection_filters": 48, - "spatial_pyramid_pooling": None, + "spatial_pyramid_pooling_key": "P4", "dilation_rates": [6, 12, 18], - "segmentation_head": None, + "upsampling_size": 4, } - self.images = np.ones((2, 96, 96, 3), dtype="float32") - self.labels = np.zeros((2, 96, 96, 2), dtype="float32") + self.input_data = np.ones((2, 96, 96, 3), dtype="float32") def test_segmentation_basics(self): - self.run_segmentation_test( - cls=DeepLabV3Plus, + self.run_vision_backbone_test( + cls=DeepLabV3Backbone, init_kwargs=self.init_kwargs, - train_data=(self.images, self.labels), - expected_output_shape=(2, 96, 96, 2), + input_data=self.input_data, + expected_output_shape=(2, 96, 96, 256), + run_mixed_precision_check=False, + run_quantization_check=False, ) @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( - cls=DeepLabV3Plus, + cls=DeepLabV3Backbone, init_kwargs=self.init_kwargs, + input_data=self.input_data, ) diff --git a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_layers.py b/keras_nlp/src/models/deeplab_v3/deeplab_v3_layers.py similarity index 70% rename from keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_layers.py rename to keras_nlp/src/models/deeplab_v3/deeplab_v3_layers.py index ba46a11652..d2eb4b4e56 100644 --- a/keras_nlp/src/models/deeplab_v3_plus/deeplab_v3_plus_layers.py +++ b/keras_nlp/src/models/deeplab_v3/deeplab_v3_layers.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any -from typing import List -from typing import Mapping - import keras from keras import ops @@ -41,10 +37,10 @@ class SpatialPyramidPooling(keras.layers.Layer): def __init__( self, - dilation_rates: List[int], - num_channels: int = 256, - activation: str = "relu", - dropout: float = 0.0, + dilation_rates, + num_channels=256, + activation="relu", + dropout=0.0, **kwargs, ): """Initializes an Atrous Spatial Pyramid Pooling layer. @@ -59,6 +55,8 @@ def __init__( which means no dropout is applied to the output. **kwargs: Additional keyword arguments to be passed. """ + self.data_format = keras.config.image_data_format() + self.channel_axis = -1 if self.data_format == "channels_last" else 1 super().__init__(**kwargs) self.dilation_rates = dilation_rates self.num_channels = num_channels @@ -66,7 +64,7 @@ def __init__( self.dropout = dropout def build(self, input_shape): - channels = input_shape[3] + channels = input_shape[self.channel_axis] # This is the parallel networks that process the input features with # different dilation rates. The output from each channel will be merged @@ -80,8 +78,9 @@ def build(self, input_shape): filters=self.num_channels, kernel_size=(1, 1), use_bias=False, + data_format=self.data_format, ), - keras.layers.BatchNormalization(), + keras.layers.BatchNormalization(axis=self.channel_axis), keras.layers.Activation(self.activation), ] ) @@ -99,8 +98,9 @@ def build(self, input_shape): padding="same", dilation_rate=dilation_rate, use_bias=False, + data_format=self.data_format, ), - keras.layers.BatchNormalization(), + keras.layers.BatchNormalization(axis=self.channel_axis), keras.layers.Activation(self.activation), ] ) @@ -108,16 +108,23 @@ def build(self, input_shape): self.aspp_parallel_channels.append(conv_sequential) # Last channel is the global average pooling with conv2D 1x1 kernel. + if self.channel_axis == -1: + reshape = keras.layers.Reshape((1, 1, channels)) + else: + reshape = keras.layers.Reshape((channels, 1, 1)) pool_sequential = keras.Sequential( [ - keras.layers.GlobalAveragePooling2D(), - keras.layers.Reshape((1, 1, channels)), + keras.layers.GlobalAveragePooling2D( + data_format=self.data_format + ), + reshape, keras.layers.Conv2D( filters=self.num_channels, kernel_size=(1, 1), use_bias=False, + data_format=self.data_format, ), - keras.layers.BatchNormalization(), + keras.layers.BatchNormalization(axis=self.channel_axis), keras.layers.Activation(self.activation), ] ) @@ -131,8 +138,9 @@ def build(self, input_shape): filters=self.num_channels, kernel_size=(1, 1), use_bias=False, + data_format=self.data_format, ), - keras.layers.BatchNormalization(), + keras.layers.BatchNormalization(axis=self.channel_axis), keras.layers.Activation(self.activation), keras.layers.Dropout(rate=self.dropout), ], @@ -147,10 +155,10 @@ def call(self, inputs, training=None): """Calls the Atrous Spatial Pyramid Pooling layer on an input. Args: - inputs: A tensor of shape [batch, height, width, channels] + inputs: A tensor of shape [batch, height, width, channels] Returns: - A tensor of shape [batch, height, width, num_channels] + A tensor of shape [batch, height, width, num_channels] """ result = [] @@ -159,26 +167,44 @@ def call(self, inputs, training=None): result.append(temp) image_shape = ops.shape(inputs) - height, width = image_shape[1], image_shape[2] - result[-1] = keras.layers.Resizing( + if self.channel_axis == -1: + height, width = image_shape[1], image_shape[2] + else: + height, width = image_shape[2], image_shape[3] + result[self.channel_axis] = keras.layers.Resizing( height, width, interpolation="bilinear", - )(result[-1]) + )(result[self.channel_axis]) - result = ops.concatenate(result, axis=-1) + result = ops.concatenate(result, axis=self.channel_axis) result = self.projection(result, training=training) return result def compute_output_shape(self, input_shape): - return tuple(input_shape[:-1]) + (self.num_channels,) - - def get_config(self) -> Mapping[str, Any]: - config = { - "dilation_rates": self.dilation_rates, - "num_channels": self.num_channels, - "activation": self.activation, - "dropout": self.dropout, - } - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) + if self.data_format == "channels_first": + return ( + input_shape[0], + self.num_channels, + input_shape[1], + input_shape[2], + ) + else: + return ( + input_shape[0], + input_shape[1], + input_shape[2], + self.num_channels, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "dilation_rates": self.dilation_rates, + "num_channels": self.num_channels, + "activation": self.activation, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter.py b/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter.py new file mode 100644 index 0000000000..3f07ff18b6 --- /dev/null +++ b/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter.py @@ -0,0 +1,115 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_nlp.src.models.image_segmenter import Segmenter + + +@keras_nlp_export("keras_nlp.models.DeepLabV3ImageSegmenter") +class DeepLabV3ImageSegmenter(Segmenter): + """DeepLabV3 and DeeplabV3 and DeeplabV3Plus segmentation task. + + Args: + backbone: A `keras_nlp.models.DeepLabV3` instance. + num_classes: int, the number of classes for the detection model. Note + that the `num_classes` contains the background class, and the + classes from the data should be represented by integers with range + [0, `num_classes`]. + activation: str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `None`. + + Example: + ```python + images = np.ones(shape=(1, 96, 96, 3)) + labels = np.zeros(shape=(1, 96, 96, 1)) + feature_pyramid_model = keras_nlp.models.DeepLabV3Backbone.from_preset("deeplabv3_resnet50") + + model = keras_hub.models.DeepLabV3ImageSegmenter( + num_classes=3, + projection_filters=48, + low_level_feature_key="P2", + spatial_pyramid_pooling_key="P5", + ) + + # Evaluate model + model(images) + + # Train model + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(from_logits=True), + metrics=["accuracy"], + ) + model.fit(images, labels, epochs=3) + ``` + """ + + backbone_cls = DeepLabV3Backbone + preprocessor_cls = None + + def __init__( + self, + backbone, + num_classes, + activation=None, + preprocessor=None, + **kwargs, + ): + data_format = keras.config.image_data_format() + # === Layers === + self.backbone = backbone + self.output_conv = keras.layers.Conv2D( + name="segmentation_output", + filters=num_classes, + kernel_size=1, + use_bias=False, + padding="same", + activation=activation, + data_format=data_format, + # Force the dtype of the classification layer to float32 + # to avoid the NAN loss issue when used with mixed + # precision API. + dtype="float32", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_conv(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter_test.py b/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter_test.py new file mode 100644 index 0000000000..56cc27bafb --- /dev/null +++ b/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter_test.py @@ -0,0 +1,77 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_nlp.src.models.deeplab_v3.deeplab_v3_segmenter import ( + DeepLabV3ImageSegmenter, +) +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class DeepLabV3ImageSegmenterTest(TestCase): + def setUp(self): + self.resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "pooling": "avg", + "block_type": "basic_block", + "use_pre_activation": False, + } + self.image_encoder = ResNetBackbone(**self.resnet_kwargs) + self.deeplab_backbone = DeepLabV3Backbone( + image_encoder=self.image_encoder, + low_level_feature_key="P2", + spatial_pyramid_pooling_key="P4", + dilation_rates=[6, 12, 18], + upsampling_size=4, + ) + self.init_kwargs = { + "backbone": self.deeplab_backbone, + "num_classes": 3, + "activation": "softmax", + } + self.images = np.ones((2, 96, 96, 3), dtype="float32") + self.labels = np.zeros((2, 96, 96, 2), dtype="float32") + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=DeepLabV3ImageSegmenter, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(1, 96, 96, 1), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DeepLabV3ImageSegmenter, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/models/image_segmenter.py b/keras_nlp/src/models/image_segmenter.py new file mode 100644 index 0000000000..a4937d94ff --- /dev/null +++ b/keras_nlp/src/models/image_segmenter.py @@ -0,0 +1,105 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.task import Task + + +@keras_nlp_export("keras_nlp.models.Segmenter") +class Segmenter(Task): + """Base class for all segmentation tasks. + + `Segmenter` tasks wrap a `keras_nlp.models.Backbone` to create a model + that can be used for segmentation. + `Segmenter` tasks take an additional + `num_classes` argument, the number of segmentation classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a image and `y` is a label from `[0, num_classes)`. + + All `Segmenter` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Example: + ```python + model = keras_nlp.models.Segmenter.from_preset( + "basnet_resnet", + num_classes=2, + ) + images = np.ones(shape=(1, 288, 288, 3)) + labels = np.zeros(shape=(1, 288, 288, 1)) + + output = model(images) + pred_labels = output[0] + + model.fit(images, labels, epochs=3) + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `Segmenter` task for training. + + The `Segmenter` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.BinaryCrossentropy` loss will be + applied for the segmentation task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.Accuracy` will be + applied to track the accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if loss == "auto": + activation = getattr(self, "activation", None) + activation = keras.activations.get(activation) + from_logits = activation != keras.activations.softmax + loss = keras.losses.BinaryCrossentropy(from_logits) + if metrics == "auto": + metrics = [keras.metrics.Accuracy()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) From fc1a3a53f888d25d9af7cbe2bd864c50d48810fb Mon Sep 17 00:00:00 2001 From: sachin prasad Date: Tue, 17 Sep 2024 14:46:23 -0700 Subject: [PATCH 29/33] fix imports --- keras_nlp/api/models/__init__.py | 8 +++---- keras_nlp/src/utils/timm/convert.py | 37 ----------------------------- 2 files changed, 4 insertions(+), 41 deletions(-) delete mode 100644 keras_nlp/src/utils/timm/convert.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index e8902b6a37..694e3d555d 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -96,16 +96,16 @@ from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) -from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone -from keras_nlp.src.models.densenet.densenet_image_classifier import ( - DenseNetImageClassifier, -) from keras_nlp.src.models.deeplab_v3.deeplab_v3_backbone import ( DeepLabV3Backbone, ) from keras_nlp.src.models.deeplab_v3.deeplab_v3_segmenter import ( DeepLabV3ImageSegmenter, ) +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.densenet.densenet_image_classifier import ( + DenseNetImageClassifier, +) from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone, ) diff --git a/keras_nlp/src/utils/timm/convert.py b/keras_nlp/src/utils/timm/convert.py deleted file mode 100644 index edfde3316b..0000000000 --- a/keras_nlp/src/utils/timm/convert.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert timm models to KerasNLP.""" - -from keras_nlp.src.utils.timm.convert_resnet import load_resnet_backbone - - -def load_timm_backbone(cls, preset, load_weights, **kwargs): - """Load a timm model config and weights as a KerasNLP backbone. - - Args: - cls (class): Keras model class. - preset (str): Preset configuration name. - load_weights (bool): Whether to load the weights. - - Returns: - backbone: Initialized Keras model backbone. - """ - if cls is None: - raise ValueError("Backbone class is None") - if cls.__name__ == "ResNetBackbone": - return load_resnet_backbone(cls, preset, load_weights, **kwargs) - raise ValueError( - f"{cls} has not been ported from the Hugging Face format yet. " - "Please check Hugging Face Hub for the Keras model. " - ) From c172031d65d5379ebb7c04749b78a20a8943e2c0 Mon Sep 17 00:00:00 2001 From: sachin prasad Date: Tue, 17 Sep 2024 14:59:22 -0700 Subject: [PATCH 30/33] nit --- keras_nlp/api/models/__init__.py | 1 - keras_nlp/src/models/segmentation.py | 105 --------------------------- requirements-common.txt | 1 - 3 files changed, 107 deletions(-) delete mode 100644 keras_nlp/src/models/segmentation.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 694e3d555d..ccd4246993 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -256,7 +256,6 @@ RobertaTextClassifierPreprocessor as RobertaPreprocessor, ) from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer -from keras_nlp.src.models.segmentation import Segmentation from keras_nlp.src.models.seq_2_seq_lm import Seq2SeqLM from keras_nlp.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor from keras_nlp.src.models.t5.t5_backbone import T5Backbone diff --git a/keras_nlp/src/models/segmentation.py b/keras_nlp/src/models/segmentation.py deleted file mode 100644 index bba775667d..0000000000 --- a/keras_nlp/src/models/segmentation.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2024 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import keras - -from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.task import Task - - -@keras_nlp_export("keras_nlp.models.Segmentation") -class Segmentation(Task): - """Base class for all segmentation tasks. - - `Segmentation` tasks wrap a `keras_nlp.models.Backbone` to create a model - that can be used for segmentation. - `Segmentation` tasks take an additional - `num_classes` argument, the number of segmentation classes. - - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - labels where `x` is a image and `y` is a label from `[0, num_classes)`. - - All `Segmentation` tasks include a `from_preset()` constructor which can be - used to load a pre-trained config and weights. - - Example: - ```python - model = keras_nlp.models.Segmentation.from_preset( - "basnet_resnet", - num_classes=2, - ) - images = np.ones(shape=(1, 288, 288, 3)) - labels = np.zeros(shape=(1, 288, 288, 1)) - - output = model(images) - pred_labels = output[0] - - model.fit(images, labels, epochs=3) - ``` - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Default compilation. - self.compile() - - def compile( - self, - optimizer="auto", - loss="auto", - *, - metrics="auto", - **kwargs, - ): - """Configures the `Segmenter` task for training. - - The `Segmenter` task extends the default compilation signature of - `keras.Model.compile` with defaults for `optimizer`, `loss`, and - `metrics`. To override these defaults, pass any value - to these arguments during compilation. - - Args: - optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` - instance. Defaults to `"auto"`, which uses the default optimizer - for the given model and task. See `keras.Model.compile` and - `keras.optimizers` for more info on possible `optimizer` values. - loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. - Defaults to `"auto"`, where a - `keras.losses.BinaryCrossentropy` loss will be - applied for the segmentation task. See - `keras.Model.compile` and `keras.losses` for more info on - possible `loss` values. - metrics: `"auto"`, or a list of metrics to be evaluated by - the model during training and testing. Defaults to `"auto"`, - where a `keras.metrics.Accuracy` will be - applied to track the accuracy of the model during training. - See `keras.Model.compile` and `keras.metrics` for - more info on possible `metrics` values. - **kwargs: See `keras.Model.compile` for a full list of arguments - supported by the compile method. - """ - if optimizer == "auto": - optimizer = keras.optimizers.Adam(5e-5) - if loss == "auto": - activation = getattr(self, "activation", None) - activation = keras.activations.get(activation) - from_logits = activation != keras.activations.softmax - loss = keras.losses.BinaryCrossentropy(from_logits) - if metrics == "auto": - metrics = [keras.metrics.Accuracy()] - super().compile( - optimizer=optimizer, - loss=loss, - metrics=metrics, - **kwargs, - ) diff --git a/requirements-common.txt b/requirements-common.txt index 27596bfb13..4e90ca9fab 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -2,7 +2,6 @@ dm-tree regex rich -huggingface_hub kagglehub # Tooling deps. astor From 3b6c0454f319cd5b1e4b4079a12a670b491116ed Mon Sep 17 00:00:00 2001 From: sachin prasad Date: Tue, 17 Sep 2024 15:03:38 -0700 Subject: [PATCH 31/33] testcase changes --- .../models/deeplab_v3/deeplab_v3_backbone.py | 2 ++ keras_nlp/src/tests/test_case.py | 29 ------------------- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone.py b/keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone.py index 9d6e2f6c25..96d2ae03ea 100644 --- a/keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone.py +++ b/keras_nlp/src/models/deeplab_v3/deeplab_v3_backbone.py @@ -78,6 +78,8 @@ class DeepLabV3Backbone(Backbone): projection_filters=48, low_level_feature_key="P2", spatial_pyramid_pooling_key="P5", + upsampling_size = 8, + dilation_rates = [6, 12, 18] ) ``` """ diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 5b6fc098fc..7e2d0660b5 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -583,35 +583,6 @@ def run_task_test( with self.assertRaisesRegex(ValueError, "You must call `compile"): task.fit(ds) - def run_segmentation_test( - self, - cls, - init_kwargs, - train_data, - expected_output_shape=None, - batch_size=2, - ): - """Run basic tests for a backbone, including compilation.""" - task = cls(**init_kwargs) - # Check serialization (without a full save). - self.run_serialization_test(task) - ds = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size) - x, y, sw = keras.utils.unpack_x_y_sample_weight(train_data) - - # Test predict. - output = task.predict(x) - if expected_output_shape is not None: - output_shape = tree.map_structure(lambda x: x.shape, output) - self.assertAllClose(output_shape, expected_output_shape) - # With a dataset. - output_ds = task.predict(ds) - self.assertAllClose(output, output_ds) - - # Test fit. - task.fit(x, y, sample_weight=sw) - # With a dataset. - task.fit(ds) - def run_preset_test( self, cls, From 704d119731ca70a8ea4d423341895e33a3e938e0 Mon Sep 17 00:00:00 2001 From: sachin prasad Date: Tue, 17 Sep 2024 15:10:12 -0700 Subject: [PATCH 32/33] Segmeter >> ImageSegmenter --- keras_nlp/api/models/__init__.py | 2 +- keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter.py | 4 ++-- keras_nlp/src/models/image_segmenter.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index ccd4246993..6b3624ce91 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -178,7 +178,7 @@ ) from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.src.models.image_classifier import ImageClassifier -from keras_nlp.src.models.image_segmenter import Segmenter +from keras_nlp.src.models.image_segmenter import ImageSegmenter as Segmenter from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter.py b/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter.py index 3f07ff18b6..5de2ec3d10 100644 --- a/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter.py +++ b/keras_nlp/src/models/deeplab_v3/deeplab_v3_segmenter.py @@ -18,11 +18,11 @@ from keras_nlp.src.models.deeplab_v3.deeplab_v3_backbone import ( DeepLabV3Backbone, ) -from keras_nlp.src.models.image_segmenter import Segmenter +from keras_nlp.src.models.image_segmenter import ImageSegmenter @keras_nlp_export("keras_nlp.models.DeepLabV3ImageSegmenter") -class DeepLabV3ImageSegmenter(Segmenter): +class DeepLabV3ImageSegmenter(ImageSegmenter): """DeepLabV3 and DeeplabV3 and DeeplabV3Plus segmentation task. Args: diff --git a/keras_nlp/src/models/image_segmenter.py b/keras_nlp/src/models/image_segmenter.py index a4937d94ff..ba8ac16cf6 100644 --- a/keras_nlp/src/models/image_segmenter.py +++ b/keras_nlp/src/models/image_segmenter.py @@ -18,7 +18,7 @@ @keras_nlp_export("keras_nlp.models.Segmenter") -class Segmenter(Task): +class ImageSegmenter(Task): """Base class for all segmentation tasks. `Segmenter` tasks wrap a `keras_nlp.models.Backbone` to create a model From 64050d51fec5ab0e5dc17522244ac56762e9748e Mon Sep 17 00:00:00 2001 From: sachin prasad Date: Tue, 17 Sep 2024 15:19:58 -0700 Subject: [PATCH 33/33] resolve conflict --- keras_nlp/api/models/__init__.py | 2 +- keras_nlp/src/models/image_segmenter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 6b3624ce91..988f6fe821 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -178,7 +178,7 @@ ) from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.src.models.image_classifier import ImageClassifier -from keras_nlp.src.models.image_segmenter import ImageSegmenter as Segmenter +from keras_nlp.src.models.image_segmenter import ImageSegmenter from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/image_segmenter.py b/keras_nlp/src/models/image_segmenter.py index ba8ac16cf6..70888dd52c 100644 --- a/keras_nlp/src/models/image_segmenter.py +++ b/keras_nlp/src/models/image_segmenter.py @@ -17,7 +17,7 @@ from keras_nlp.src.models.task import Task -@keras_nlp_export("keras_nlp.models.Segmenter") +@keras_nlp_export("keras_nlp.models.ImageSegmenter") class ImageSegmenter(Task): """Base class for all segmentation tasks.