import torch
import torch.nn.functional as F
from torch import nn as nn
from torch.autograd import Variable
from torch.nn import MSELoss, SmoothL1Loss, L1Loss

from pytorch3dunet.unet3d.utils import expand_as_one_hot


def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi-channel input and target.
    Assumes the input is a normalized probability, e.g. a result of the Sigmoid or Softmax function.

    Args:
        input (torch.Tensor): NxCxSpatial input tensor
        target (torch.Tensor): NxCxSpatial target tensor
        epsilon (float): prevents division by zero
        weight (torch.Tensor): Cx1 tensor of weight per channel/class
    """
    # input and target shapes must match
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    input = flatten(input)
    target = flatten(target)
    target = target.float()

    # compute the per-channel intersection
    intersect = (input * target).sum(-1)
    if weight is not None:
        intersect = weight * intersect

    # squared-terms denominator as in the V-Net formulation
    denominator = (input * input).sum(-1) + (target * target).sum(-1)
    return 2 * (intersect / denominator.clamp(min=epsilon))
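

# Illustrative shape sketch for compute_per_channel_dice (not part of the original
# module, added for clarity): with a batch of N=2 volumes and C=3 channels,
#
#   probs  = torch.sigmoid(torch.randn(2, 3, 8, 16, 16))    # N x C x D x H x W probabilities
#   target = torch.randint(0, 2, (2, 3, 8, 16, 16))          # binary targets, same shape
#   dice   = compute_per_channel_dice(probs, target)         # tensor of shape (3,): one score per channel
#
# Passing e.g. weight=torch.tensor([1., 2., 2.]) scales the per-channel intersection
# before the final ratio is taken.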


class _MaskingLossWrapper(nn.Module):
    """
    Loss wrapper which prevents the gradient of the loss from being computed where the target is equal to `ignore_index`.
    """

    def __init__(self, loss, ignore_index):
        super(_MaskingLossWrapper, self).__init__()
        assert ignore_index is not None, 'ignore_index cannot be None'
        self.loss = loss
        self.ignore_index = ignore_index

    def forward(self, input, target):
        mask = target.clone().ne_(self.ignore_index)
        mask.requires_grad = False

        # mask out input/target so that the gradient is zero where the target equals ignore_index
        input = input * mask
        target = target * mask

        # forward the masked input and target to the wrapped loss
        return self.loss(input, target)


class SkipLastTargetChannelWrapper(nn.Module):
    """
    Loss wrapper which removes an additional target channel before the loss is computed.
    """

    def __init__(self, loss, squeeze_channel=False):
        super(SkipLastTargetChannelWrapper, self).__init__()
        self.loss = loss
        self.squeeze_channel = squeeze_channel

    def forward(self, input, target):
        assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'

        # skip the last target channel
        target = target[:, :-1, ...]

        if self.squeeze_channel:
            # squeeze the channel dimension if it is a singleton
            target = torch.squeeze(target, dim=1)
        return self.loss(input, target)


class _AbstractDiceLoss(nn.Module):
    """
    Base class for different implementations of the Dice loss.
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super(_AbstractDiceLoss, self).__init__()
        self.register_buffer('weight', weight)

        # The output from the network during training is assumed to be un-normalized logits.
        # Since (soft) Dice is usually applied to binary data, normalizing the channels with
        # Sigmoid is the default choice even for multi-class segmentation problems; use
        # 'softmax' to get a proper probability distribution over the channels instead, or
        # 'none' if the input is already normalized.
        assert normalization in ['sigmoid', 'softmax', 'none']
        if normalization == 'sigmoid':
            self.normalization = nn.Sigmoid()
        elif normalization == 'softmax':
            self.normalization = nn.Softmax(dim=1)
        else:
            self.normalization = lambda x: x

    def dice(self, input, target, weight):
        # actual Dice score computation; to be implemented by the subclass
        raise NotImplementedError

    def forward(self, input, target):
        # get probabilities from logits
        input = self.normalization(input)

        # compute the per-channel Dice coefficient
        per_channel_dice = self.dice(input, target, weight=self.weight)

        # average Dice score across all channels/classes
        return 1. - torch.mean(per_channel_dice)


class DiceLoss(_AbstractDiceLoss):
    """Computes the Dice loss according to https://arxiv.org/abs/1606.04797.
    For multi-class segmentation the `weight` parameter can be used to assign different weights per class.
    The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function.
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super().__init__(weight, normalization)

    def dice(self, input, target, weight):
        return compute_per_channel_dice(input, target, weight=self.weight)


class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes the Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
    """

    def __init__(self, normalization='sigmoid', epsilon=1e-6):
        super().__init__(weight=None, normalization=normalization)
        self.epsilon = epsilon

    def dice(self, input, target, weight):
        assert input.size() == target.size(), "'input' and 'target' must have the same shape"

        input = flatten(input)
        target = flatten(target)
        target = target.float()

        if input.size(0) == 1:
            # for GDL to make sense we need at least two channels (see https://arxiv.org/pdf/1707.03237.pdf),
            # so put the foreground and background voxels in separate channels
            input = torch.cat((input, 1 - input), dim=0)
            target = torch.cat((target, 1 - target), dim=0)

        # GDL weighting: the contribution of each label is corrected by the inverse of its volume
        w_l = target.sum(-1)
        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
        w_l.requires_grad = False

        intersect = (input * target).sum(-1)
        intersect = intersect * w_l

        denominator = (input + target).sum(-1)
        denominator = (denominator * w_l).clamp(min=self.epsilon)

        return 2 * (intersect.sum() / denominator.sum())


class BCEDiceLoss(nn.Module):
    """Linear combination of BCE and Dice losses."""

    def __init__(self, alpha, beta):
        super(BCEDiceLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, input, target):
        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)
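

# Hedged usage sketch for BCEDiceLoss (illustrative values, not taken from the
# original configuration files): alpha and beta weight the BCE and Dice terms.
#
#   criterion = BCEDiceLoss(alpha=1., beta=1.)
#   logits = torch.randn(2, 1, 8, 16, 16)
#   target = torch.randint(0, 2, logits.shape).float()
#   loss = criterion(logits, target)
#
# The input is expected to be raw logits: BCEWithLogitsLoss applies the sigmoid
# internally and DiceLoss normalizes with sigmoid by default.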


class WeightedCrossEntropyLoss(nn.Module):
    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
    """

    def __init__(self, ignore_index=-1):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, input, target):
        weight = self._class_weights(input)
        return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index)

    @staticmethod
    def _class_weights(input):
        # normalize the input first
        input = F.softmax(input, dim=1)
        flattened = flatten(input)
        numerator = (1. - flattened).sum(-1)
        denominator = flattened.sum(-1)
        class_weights = Variable(numerator / denominator, requires_grad=False)
        return class_weights


class PixelWiseCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights=None, ignore_index=None):
        super(PixelWiseCrossEntropyLoss, self).__init__()
        self.register_buffer('class_weights', class_weights)
        self.ignore_index = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, target, weights):
        assert target.size() == weights.size()
        # normalize the input
        log_probabilities = self.log_softmax(input)
        # the target comes as (NxDxHxW) class indices, so expand it to one-hot (NxCxDxHxW)
        target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index)
        # expand the per-voxel weights to the shape of the input
        weights = weights.unsqueeze(1)
        weights = weights.expand_as(input)

        # create default class_weights if none were given; keep them on the same device as the input
        if self.class_weights is None:
            class_weights = torch.ones(input.size()[1]).float().to(input.device)
        else:
            class_weights = self.class_weights

        # reshape class_weights so that they broadcast against the voxel weights
        class_weights = class_weights.view(1, -1, 1, 1, 1)

        # multiply the voxel weights by the class weights
        weights = class_weights * weights

        # compute the weighted negative log-likelihood per voxel and average
        result = -weights * target * log_probabilities
        return result.mean()
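

# Hedged usage sketch for PixelWiseCrossEntropyLoss (illustrative shapes; assumes
# expand_as_one_hot behaves as in pytorch3dunet.unet3d.utils): unlike the other
# criteria, forward() takes a third argument holding a per-voxel weight map.
#
#   criterion = PixelWiseCrossEntropyLoss()
#   logits  = torch.randn(2, 3, 8, 16, 16)           # N x C x D x H x W
#   target  = torch.randint(0, 3, (2, 8, 16, 16))    # N x D x H x W class indices
#   weights = torch.ones(2, 8, 16, 16)               # N x D x H x W voxel weights
#   loss = criterion(logits, target, weights)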


class WeightedSmoothL1Loss(nn.SmoothL1Loss):
    def __init__(self, threshold, initial_weight, apply_below_threshold=True):
        super().__init__(reduction="none")
        self.threshold = threshold
        self.apply_below_threshold = apply_below_threshold
        self.weight = initial_weight

    def forward(self, input, target):
        l1 = super().forward(input, target)

        # re-weight the elements selected by the threshold
        if self.apply_below_threshold:
            mask = target < self.threshold
        else:
            mask = target >= self.threshold

        l1[mask] = l1[mask] * self.weight

        return l1.mean()
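

# Hedged usage sketch for WeightedSmoothL1Loss (illustrative parameter values):
# elements whose target value lies below `threshold` (or above it, when
# apply_below_threshold=False) have their SmoothL1 term scaled by `initial_weight`.
#
#   criterion = WeightedSmoothL1Loss(threshold=0., initial_weight=0.1,
#                                    apply_below_threshold=True)
#   pred   = torch.randn(2, 1, 8, 16, 16)
#   target = torch.randn(2, 1, 8, 16, 16)
#   loss = criterion(pred, target)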


def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    # number of channels
    C = tensor.size(1)
    # new axis order: channel axis first
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)
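

# Worked shape example for flatten (illustrative): a (2, 3, 8, 16, 16) tensor,
# i.e. N=2, C=3, D=8, H=16, W=16, is permuted to (3, 2, 8, 16, 16) and reshaped
# to (3, 2 * 8 * 16 * 16) = (3, 4096), so reductions over dim=-1 aggregate every
# voxel of a channel across the whole batch.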


def get_loss_criterion(config):
    """
    Returns the loss function based on the provided configuration.
    :param config: (dict) a top level configuration object containing the 'loss' key
    :return: an instance of the loss function
    """
    assert 'loss' in config, 'Could not find loss function configuration'
    loss_config = config['loss']
    name = loss_config.pop('name')

    ignore_index = loss_config.pop('ignore_index', None)
    skip_last_target = loss_config.pop('skip_last_target', False)
    weight = loss_config.pop('weight', None)

    if weight is not None:
        weight = torch.tensor(weight)

    pos_weight = loss_config.pop('pos_weight', None)
    if pos_weight is not None:
        pos_weight = torch.tensor(pos_weight)

    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)

    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
        # use the masking wrapper only for losses that cannot handle 'ignore_index' themselves
        loss = _MaskingLossWrapper(loss, ignore_index)

    if skip_last_target:
        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))

    if torch.cuda.is_available():
        loss = loss.cuda()

    return loss
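

# Hedged configuration sketch for get_loss_criterion (illustrative dict; the valid
# 'name' values are exactly the ones handled by _create_loss below):
#
#   config = {
#       'loss': {
#           'name': 'BCEDiceLoss',
#           'alpha': 1.0,            # weight of the BCE term
#           'beta': 1.0,             # weight of the Dice term
#           'skip_last_target': False,
#           # 'ignore_index': -1,    # would wrap the loss in _MaskingLossWrapper
#       }
#   }
#   criterion = get_loss_criterion(config)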


def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
    if name == 'BCEWithLogitsLoss':
        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    elif name == 'BCEDiceLoss':
        alpha = loss_config.get('alpha', 1.)
        beta = loss_config.get('beta', 1.)
        return BCEDiceLoss(alpha, beta)
    elif name == 'CrossEntropyLoss':
        if ignore_index is None:
            ignore_index = -100  # the default 'ignore_index' of nn.CrossEntropyLoss
        return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
    elif name == 'WeightedCrossEntropyLoss':
        if ignore_index is None:
            ignore_index = -100  # the default 'ignore_index' of nn.CrossEntropyLoss
        return WeightedCrossEntropyLoss(ignore_index=ignore_index)
    elif name == 'PixelWiseCrossEntropyLoss':
        return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index)
    elif name == 'GeneralizedDiceLoss':
        normalization = loss_config.get('normalization', 'sigmoid')
        return GeneralizedDiceLoss(normalization=normalization)
    elif name == 'DiceLoss':
        normalization = loss_config.get('normalization', 'sigmoid')
        return DiceLoss(weight=weight, normalization=normalization)
    elif name == 'MSELoss':
        return MSELoss()
    elif name == 'SmoothL1Loss':
        return SmoothL1Loss()
    elif name == 'L1Loss':
        return L1Loss()
    elif name == 'WeightedSmoothL1Loss':
        return WeightedSmoothL1Loss(threshold=loss_config['threshold'],
                                    initial_weight=loss_config['initial_weight'],
                                    apply_below_threshold=loss_config.get('apply_below_threshold', True))
    else:
        raise RuntimeError(f"Unsupported loss function: '{name}'")
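

if __name__ == '__main__':
    # Minimal smoke test (a sketch added for illustration, not part of the original
    # module): builds a DiceLoss via get_loss_criterion and evaluates it on random
    # tensors. Shapes follow the N x C x D x H x W convention used above.
    config = {'loss': {'name': 'DiceLoss', 'normalization': 'sigmoid'}}
    criterion = get_loss_criterion(config)
    logits = torch.randn(2, 3, 8, 16, 16)
    target = torch.randint(0, 2, logits.shape).float()
    if torch.cuda.is_available():
        logits, target = logits.cuda(), target.cuda()
    print('DiceLoss:', criterion(logits, target).item())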