|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
def get_map_label_loss(opt):
    """Build a ``MapLabelLoss`` configured from an options object.

    Args:
        opt: options/config object; only its ``label_loss_on_whole_map``
            attribute is read.

    Returns:
        A new ``MapLabelLoss`` instance.
    """
    whole_map_flag = opt.label_loss_on_whole_map
    return MapLabelLoss(whole_map_flag)
|
|
|
|
|
class MapLabelLoss(nn.Module):
    """Binary cross-entropy loss on per-sample predictions and/or a map.

    With ``label_loss_on_whole_map=False`` the loss is the element-wise BCE
    between ``pred`` and ``label``, averaged over all elements.

    With ``label_loss_on_whole_map=True`` each sample contributes one BCE
    term: samples whose label is 0 are scored on the mean of their
    ``out_map``, all other samples are scored on ``pred``. The per-sample
    terms are averaged over the batch.
    """

    def __init__(self, label_loss_on_whole_map=False):
        """
        Args:
            label_loss_on_whole_map: if True, use the map mean (instead of
                ``pred``) as the BCE input for samples labelled 0.
        """
        super().__init__()
        # reduction="none" so forward() controls how terms are aggregated.
        self.bce_loss = nn.BCELoss(reduction="none")
        self.label_loss_on_whole_map = label_loss_on_whole_map

    def forward(self, pred, out_map, label):
        """Compute the loss for one batch.

        Args:
            pred: predicted probabilities in [0, 1] (BCELoss requirement).
                In plain mode it must have exactly the shape of ``label``.
                In whole-map mode it must hold one value per sample
                (e.g. shape ``(B,)`` or ``(B, 1)``).
            out_map: probability map whose first dimension is assumed to
                be the batch; only read in whole-map mode.
            label: float targets; first dimension is the batch. In
                whole-map mode it must hold one value per sample.

        Returns:
            dict with a single key ``"loss"`` holding a scalar tensor.
        """
        if self.label_loss_on_whole_map:
            loss = self._whole_map_loss(pred, out_map, label)
        else:
            loss = self.bce_loss(pred, label).mean()
        return {"loss": loss}

    def _whole_map_loss(self, pred, out_map, label):
        """Batch-mean BCE on the map mean (label 0) or on pred (otherwise)."""
        batch_size = label.shape[0]
        # Flatten to one value per sample so (B,) and (B, 1) inputs behave
        # the same and BCE input/target shapes always agree.
        labels = label.reshape(batch_size)
        preds = pred.reshape(batch_size)
        map_means = out_map.reshape(batch_size, -1).mean(dim=1)
        # Choose the scored input per sample without a Python-level branch
        # on tensor values (which would force a device->host sync per
        # sample). torch.where routes gradients only to the chosen source.
        scored = torch.where(labels == 0, map_means, preds)
        return self.bce_loss(scored, labels).mean()
|
|