import importlib |
import numpy as np |
import torch |
from skimage import measure |
from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio, mean_squared_error |
from pytorch3dunet.unet3d.losses import compute_per_channel_dice |
from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy |
from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, convert_to_numpy |
# Module-level logger shared by the metric classes in this file.
logger = get_logger('EvalMetric')
class DiceCoefficient:
    """Dice coefficient averaged over channels.

    The per-channel Dice score (as described in https://arxiv.org/pdf/1707.03237.pdf) is obtained from
    ``compute_per_channel_dice`` and reduced with a plain mean. ``input`` is expected to hold
    probabilities, not logits.

    Mostly useful when every channel carries the same semantic class (e.g. affinities computed with
    different offsets). Do not use it when training with DiceLoss, otherwise the results will be biased
    towards the loss.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        # smoothing term forwarded to compute_per_channel_dice; extra config keys are ignored
        self.epsilon = epsilon

    def __call__(self, input, target):
        per_channel_dice = compute_per_channel_dice(input, target, epsilon=self.epsilon)
        return torch.mean(per_channel_dice)
class MeanIoU:
    """
    Computes IoU for each class separately and then averages over all classes.

    Predictions are binarized first: argmax over channels, or a 0.5 threshold when the input has a
    single channel. Per-channel IoUs are averaged per sample, and the per-sample means are then
    averaged over the batch.
    """

    def __init__(self, skip_channels=(), ignore_index=None, **kwargs):
        """
        :param skip_channels: list/tuple of channels to be ignored from the IoU computation
        :param ignore_index: id of the label to be ignored from IoU computation
        """
        self.ignore_index = ignore_index
        self.skip_channels = skip_channels

    def __call__(self, input, target):
        """
        :param input: 5D probability maps torch float tensor (NxCxDxHxW)
        :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
        :return: intersection over union averaged over all channels
        """
        assert input.dim() == 5

        n_classes = input.size()[1]

        if target.dim() == 4:
            target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)

        assert input.size() == target.size()

        per_batch_iou = []
        for _input, _target in zip(input, target):
            binary_prediction = self._binarize_predictions(_input, n_classes)

            if self.ignore_index is not None:
                # Zero out ignored voxels in both tensors. masked_fill returns new tensors: iterating
                # `target` yields views, so in-place assignment would corrupt the caller's target.
                mask = _target == self.ignore_index
                binary_prediction = binary_prediction.masked_fill(mask, 0)
                _target = _target.masked_fill(mask, 0)

            # uint8 so that bitwise & / | can be used in _jaccard_index
            binary_prediction = binary_prediction.byte()
            _target = _target.byte()

            per_channel_iou = [
                self._jaccard_index(binary_prediction[c], _target[c])
                for c in range(n_classes)
                if c not in self.skip_channels
            ]

            assert per_channel_iou, "All channels were ignored from the computation"
            mean_iou = torch.mean(torch.tensor(per_channel_iou))
            per_batch_iou.append(mean_iou)

        return torch.mean(torch.tensor(per_batch_iou))

    def _binarize_predictions(self, input, n_classes):
        """
        Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns a uint8
        tensor of the same size as the input tensor. A single-channel input is thresholded at 0.5 instead.
        """
        if n_classes == 1:
            # argmax over a single channel is meaningless: threshold the probability map
            return (input > 0.5).byte()

        _, max_index = torch.max(input, dim=0, keepdim=True)
        return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1)

    def _jaccard_index(self, prediction, target):
        """
        Computes IoU for a given target and prediction tensors (uint8 masks).
        The denominator is clamped so that an empty union yields 0 instead of NaN.
        """
        return torch.sum(prediction & target).float() / torch.clamp(torch.sum(prediction | target).float(), min=1e-8)
class AdaptedRandError:
    """
    Adapted Rand error functor following the SNEMI3D contest definition
    (http://brainiac2.mit.edu/SNEMI3D/evaluation).

    Generic implementation: the network output is converted into one or more segmentations by
    `input_to_segm()`. Each of them is scored against the ground truth with skimage's
    `adapted_rand_error`, and the lowest error per sample is kept. Subclasses customise the conversion
    by overriding `input_to_segm`.

    Args:
        use_last_target (bool): if true, use the last channel from the target to compute the ARand, otherwise the first.
        ignore_index (int, optional): ground-truth label that is relabelled to 0 before scoring.
    """

    def __init__(self, use_last_target=False, ignore_index=None, **kwargs):
        self.use_last_target = use_last_target
        self.ignore_index = ignore_index

    def __call__(self, input, target):
        """
        Compute the ARand error for every (input, target) pair in the batch and return the batch mean.

        Args:
            input (torch.tensor): 5D (NCDHW) output from the network
            target (torch.tensor): 5D (NCDHW) ground truth segmentation

        Returns:
            average ARand Error across the batch (0-d torch tensor)
        """
        input, target = convert_to_numpy(input, target)

        target_channel = -1 if self.use_last_target else 0
        gt = target[:, target_channel, ...].astype(np.int32)
        if self.ignore_index is not None:
            gt[gt == self.ignore_index] = 0

        scores = []
        for sample_input, sample_gt in zip(input, gt):
            if np.all(sample_gt == sample_gt.flat[0]):
                # a single-label ground truth contributes an error of 0 instead of being scored
                logger.info('Skipping ARandError computation: only 1 label present in the ground truth')
                scores.append(0.)
                continue

            segm = self.input_to_segm(sample_input)
            assert segm.ndim == 4
            # one candidate segmentation per leading index; keep the best (lowest) error
            channel_errors = [adapted_rand_error(sample_gt, channel_segm)[0] for channel_segm in segm]
            scores.append(np.min(channel_errors))

        mean_arand = torch.mean(torch.tensor(scores))
        logger.info(f'ARand: {mean_arand.item()}')
        return mean_arand

    def input_to_segm(self, input):
        """
        Convert a single network output into the segmentation(s) to be scored.

        The base implementation is the identity, i.e. the input is assumed to already be a segmentation.
        A subclass could e.g. threshold boundary probability maps and run connected components.

        :param input: 4D tensor (CDHW)
        :return: 4D segmentation volume (one segmentation per channel)
        """
        return input
class BoundaryAdaptedRandError(AdaptedRandError):
    """
    Compute ARand between the input boundary map and target segmentation.

    The boundary map is thresholded at every value in `thresholds`, and connected components is run on
    each mask. This yields one predicted segmentation per (channel, threshold) pair.
    """

    def __init__(self, thresholds=None, use_last_target=True, ignore_index=None, input_channel=None, invert_pmaps=True,
                 save_plots=False, plots_dir='.', **kwargs):
        """
        :param thresholds: list of probability thresholds; defaults to [0.3, 0.4, 0.5, 0.6]
        :param use_last_target: use the last target channel as ground truth (otherwise the first)
        :param ignore_index: ground-truth label relabelled to 0 before scoring
        :param input_channel: if given, only this input channel is thresholded
        :param invert_pmaps: if True, voxels above the threshold (boundaries) become background (label 0)
        :param save_plots: passed through to the parent constructor
        :param plots_dir: passed through to the parent constructor
        """
        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, save_plots=save_plots,
                         plots_dir=plots_dir, **kwargs)
        if thresholds is None:
            thresholds = [0.3, 0.4, 0.5, 0.6]
        assert isinstance(thresholds, list)
        self.thresholds = thresholds
        self.input_channel = input_channel
        self.invert_pmaps = invert_pmaps

    def input_to_segm(self, input):
        """
        :param input: 4D array (CDHW) of boundary probability maps
        :return: 4D array of connected-component segmentations, channel-major then threshold order
        """
        if self.input_channel is not None:
            input = np.expand_dims(input[self.input_channel], axis=0)

        segs = []
        for predictions in input:
            for th in self.thresholds:
                # Threshold into a separate variable. Rebinding `predictions` here would make every threshold
                # after the first operate on the already binarized (and possibly inverted) mask.
                mask = predictions > th

                if self.invert_pmaps:
                    # for connected component analysis the boundary signal must be background (label 0)
                    mask = np.logical_not(mask)

                # run connected components on the predicted mask; consider only 1-connectivity
                seg = measure.label(mask.astype(np.uint8), background=0, connectivity=1)
                segs.append(seg)

        return np.stack(segs)
class GenericAdaptedRandError(AdaptedRandError):
    """
    ARand error over a selection of input channels.

    Each selected channel (optionally inverted) is binarized at several thresholds followed by
    connected-component labelling.
    """

    def __init__(self, input_channels, thresholds=None, use_last_target=True, ignore_index=None, invert_channels=None,
                 **kwargs):
        """
        :param input_channels: list/tuple of input channel indices used to build candidate segmentations
        :param thresholds: list of thresholds; defaults to [0.3, 0.4, 0.5, 0.6]
        :param use_last_target: use the last target channel as ground truth (otherwise the first)
        :param ignore_index: ground-truth label relabelled to 0 before scoring
        :param invert_channels: indices of channels c that are replaced by 1 - c before thresholding
        """
        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, **kwargs)
        assert isinstance(input_channels, (list, tuple))
        self.input_channels = input_channels

        chosen_thresholds = [0.3, 0.4, 0.5, 0.6] if thresholds is None else thresholds
        assert isinstance(chosen_thresholds, list)
        self.thresholds = chosen_thresholds

        self.invert_channels = [] if invert_channels is None else invert_channels

    def input_to_segm(self, input):
        # gather the requested channels, flipping those listed in invert_channels
        selected = np.stack([1 - input[ch] if ch in self.invert_channels else input[ch]
                             for ch in self.input_channels])
        # one connected-components segmentation per (channel, threshold) pair, channel-major order
        return np.stack([
            measure.label((channel > th).astype(np.uint8), background=0, connectivity=1)
            for channel in selected
            for th in self.thresholds
        ])
class GenericAveragePrecision:
    """
    Base class for instance-segmentation metrics.

    The network output is converted into several candidate segmentations by `input_to_seg` (implemented
    by subclasses). For every batch sample, the best score of the configured metric against the
    ground-truth instances is kept, and the batch mean of these scores is returned.
    """

    def __init__(self, min_instance_size=None, use_last_target=False, metric='ap', **kwargs):
        """
        :param min_instance_size: instances with fewer voxels than this are relabelled to 0 (None disables it)
        :param use_last_target: for 5D targets, use the last channel as ground truth (otherwise the first)
        :param metric: 'ap' for AveragePrecision, 'acc' for Accuracy with iou_threshold=0.5
        """
        self.min_instance_size = min_instance_size
        self.use_last_target = use_last_target
        assert metric in ['ap', 'acc']
        if metric == 'ap':
            self.metric = AveragePrecision()
        else:
            self.metric = Accuracy(iou_threshold=0.5)

    def __call__(self, input, target):
        """
        :param input: network output tensor, or a tuple of two tensors for multi-head networks
        :param target: ground-truth instance segmentation tensor; a 5D target is reduced to one channel
        :return: mean over the batch of the best per-sample score (torch tensor)
        """
        if target.dim() == 5:
            if self.use_last_target:
                target = target[:, -1, ...]
            else:
                target = target[:, 0, ...]

        input1 = input2 = input
        multi_head = isinstance(input, tuple)
        if multi_head:
            input1, input2 = input

        input1, input2, target = convert_to_numpy(input1, input2, target)

        batch_aps = []
        for i_batch, (inp1, inp2, tar) in enumerate(zip(input1, input2, target)):
            inp = (inp1, inp2) if multi_head else inp1
            segs = self.input_to_seg(inp, tar)
            assert segs.ndim == 4
            tar = self._filter_instances(self.target_to_seg(tar))
            segs_aps = [self.metric(self._filter_instances(seg), tar) for seg in segs]
            logger.info(f'Batch: {i_batch}. Max Average Precision for channel: {np.argmax(segs_aps)}')
            batch_aps.append(np.max(segs_aps))

        return torch.tensor(batch_aps).mean()

    def _filter_instances(self, input):
        """
        Filters instances smaller than 'min_instance_size' by overriding them with 0-index.

        The given array is never modified. A filtered copy is returned, or the input itself when nothing
        needs filtering.

        :param input: input instance segmentation (numpy array)
        :return: instance segmentation with small instances set to 0
        """
        if self.min_instance_size is None:
            return input

        labels, counts = np.unique(input, return_counts=True)
        small_labels = labels[counts < self.min_instance_size]
        if small_labels.size == 0:
            return input

        # work on a copy: the array may be a numpy view sharing memory with the caller's target tensor
        filtered = input.copy()
        filtered[np.isin(filtered, small_labels)] = 0
        return filtered

    def input_to_seg(self, input, target=None):
        raise NotImplementedError

    def target_to_seg(self, target):
        return target
class BlobsAveragePrecision(GenericAveragePrecision):
    """
    Computes Average Precision given foreground prediction and ground truth instance segmentation.

    The foreground channel is thresholded at each value of `thresholds` and connected components are
    labelled, giving one candidate segmentation per threshold.
    """

    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, input_channel=0, **kwargs):
        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
        chosen_thresholds = [0.4, 0.5, 0.6, 0.7, 0.8] if thresholds is None else thresholds
        assert isinstance(chosen_thresholds, list)
        self.thresholds = chosen_thresholds
        self.input_channel = input_channel

    def input_to_seg(self, input, target=None):
        foreground = input[self.input_channel]
        return np.stack([
            measure.label((foreground > th).astype(np.uint8), background=0, connectivity=1)
            for th in self.thresholds
        ])
class BlobsBoundaryAveragePrecision(GenericAveragePrecision):
    """
    Computes Average Precision given foreground prediction, boundary prediction and ground truth instance segmentation.

    Segmentation mask is computed as (P_mask - P_boundary) > th followed by a connected component, with
    channel 0 used as the foreground map and channel 1 as the boundary map.
    """

    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, **kwargs):
        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
        chosen_thresholds = [0.3, 0.4, 0.5, 0.6, 0.7] if thresholds is None else thresholds
        assert isinstance(chosen_thresholds, list)
        self.thresholds = chosen_thresholds

    def input_to_seg(self, input, target=None):
        # foreground probability minus boundary probability
        difference = input[0] - input[1]
        return np.stack([
            measure.label((difference > th).astype(np.uint8), background=0, connectivity=1)
            for th in self.thresholds
        ])
class BoundaryAveragePrecision(GenericAveragePrecision):
    """
    Computes Average Precision given boundary prediction and ground truth instance segmentation.

    The boundary channel is thresholded at each value of `thresholds`. Voxels above the threshold become
    background, and connected components of the remaining voxels give one candidate segmentation per
    threshold.
    """

    def __init__(self, thresholds=None, min_instance_size=None, input_channel=0, metric='ap', **kwargs):
        """
        :param thresholds: list of thresholds; defaults to [0.3, 0.4, 0.5, 0.6]
        :param min_instance_size: instances smaller than this are relabelled to 0 before scoring
        :param input_channel: index of the boundary channel in the network output
        :param metric: 'ap' or 'acc'. Previously this was not accepted and a configured value was silently
                       dropped via **kwargs; it is now forwarded to the base class like in the other
                       *AveragePrecision metrics.
        """
        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
        if thresholds is None:
            thresholds = [0.3, 0.4, 0.5, 0.6]
        assert isinstance(thresholds, list)
        self.thresholds = thresholds
        self.input_channel = input_channel

    def input_to_seg(self, input, target=None):
        input = input[self.input_channel]
        segs = []
        for th in self.thresholds:
            # boundaries (above threshold) become background so that cells form separate components
            seg = measure.label(np.logical_not(input > th).astype(np.uint8), background=0, connectivity=1)
            segs.append(seg)
        return np.stack(segs)
class PSNR:
    """
    Peak Signal to Noise Ratio between network output and target, e.g. as an eval metric for denoising.
    Computed with skimage.metrics.peak_signal_noise_ratio.
    """

    def __init__(self, **kwargs):
        """No configuration: any config keys are accepted and ignored."""

    def __call__(self, input, target):
        prediction, reference = convert_to_numpy(input, target)
        # skimage takes the reference (ground-truth) image first.
        # NOTE(review): no data_range is passed, so skimage derives it from the reference dtype — confirm
        # this matches the value range of the data.
        return peak_signal_noise_ratio(reference, prediction)
class MSE:
    """
    Mean squared error between network output and target, computed with skimage.metrics.mean_squared_error.
    """

    def __init__(self, **kwargs):
        """No configuration: any config keys are accepted and ignored."""

    def __call__(self, input, target):
        prediction, reference = convert_to_numpy(input, target)
        return mean_squared_error(prediction, reference)
def get_evaluation_metric(config):
    """
    Instantiate the evaluation metric described by ``config['eval_metric']``.

    The class is looked up by its ``name`` in ``pytorch3dunet.unet3d.metrics``. It is then constructed with
    the whole ``eval_metric`` dict as keyword arguments, including ``name`` itself.

    :param config: (dict) a top level configuration object containing the 'eval_metric' key
    :return: an instance of the evaluation metric
    """
    assert 'eval_metric' in config, 'Could not find evaluation metric configuration'
    metric_config = config['eval_metric']
    class_name = metric_config['name']

    metrics_module = importlib.import_module('pytorch3dunet.unet3d.metrics')
    metric_class = getattr(metrics_module, class_name)
    return metric_class(**metric_config)