# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn import torch.nn.functional as F from mmdet.registry import MODELS def ae_loss_per_image(tl_preds, br_preds, match): """Associative Embedding Loss in one image. Associative Embedding Loss including two parts: pull loss and push loss. Pull loss makes embedding vectors from same object closer to each other. Push loss distinguish embedding vector from different objects, and makes the gap between them is large enough. During computing, usually there are 3 cases: - no object in image: both pull loss and push loss will be 0. - one object in image: push loss will be 0 and pull loss is computed by the two corner of the only object. - more than one objects in image: pull loss is computed by corner pairs from each object, push loss is computed by each object with all other objects. We use confusion matrix with 0 in diagonal to compute the push loss. Args: tl_preds (tensor): Embedding feature map of left-top corner. br_preds (tensor): Embedding feature map of bottim-right corner. match (list): Downsampled coordinates pair of each ground truth box. """ tl_list, br_list, me_list = [], [], [] if len(match) == 0: # no object in image pull_loss = tl_preds.sum() * 0. push_loss = tl_preds.sum() * 0. else: for m in match: [tl_y, tl_x], [br_y, br_x] = m tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1) br_e = br_preds[:, br_y, br_x].view(-1, 1) tl_list.append(tl_e) br_list.append(br_e) me_list.append((tl_e + br_e) / 2.0) tl_list = torch.cat(tl_list) br_list = torch.cat(br_list) me_list = torch.cat(me_list) assert tl_list.size() == br_list.size() # N is object number in image, M is dimension of embedding vector N, M = tl_list.size() pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2) pull_loss = pull_loss.sum() / N margin = 1 # exp setting of CornerNet, details in section 3.3 of paper # confusion matrix of push loss conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list conf_weight = 1 - torch.eye(N).type_as(me_list) conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs()) if N > 1: # more than one object in current image push_loss = F.relu(conf_mat).sum() / (N * (N - 1)) else: push_loss = tl_preds.sum() * 0. return pull_loss, push_loss @MODELS.register_module() class AssociativeEmbeddingLoss(nn.Module): """Associative Embedding Loss. More details can be found in `Associative Embedding `_ and `CornerNet `_ . Code is modified from `kp_utils.py `_ # noqa: E501 Args: pull_weight (float): Loss weight for corners from same object. push_weight (float): Loss weight for corners from different object. """ def __init__(self, pull_weight=0.25, push_weight=0.25): super(AssociativeEmbeddingLoss, self).__init__() self.pull_weight = pull_weight self.push_weight = push_weight def forward(self, pred, target, match): """Forward function.""" batch = pred.size(0) pull_all, push_all = 0.0, 0.0 for i in range(batch): pull, push = ae_loss_per_image(pred[i], target[i], match[i]) pull_all += self.pull_weight * pull push_all += self.push_weight * push return pull_all, push_all