@munehiro-k
Contributor

Hi, @qubvel. Thanks for your great work!

This PR fixes #377. It was previously submitted as PR #561, which went stale and was closed.

I ran into issue #377 again and found that the PR is still effective.
(My typical use case involves processing small images in real time, where a small encoder depth is preferred.)
So I did some maintenance and rebased onto edadc0d.

Update

  1. Fix the feature index mismatch that occurs when encoder_depth is 3 or 4.
    • Please refer to the attached file for the combination of tensor shapes in each case: tensor_shapes.md.
  2. Modify the docstring of the upsampling argument to state the condition under which the output shape matches the input shape.
    • In the case (encoder_depth, encoder_output_stride) = (3, 16), upsampling should be set to 2.
  3. Fix a type hint and add a value check.
    • Type hint error: Sequence[int, ...] should be Sequence[int].
    • Value check: add a validation to make sure encoder_depth is either 3, 4, or 5.
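
The value check in item 3 could look roughly like the following. This is a minimal sketch, not the code merged in this PR: the function name check_encoder_depth and the error message are hypothetical, and only the allowed values (3, 4, 5) come from the description above.

```python
def check_encoder_depth(encoder_depth: int) -> int:
    """Hypothetical helper: accept only the depths named in this PR (3, 4, or 5)."""
    allowed = (3, 4, 5)
    if encoder_depth not in allowed:
        # Fail early with a clear message instead of a later index error.
        raise ValueError(
            f"encoder_depth must be one of {allowed}, got {encoder_depth}"
        )
    return encoder_depth
```

Calling check_encoder_depth(2) raises ValueError, while check_encoder_depth(4) returns 4 unchanged.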

Test Code

from itertools import product

import torch
import segmentation_models_pytorch as smp

input_shape = (10, 3, 192, 128)
input_tensor = torch.zeros(input_shape)
for up, depth, stride in product((2, 4), (3, 4, 5), (8, 16)):
    net = smp.DeepLabV3Plus(
        encoder_name="timm-mobilenetv3_small_minimal_100",
        encoder_weights="imagenet",
        encoder_depth=depth,
        encoder_output_stride=stride,
        upsampling=up,
        classes=1,
        activation=None
    )
    print(f"encoder_depth={depth}, encoder_output_stride={stride:2}, upsampling={up}")
    output_shape = tuple(net(input_tensor).shape)
    preserved = all(input_shape[i] == output_shape[i] for i in (0, 2, 3))
    print(f"  output shape: {output_shape}, preserve shape: {preserved}")

Output

encoder_depth=3, encoder_output_stride= 8, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=3, encoder_output_stride=16, upsampling=2
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=4, encoder_output_stride= 8, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=4, encoder_output_stride=16, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=5, encoder_output_stride= 8, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=5, encoder_output_stride=16, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=3, encoder_output_stride= 8, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=3, encoder_output_stride=16, upsampling=4
  output shape: (10, 1, 384, 256), preserve shape: False
encoder_depth=4, encoder_output_stride= 8, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=4, encoder_output_stride=16, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=5, encoder_output_stride= 8, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=5, encoder_output_stride=16, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
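
The pattern in this output can be summarized as: the input shape is preserved with upsampling=4, except for (encoder_depth, encoder_output_stride) = (3, 16), where upsampling=2 is needed. The sketch below encodes only that observation in plain Python and checks it against the twelve results listed above. The helper shape_preserving_upsampling is hypothetical (not part of segmentation_models_pytorch) and covers only the combinations tested here.

```python
def shape_preserving_upsampling(encoder_depth: int, encoder_output_stride: int) -> int:
    """Hypothetical helper: upsampling value that preserved the input shape
    in the test output above. Covers only the tested combinations."""
    if encoder_depth not in (3, 4, 5) or encoder_output_stride not in (8, 16):
        raise ValueError("combination not covered by the test output")
    return 2 if (encoder_depth, encoder_output_stride) == (3, 16) else 4


# (encoder_depth, encoder_output_stride, upsampling) -> shape preserved?
# Copied from the output listed above.
observed = {
    (3, 8, 2): False, (3, 16, 2): True,
    (4, 8, 2): False, (4, 16, 2): False,
    (5, 8, 2): False, (5, 16, 2): False,
    (3, 8, 4): True, (3, 16, 4): False,
    (4, 8, 4): True, (4, 16, 4): True,
    (5, 8, 4): True, (5, 16, 4): True,
}

# The rule is consistent with every observed row.
consistent = all(
    preserved == (up == shape_preserving_upsampling(depth, stride))
    for (depth, stride, up), preserved in observed.items()
)
```

Here consistent evaluates to True, which matches the docstring change described in item 2 of the update list.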

@brianhou0208
Contributor

brianhou0208 commented Nov 29, 2024

Hi @munehiro-k,

I also ran into this problem. I think I can contribute another PR to solve the problems with DeepLabV3 and DeepLabV3+ for different encoder depths and output strides.

@qubvel qubvel self-requested a review November 29, 2024 18:01
Collaborator

@qubvel qubvel left a comment

Thanks for the update, it looks great to me! Special thanks for providing the testing code sample. 🤗

@qubvel qubvel merged commit cc482aa into qubvel-org:main Nov 29, 2024
12 checks passed

Development

Successfully merging this pull request may close these issues.

encoder_depth error

3 participants