diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py
index ea7232e6..01a6129f 100644
--- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py
+++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py
@@ -55,6 +55,7 @@ class DeepLabV3PlusDecoder(nn.Module):
     def __init__(
         self,
         encoder_channels,
+        encoder_depth=5,
         out_channels=256,
         atrous_rates=(12, 24, 36),
         output_stride=16,
@@ -76,7 +77,14 @@ def __init__(
         scale_factor = 2 if output_stride == 8 else 4
         self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
 
-        highres_in_channels = encoder_channels[-4]
+        if encoder_depth == 3 and output_stride == 8:
+            self.highres_input_index = -2
+        elif encoder_depth == 3 or encoder_depth == 4:
+            self.highres_input_index = -3
+        else:
+            self.highres_input_index = -4
+
+        highres_in_channels = encoder_channels[self.highres_input_index]
         highres_out_channels = 48  # proposed by authors of paper
         self.block1 = nn.Sequential(
             nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False),
@@ -98,7 +106,7 @@ def __init__(
     def forward(self, *features):
         aspp_features = self.aspp(features[-1])
         aspp_features = self.up(aspp_features)
-        high_res_features = self.block1(features[-4])
+        high_res_features = self.block1(features[self.highres_input_index])
         concat_features = torch.cat([aspp_features, high_res_features], dim=1)
         fused_features = self.block2(concat_features)
         return fused_features
diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py
index a88364df..d1b15fbd 100644
--- a/segmentation_models_pytorch/decoders/deeplabv3/model.py
+++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py
@@ -108,7 +108,8 @@ class DeepLabV3Plus(SegmentationModel):
             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
                 **callable** and **None**.
             Default is **None**
-        upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
+        upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. In case
+            **encoder_depth** and **encoder_output_stride** are 3 and 16 resp., set **upsampling** to 2 to preserve it.
         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
             on top of encoder if **aux_params** is not **None** (default). Supported params:
                 - classes (int): A number of classes
@@ -153,6 +154,7 @@ def __init__(
 
         self.decoder = DeepLabV3PlusDecoder(
             encoder_channels=self.encoder.out_channels,
+            encoder_depth=encoder_depth,
             out_channels=decoder_channels,
             atrous_rates=decoder_atrous_rates,
             output_stride=encoder_output_stride,