diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py
index ed50cff21a5..9c7cc3f0f8e 100644
--- a/mmdet/models/losses/focal_loss.py
+++ b/mmdet/models/losses/focal_loss.py
@@ -72,6 +72,7 @@ def py_focal_loss_with_prob(pred,
         pred (torch.Tensor): The prediction probability with shape (N, C),
             C is the number of classes.
         target (torch.Tensor): The learning label of the prediction.
+            The target shape can be (N,C) or (N,); (N,C) means one-hot form.
         weight (torch.Tensor, optional): Sample-wise loss weight.
         gamma (float, optional): The gamma for calculating the modulating
             factor. Defaults to 2.0.
@@ -82,9 +83,10 @@ def py_focal_loss_with_prob(pred,
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
     """
-    num_classes = pred.size(1)
-    target = F.one_hot(target, num_classes=num_classes + 1)
-    target = target[:, :num_classes]
+    if pred.dim() != target.dim():
+        num_classes = pred.size(1)
+        target = F.one_hot(target, num_classes=num_classes + 1)
+        target = target[:, :num_classes]
     target = target.type_as(pred)
 
     pt = (1 - pred) * target + pred * (1 - target)
@@ -204,6 +206,8 @@ def forward(self,
         Args:
             pred (torch.Tensor): The prediction.
             target (torch.Tensor): The learning label of the prediction.
+                The target shape can be (N,C) or (N,); (N,C) means
+                one-hot form.
             weight (torch.Tensor, optional): The weight of loss for each
                 prediction. Defaults to None.
             avg_factor (int, optional): Average factor that is used to average
@@ -222,7 +226,10 @@ def forward(self,
             if self.activated:
                 calculate_loss_func = py_focal_loss_with_prob
             else:
-                if torch.cuda.is_available() and pred.is_cuda:
+                if pred.dim() == target.dim():
+                    # this means that the target is already in one-hot form.
+                    calculate_loss_func = py_sigmoid_focal_loss
+                elif torch.cuda.is_available() and pred.is_cuda:
                     calculate_loss_func = sigmoid_focal_loss
                 else:
                     num_classes = pred.size(1)
diff --git a/mmdet/models/losses/gfocal_loss.py b/mmdet/models/losses/gfocal_loss.py
index eb1c2401758..b3a1172207e 100644
--- a/mmdet/models/losses/gfocal_loss.py
+++ b/mmdet/models/losses/gfocal_loss.py
@@ -1,9 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from mmdet.models.losses.utils import weighted_loss
 from mmdet.registry import MODELS
-from .utils import weighted_loss
 
 
 @weighted_loss
@@ -50,6 +53,47 @@ def quality_focal_loss(pred, target, beta=2.0):
     return loss
 
 
+@weighted_loss
+def quality_focal_loss_tensor_target(pred, target, beta=2.0, activated=False):
+    """`QualityFocal Loss <https://arxiv.org/abs/2008.13367>`_
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), C is the
+            number of classes
+        target (torch.Tensor): The learning target of the iou-aware
+            classification score with shape (N, C), C is the number of classes.
+        beta (float): The beta parameter for calculating the modulating factor.
+            Defaults to 2.0.
+        activated (bool): Whether the input is activated.
+            If True, it means the input has been activated and can be
+            treated as probabilities. Else, it should be treated as logits.
+            Defaults to False.
+ """ + # pred and target should be of the same size + assert pred.size() == target.size() + if activated: + pred_sigmoid = pred + loss_function = F.binary_cross_entropy + else: + pred_sigmoid = pred.sigmoid() + loss_function = F.binary_cross_entropy_with_logits + + scale_factor = pred_sigmoid + target = target.type_as(pred) + + zerolabel = scale_factor.new_zeros(pred.shape) + loss = loss_function( + pred, zerolabel, reduction='none') * scale_factor.pow(beta) + + pos = (target != 0) + scale_factor = target[pos] - pred_sigmoid[pos] + loss[pos] = loss_function( + pred[pos], target[pos], + reduction='none') * scale_factor.abs().pow(beta) + + loss = loss.sum(dim=1, keepdim=False) + return loss + + @weighted_loss def quality_focal_loss_with_prob(pred, target, beta=2.0): r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning @@ -166,8 +210,11 @@ def forward(self, pred (torch.Tensor): Predicted joint representation of classification and quality (IoU) estimation with shape (N, C), C is the number of classes. - target (tuple([torch.Tensor])): Target category label with shape - (N,) and target quality label with shape (N,). + target (Union(tuple([torch.Tensor]),Torch.Tensor)): The type is + tuple, it should be included Target category label with + shape (N,) and target quality label with shape (N,).The type + is torch.Tensor, the target should be one-hot form with + soft weights. weight (torch.Tensor, optional): The weight of loss for each prediction. Defaults to None. avg_factor (int, optional): Average factor that is used to average @@ -184,6 +231,12 @@ def forward(self, calculate_loss_func = quality_focal_loss_with_prob else: calculate_loss_func = quality_focal_loss + if isinstance(target, torch.Tensor): + # the target shape with (N,C) or (N,C,...), which means + # the target is one-hot form with soft weights. + calculate_loss_func = partial( + quality_focal_loss_tensor_target, activated=self.activated) + loss_cls = self.loss_weight * calculate_loss_func( pred, target, diff --git a/tests/test_models/test_losses/test_loss.py b/tests/test_models/test_losses/test_loss.py index 166cc85fa06..040589012c4 100644 --- a/tests/test_models/test_losses/test_loss.py +++ b/tests/test_models/test_losses/test_loss.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
 import pytest
 import torch
+import torch.nn.functional as F
 from mmengine.utils import digit_version
 
 from mmdet.models.losses import (BalancedL1Loss, CrossEntropyLoss, DiceLoss,
@@ -29,7 +30,7 @@ def test_iou_type_loss_zeros_weight(loss_class):
 @pytest.mark.parametrize('loss_class', [
     BalancedL1Loss, BoundedIoULoss, CIoULoss, CrossEntropyLoss, DIoULoss,
     EIoULoss, FocalLoss, DistributionFocalLoss, MSELoss, SeesawLoss,
-    GaussianFocalLoss, GIoULoss, IoULoss, L1Loss, QualityFocalLoss,
+    GaussianFocalLoss, GIoULoss, QualityFocalLoss, IoULoss, L1Loss,
     VarifocalLoss, GHMR, GHMC, SmoothL1Loss, KnowledgeDistillationKLDivLoss,
     DiceLoss
 ])
@@ -46,6 +47,26 @@ def test_loss_with_reduction_override(loss_class):
         pred, target, weight, reduction_override=reduction_override)
 
 
+@pytest.mark.parametrize('loss_class', [QualityFocalLoss])
+@pytest.mark.parametrize('activated', [False, True])
+def test_QualityFocalLoss_Loss(loss_class, activated):
+    input_shape = (4, 5)
+    pred = torch.rand(input_shape)
+    label = torch.Tensor([0, 1, 2, 0]).long()
+    quality_label = torch.rand(input_shape[0])
+
+    original_loss = loss_class(activated=activated)(pred,
+                                                    (label, quality_label))
+    assert isinstance(original_loss, torch.Tensor)
+
+    target = torch.nn.functional.one_hot(label, 5)
+    target = target * quality_label.reshape(input_shape[0], 1)
+
+    new_loss = loss_class(activated=activated)(pred, target)
+    assert isinstance(new_loss, torch.Tensor)
+    assert new_loss == original_loss
+
+
 @pytest.mark.parametrize('loss_class', [
     IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, EIoULoss, MSELoss,
     L1Loss, SmoothL1Loss, BalancedL1Loss
@@ -86,7 +107,7 @@ def test_regression_losses(loss_class, input_shape):
     assert isinstance(loss, torch.Tensor)
 
 
-@pytest.mark.parametrize('loss_class', [FocalLoss, CrossEntropyLoss])
+@pytest.mark.parametrize('loss_class', [CrossEntropyLoss])
 @pytest.mark.parametrize('input_shape', [(10, 5), (0, 5)])
 def test_classification_losses(loss_class, input_shape):
     if input_shape[0] == 0 and digit_version(
@@ -124,6 +145,42 @@ def test_classification_losses(loss_class, input_shape):
     assert isinstance(loss, torch.Tensor)
 
 
+@pytest.mark.parametrize('loss_class', [FocalLoss])
+@pytest.mark.parametrize('input_shape', [(10, 5), (3, 5, 40, 40)])
+def test_FocalLoss_loss(loss_class, input_shape):
+    pred = torch.rand(input_shape)
+    target = torch.randint(0, 5, (input_shape[0], ))
+    if len(input_shape) == 4:
+        B, N, W, H = input_shape
+        target = F.one_hot(torch.randint(0, 5, (B * W * H, )),
+                           5).reshape(B, W, H, N).permute(0, 3, 1, 2)
+
+    # Test loss forward
+    loss = loss_class()(pred, target)
+    assert isinstance(loss, torch.Tensor)
+
+    # Test loss forward with reduction_override
+    loss = loss_class()(pred, target, reduction_override='mean')
+    assert isinstance(loss, torch.Tensor)
+
+    # Test loss forward with avg_factor
+    loss = loss_class()(pred, target, avg_factor=10)
+    assert isinstance(loss, torch.Tensor)
+
+    with pytest.raises(ValueError):
+        # loss can evaluate with avg_factor only if
+        # reduction is None, 'none' or 'mean'.
+        reduction_override = 'sum'
+        loss_class()(
+            pred, target, avg_factor=10, reduction_override=reduction_override)
+
+    # Test loss forward with avg_factor and reduction
+    for reduction_override in [None, 'none', 'mean']:
+        loss_class()(
+            pred, target, avg_factor=10, reduction_override=reduction_override)
+        assert isinstance(loss, torch.Tensor)
+
+
 @pytest.mark.parametrize('loss_class', [GHMR])
 @pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
 def test_GHMR_loss(loss_class, input_shape):
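
For orientation (not part of the diff above): after this change, FocalLoss accepts a one-hot target of shape (N, C) in addition to the usual (N,) label tensor, and QualityFocalLoss accepts a plain torch.Tensor target (a one-hot target scaled by soft quality weights) in addition to the (label, quality) tuple. Below is a minimal usage sketch of both new call paths, assuming a patched mmdet is importable; the tensor values are arbitrary examples.

import torch
import torch.nn.functional as F

from mmdet.models.losses import FocalLoss, QualityFocalLoss

pred = torch.rand(4, 5)             # (N, C) classification logits
label = torch.tensor([0, 1, 2, 0])  # (N,) hard category labels
quality = torch.rand(4)             # (N,) IoU-aware quality scores

# FocalLoss: a target with the same number of dims as pred is now routed to
# py_sigmoid_focal_loss and treated as a one-hot (or soft) target.
onehot = F.one_hot(label, num_classes=5).float()
loss_focal = FocalLoss()(pred, onehot)

# QualityFocalLoss: a tensor target (one-hot form with soft weights) is now
# routed to quality_focal_loss_tensor_target and should match the tuple form,
# mirroring the new test_QualityFocalLoss_Loss unit test above.
soft_target = F.one_hot(label, num_classes=5) * quality.reshape(-1, 1)
loss_tuple = QualityFocalLoss()(pred, (label, quality))
loss_tensor = QualityFocalLoss()(pred, soft_target)
assert torch.allclose(loss_tuple, loss_tensor)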