mart9992's picture
m
2cd560a
raw
history blame
1.28 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
@LOSSES.register_module()
class BCELoss(nn.Module):
    """Binary Cross Entropy loss.

    Thin wrapper around ``torch.nn.functional.binary_cross_entropy`` that
    optionally applies per-sample or per-label weights before averaging and
    scales the result by a constant factor.

    Args:
        use_target_weight (bool): If True, ``forward`` requires
            ``target_weight`` and multiplies the element-wise loss by it
            before taking the mean. Default: False.
        loss_weight (float): Constant multiplier applied to the final loss.
            Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.binary_cross_entropy
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification. It is passed
                unchanged to ``F.binary_cross_entropy``, which expects
                probabilities in [0, 1] (not raw logits).
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels. Only used (and then
                required) when ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss, i.e. the mean (optionally weighted)
            binary cross entropy multiplied by ``loss_weight``.

        Raises:
            AssertionError: If ``use_target_weight`` is True but
                ``target_weight`` is None.
        """
        if not self.use_target_weight:
            # Unweighted path: rely on the criterion's default mean reduction.
            return self.criterion(output, target) * self.loss_weight

        if target_weight is None:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is kept for
            # backward compatibility with existing callers.
            raise AssertionError(
                'target_weight must be provided when use_target_weight=True')

        # Keep the element-wise loss so the weights can be applied per entry.
        loss = self.criterion(output, target, reduction='none')
        if target_weight.dim() == 1:
            # Per-sample weights [N] -> [N, 1] so they broadcast across labels.
            target_weight = target_weight[:, None]
        loss = (loss * target_weight).mean()

        return loss * self.loss_weight