3v324v23's picture
add
c310e19
raw
history blame
No virus
2.12 kB
#!/usr/bin/env python3
"""
This file contains specific functions for computing losses on the SEG
file
"""
import torch
class SEGLossComputation(object):
"""
This class computes the SEG loss.
"""
def __init__(self, cfg):
self.eps = 1e-6
self.cfg = cfg
def __call__(self, preds, targets):
"""
Arguments:
preds (Tensor)
targets (list[Tensor])
masks (list[Tensor])
Returns:
seg_loss (Tensor)
"""
image_size = (preds.shape[2], preds.shape[3])
segm_targets, masks = self.prepare_targets(targets, image_size)
device = preds.device
segm_targets = segm_targets.float().to(device)
masks = masks.float().to(device)
seg_loss = self.dice_loss(preds, segm_targets, masks)
return seg_loss
def dice_loss(self, pred, gt, m):
intersection = torch.sum(pred * gt * m)
union = torch.sum(pred * m) + torch.sum(gt * m) + self.eps
loss = 1 - 2.0 * intersection / union
return loss
def project_masks_on_image(self, mask_polygons, labels, shrink_ratio, image_size):
seg_map, training_mask = mask_polygons.convert_seg_map(
labels, shrink_ratio, image_size, self.cfg.MODEL.SEG.IGNORE_DIFFICULT
)
return torch.from_numpy(seg_map), torch.from_numpy(training_mask)
def prepare_targets(self, targets, image_size):
segms = []
training_masks = []
for target_per_image in targets:
segmentation_masks = target_per_image.get_field("masks")
labels = target_per_image.get_field("labels")
seg_maps_per_image, training_masks_per_image = self.project_masks_on_image(
segmentation_masks, labels, self.cfg.MODEL.SEG.SHRINK_RATIO, image_size
)
segms.append(seg_maps_per_image)
training_masks.append(training_masks_per_image)
return torch.stack(segms), torch.stack(training_masks)
def make_seg_loss_evaluator(cfg):
loss_evaluator = SEGLossComputation(cfg)
return loss_evaluator