import torch
import numpy as np
from torch.nn.modules import loss


class WARPLoss(loss.Module):
    """Weighted Approximate-Rank Pairwise (WARP) hinge loss for multi-label targets.

    For every positive label of an example, negative labels are drawn uniformly
    at random until one violates the unit margin (``1 - s_pos + s_neg > 0``) or
    the trial budget (``n_labels - 1`` draws) is spent.  The number of draws
    ``N`` gives the rank estimate ``floor((n_labels - 1) / N)``, which is mapped
    through the harmonic weight ``L(k) = 1 + 1/2 + ... + 1/k`` and used to scale
    the hinge loss summed over all negatives of that example.
    """

    def __init__(self, num_labels=204):
        """Precompute the rank weights.

        :param num_labels: how many weights to precompute; ``rank_weights[k - 1]``
            holds ``L(k)``.  The table grows on demand in ``forward`` if a
            larger label count is seen.
        """
        super(WARPLoss, self).__init__()
        self.rank_weights = [1.0 / 1]
        for i in range(1, num_labels):
            # L(i + 1) = L(i) + 1 / (i + 1).  The parentheses matter:
            # ``1.0 / i + 1`` would add ``1 + 1/i`` instead.
            self.rank_weights.append(self.rank_weights[i - 1] + 1.0 / (i + 1))

    def _rank_weight(self, rank):
        """Return L(rank) for rank >= 1, extending the cached table if needed."""
        while len(self.rank_weights) < rank:
            k = len(self.rank_weights) + 1
            self.rank_weights.append(self.rank_weights[-1] + 1.0 / k)
        return self.rank_weights[rank - 1]

    def forward(self, input, target) -> object:
        """Compute the WARP loss summed over the batch.

        :param input: score tensor of size batch x n_labels.
        :param target: tensor of size batch x n_labels; entries equal to 1 are
            positive labels, entries equal to 0 are negative labels.
        :return: scalar tensor with the summed loss, or the float ``0.0`` when no
            example has both a positive and a negative label.
        """
        batch_size, n_labels = target.size()[0], target.size()[1]
        max_num_trials = n_labels - 1
        total_loss = 0.0
        for i in range(batch_size):
            neg_labels_idx = torch.nonzero(target[i] == 0, as_tuple=True)[0]
            neg_list = neg_labels_idx.tolist()
            num_neg = len(neg_list)
            if num_neg == 0:
                # Every label is positive: nothing to rank against, and
                # sampling from an empty set would fail.
                continue
            pos_list = torch.nonzero(target[i] == 1, as_tuple=True)[0].tolist()
            for j in pos_list:
                pos_score = input[i, j]
                # num_trials counts draws *including* the first one, so it is
                # always >= 1 and the rank estimate below never divides by zero.
                num_trials = 1
                neg_idx = neg_list[np.random.randint(num_neg)]
                sample_score_margin = 1 - pos_score + input[i, neg_idx]
                # Keep sampling while the drawn negative does not violate the
                # margin (a non-positive margin contributes no hinge loss).
                while sample_score_margin <= 0 and num_trials < max_num_trials:
                    neg_idx = neg_list[np.random.randint(num_neg)]
                    num_trials += 1
                    sample_score_margin = 1 - pos_score + input[i, neg_idx]
                # Integer rank estimate in [1, n_labels - 1].
                rank = max(1, max_num_trials // num_trials)
                weight = self._rank_weight(rank)
                # Hinge loss over every negative of this example, vectorised.
                margins = 1 - pos_score + input[i, neg_labels_idx]
                total_loss = total_loss + weight * torch.clamp(margins, min=0.0).sum()
        return total_loss


class MultiLabelSoftmaxRegressionLoss(loss.Module):
    """Negative sum of ``input * target`` over all elements.

    Intended to be fed log-probabilities (see ``LossFactory`` type 'MSR').
    """

    def __init__(self):
        super(MultiLabelSoftmaxRegressionLoss, self).__init__()

    def forward(self, input, target) -> object:
        """Return ``-sum(input * target)``."""
        return -1 * torch.sum(input * target)


class LossFactory(object):
    """Build a loss criterion by name.

    Supported ``type`` values: 'BCE', 'CE', 'WARP', 'MSR'.
    """

    def __init__(self, type, num_labels=156):
        """
        :param type: loss name, one of 'BCE', 'CE', 'WARP', 'MSR'.
        :param num_labels: number of precomputed rank weights for 'WARP'.
        :raises ValueError: if ``type`` is not a supported loss name.
        """
        self.type = type
        # Activation associated with the loss.  It is stored for callers only;
        # compute_loss does not apply it.  None for 'BCE' and 'CE'.
        self.activation_func = None
        if type == 'BCE':
            # BCELoss expects probabilities, not raw logits.
            self.loss = torch.nn.BCELoss()
        elif type == 'CE':
            self.loss = torch.nn.CrossEntropyLoss()
        elif type == 'WARP':
            self.activation_func = torch.nn.Softmax(dim=1)
            self.loss = WARPLoss(num_labels=num_labels)
        elif type == 'MSR':
            self.activation_func = torch.nn.LogSoftmax(dim=1)
            self.loss = MultiLabelSoftmaxRegressionLoss()
        else:
            raise ValueError('Unknown loss type: %r' % (type,))

    def compute_loss(self, output, target):
        """Return ``self.loss(output, target)``.

        No activation is applied here; ``output`` must already be in the form
        the selected loss expects.
        """
        return self.loss(output, target)