diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py
index 31698e0a1c..ca1de9b090 100644
--- a/keras_nlp/src/models/resnet/resnet_backbone.py
+++ b/keras_nlp/src/models/resnet/resnet_backbone.py
@@ -27,9 +27,10 @@ class ResNetBackbone(FeaturePyramidBackbone):
     This class implements a ResNet backbone as described in [Deep Residual
     Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
     CVPR 2016), [Identity Mappings in Deep Residual Networks](
-    https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An
+    https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An
     improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
-    NeurIPS 2021 Workshop).
+    NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with
+    Convolutional Neural Networks](https://arxiv.org/abs/1812.01187).
 
     The difference in ResNet and ResNetV2 rests in the structure of their
     individual building blocks. In ResNetV2, the batch normalization and
@@ -37,18 +38,31 @@ class ResNetBackbone(FeaturePyramidBackbone):
     the batch normalization and ReLU activation are applied after the
     convolution layers.
 
+    ResNetVd introduces two key modifications to the standard ResNet. First,
+    the initial convolutional layer is replaced by a series of three
+    successive convolutional layers. Second, shortcut connections use an
+    additional pooling operation rather than performing downsampling within
+    the convolutional layers themselves.
+
     Note that `ResNetBackbone` expects the inputs to be images with a value
     range of `[0, 255]` when `include_rescaling=True`.
 
     Args:
+        input_conv_filters: list of ints. The number of filters of the initial
+            convolution(s).
+        input_conv_kernel_sizes: list of ints. The kernel sizes of the initial
+            convolution(s).
         stackwise_num_filters: list of ints. The number of filters for each
             stack.
         stackwise_num_blocks: list of ints. The number of blocks for each
             stack.
         stackwise_num_strides: list of ints. The number of strides for each
             stack.
         block_type: str. The block type to stack. One of `"basic_block"` or
-            `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
-            Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
+            `"bottleneck_block"`, `"basic_block_vd"` or
+            `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
+            ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
+            ResNet152, and add the `"_vd"` suffix for the respective
+            ResNet_vd variants.
         use_pre_activation: boolean. Whether to use pre-activation or not.
             `True` for ResNetV2, `False` for ResNet.
         include_rescaling: boolean. If `True`, rescale the input using
@@ -88,6 +102,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
 
     # Randomly initialized ResNetV2 backbone with a custom config.
     model = keras_nlp.models.ResNetBackbone(
+        input_conv_filters=[64],
+        input_conv_kernel_sizes=[7],
         stackwise_num_filters=[64, 64, 64],
         stackwise_num_blocks=[2, 2, 2],
         stackwise_num_strides=[1, 2, 2],
@@ -101,6 +117,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
 
     def __init__(
         self,
+        input_conv_filters,
+        input_conv_kernel_sizes,
         stackwise_num_filters,
         stackwise_num_blocks,
         stackwise_num_strides,
@@ -113,6 +131,13 @@ def __init__(
         dtype=None,
         **kwargs,
     ):
+        if len(input_conv_filters) != len(input_conv_kernel_sizes):
+            raise ValueError(
+                "The length of `input_conv_filters` and "
+                "`input_conv_kernel_sizes` must be the same. "
+                f"Received: input_conv_filters={input_conv_filters}, "
+                f"input_conv_kernel_sizes={input_conv_kernel_sizes}."
+            )
         if len(stackwise_num_filters) != len(stackwise_num_blocks) or len(
             stackwise_num_filters
         ) != len(stackwise_num_strides):
@@ -128,14 +153,20 @@ def __init__(
                 "The first element of `stackwise_num_filters` must be 64. "
                 f"Received: stackwise_num_filters={stackwise_num_filters}"
             )
-        if block_type not in ("basic_block", "bottleneck_block"):
+        if block_type not in (
+            "basic_block",
+            "bottleneck_block",
+            "basic_block_vd",
+            "bottleneck_block_vd",
+        ):
             raise ValueError(
-                '`block_type` must be either `"basic_block"` or '
-                f'`"bottleneck_block"`. Received block_type={block_type}.'
+                '`block_type` must be one of `"basic_block"`, '
+                '`"bottleneck_block"`, `"basic_block_vd"` or '
+                f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
             )
-        version = "v1" if not use_pre_activation else "v2"
         data_format = standardize_data_format(data_format)
         bn_axis = -1 if data_format == "channels_last" else 1
+        num_input_convs = len(input_conv_filters)
         num_stacks = len(stackwise_num_filters)
 
         # === Functional Model ===
@@ -155,29 +186,56 @@ def __init__(
         # The padding between torch and tensorflow/jax differs when `strides>1`.
         # Therefore, we need to manually pad the tensor.
         x = layers.ZeroPadding2D(
-            3,
+            (input_conv_kernel_sizes[0] - 1) // 2,
             data_format=data_format,
             dtype=dtype,
             name="conv1_pad",
         )(x)
         x = layers.Conv2D(
-            64,
-            7,
+            input_conv_filters[0],
+            input_conv_kernel_sizes[0],
             strides=2,
             data_format=data_format,
             use_bias=False,
+            padding="valid",
             dtype=dtype,
             name="conv1_conv",
         )(x)
+        for conv_index in range(1, num_input_convs):
+            x = layers.BatchNormalization(
+                axis=bn_axis,
+                epsilon=1e-5,
+                momentum=0.9,
+                dtype=dtype,
+                name=f"conv{conv_index}_bn",
+            )(x)
+            x = layers.Activation(
+                "relu", dtype=dtype, name=f"conv{conv_index}_relu"
+            )(x)
+            x = layers.Conv2D(
+                input_conv_filters[conv_index],
+                input_conv_kernel_sizes[conv_index],
+                strides=1,
+                data_format=data_format,
+                use_bias=False,
+                padding="same",
+                dtype=dtype,
+                name=f"conv{conv_index+1}_conv",
+            )(x)
+
         if not use_pre_activation:
             x = layers.BatchNormalization(
                 axis=bn_axis,
                 epsilon=1e-5,
                 momentum=0.9,
                 dtype=dtype,
-                name="conv1_bn",
+                name=f"conv{num_input_convs}_bn",
+            )(x)
+            x = layers.Activation(
+                "relu",
+                dtype=dtype,
+                name=f"conv{num_input_convs}_relu",
             )(x)
-            x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)
 
         if use_pre_activation:
             # A workaround for ResNetV2: we need -inf padding to prevent zeros
@@ -210,12 +268,10 @@ def __init__(
                 stride=stackwise_num_strides[stack_index],
                 block_type=block_type,
                 use_pre_activation=use_pre_activation,
-                first_shortcut=(
-                    block_type == "bottleneck_block" or stack_index > 0
-                ),
+                first_shortcut=(block_type != "basic_block" or stack_index > 0),
                 data_format=data_format,
                 dtype=dtype,
-                name=f"{version}_stack{stack_index}",
+                name=f"stack{stack_index}",
             )
             pyramid_outputs[f"P{stack_index + 2}"] = x
 
@@ -248,6 +304,8 @@ def __init__(
         )
 
         # === Config ===
+        self.input_conv_filters = input_conv_filters
+        self.input_conv_kernel_sizes = input_conv_kernel_sizes
         self.stackwise_num_filters = stackwise_num_filters
         self.stackwise_num_blocks = stackwise_num_blocks
         self.stackwise_num_strides = stackwise_num_strides
@@ -262,6 +320,8 @@ def get_config(self):
         config = super().get_config()
         config.update(
             {
+                "input_conv_filters": self.input_conv_filters,
+                "input_conv_kernel_sizes": self.input_conv_kernel_sizes,
                 "stackwise_num_filters": self.stackwise_num_filters,
                 "stackwise_num_blocks": self.stackwise_num_blocks,
                 "stackwise_num_strides": self.stackwise_num_strides,
@@ -327,7 +387,10 @@ def apply_basic_block(
         )(x_preact)
 
     if conv_shortcut:
-        x = x_preact if x_preact is not None else x
+        if x_preact is not None:
+            shortcut = x_preact
+        else:
+            shortcut = x
         shortcut = layers.Conv2D(
             filters,
             1,
@@ -336,7 +399,7 @@ def apply_basic_block(
             use_bias=False,
             dtype=dtype,
             name=f"{name}_0_conv",
-        )(x)
+        )(shortcut)
         if not use_pre_activation:
             shortcut = layers.BatchNormalization(
                 axis=bn_axis,
@@ -452,7 +515,10 @@ def apply_bottleneck_block(
         )(x_preact)
 
     if conv_shortcut:
-        x = x_preact if x_preact is not None else x
+        if x_preact is not None:
+            shortcut = x_preact
+        else:
+            shortcut = x
         shortcut = layers.Conv2D(
             4 * filters,
             1,
@@ -461,7 +527,295 @@ def apply_bottleneck_block(
             use_bias=False,
             dtype=dtype,
             name=f"{name}_0_conv",
+        )(shortcut)
+        if not use_pre_activation:
+            shortcut = layers.BatchNormalization(
+                axis=bn_axis,
+                epsilon=1e-5,
+                momentum=0.9,
+                dtype=dtype,
+                name=f"{name}_0_bn",
+            )(shortcut)
+    else:
+        shortcut = x
+
+    x = x_preact if x_preact is not None else x
+    x = layers.Conv2D(
+        filters,
+        1,
+        strides=1,
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_1_conv",
+    )(x)
+    x = layers.BatchNormalization(
+        axis=bn_axis,
+        epsilon=1e-5,
+        momentum=0.9,
+        dtype=dtype,
+        name=f"{name}_1_bn",
+    )(x)
+    x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
+
+    if stride > 1:
+        x = layers.ZeroPadding2D(
+            (kernel_size - 1) // 2,
+            data_format=data_format,
+            dtype=dtype,
+            name=f"{name}_2_pad",
         )(x)
+    x = layers.Conv2D(
+        filters,
+        kernel_size,
+        strides=stride,
+        padding="valid" if stride > 1 else "same",
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_2_conv",
+    )(x)
+    x = layers.BatchNormalization(
+        axis=bn_axis,
+        epsilon=1e-5,
+        momentum=0.9,
+        dtype=dtype,
+        name=f"{name}_2_bn",
+    )(x)
+    x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x)
+
+    x = layers.Conv2D(
+        4 * filters,
+        1,
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_3_conv",
+    )(x)
+    if not use_pre_activation:
+        x = layers.BatchNormalization(
+            axis=bn_axis,
+            epsilon=1e-5,
+            momentum=0.9,
+            dtype=dtype,
+            name=f"{name}_3_bn",
+        )(x)
+        x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
+        x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
+    else:
+        x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
+    return x
+
+
+def apply_basic_block_vd(
+    x,
+    filters,
+    kernel_size=3,
+    stride=1,
+    conv_shortcut=False,
+    use_pre_activation=False,
+    data_format=None,
+    dtype=None,
+    name=None,
+):
+    """Applies a vd-style basic residual block.
+
+    Args:
+        x: Tensor. The input tensor to pass through the block.
+        filters: int. The number of filters in the block.
+        kernel_size: int. The kernel size of the main convolutions. Defaults
+            to `3`.
+        stride: int. The stride length of the first layer. Defaults to `1`.
+        conv_shortcut: bool. If `True`, use a convolution shortcut. If
+            `False`, use an identity or pooling shortcut based on the stride.
+            Defaults to `False`.
+        use_pre_activation: boolean. Whether to use pre-activation or not.
+            `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+        data_format: `None` or str. The ordering of the dimensions in the
+            inputs. Can be `"channels_last"`
+            (`(batch_size, height, width, channels)`) or `"channels_first"`
+            (`(batch_size, channels, height, width)`).
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the model's computations and weights.
+        name: str. A prefix for the layer names used in the block.
+
+    Returns:
+        The output tensor for the basic residual block.
+    """
+    data_format = data_format or keras.config.image_data_format()
+    bn_axis = -1 if data_format == "channels_last" else 1
+
+    x_preact = None
+    if use_pre_activation:
+        x_preact = layers.BatchNormalization(
+            axis=bn_axis,
+            epsilon=1e-5,
+            momentum=0.9,
+            dtype=dtype,
+            name=f"{name}_pre_activation_bn",
+        )(x)
+        x_preact = layers.Activation(
+            "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
+        )(x_preact)
+
+    if conv_shortcut:
+        if x_preact is not None:
+            shortcut = x_preact
+        elif stride > 1:
+            shortcut = layers.AveragePooling2D(
+                2,
+                strides=stride,
+                data_format=data_format,
+                dtype=dtype,
+                padding="same",
+            )(x)
+        else:
+            shortcut = x
+        shortcut = layers.Conv2D(
+            filters,
+            1,
+            strides=1,
+            data_format=data_format,
+            use_bias=False,
+            dtype=dtype,
+            name=f"{name}_0_conv",
+        )(shortcut)
+        if not use_pre_activation:
+            shortcut = layers.BatchNormalization(
+                axis=bn_axis,
+                epsilon=1e-5,
+                momentum=0.9,
+                dtype=dtype,
+                name=f"{name}_0_bn",
+            )(shortcut)
+    else:
+        shortcut = x
+
+    x = x_preact if x_preact is not None else x
+    if stride > 1:
+        x = layers.ZeroPadding2D(
+            (kernel_size - 1) // 2,
+            data_format=data_format,
+            dtype=dtype,
+            name=f"{name}_1_pad",
+        )(x)
+    x = layers.Conv2D(
+        filters,
+        kernel_size,
+        strides=stride,
+        padding="valid" if stride > 1 else "same",
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_1_conv",
+    )(x)
+    x = layers.BatchNormalization(
+        axis=bn_axis,
+        epsilon=1e-5,
+        momentum=0.9,
+        dtype=dtype,
+        name=f"{name}_1_bn",
+    )(x)
+    x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
+
+    x = layers.Conv2D(
+        filters,
+        kernel_size,
+        strides=1,
+        padding="same",
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_2_conv",
+    )(x)
+    if not use_pre_activation:
+        x = layers.BatchNormalization(
+            axis=bn_axis,
+            epsilon=1e-5,
+            momentum=0.9,
+            dtype=dtype,
+            name=f"{name}_2_bn",
+        )(x)
+        x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
+        x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
+    else:
+        x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
+    return x
+
+
+def apply_bottleneck_block_vd(
+    x,
+    filters,
+    kernel_size=3,
+    stride=1,
+    conv_shortcut=False,
+    use_pre_activation=False,
+    data_format=None,
+    dtype=None,
+    name=None,
+):
+    """Applies a vd-style bottleneck residual block.
+
+    Args:
+        x: Tensor. The input tensor to pass through the block.
+        filters: int. The number of filters in the block.
+        kernel_size: int. The kernel size of the bottleneck layer. Defaults
+            to `3`.
+        stride: int. The stride length of the first layer. Defaults to `1`.
+        conv_shortcut: bool. If `True`, use a convolution shortcut. If
+            `False`, use an identity or pooling shortcut based on the stride.
+            Defaults to `False`.
+        use_pre_activation: boolean. Whether to use pre-activation or not.
+            `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+        data_format: `None` or str. The ordering of the dimensions in the
+            inputs. Can be `"channels_last"`
+            (`(batch_size, height, width, channels)`) or `"channels_first"`
+            (`(batch_size, channels, height, width)`).
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the model's computations and weights.
+        name: str. A prefix for the layer names used in the block.
+
+    Returns:
+        The output tensor for the residual block.
+    """
+    data_format = data_format or keras.config.image_data_format()
+    bn_axis = -1 if data_format == "channels_last" else 1
+
+    x_preact = None
+    if use_pre_activation:
+        x_preact = layers.BatchNormalization(
+            axis=bn_axis,
+            epsilon=1e-5,
+            momentum=0.9,
+            dtype=dtype,
+            name=f"{name}_pre_activation_bn",
+        )(x)
+        x_preact = layers.Activation(
+            "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
+        )(x_preact)
+
+    if conv_shortcut:
+        if x_preact is not None:
+            shortcut = x_preact
+        elif stride > 1:
+            shortcut = layers.AveragePooling2D(
+                2,
+                strides=stride,
+                data_format=data_format,
+                dtype=dtype,
+                padding="same",
+            )(x)
+        else:
+            shortcut = x
+        shortcut = layers.Conv2D(
+            4 * filters,
+            1,
+            strides=1,
+            data_format=data_format,
+            use_bias=False,
+            dtype=dtype,
+            name=f"{name}_0_conv",
+        )(shortcut)
         if not use_pre_activation:
             shortcut = layers.BatchNormalization(
                 axis=bn_axis,
@@ -561,8 +915,11 @@ def apply_stack(
         blocks: int. The number of blocks in the stack.
         stride: int. The stride length of the first layer in the first block.
         block_type: str. The block type to stack. One of `"basic_block"` or
-            `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
-            Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
+            `"bottleneck_block"`, `"basic_block_vd"` or
+            `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
+            ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
+            ResNet152, and add the `"_vd"` suffix for the respective
+            ResNet_vd variants.
         use_pre_activation: boolean. Whether to use pre-activation or not.
             `True` for ResNetV2, `False` for ResNet and ResNeXt.
         first_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
@@ -580,17 +937,21 @@ def apply_stack(
         Output tensor for the stacked blocks.
     """
     if name is None:
-        version = "v1" if not use_pre_activation else "v2"
-        name = f"{version}_stack"
+        name = "stack"
 
     if block_type == "basic_block":
        block_fn = apply_basic_block
     elif block_type == "bottleneck_block":
        block_fn = apply_bottleneck_block
+    elif block_type == "basic_block_vd":
+        block_fn = apply_basic_block_vd
+    elif block_type == "bottleneck_block_vd":
+        block_fn = apply_bottleneck_block_vd
     else:
         raise ValueError(
-            '`block_type` must be either `"basic_block"` or '
-            f'`"bottleneck_block"`. Received block_type={block_type}.'
+            '`block_type` must be one of `"basic_block"`, '
+            '`"bottleneck_block"`, `"basic_block_vd"` or '
+            f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
        )
    for i in range(blocks):
        if i == 0:
diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py
index a6a30362cd..f52800801f 100644
--- a/keras_nlp/src/models/resnet/resnet_backbone_test.py
+++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py
@@ -24,6 +24,8 @@ class ResNetBackboneTest(TestCase):
     def setUp(self):
         self.init_kwargs = {
+            "input_conv_filters": [64],
+            "input_conv_kernel_sizes": [7],
             "stackwise_num_filters": [64, 64, 64],
             "stackwise_num_blocks": [2, 2, 2],
             "stackwise_num_strides": [1, 2, 2],
@@ -38,18 +40,32 @@ def setUp(self):
         ("v1_bottleneck", False, "bottleneck_block"),
         ("v2_basic", True, "basic_block"),
         ("v2_bottleneck", True, "bottleneck_block"),
+        ("vd_basic", False, "basic_block_vd"),
+        ("vd_bottleneck", False, "bottleneck_block_vd"),
     )
     def test_backbone_basics(self, use_pre_activation, block_type):
         init_kwargs = self.init_kwargs.copy()
         init_kwargs.update(
-            {"block_type": block_type, "use_pre_activation": use_pre_activation}
+            {
+                "block_type": block_type,
+                "use_pre_activation": use_pre_activation,
+            }
         )
+        if block_type in ("basic_block_vd", "bottleneck_block_vd"):
+            init_kwargs.update(
+                {
+                    "input_conv_filters": [32, 32, 64],
+                    "input_conv_kernel_sizes": [3, 3, 3],
+                }
+            )
         self.run_vision_backbone_test(
             cls=ResNetBackbone,
             init_kwargs=init_kwargs,
             input_data=self.input_data,
             expected_output_shape=(
-                (2, 64) if block_type == "basic_block" else (2, 256)
+                (2, 64)
+                if block_type in ("basic_block", "basic_block_vd")
+                else (2, 256)
             ),
         )
 
@@ -76,6 +92,8 @@ def test_pyramid_output_format(self):
         ("v1_bottleneck", False, "bottleneck_block"),
         ("v2_basic", True, "basic_block"),
         ("v2_bottleneck", True, "bottleneck_block"),
+        ("vd_basic", False, "basic_block_vd"),
+        ("vd_bottleneck", False, "bottleneck_block_vd"),
     )
     @pytest.mark.large
     def test_saved_model(self, use_pre_activation, block_type):
@@ -87,6 +105,13 @@ def test_saved_model(self, use_pre_activation, block_type):
                 "image_shape": (None, None, 3),
             }
         )
+        if block_type in ("basic_block_vd", "bottleneck_block_vd"):
+            init_kwargs.update(
+                {
+                    "input_conv_filters": [32, 32, 64],
+                    "input_conv_kernel_sizes": [3, 3, 3],
+                }
+            )
         self.run_model_saving_test(
             cls=ResNetBackbone,
             init_kwargs=init_kwargs,
diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py
index 893ec42487..da06c80320 100644
--- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py
+++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py
@@ -26,6 +26,8 @@ def setUp(self):
         self.images = ops.ones((2, 16, 16, 3))
         self.labels = [0, 3]
         self.backbone = ResNetBackbone(
+            input_conv_filters=[64],
+            input_conv_kernel_sizes=[7],
            stackwise_num_filters=[64, 64, 64],
            stackwise_num_blocks=[2, 2, 2],
            stackwise_num_strides=[1, 2, 2],
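
Reviewer note, not part of the patch: below is a minimal usage sketch of the
new stem arguments. It assumes the public `keras_nlp.models.ResNetBackbone`
export matches the class modified above; the stack widths and depths are
illustrative ResNet50_vd-style values, not a preset shipped by this change.

import keras
import numpy as np

import keras_nlp

# Hypothetical ResNet50_vd-flavored configuration: a stacked 3x3 stem
# (`input_conv_*`) replaces the single 7x7 convolution, and the `_vd`
# bottleneck blocks average-pool inside their downsampling shortcuts.
backbone = keras_nlp.models.ResNetBackbone(
    input_conv_filters=[32, 32, 64],
    input_conv_kernel_sizes=[3, 3, 3],
    stackwise_num_filters=[64, 128, 256, 512],
    stackwise_num_blocks=[3, 4, 6, 3],
    stackwise_num_strides=[1, 2, 2, 2],
    block_type="bottleneck_block_vd",
    use_pre_activation=False,
    include_rescaling=True,
)

# Forward pass over a dummy batch in the `[0, 255]` range.
images = np.random.uniform(0, 255, size=(2, 224, 224, 3)).astype("float32")
features = backbone(images)

# As a `FeaturePyramidBackbone`, the per-stack outputs (P2..P5 here) can
# be lifted into a feature-extractor model.
pyramid_model = keras.Model(backbone.inputs, backbone.pyramid_outputs)
pyramid_features = pyramid_model(images)

The pyramid model follows the `pyramid_outputs` pattern this backbone already
exposes via `pyramid_outputs[f"P{stack_index + 2}"]`; if the public export
path differs, only the import line should need adjusting.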