| 
 | 1 | +import keras  | 
 | 2 | + | 
 | 3 | +from keras_hub.src.api_export import keras_hub_export  | 
 | 4 | +from keras_hub.src.models.backbone import Backbone  | 
 | 5 | + | 
 | 6 | + | 
 | 7 | +@keras_hub_export("keras_hub.models.SegFormerBackbone")  | 
 | 8 | +class SegFormerBackbone(Backbone):  | 
 | 9 | +    """A Keras model implementing the SegFormer architecture for semantic segmentation.  | 
 | 10 | +
  | 
 | 11 | +    This class implements the majority of the SegFormer architecture described in  | 
 | 12 | +    [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers]  | 
 | 13 | +    (https://arxiv.org/abs/2105.15203) and [based on the TensorFlow implementation from DeepVision]  | 
 | 14 | +    (https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).  | 
 | 15 | +
  | 
 | 16 | +    SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and  | 
 | 17 | +    and use a very lightweight all-MLP decoder head.  | 
 | 18 | +
  | 
 | 19 | +    The MiT encoder uses a hierarchical transformer which outputs features at multiple scales,  | 
 | 20 | +    similar to that of the hierarchical outputs typically associated with CNNs.  | 
 | 21 | +
  | 
 | 22 | +    Args:  | 
 | 23 | +        image_encoder: `keras.Model`. The backbone network for the model that is  | 
 | 24 | +            used as a feature extractor for the SegFormer encoder.  | 
 | 25 | +            Should be used with the MiT backbone model  | 
 | 26 | +            (`keras_hub.models.MiTBackbone`) which was created  | 
 | 27 | +            specifically for SegFormers.  | 
 | 28 | +        num_classes: int, the number of classes for the detection model,  | 
 | 29 | +            including the background class.  | 
 | 30 | +        projection_filters: int, number of filters in the  | 
 | 31 | +            convolution layer projecting the concatenated features into  | 
 | 32 | +            a segmentation map. Defaults to 256`.  | 
 | 33 | +
  | 
 | 34 | +    Example:  | 
 | 35 | +
  | 
 | 36 | +    Using the class with a custom `backbone`:  | 
 | 37 | +
  | 
 | 38 | +    ```python  | 
 | 39 | +    import keras_hub  | 
 | 40 | +
  | 
 | 41 | +    backbone = keras_hub.models.MiTBackbone(  | 
 | 42 | +        depths=[2, 2, 2, 2],  | 
 | 43 | +        image_shape=(224, 224, 3),  | 
 | 44 | +        hidden_dims=[32, 64, 160, 256],  | 
 | 45 | +        num_layers=4,  | 
 | 46 | +        blockwise_num_heads=[1, 2, 5, 8],  | 
 | 47 | +        blockwise_sr_ratios=[8, 4, 2, 1],  | 
 | 48 | +        max_drop_path_rate=0.1,  | 
 | 49 | +        patch_sizes=[7, 3, 3, 3],  | 
 | 50 | +        strides=[4, 2, 2, 2],  | 
 | 51 | +    )  | 
 | 52 | +
  | 
 | 53 | +    segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256)  | 
 | 54 | +    ```  | 
 | 55 | +
  | 
 | 56 | +    Using the class with a preset `backbone`:  | 
 | 57 | +
  | 
 | 58 | +    ```python  | 
 | 59 | +    import keras_hub  | 
 | 60 | +
  | 
 | 61 | +    backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")  | 
 | 62 | +    segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256)  | 
 | 63 | +    ```  | 
 | 64 | +
  | 
 | 65 | +    """  | 
 | 66 | + | 
 | 67 | +    def __init__(  | 
 | 68 | +        self,  | 
 | 69 | +        image_encoder,  | 
 | 70 | +        projection_filters,  | 
 | 71 | +        **kwargs,  | 
 | 72 | +    ):  | 
 | 73 | +        if not isinstance(image_encoder, keras.layers.Layer) or not isinstance(  | 
 | 74 | +            image_encoder, keras.Model  | 
 | 75 | +        ):  | 
 | 76 | +            raise ValueError(  | 
 | 77 | +                "Argument `image_encoder` must be a `keras.layers.Layer` instance "  | 
 | 78 | +                f" or `keras.Model`. Received instead "  | 
 | 79 | +                f"image_encoder={image_encoder} (of type {type(image_encoder)})."  | 
 | 80 | +            )  | 
 | 81 | + | 
 | 82 | +        # === Layers ===  | 
 | 83 | +        inputs = keras.layers.Input(shape=image_encoder.input.shape[1:])  | 
 | 84 | + | 
 | 85 | +        self.feature_extractor = keras.Model(  | 
 | 86 | +            image_encoder.inputs, image_encoder.pyramid_outputs  | 
 | 87 | +        )  | 
 | 88 | + | 
 | 89 | +        features = self.feature_extractor(inputs)  | 
 | 90 | +        # Get height and width of level one output  | 
 | 91 | +        _, height, width, _ = features["P1"].shape  | 
 | 92 | + | 
 | 93 | +        self.mlp_blocks = []  | 
 | 94 | + | 
 | 95 | +        for feature_dim, feature in zip(image_encoder.hidden_dims, features):  | 
 | 96 | +            self.mlp_blocks.append(  | 
 | 97 | +                keras.layers.Dense(  | 
 | 98 | +                    projection_filters, name=f"linear_{feature_dim}"  | 
 | 99 | +                )  | 
 | 100 | +            )  | 
 | 101 | + | 
 | 102 | +        self.resizing = keras.layers.Resizing(  | 
 | 103 | +            height, width, interpolation="bilinear"  | 
 | 104 | +        )  | 
 | 105 | +        self.concat = keras.layers.Concatenate(axis=-1)  | 
 | 106 | +        self.linear_fuse = keras.Sequential(  | 
 | 107 | +            [  | 
 | 108 | +                keras.layers.Conv2D(  | 
 | 109 | +                    filters=projection_filters, kernel_size=1, use_bias=False  | 
 | 110 | +                ),  | 
 | 111 | +                keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9),  | 
 | 112 | +                keras.layers.Activation("relu"),  | 
 | 113 | +            ]  | 
 | 114 | +        )  | 
 | 115 | + | 
 | 116 | +        # === Functional Model ===  | 
 | 117 | +        # Project all multi-level outputs onto  | 
 | 118 | +        # the same dimensionality and feature map shape  | 
 | 119 | +        multi_layer_outs = []  | 
 | 120 | +        for index, (feature_dim, feature) in enumerate(  | 
 | 121 | +            zip(image_encoder.hidden_dims, features)  | 
 | 122 | +        ):  | 
 | 123 | +            out = self.mlp_blocks[index](features[feature])  | 
 | 124 | +            out = self.resizing(out)  | 
 | 125 | +            multi_layer_outs.append(out)  | 
 | 126 | + | 
 | 127 | +        # Concat now-equal feature maps  | 
 | 128 | +        concatenated_outs = self.concat(multi_layer_outs[::-1])  | 
 | 129 | + | 
 | 130 | +        # Fuse concatenated features into a segmentation map  | 
 | 131 | +        seg = self.linear_fuse(concatenated_outs)  | 
 | 132 | + | 
 | 133 | +        super().__init__(  | 
 | 134 | +            inputs=inputs,  | 
 | 135 | +            outputs=seg,  | 
 | 136 | +            **kwargs,  | 
 | 137 | +        )  | 
 | 138 | + | 
 | 139 | +        # === Config ===  | 
 | 140 | +        self.projection_filters = projection_filters  | 
 | 141 | +        self.image_encoder = image_encoder  | 
 | 142 | + | 
 | 143 | +    def get_config(self):  | 
 | 144 | +        config = super().get_config()  | 
 | 145 | +        config.update(  | 
 | 146 | +            {  | 
 | 147 | +                "projection_filters": self.projection_filters,  | 
 | 148 | +                "image_encoder": keras.saving.serialize_keras_object(  | 
 | 149 | +                    self.image_encoder  | 
 | 150 | +                ),  | 
 | 151 | +            }  | 
 | 152 | +        )  | 
 | 153 | +        return config  | 
 | 154 | + | 
 | 155 | +    @classmethod  | 
 | 156 | +    def from_config(cls, config):  | 
 | 157 | +        if "image_encoder" in config and isinstance(  | 
 | 158 | +            config["image_encoder"], dict  | 
 | 159 | +        ):  | 
 | 160 | +            config["image_encoder"] = keras.layers.deserialize(  | 
 | 161 | +                config["image_encoder"]  | 
 | 162 | +            )  | 
 | 163 | +        return super().from_config(config)  | 
0 commit comments