diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index e079aa7c9e..f352a01b92 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -161,6 +161,12 @@ MistralPreprocessor, ) from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( + MiTImageClassifier, +) from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM from keras_nlp.src.models.opt.opt_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/mix_transformer/__init__.py b/keras_nlp/src/models/mix_transformer/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py new file mode 100644 index 0000000000..2cfe7f6761 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py @@ -0,0 +1,181 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +import numpy as np +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_nlp.src.models.mix_transformer.mix_transformer_layers import ( + HierarchicalTransformerEncoder, +) +from keras_nlp.src.models.mix_transformer.mix_transformer_layers import ( + OverlappingPatchingAndEmbedding, +) + + +@keras_nlp_export("keras_nlp.models.MiTBackbone") +class MiTBackbone(FeaturePyramidBackbone): + def __init__( + self, + depths, + num_layers, + blockwise_num_heads, + blockwise_sr_ratios, + end_value, + patch_sizes, + strides, + include_rescaling=True, + image_shape=(224, 224, 3), + hidden_dims=None, + **kwargs, + ): + """A Backbone implementing the MixTransformer. 
+
+        This architecture is to be used as a backbone for the SegFormer
+        architecture [SegFormer: Simple and Efficient Design for Semantic
+        Segmentation with Transformers](https://arxiv.org/abs/2105.15203).
+        Based on [the TensorFlow implementation from DeepVision](
+            https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer).
+
+        Args:
+            depths: list of integers. The number of transformer encoders
+                per hierarchical layer of the network.
+            num_layers: int. The number of Transformer layers.
+            blockwise_num_heads: list of integers, the number of heads to use
+                in the attention computation for each layer.
+            blockwise_sr_ratios: list of integers, the sequence reduction
+                ratio to perform for each layer on the sequence before key and
+                value projections. If set to > 1, a `Conv2D` layer is used to
+                reduce the length of the sequence.
+            end_value: float. The end value of the drop path (stochastic
+                depth) rate schedule; rates are linearly spaced from `0.0`
+                to `end_value` across all transformer blocks.
+            patch_sizes: list of integers, the patch size to apply for each
+                layer.
+            strides: list of integers, the stride to apply for each layer.
+            include_rescaling: bool, whether to rescale the inputs. If set
+                to `True`, inputs will be passed through a `Rescaling(1/255.0)`
+                layer. Defaults to `True`.
+            image_shape: optional shape tuple, defaults to `(224, 224, 3)`.
+            hidden_dims: the embedding dims per hierarchical layer, used as
+                the levels of the feature pyramid.
+
+        Examples:
+
+        Using the class with a preset `backbone`:
+
+        ```python
+        images = np.ones(shape=(1, 96, 96, 3))
+        labels = np.zeros(shape=(1, 96, 96, 1))
+        backbone = keras_nlp.models.MiTBackbone.from_preset("mit_b0_imagenet")
+
+        # Evaluate the backbone
+        backbone(images)
+
+        # Train the backbone
+        backbone.compile(
+            optimizer="adam",
+            loss=keras.losses.BinaryCrossentropy(from_logits=False),
+            metrics=["accuracy"],
+        )
+        backbone.fit(images, labels, epochs=3)
+        ```
+        """
+        dpr = [x for x in np.linspace(0.0, end_value, sum(depths))]
+
+        # === Layers ===
+        cur = 0
+        patch_embedding_layers = []
+        transformer_blocks = []
+        layer_norms = []
+
+        for i in range(num_layers):
+            patch_embed_layer = OverlappingPatchingAndEmbedding(
+                project_dim=hidden_dims[i],
+                patch_size=patch_sizes[i],
+                stride=strides[i],
+                name=f"patch_and_embed_{i}",
+            )
+            patch_embedding_layers.append(patch_embed_layer)
+
+            transformer_block = [
+                HierarchicalTransformerEncoder(
+                    project_dim=hidden_dims[i],
+                    num_heads=blockwise_num_heads[i],
+                    sr_ratio=blockwise_sr_ratios[i],
+                    drop_prob=dpr[cur + k],
+                    name=f"hierarchical_encoder_{i}_{k}",
+                )
+                for k in range(depths[i])
+            ]
+            transformer_blocks.append(transformer_block)
+            cur += depths[i]
+            layer_norms.append(keras.layers.LayerNormalization())
+
+        # === Functional Model ===
+        image_input = keras.layers.Input(shape=image_shape)
+        x = image_input
+
+        if include_rescaling:
+            x = keras.layers.Rescaling(scale=1 / 255)(x)
+
+        pyramid_outputs = {}
+        for i in range(num_layers):
+            # Compute new height/width after the `proj`
+            # call in `OverlappingPatchingAndEmbedding`
+            stride = strides[i]
+            new_height, new_width = (
+                int(ops.shape(x)[1] / stride),
+                int(ops.shape(x)[2] / stride),
+            )
+
+            x = patch_embedding_layers[i](x)
+            for blk in transformer_blocks[i]:
+                x = blk(x)
+            x = layer_norms[i](x)
+            x = keras.layers.Reshape(
+                (new_height, new_width, -1), name=f"output_level_{i}"
+            )(x)
+            pyramid_outputs[f"P{i + 1}"] = x
+
+        super().__init__(inputs=image_input, outputs=x, **kwargs)
+
+        # === Config ===
+        self.depths = depths
+        self.include_rescaling = include_rescaling
+        self.image_shape = image_shape
+        self.hidden_dims = hidden_dims
+        self.pyramid_outputs = 
pyramid_outputs + self.num_layers = num_layers + self.blockwise_num_heads = blockwise_num_heads + self.blockwise_sr_ratios = blockwise_sr_ratios + self.end_value = end_value + self.patch_sizes = patch_sizes + self.strides = strides + + def get_config(self): + config = super().get_config() + config.update( + { + "depths": self.depths, + "include_rescaling": self.include_rescaling, + "hidden_dims": self.hidden_dims, + "image_shape": self.image_shape, + "num_layers": self.num_layers, + "blockwise_num_heads": self.blockwise_num_heads, + "blockwise_sr_ratios": self.blockwise_sr_ratios, + "end_value": self.end_value, + "patch_sizes": self.patch_sizes, + "strides": self.strides, + } + ) + return config diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py new file mode 100644 index 0000000000..4f1955297f --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from keras import models + +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.tests.test_case import TestCase + + +class MiTBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "depths": [2, 2], + "include_rescaling": True, + "image_shape": (16, 16, 3), + "hidden_dims": [4, 8], + "num_layers": 2, + "blockwise_num_heads": [1, 2], + "blockwise_sr_ratios": [8, 4], + "end_value": 0.1, + "patch_sizes": [7, 3], + "strides": [4, 2], + } + self.input_size = 16 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) + + def test_backbone_basics(self): + self.run_backbone_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 2, 2, 8), + run_quantization_check=False, + run_mixed_precision_check=False, + ) + + def test_pyramid_output_format(self): + init_kwargs = self.init_kwargs + backbone = MiTBackbone(**init_kwargs) + model = models.Model(backbone.inputs, backbone.pyramid_outputs) + output_data = model(self.input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual(list(output_data.keys()), ["P1", "P2"]) + for k, v in output_data.items(): + size = self.input_size // (2 ** (int(k[1:]) + 1)) + self.assertEqual(tuple(v.shape[:3]), (2, size, size)) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py new file mode 100644 index 0000000000..a9a51b63ba --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py @@ -0,0 
+1,133 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.image_classifier import ImageClassifier
+from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
+    MiTBackbone,
+)
+
+
+@keras_nlp_export("keras_nlp.models.MiTImageClassifier")
+class MiTImageClassifier(ImageClassifier):
+    """MiTImageClassifier image classifier model.
+
+    Args:
+        backbone: A `keras_nlp.models.MiTBackbone` instance.
+        num_classes: int. The number of classes to predict.
+        activation: `None`, str or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `"softmax"`.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    where `x` is a tensor and `y` is an integer from `[0, num_classes)`.
+
+    All `ImageClassifier` tasks include a `from_preset()` constructor which
+    can be used to load a pre-trained config and weights.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and run inference
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    classifier = keras_nlp.models.MiTImageClassifier.from_preset(
+        "mit_b0_imagenet")
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    # Load preset and train
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    labels = [0, 3]
+    classifier = keras_nlp.models.MiTImageClassifier.from_preset(
+        "mit_b0_imagenet")
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Call `fit()` with a custom loss, optimizer and frozen backbone.
+    ```python
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    labels = [0, 3]
+    classifier = keras_nlp.models.MiTImageClassifier.from_preset(
+        "mit_b0_imagenet")
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    labels = [0, 3]
+    # Hyperparameters below correspond to the MiT-B0 configuration.
+    backbone = keras_nlp.models.MiTBackbone(
+        depths=[2, 2, 2, 2],
+        num_layers=4,
+        blockwise_num_heads=[1, 2, 5, 8],
+        blockwise_sr_ratios=[8, 4, 2, 1],
+        end_value=0.1,
+        patch_sizes=[7, 3, 3, 3],
+        strides=[4, 2, 2, 2],
+        include_rescaling=False,
+        image_shape=(224, 224, 3),
+        hidden_dims=[32, 64, 160, 256],
+    )
+    classifier = keras_nlp.models.MiTImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = MiTBackbone
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        activation="softmax",
+        preprocessor=None,  # Dummy arg, needed for the saved model test.
+        # TODO: update this once the preprocessor flow is figured out.
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        outputs = self.output_dense(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.activation = activation
+
+    def get_config(self):
+        # Backbone is serialized in `super()`.
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "activation": self.activation,
+            }
+        )
+        return config
diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py
new file mode 100644
index 0000000000..57b0671be2
--- /dev/null
+++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py
@@ -0,0 +1,70 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytest
+
+from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
+    MiTBackbone,
+)
+from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import (
+    MiTImageClassifier,
+)
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class MiTImageClassifierTest(TestCase):
+    def setUp(self):
+        # Set up the model.
+        self.images = np.ones((2, 16, 16, 3), dtype="float32")
+        self.labels = [0, 3]
+        self.backbone = MiTBackbone(
+            depths=[2, 2],
+            include_rescaling=True,
+            image_shape=(16, 16, 3),
+            hidden_dims=[4, 8],
+            num_layers=2,
+            blockwise_num_heads=[1, 2],
+            blockwise_sr_ratios=[8, 4],
+            end_value=0.1,
+            patch_sizes=[7, 3],
+            strides=[4, 2],
+        )
+        self.init_kwargs = {
+            "backbone": self.backbone,
+            "num_classes": 2,
+            "activation": "softmax",
+        }
+        self.train_data = (
+            self.images,
+            self.labels,
+        )
+
+    def test_classifier_basics(self):
+        pytest.skip(
+            reason="TODO: enable after preprocessor flow is figured out"
+        )
+        self.run_task_test(
+            cls=MiTImageClassifier,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 2),
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=MiTImageClassifier,
+            init_kwargs=self.init_kwargs,
+            input_data=self.images,
+        )
diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py
new file mode 100644
index 0000000000..53d99fe484
--- /dev/null
+++ b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py
@@ -0,0 +1,300 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import keras
+from keras import ops
+from keras import random
+
+
+class OverlappingPatchingAndEmbedding(keras.layers.Layer):
+    def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs):
+        """Overlapping Patching and Embedding layer.
+
+        Differs from `PatchingAndEmbedding` in that the patch size does not
+        affect the sequence length, which is fully derived from the `stride`
+        parameter. Additionally, no positional embedding is done
+        as part of the layer - only a projection using a `Conv2D` layer.
+
+        Args:
+            project_dim: integer, the dimensionality of the projection.
+                Defaults to `32`.
+            patch_size: integer, the size of the patches to encode.
+                Defaults to `7`.
+            stride: integer, the stride to use for the patching before
+                projection. Defaults to `4`.
+        """
+        super().__init__(**kwargs)
+
+        self.project_dim = project_dim
+        self.patch_size = patch_size
+        self.stride = stride
+
+        self.proj = keras.layers.Conv2D(
+            filters=project_dim,
+            kernel_size=patch_size,
+            strides=stride,
+            padding="same",
+        )
+        self.norm = keras.layers.LayerNormalization()
+
+    def call(self, x):
+        x = self.proj(x)
+        # B, H, W, C
+        shape = x.shape
+        x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3]))
+        x = self.norm(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "project_dim": self.project_dim,
+                "patch_size": self.patch_size,
+                "stride": self.stride,
+            }
+        )
+        return config
+
+
+class HierarchicalTransformerEncoder(keras.layers.Layer):
+    """Hierarchical transformer encoder block implementation as a Keras Layer.
+
+    The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention`
+    alternative for computational efficiency, and is meant to be used
+    within the SegFormer architecture.
+
+    Args:
+        project_dim: integer, the dimensionality of the projection of the
+            encoder, and output of the `SegFormerMultiheadAttention` layer.
+            Due to the residual addition, the input dimensionality has to be
+            equal to the output dimensionality.
+        num_heads: integer, the number of heads for the
+            `SegFormerMultiheadAttention` layer.
+        drop_prob: float, the probability of dropping a random
+            sample using the `DropPath` layer. Defaults to `0.0`.
+        layer_norm_epsilon: float, the epsilon for
+            `LayerNormalization` layers. Defaults to `1e-06`.
+        sr_ratio: integer, the ratio to use within
+            `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
+            layer is used to reduce the length of the sequence. Defaults
+            to `1`.
+    """
+
+    def __init__(
+        self,
+        project_dim,
+        num_heads,
+        sr_ratio=1,
+        drop_prob=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.project_dim = project_dim
+        self.num_heads = num_heads
+        self.drop_prob = drop_prob
+        self.sr_ratio = sr_ratio
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+        self.norm1 = keras.layers.LayerNormalization(
+            epsilon=layer_norm_epsilon
+        )
+        self.attn = SegFormerMultiheadAttention(
+            project_dim, num_heads, sr_ratio
+        )
+        self.drop_path = DropPath(drop_prob)
+        self.norm2 = keras.layers.LayerNormalization(
+            epsilon=layer_norm_epsilon
+        )
+        self.mlp = MixFFN(
+            channels=project_dim,
+            mid_channels=int(project_dim * 4),
+        )
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        # The input is a (batch_size, sequence_length, channels) sequence;
+        # for a square feature map, H = W = sqrt(sequence_length).
+        self.H = ops.sqrt(ops.cast(input_shape[1], "float32"))
+        self.W = ops.sqrt(ops.cast(input_shape[1], "float32"))
+
+    def call(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+    def get_config(self):
+        # `self.mlp` is rebuilt in `__init__`, so it is not serialized here.
+        config = super().get_config()
+        config.update(
+            {
+                "project_dim": self.project_dim,
+                "num_heads": self.num_heads,
+                "drop_prob": self.drop_prob,
+                "sr_ratio": self.sr_ratio,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+
+class MixFFN(keras.layers.Layer):
+    def __init__(self, channels, mid_channels, **kwargs):
+        super().__init__(**kwargs)
+        self.channels = channels
+        self.mid_channels = mid_channels
+        self.fc1 = keras.layers.Dense(mid_channels)
+        self.dwconv = keras.layers.DepthwiseConv2D(
+            kernel_size=3,
+            strides=1,
+            padding="same",
+        )
+        self.fc2 = keras.layers.Dense(channels)
+
+    def call(self, x):
+        x = self.fc1(x)
+        shape = ops.shape(x)
+        # Assumes a square feature map: H = W = sqrt(sequence_length).
+        H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1]))
+        B, C = shape[0], shape[2]
+        x = ops.reshape(x, (B, H, W, C))
+        x = self.dwconv(x)
+        x = ops.reshape(x, (B, -1, C))
+        x = ops.nn.gelu(x)
+        x = self.fc2(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "channels": self.channels,
+                "mid_channels": self.mid_channels,
+            }
+        )
+        return config
+
+
+class SegFormerMultiheadAttention(keras.layers.Layer):
+    def __init__(self, project_dim, num_heads, sr_ratio, **kwargs):
+        """Efficient MultiHeadAttention implementation as a Keras layer.
+
+        A huge bottleneck in scaling transformers is the self-attention
+        layer with an O(n^2) complexity.
+
+        SegFormerMultiheadAttention performs a sequence reduction (SR)
+        operation with a given ratio, to reduce the sequence length before
+        performing key and value projections, reducing the O(n^2) complexity
+        to O(n^2/R), where R is the sequence reduction ratio.
+
+        Args:
+            project_dim: integer, the dimensionality of the projection
+                of the `SegFormerMultiheadAttention` layer.
+            num_heads: integer, the number of heads to use in the
+                attention computation.
+            sr_ratio: integer, the sequence reduction ratio to perform
+                on the sequence before key and value projections.
+ """ + super().__init__() + self.num_heads = num_heads + self.sr_ratio = sr_ratio + self.scale = (project_dim // num_heads) ** -0.5 + self.q = keras.layers.Dense(project_dim) + self.k = keras.layers.Dense(project_dim) + self.v = keras.layers.Dense(project_dim) + self.proj = keras.layers.Dense(project_dim) + + if sr_ratio > 1: + self.sr = keras.layers.Conv2D( + filters=project_dim, + kernel_size=sr_ratio, + strides=sr_ratio, + padding="same", + ) + self.norm = keras.layers.LayerNormalization() + + def call(self, x): + input_shape = ops.shape(x) + H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1])) + B, C = input_shape[0], input_shape[2] + + q = self.q(x) + q = ops.reshape( + q, + ( + input_shape[0], + input_shape[1], + self.num_heads, + input_shape[2] // self.num_heads, + ), + ) + q = ops.transpose(q, [0, 2, 1, 3]) + + if self.sr_ratio > 1: + x = ops.reshape( + ops.transpose(x, [0, 2, 1]), + (B, H, W, C), + ) + x = self.sr(x) + x = ops.reshape(x, [input_shape[0], input_shape[2], -1]) + x = ops.transpose(x, [0, 2, 1]) + x = self.norm(x) + + k = self.k(x) + v = self.v(x) + + k = ops.transpose( + ops.reshape( + k, + [B, -1, self.num_heads, C // self.num_heads], + ), + [0, 2, 1, 3], + ) + + v = ops.transpose( + ops.reshape( + v, + [B, -1, self.num_heads, C // self.num_heads], + ), + [0, 2, 1, 3], + ) + + attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale + attn = ops.nn.softmax(attn, axis=-1) + + attn = attn @ v + attn = ops.reshape( + ops.transpose(attn, [0, 2, 1, 3]), + [input_shape[0], input_shape[1], input_shape[2]], + ) + + x = self.proj(attn) + return x + + +class DropPath(keras.layers.Layer): + """Implements the DropPath layer. + + DropPath randomly drops samples during + training with a probability of `rate`. Note that this layer drops individual + samples within a batch and not the entire batch, whereas StochasticDepth + randomly drops the entire batch. + + Args: + rate: float, the probability of the residual branch being dropped. + seed: (Optional) integer. Used to create a random seed. + """ + + def __init__(self, rate=0.5, seed=None, **kwargs): + super().__init__(**kwargs) + self.rate = rate + self._seed_val = seed + self.seed = random.SeedGenerator(seed=seed) + + def call(self, x, training=None): + if self.rate == 0.0 or not training: + return x + else: + batch_size = x.shape[0] or ops.shape(x)[0] + drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1) + drop_map = ops.cast( + random.uniform(drop_map_shape, seed=self.seed) > self.rate, + x.dtype, + ) + x = x / (1.0 - self.rate) + x = x * drop_map + return x + + def get_config(self): + config = super().get_config() + config.update({"rate": self.rate, "seed": self._seed_val}) + return config