import importlib

import numpy as np
import torch
from skimage import measure
from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio, mean_squared_error

from pytorch3dunet.unet3d.losses import compute_per_channel_dice
from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy
from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, convert_to_numpy

logger = get_logger('EvalMetric')

class DiceCoefficient:
    """Computes Dice Coefficient.
    Generalized to multiple channels by computing per-channel Dice Score
    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then taking the average.
    Input is expected to be probabilities instead of logits.
    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        self.epsilon = epsilon

    def __call__(self, input, target):
        # average the per-channel Dice scores to get the final value
        return torch.mean(compute_per_channel_dice(input, target, epsilon=self.epsilon))
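
# A minimal usage sketch, assuming a single-channel (foreground) head; `net`, `x` and `gt`
# are hypothetical placeholders and `gt` is a float target of the same shape as the prediction:
#
#   pred = torch.sigmoid(net(x))        # (N, 1, D, H, W) probabilities, not logits
#   dice = DiceCoefficient()(pred, gt)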


class MeanIoU:
    """
    Computes IoU for each class separately and then averages over all classes.
    """

    def __init__(self, skip_channels=(), ignore_index=None, **kwargs):
        """
        :param skip_channels: list/tuple of channels to be ignored from the IoU computation
        :param ignore_index: id of the label to be ignored from IoU computation
        """
        self.ignore_index = ignore_index
        self.skip_channels = skip_channels

    def __call__(self, input, target):
        """
        :param input: 5D probability maps torch float tensor (NxCxDxHxW)
        :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
        :return: intersection over union averaged over all channels
        """
        assert input.dim() == 5

        n_classes = input.size()[1]

        if target.dim() == 4:
            target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)

        assert input.size() == target.size()

        per_batch_iou = []
        for _input, _target in zip(input, target):
            binary_prediction = self._binarize_predictions(_input, n_classes)

            if self.ignore_index is not None:
                # zero out the ignore_index voxels in both the prediction and the target
                mask = _target == self.ignore_index
                binary_prediction[mask] = 0
                _target[mask] = 0

            # convert to uint8 just in case
            binary_prediction = binary_prediction.byte()
            _target = _target.byte()

            per_channel_iou = []
            for c in range(n_classes):
                if c in self.skip_channels:
                    continue

                per_channel_iou.append(self._jaccard_index(binary_prediction[c], _target[c]))

            assert per_channel_iou, "All channels were ignored from the computation"
            mean_iou = torch.mean(torch.tensor(per_channel_iou))
            per_batch_iou.append(mean_iou)

        return torch.mean(torch.tensor(per_batch_iou))

    def _binarize_predictions(self, input, n_classes):
        """
        Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns byte tensor of the
        same size as the input tensor.
        """
        if n_classes == 1:
            # for single channel input just threshold the probability map
            result = input > 0.5
            return result.long()

        _, max_index = torch.max(input, dim=0, keepdim=True)
        return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1)

    def _jaccard_index(self, prediction, target):
        """
        Computes IoU for a given target and prediction tensors
        """
        return torch.sum(prediction & target).float() / torch.clamp(torch.sum(prediction | target).float(), min=1e-8)
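
# A minimal usage sketch, assuming a 3-class softmax head; `net`, `x` and `gt` are
# hypothetical placeholders:
#
#   pred = torch.softmax(net(x), dim=1)           # (N, 3, D, H, W) probabilities
#   iou = MeanIoU(skip_channels=(0,))(pred, gt)   # gt: (N, D, H, W) integer label volume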


class AdaptedRandError:
    """
    A functor which computes an Adapted Rand error as defined by the SNEMI3D contest
    (http://brainiac2.mit.edu/SNEMI3D/evaluation).

    This is a generic implementation which takes the input, converts it to the segmentation image (see `input_to_segm()`)
    and then computes the ARand between the segmentation and the ground truth target. For a concrete use case
    it's usually enough to extend this class and implement the `input_to_segm` method.

    Args:
        use_last_target (bool): if true, use the last channel from the target to compute the ARand, otherwise the first.
        ignore_index (int, optional): ground truth label which is mapped to 0 (background) before computing the ARand.
    """

    def __init__(self, use_last_target=False, ignore_index=None, **kwargs):
        self.use_last_target = use_last_target
        self.ignore_index = ignore_index

    def __call__(self, input, target):
        """
        Compute ARand Error for each input, target pair in the batch and return the mean value.

        Args:
            input (torch.tensor): 5D (NCDHW) output from the network
            target (torch.tensor): 5D (NCDHW) ground truth segmentation

        Returns:
            average ARand Error across the batch
        """
        # convert input and target to numpy arrays
        input, target = convert_to_numpy(input, target)
        if self.use_last_target:
            target = target[:, -1, ...]  # 4D
        else:
            # use the 1st target channel
            target = target[:, 0, ...]  # 4D

        # ARand expects an integer-typed ground truth
        target = target.astype(np.int32)

        if self.ignore_index is not None:
            target[target == self.ignore_index] = 0

        per_batch_arand = []
        for _input, _target in zip(input, target):
            if np.all(_target == _target.flat[0]):
                # ARand is undefined when the ground truth contains a single label
                logger.info('Skipping ARandError computation: only 1 label present in the ground truth')
                per_batch_arand.append(0.)
                continue

            # convert the input to a segmentation volume (CDHW)
            segm = self.input_to_segm(_input)
            assert segm.ndim == 4

            # compute the ARand for each channel and keep the minimum (best) value
            per_channel_arand = [adapted_rand_error(_target, channel_segm)[0] for channel_segm in segm]
            per_batch_arand.append(np.min(per_channel_arand))

        # return the mean ARand error across the batch
        mean_arand = torch.mean(torch.tensor(per_batch_arand))
        logger.info(f'ARand: {mean_arand.item()}')
        return mean_arand

    def input_to_segm(self, input):
        """
        Converts input tensor (output from the network) to the segmentation image. E.g. if the input is the boundary
        pmaps then one option would be to threshold it and run connected components in order to return the segmentation.

        :param input: 4D tensor (CDHW)
        :return: 4D (CDHW) segmentation volume (one segmentation per channel)
        """
        # identity transform by default
        return input


class BoundaryAdaptedRandError(AdaptedRandError):
    """
    Compute ARand between the input boundary map and target segmentation.
    The boundary map is thresholded and connected component labeling is run to get the predicted segmentation.
    """

    def __init__(self, thresholds=None, use_last_target=True, ignore_index=None, input_channel=None, invert_pmaps=True,
                 save_plots=False, plots_dir='.', **kwargs):
        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, save_plots=save_plots,
                         plots_dir=plots_dir, **kwargs)

        if thresholds is None:
            thresholds = [0.3, 0.4, 0.5, 0.6]
        assert isinstance(thresholds, list)
        self.thresholds = thresholds
        self.input_channel = input_channel
        self.invert_pmaps = invert_pmaps

    def input_to_segm(self, input):
        if self.input_channel is not None:
            input = np.expand_dims(input[self.input_channel], axis=0)

        segs = []
        for predictions in input:
            for th in self.thresholds:
                # threshold the boundary probability map; use a separate variable so that each
                # threshold is applied to the original probabilities
                prediction = predictions > th

                if self.invert_pmaps:
                    # for connected component analysis the boundary signal has to be treated as background,
                    # so assign the 0-label to the boundary mask
                    prediction = np.logical_not(prediction)

                prediction = prediction.astype(np.uint8)

                # run connected components on the predicted mask; consider only 1-connectivity
                seg = measure.label(prediction, background=0, connectivity=1)
                segs.append(seg)

        return np.stack(segs)


class GenericAdaptedRandError(AdaptedRandError):
    def __init__(self, input_channels, thresholds=None, use_last_target=True, ignore_index=None, invert_channels=None,
                 **kwargs):

        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, **kwargs)
        assert isinstance(input_channels, (list, tuple))
        self.input_channels = input_channels
        if thresholds is None:
            thresholds = [0.3, 0.4, 0.5, 0.6]
        assert isinstance(thresholds, list)
        self.thresholds = thresholds
        if invert_channels is None:
            invert_channels = []
        self.invert_channels = invert_channels

    def input_to_segm(self, input):
        # pick only the channels specified in input_channels
        results = []
        for i in self.input_channels:
            c = input[i]
            # invert the probability map if requested
            if i in self.invert_channels:
                c = 1 - c
            results.append(c)

        input = np.stack(results)

        segs = []
        for predictions in input:
            for th in self.thresholds:
                # threshold the probability map and run connected components
                seg = measure.label((predictions > th).astype(np.uint8), background=0, connectivity=1)
                segs.append(seg)

        return np.stack(segs)


class GenericAveragePrecision:
    """
    Base class for instance segmentation metrics: converts the network output to an instance
    segmentation via `input_to_seg` and scores it against the ground truth with either
    Average Precision ('ap') or Accuracy at IoU 0.5 ('acc').
    """

    def __init__(self, min_instance_size=None, use_last_target=False, metric='ap', **kwargs):
        self.min_instance_size = min_instance_size
        self.use_last_target = use_last_target
        assert metric in ['ap', 'acc']
        if metric == 'ap':
            # use average precision
            self.metric = AveragePrecision()
        else:
            # use accuracy at the 0.5 IoU threshold
            self.metric = Accuracy(iou_threshold=0.5)

    def __call__(self, input, target):
        if target.dim() == 5:
            if self.use_last_target:
                target = target[:, -1, ...]  # 4D
            else:
                # use the 1st target channel
                target = target[:, 0, ...]  # 4D

        input1 = input2 = input
        multi_head = isinstance(input, tuple)
        if multi_head:
            input1, input2 = input

        input1, input2, target = convert_to_numpy(input1, input2, target)

        batch_aps = []
        # iterate over the batch
        for i_batch, (inp1, inp2, tar) in enumerate(zip(input1, input2, target)):
            if multi_head:
                inp = (inp1, inp2)
            else:
                inp = inp1

            segs = self.input_to_seg(inp, tar)  # expects 4D (CDHW)
            assert segs.ndim == 4

            # convert the target to an instance segmentation
            tar = self.target_to_seg(tar)

            # filter small instances if necessary
            tar = self._filter_instances(tar)

            # compute the metric per channel and keep the best score
            segs_aps = [self.metric(self._filter_instances(seg), tar) for seg in segs]

            logger.info(f'Batch: {i_batch}. Max Average Precision for channel: {np.argmax(segs_aps)}')

            batch_aps.append(np.max(segs_aps))

        return torch.tensor(batch_aps).mean()

    def _filter_instances(self, input):
        """
        Filters out instances smaller than 'min_instance_size' by overwriting them with the 0 (background) label
        :param input: input instance segmentation
        """
        if self.min_instance_size is not None:
            labels, counts = np.unique(input, return_counts=True)
            for label, count in zip(labels, counts):
                if count < self.min_instance_size:
                    input[input == label] = 0
        return input

    def input_to_seg(self, input, target=None):
        raise NotImplementedError

    def target_to_seg(self, target):
        return target


class BlobsAveragePrecision(GenericAveragePrecision):
    """
    Computes Average Precision given foreground prediction and ground truth instance segmentation.
    """

    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, input_channel=0, **kwargs):
        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
        if thresholds is None:
            thresholds = [0.4, 0.5, 0.6, 0.7, 0.8]
        assert isinstance(thresholds, list)
        self.thresholds = thresholds
        self.input_channel = input_channel

    def input_to_seg(self, input, target=None):
        input = input[self.input_channel]
        segs = []
        for th in self.thresholds:
            # threshold the foreground probability map and run connected components
            mask = (input > th).astype(np.uint8)
            seg = measure.label(mask, background=0, connectivity=1)
            segs.append(seg)
        return np.stack(segs)
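
# A usage sketch, assuming channel 0 of the network output holds the foreground probability
# and the last target channel holds the instance labels; the threshold value is illustrative:
#
#   metric = BlobsAveragePrecision(thresholds=[0.5], input_channel=0)
#   score = metric(pred, gt)  # pred: (N, C, D, H, W) probabilities, gt: (N, 1, D, H, W) instance labels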


class BlobsBoundaryAveragePrecision(GenericAveragePrecision):
    """
    Computes Average Precision given foreground prediction, boundary prediction and ground truth instance segmentation.
    The segmentation mask is computed as (P_mask - P_boundary) > th followed by connected component labeling.
    """

    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, **kwargs):
        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
        if thresholds is None:
            thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
        assert isinstance(thresholds, list)
        self.thresholds = thresholds

    def input_to_seg(self, input, target=None):
        # input = P_mask - P_boundary
        input = input[0] - input[1]
        segs = []
        for th in self.thresholds:
            # threshold and run connected components
            mask = (input > th).astype(np.uint8)
            seg = measure.label(mask, background=0, connectivity=1)
            segs.append(seg)
        return np.stack(segs)


class BoundaryAveragePrecision(GenericAveragePrecision):
    """
    Computes Average Precision given boundary prediction and ground truth instance segmentation.
    """

    def __init__(self, thresholds=None, min_instance_size=None, input_channel=0, **kwargs):
        super().__init__(min_instance_size=min_instance_size, use_last_target=True)
        if thresholds is None:
            thresholds = [0.3, 0.4, 0.5, 0.6]
        assert isinstance(thresholds, list)
        self.thresholds = thresholds
        self.input_channel = input_channel

    def input_to_seg(self, input, target=None):
        input = input[self.input_channel]
        segs = []
        for th in self.thresholds:
            # invert the thresholded boundary map so that non-boundary regions form the instances
            seg = measure.label(np.logical_not(input > th).astype(np.uint8), background=0, connectivity=1)
            segs.append(seg)
        return np.stack(segs)


class PSNR:
    """
    Computes Peak Signal to Noise Ratio. Use e.g. as an eval metric for a denoising task.
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, input, target):
        input, target = convert_to_numpy(input, target)
        return peak_signal_noise_ratio(target, input)


class MSE:
    """
    Computes MSE between input and target
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, input, target):
        input, target = convert_to_numpy(input, target)
        return mean_squared_error(input, target)


def get_evaluation_metric(config):
    """
    Returns the evaluation metric function based on the provided configuration
    :param config: (dict) a top level configuration object containing the 'eval_metric' key
    :return: an instance of the evaluation metric
    """

    def _metric_class(class_name):
        m = importlib.import_module('pytorch3dunet.unet3d.metrics')
        clazz = getattr(m, class_name)
        return clazz

    assert 'eval_metric' in config, 'Could not find evaluation metric configuration'
    metric_config = config['eval_metric']
    metric_class = _metric_class(metric_config['name'])
    return metric_class(**metric_config)
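
# A minimal, self-contained smoke test sketch for the factory above; the 'eval_metric'
# values are illustrative assumptions, not defaults taken from any project config.
if __name__ == '__main__':
    config = {
        'eval_metric': {
            'name': 'MeanIoU',
            'ignore_index': None,
            'skip_channels': (),
        }
    }
    metric = get_evaluation_metric(config)
    # random 2-class probabilities and a matching integer label volume
    pred = torch.softmax(torch.randn(1, 2, 8, 16, 16), dim=1)
    gt = torch.randint(0, 2, (1, 8, 16, 16))
    logger.info(f'MeanIoU on random data: {metric(pred, gt)}')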