# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class BCELoss(nn.Module):
    """Binary Cross Entropy loss.

    Thin wrapper around :func:`torch.nn.functional.binary_cross_entropy`.
    That function operates on probabilities, so ``output`` is expected to
    already lie in ``[0, 1]`` (e.g. after a sigmoid), not to be raw logits.

    Args:
        use_target_weight (bool): If True, ``forward`` requires a
            ``target_weight`` tensor that multiplies the element-wise loss
            before it is averaged. Default: False.
        loss_weight (float): Scalar multiplier applied to the final loss.
            Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.binary_cross_entropy
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification
                (probabilities in ``[0, 1]``).
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels. A 1-D tensor is treated as
                one weight per sample and broadcast over all K labels.
                Ignored when ``use_target_weight`` is False.

        Returns:
            torch.Tensor: Scalar loss (mean over all elements), scaled by
            ``loss_weight``.

        Raises:
            ValueError: If ``use_target_weight`` is True and
                ``target_weight`` is None.
        """
        if self.use_target_weight:
            # Explicit check rather than ``assert`` so the validation is not
            # stripped when Python runs with ``-O``.
            if target_weight is None:
                raise ValueError(
                    'target_weight must be provided when '
                    'use_target_weight=True')
            # Keep per-element losses so they can be weighted before
            # averaging.
            loss = self.criterion(output, target, reduction='none')
            if target_weight.dim() == 1:
                # [N] -> [N, 1] so per-sample weights broadcast over labels.
                target_weight = target_weight[:, None]
            loss = (loss * target_weight).mean()
        else:
            # Default reduction of binary_cross_entropy is 'mean'.
            loss = self.criterion(output, target)

        return loss * self.loss_weight