File size: 6,689 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from functools import partial
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import print_log
from torch import Tensor

from mmdet.registry import MODELS


@MODELS.register_module()
class EQLV2Loss(nn.Module):
    """`Equalization Loss v2 <https://arxiv.org/abs/2012.08548>`_.

    A sigmoid binary cross-entropy loss whose per-class positive and negative
    terms are re-weighted from the running ratio of accumulated positive to
    negative gradient magnitudes of each class (kept in the ``pos_grad``,
    ``neg_grad`` and ``pos_neg`` buffers).

    The classifier must output ``num_classes + 1`` channels: the last channel
    is excluded from the gradient statistics and always gets a weight of 1.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 reduction: str = 'mean',
                 class_weight: Optional[Tensor] = None,
                 loss_weight: float = 1.0,
                 num_classes: int = 1203,
                 use_distributed: bool = False,
                 mu: float = 0.8,
                 alpha: float = 4.0,
                 gamma: int = 12,
                 vis_grad: bool = False,
                 test_with_obj: bool = True) -> None:
        """Build the EQLv2 loss.

        Args:
            use_sigmoid (bool): Ignored. EQLv2 always uses the sigmoid
                function, so ``self.use_sigmoid`` is forced to True.
            reduction (str, optional): Stored as ``self.reduction``; not used
                by :meth:`forward`, which always averages over samples.
                Defaults to 'mean'.
            class_weight (Tensor, optional): Stored as ``self.class_weight``;
                not used by :meth:`forward`. Defaults to None.
            loss_weight (float, optional): Multiplier applied to the total
                EQLv2 loss. Defaults to 1.0.
            num_classes (int): Number of foreground classes, e.g. 1203 for
                LVIS v1.0 and 1230 for LVIS v0.5. Defaults to 1203.
            use_distributed (bool): If True, the per-class gradient
                statistics are summed across processes with ``all_reduce``
                whenever a default process group is initialized.
                Defaults to False.
            mu (float, optional): Center of the sigmoid that maps the
                positive/negative gradient ratio to the negative weight.
                Defaults to 0.8.
            alpha (float, optional): Scale of the extra positive weight,
                ``pos_w = 1 + alpha * (1 - neg_w)``. Defaults to 4.0.
            gamma (int, optional): Steepness of the sigmoid mapping
                function. Defaults to 12.
            vis_grad (bool, optional): Stored as ``self.vis_grad``; not used
                within this class. Defaults to False.
            test_with_obj (bool, optional): If True, :meth:`get_activation`
                scales every foreground probability by ``1 - p_last``, where
                ``p_last`` is the probability of the last channel.
                Defaults to True.
        """
        super().__init__()
        # EQLv2 is only defined for sigmoid outputs, so the argument is
        # deliberately ignored.
        self.use_sigmoid = True
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.num_classes = num_classes
        self.group = True

        # cfg for eqlv2
        self.vis_grad = vis_grad
        self.mu = mu
        self.alpha = alpha
        self.gamma = gamma
        self.use_distributed = use_distributed

        # Running per-class statistics (one entry per foreground class).
        self.register_buffer('pos_grad', torch.zeros(self.num_classes))
        self.register_buffer('neg_grad', torch.zeros(self.num_classes))
        # At the beginning of training, we set a high value (eg. 100)
        # for the initial gradient ratio so that the weight for pos
        # gradients and neg gradients are 1.
        self.register_buffer('pos_neg', torch.ones(self.num_classes) * 100)

        self.test_with_obj = test_with_obj

        def _func(x, gamma, mu):
            # Logistic mapping: ~0 for x << mu, 0.5 at x == mu, ~1 for x >> mu.
            return 1 / (1 + torch.exp(-gamma * (x - mu)))

        self.map_func = partial(_func, gamma=self.gamma, mu=self.mu)

        print_log(
            f'build EQL v2, gamma: {gamma}, mu: {mu}, alpha: {alpha}',
            logger='current',
            level=logging.DEBUG)

    def forward(self,
                cls_score: Tensor,
                label: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Compute the EQLv2 loss and update the gradient statistics.

        Args:
            cls_score (Tensor): Logits with shape (N, C), where C must equal
                ``num_classes + 1``.
            label (Tensor): Integer class indices with shape (N,), each in
                ``[0, C)``. They are one-hot encoded internally.
            weight (Tensor, optional): Ignored; accepted for interface
                compatibility with other losses. Defaults to None.
            avg_factor (int, optional): Ignored; accepted for interface
                compatibility. Defaults to None.
            reduction_override (str, optional): Ignored; accepted for
                interface compatibility. Defaults to None.

        Returns:
            Tensor: Scalar loss ``loss_weight * sum(weighted BCE) / N``.
            An empty batch (N == 0) yields zero instead of NaN.

        Side effects:
            Sets ``self.n_i``, ``self.n_c``, ``self.gt_classes`` and
            ``self.pred_class_logits``. Updates the ``pos_grad``,
            ``neg_grad`` and ``pos_neg`` buffers via :meth:`collect_grad`.
        """
        self.n_i, self.n_c = cls_score.size()
        self.gt_classes = label
        self.pred_class_logits = cls_score

        # One-hot encode the class indices: (N,) -> (N, C).
        target = cls_score.new_zeros(self.n_i, self.n_c)
        target[torch.arange(self.n_i, device=cls_score.device), label] = 1

        pos_w, neg_w = self.get_weight(cls_score)

        # Element-wise weight: pos_w on the ground-truth column, neg_w on all
        # others. Named to avoid shadowing the (ignored) ``weight`` argument.
        eql_weight = pos_w * target + neg_w * (1 - target)

        cls_loss = F.binary_cross_entropy_with_logits(
            cls_score, target, reduction='none')
        # Clamp the divisor so an empty batch gives 0 instead of 0 / 0 = NaN.
        cls_loss = torch.sum(cls_loss * eql_weight) / max(self.n_i, 1)

        self.collect_grad(cls_score.detach(), target.detach(),
                          eql_weight.detach())

        return self.loss_weight * cls_loss

    def get_channel_num(self, num_classes):
        """Return the classifier channel count: ``num_classes`` + 1 extra."""
        num_channel = num_classes + 1
        return num_channel

    def get_activation(self, pred):
        """Convert logits of shape (N, C) to scores.

        Applies ``sigmoid``. When ``test_with_obj`` is True, every column
        except the last is multiplied by ``1 - p_last``, the last column's
        probability.
        """
        pred = torch.sigmoid(pred)
        n_i, _ = pred.size()
        bg_score = pred[:, -1].view(n_i, 1)
        if self.test_with_obj:
            pred[:, :-1] *= (1 - bg_score)
        return pred

    def collect_grad(self, pred, target, weight):
        """Accumulate per-class gradient magnitudes and refresh ``pos_neg``.

        Args:
            pred (Tensor): Detached logits of shape (N, C).
            target (Tensor): One-hot targets of shape (N, C).
            weight (Tensor): Element-wise EQLv2 weights of shape (N, C).
        """
        # |d BCE / d logit| = |sigmoid(x) - target| for binary targets.
        prob = torch.sigmoid(pred)
        grad = target * (prob - 1) + (1 - target) * prob
        grad = torch.abs(grad)

        # do not collect grad for objectiveness branch [:-1]
        pos_grad = torch.sum(grad * target * weight, dim=0)[:-1]
        neg_grad = torch.sum(grad * (1 - target) * weight, dim=0)[:-1]

        # Only reduce when a process group actually exists; otherwise fall
        # back to local statistics instead of raising.
        if (self.use_distributed and dist.is_available()
                and dist.is_initialized()):
            dist.all_reduce(pos_grad)
            dist.all_reduce(neg_grad)

        self.pos_grad += pos_grad
        self.neg_grad += neg_grad
        self.pos_neg = self.pos_grad / (self.neg_grad + 1e-10)

    def get_weight(self, pred):
        """Return ``(pos_w, neg_w)``, each expanded to ``(n_i, n_c)``.

        ``neg_w`` is ``map_func(pos_neg)`` for the foreground classes, with a
        constant 1 appended for the last channel. ``pos_w`` is
        ``1 + alpha * (1 - neg_w)``. Relies on ``self.n_i`` / ``self.n_c``
        having been set by :meth:`forward`.
        """
        neg_w = torch.cat([self.map_func(self.pos_neg), pred.new_ones(1)])
        pos_w = 1 + self.alpha * (1 - neg_w)
        neg_w = neg_w.view(1, -1).expand(self.n_i, self.n_c)
        pos_w = pos_w.view(1, -1).expand(self.n_i, self.n_c)
        return pos_w, neg_w