# WSCL/losses/map_label_loss.py
# Author: yhzhai; commit 482ab8a ("release code"); original size 1.12 kB.
import torch
import torch.nn as nn
def get_map_label_loss(opt):
    """Construct the map/label BCE loss from a parsed options object.

    Only ``opt.label_loss_on_whole_map`` is read; it is forwarded as the
    ``label_loss_on_whole_map`` flag of ``MapLabelLoss``.
    """
    whole_map_flag = opt.label_loss_on_whole_map
    return MapLabelLoss(label_loss_on_whole_map=whole_map_flag)
class MapLabelLoss(nn.Module):
    """Image-level binary cross-entropy loss, optionally driven by the map.

    Per-sample BCE losses are averaged over the batch and returned as
    ``{"loss": <0-d tensor>}``.

    When ``label_loss_on_whole_map`` is True, samples whose label is 0
    (pristine images) are scored by the mean of their entire ``out_map``
    instead of by ``pred``; all other samples (modified images) still use
    ``pred``. When it is False, ``pred`` is used for every sample.

    Args:
        label_loss_on_whole_map: Whether pristine samples take their score
            from the mean of ``out_map`` rather than from ``pred``.
    """

    def __init__(self, label_loss_on_whole_map=False):
        super().__init__()
        # Per-element losses; the batch mean is taken explicitly in forward().
        self.bce_loss = nn.BCELoss(reduction="none")
        self.label_loss_on_whole_map = label_loss_on_whole_map

    def forward(self, pred, out_map, label):
        """Compute the batch-mean BCE between per-sample scores and ``label``.

        Args:
            pred: Per-sample probabilities in ``[0, 1]``; must have the same
                shape as ``label`` (presumably ``(B,)`` -- confirm with callers).
            out_map: Per-sample maps with the batch on dim 0 (assumed to be
                something like ``(B, 1, H, W)``; not checked). Only read when
                ``label_loss_on_whole_map`` is True.
            label: Float targets in ``[0, 1]`` with the batch on dim 0;
                ``0`` marks a pristine image.

        Returns:
            dict: ``{"loss": loss}`` with ``loss`` a 0-d tensor. An empty batch
            yields NaN (mean of an empty tensor) in both modes.
        """
        if self.label_loss_on_whole_map:
            # Collapse each sample's map to a single score; a 1-D out_map
            # already holds exactly one value per sample.
            if out_map.dim() > 1:
                map_score = out_map.flatten(start_dim=1).mean(dim=1)
            else:
                map_score = out_map
            # Vectorized per-sample selection (no Python loop, no per-sample
            # host sync). torch.where routes gradients only to the chosen
            # source, so pristine samples train out_map and modified ones
            # train pred, exactly as a per-sample loop would.
            pred = torch.where(label == 0, map_score.reshape(label.shape), pred)
        loss = self.bce_loss(pred, label).mean()
        return {"loss": loss}