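# NOTE: the lines below appear to be residue of the file-hosting page this module was
# copied from and are not part of the code:
#   uploader: hyzhou | commit message: "upload everything" | commit: cca9b7e | size: 16.9 kB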
import importlib
import numpy as np
import torch
from skimage import measure
from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio, mean_squared_error
from pytorch3dunet.unet3d.losses import compute_per_channel_dice
from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy
from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, convert_to_numpy
# Module-level logger named 'EvalMetric'; the metric classes in this file log through it.
logger = get_logger('EvalMetric')
class DiceCoefficient:
    """Mean per-channel Dice coefficient.

    The Dice score is computed for every channel separately
    (as described in https://arxiv.org/pdf/1707.03237.pdf) and the per-channel scores are then averaged.
    Input is expected to be probabilities instead of logits.

    This metric is mostly useful when channels contain the same semantic class
    (e.g. affinities computed with different offsets).
    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        """
        :param epsilon: value forwarded to `compute_per_channel_dice` as its `epsilon` argument
        :param kwargs: accepted and ignored
        """
        self.epsilon = epsilon

    def __call__(self, input, target):
        """
        :param input: predictions, forwarded unchanged to `compute_per_channel_dice`
        :param target: ground truth, forwarded unchanged to `compute_per_channel_dice`
        :return: mean of the per-channel Dice scores
        """
        channel_scores = compute_per_channel_dice(input, target, epsilon=self.epsilon)
        # collapse the per-channel scores into the final score
        return torch.mean(channel_scores)
class MeanIoU:
    """
    Computes IoU for each class separately and then averages over all classes.
    """

    def __init__(self, skip_channels=(), ignore_index=None, **kwargs):
        """
        :param skip_channels: list/tuple of channels to be ignored from the IoU computation
        :param ignore_index: id of the label to be ignored from IoU computation
        :param kwargs: accepted and ignored
        """
        self.ignore_index = ignore_index
        self.skip_channels = skip_channels

    def __call__(self, input, target):
        """
        :param input: 5D probability maps torch float tensor (NxCxDxHxW)
        :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot.
            The tensor passed in by the caller is never modified.
        :return: intersection over union averaged over all channels (and then over the batch)
        """
        assert input.dim() == 5

        n_classes = input.size()[1]

        if target.dim() == 4:
            target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)

        assert input.size() == target.size()

        per_batch_iou = []
        for _input, _target in zip(input, target):
            binary_prediction = self._binarize_predictions(_input, n_classes)

            if self.ignore_index is not None:
                # zero out ignore_index in both prediction and target
                mask = _target == self.ignore_index
                binary_prediction[mask] = 0
                # `_target` is a view into the caller's tensor: use the out-of-place masked_fill
                # so that the ground truth handed to this metric is left untouched
                _target = _target.masked_fill(mask, 0)

            # convert to uint8 just in case (bitwise and/or are used by the Jaccard index)
            binary_prediction = binary_prediction.byte()
            _target = _target.byte()

            per_channel_iou = [
                self._jaccard_index(binary_prediction[c], _target[c])
                for c in range(n_classes)
                if c not in self.skip_channels
            ]

            assert per_channel_iou, "All channels were ignored from the computation"
            mean_iou = torch.mean(torch.tensor(per_channel_iou))
            per_batch_iou.append(mean_iou)

        return torch.mean(torch.tensor(per_batch_iou))

    def _binarize_predictions(self, input, n_classes):
        """
        Puts 1 for the class/channel with the highest probability and 0 in other channels.

        :param input: probability maps for a single batch element (channel dimension first)
        :param n_classes: number of channels in `input`
        :return: tensor of the same size as `input`; long tensor (thresholded at 0.5) for single-channel
            input, uint8 tensor otherwise
        """
        if n_classes == 1:
            # for single channel input just threshold the probability map
            result = input > 0.5
            return result.long()

        _, max_index = torch.max(input, dim=0, keepdim=True)
        return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1)

    def _jaccard_index(self, prediction, target):
        """
        Computes IoU for a given target and prediction tensors.
        The union is clamped to 1e-8 so an empty prediction and target give 0 instead of a division by zero.
        """
        intersection = torch.sum(prediction & target).float()
        union = torch.sum(prediction | target).float()
        return intersection / torch.clamp(union, min=1e-8)
class AdaptedRandError:
    """
    A functor which computes an Adapted Rand error as defined by the SNEMI3D contest
    (http://brainiac2.mit.edu/SNEMI3D/evaluation).

    Generic implementation: the network output is first turned into a segmentation (see `input_to_segm()`)
    and the ARand between that segmentation and the ground truth target is computed. For a specific use case
    it is enough to subclass and override `input_to_segm`.

    Args:
        use_last_target (bool): if true, use the last channel from the target to compute the ARand, otherwise the first.
        ignore_index (int): label which is overwritten with 0 in the target before the ARand is computed.
    """

    def __init__(self, use_last_target=False, ignore_index=None, **kwargs):
        # any additional keyword arguments are accepted and ignored
        self.use_last_target = use_last_target
        self.ignore_index = ignore_index

    def __call__(self, input, target):
        """
        Compute ARand Error for each input, target pair in the batch and return the mean value.

        Args:
            input (torch.tensor): 5D (NCDHW) output from the network
            target (torch.tensor): 5D (NCDHW) ground truth segmentation

        Returns:
            average ARand Error across the batch
        """
        # work on numpy arrays from here on
        input, target = convert_to_numpy(input, target)

        # select the target channel (last or first) -> 4D, and make sure it is of integer type
        target_channel = -1 if self.use_last_target else 0
        target = target[:, target_channel, ...].astype(np.int32)

        if self.ignore_index is not None:
            target[target == self.ignore_index] = 0

        per_batch_arand = []
        for sample, gt in zip(input, target):
            if np.all(gt == gt.flat[0]):
                # only one label in the patch: skip the ARand eval (zero-division) and record 0
                logger.info('Skipping ARandError computation: only 1 label present in the ground truth')
                per_batch_arand.append(0.)
            else:
                # convert the sample to segmentation CDHW
                segm = self.input_to_segm(sample)
                assert segm.ndim == 4
                # one ARand per segmentation channel; keep the smallest one
                channel_errors = [adapted_rand_error(gt, channel_segm)[0] for channel_segm in segm]
                per_batch_arand.append(np.min(channel_errors))

        # mean arand error across the batch
        mean_arand = torch.mean(torch.tensor(per_batch_arand))
        logger.info(f'ARand: {mean_arand.item()}')
        return mean_arand

    def input_to_segm(self, input):
        """
        Converts input tensor (output from the network) to the segmentation image. E.g. if the input is the boundary
        pmaps then one option would be to threshold it and run connected components in order to return the segmentation.

        :param input: 4D tensor (CDHW)
        :return: 4D segmentation volume (one segmentation per channel)
        """
        # by default the input is assumed to be a segmentation volume already
        return input
class BoundaryAdaptedRandError(AdaptedRandError):
    """
    Compute ARand between the input boundary map and target segmentation.
    Boundary map is thresholded, and connected components is run to get the predicted segmentation.

    Args:
        thresholds (list): thresholds applied to the probability maps; defaults to [0.3, 0.4, 0.5, 0.6]
        use_last_target (bool): forwarded to `AdaptedRandError`
        ignore_index (int): forwarded to `AdaptedRandError`
        input_channel (int): if given, only this channel of the input is segmented, otherwise all channels are
        invert_pmaps (bool): if true the thresholded mask is inverted before connected components
        save_plots, plots_dir, kwargs: forwarded to `AdaptedRandError`
    """

    def __init__(self, thresholds=None, use_last_target=True, ignore_index=None, input_channel=None, invert_pmaps=True,
                 save_plots=False, plots_dir='.', **kwargs):
        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, save_plots=save_plots,
                         plots_dir=plots_dir, **kwargs)

        if thresholds is None:
            thresholds = [0.3, 0.4, 0.5, 0.6]
        assert isinstance(thresholds, list)
        self.thresholds = thresholds
        self.input_channel = input_channel
        self.invert_pmaps = invert_pmaps

    def input_to_segm(self, input):
        """
        Thresholds every selected channel at every threshold and labels the connected components.

        :param input: network output indexed by channel along the first axis
        :return: stacked segmentations, one per (channel, threshold) pair, channel-major order
        """
        if self.input_channel is not None:
            input = np.expand_dims(input[self.input_channel], axis=0)

        segs = []
        for predictions in input:
            for th in self.thresholds:
                # threshold the probability map; `predictions` must NOT be rebound here, otherwise every
                # threshold after the first one would be applied to the previous binary mask
                mask = predictions > th

                if self.invert_pmaps:
                    # for connected component analysis we need to treat boundary signal as background
                    # assign 0-label to boundary mask
                    mask = np.logical_not(mask)

                # run connected components on the predicted mask; consider only 1-connectivity
                seg = measure.label(mask.astype(np.uint8), background=0, connectivity=1)
                segs.append(seg)

        return np.stack(segs)
class GenericAdaptedRandError(AdaptedRandError):
    """
    ARand computed on segmentations obtained from selected input channels: each selected channel
    (optionally inverted as `1 - channel`) is thresholded and connected components are run on the mask.
    """

    def __init__(self, input_channels, thresholds=None, use_last_target=True, ignore_index=None, invert_channels=None,
                 **kwargs):
        """
        :param input_channels: list/tuple of input channel indices to be segmented
        :param thresholds: list of thresholds; defaults to [0.3, 0.4, 0.5, 0.6]
        :param use_last_target: forwarded to `AdaptedRandError`
        :param ignore_index: forwarded to `AdaptedRandError`
        :param invert_channels: channel indices which are inverted before thresholding; defaults to none
        """
        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, **kwargs)
        assert isinstance(input_channels, (list, tuple))
        self.input_channels = input_channels

        self.thresholds = [0.3, 0.4, 0.5, 0.6] if thresholds is None else thresholds
        assert isinstance(self.thresholds, list)

        self.invert_channels = [] if invert_channels is None else invert_channels

    def input_to_segm(self, input):
        """
        :param input: network output indexed by channel along the first axis
        :return: stacked segmentations, one per (selected channel, threshold) pair, channel-major order
        """
        # keep only the requested channels, inverting the ones listed in invert_channels
        selected = np.stack([
            1 - input[idx] if idx in self.invert_channels else input[idx]
            for idx in self.input_channels
        ])

        # run connected components on every thresholded mask; consider only 1-connectivity
        labelled = [
            measure.label((pmap > th).astype(np.uint8), background=0, connectivity=1)
            for pmap in selected
            for th in self.thresholds
        ]
        return np.stack(labelled)
class GenericAveragePrecision:
    """
    Base class for instance-segmentation metrics. For every batch element the input is converted to one or
    more candidate segmentations (`input_to_seg`, implemented by subclasses), each candidate is scored
    against the target with the configured metric and the best score is kept. The mean over the batch is returned.
    """

    def __init__(self, min_instance_size=None, use_last_target=False, metric='ap', **kwargs):
        """
        :param min_instance_size: instances with fewer voxels are set to 0 in both prediction and target;
            None disables the filtering
        :param use_last_target: for a 5D target use its last channel if true, otherwise the first one
        :param metric: 'ap' (AveragePrecision) or 'acc' (Accuracy at 0.5 IoU)
        :param kwargs: accepted and ignored
        """
        self.min_instance_size = min_instance_size
        self.use_last_target = use_last_target
        assert metric in ['ap', 'acc']
        if metric == 'ap':
            # use AveragePrecision
            self.metric = AveragePrecision()
        else:
            # use Accuracy at 0.5 IoU
            self.metric = Accuracy(iou_threshold=0.5)

    def __call__(self, input, target):
        """
        :param input: network output, or a tuple of two outputs (multi-head network)
        :param target: 4D ground truth or 5D ground truth from which a single channel is taken
        :return: mean over the batch of the best score among the candidate segmentations
        """
        if target.dim() == 5:
            # use the last or the 1st target channel -> 4D
            target_channel = -1 if self.use_last_target else 0
            target = target[:, target_channel, ...]

        multi_head = isinstance(input, tuple)
        if multi_head:
            input1, input2 = input
        else:
            input1 = input2 = input

        input1, input2, target = convert_to_numpy(input1, input2, target)

        batch_aps = []
        # iterate over the batch
        for i_batch, (inp1, inp2, tar) in enumerate(zip(input1, input2, target)):
            inp = (inp1, inp2) if multi_head else inp1

            segs = self.input_to_seg(inp, tar)  # expects 4D
            assert segs.ndim == 4
            # convert target to seg and filter small instances if necessary
            tar = self._filter_instances(self.target_to_seg(tar))

            # compute the metric per candidate segmentation
            segs_aps = [self.metric(self._filter_instances(seg), tar) for seg in segs]

            logger.info(f'Batch: {i_batch}. Max Average Precision for channel: {np.argmax(segs_aps)}')
            # save max AP
            batch_aps.append(np.max(segs_aps))

        return torch.tensor(batch_aps).mean()

    def _filter_instances(self, input):
        """
        Filters instances smaller than 'min_instance_size' by overriding them with 0-index.
        The array passed in is never modified: when something has to be removed a copy is returned.

        :param input: input instance segmentation
        :return: `input` itself if nothing was filtered, otherwise a filtered copy
        """
        if self.min_instance_size is None:
            return input

        labels, counts = np.unique(input, return_counts=True)
        small_labels = labels[counts < self.min_instance_size]
        if small_labels.size == 0:
            return input

        # single pass over the volume instead of one scan per small label
        filtered = input.copy()
        filtered[np.isin(input, small_labels)] = 0
        return filtered

    def input_to_seg(self, input, target=None):
        """Converts the input to a 4D stack of candidate segmentations; must be implemented by subclasses."""
        raise NotImplementedError

    def target_to_seg(self, target):
        """Converts the target to a segmentation; by default the target is returned unchanged."""
        return target
class BlobsAveragePrecision(GenericAveragePrecision):
    """
    Computes Average Precision given foreground prediction and ground truth instance segmentation.
    """

    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, input_channel=0, **kwargs):
        """
        :param thresholds: list of thresholds applied to the foreground channel; defaults to [0.4, 0.5, 0.6, 0.7, 0.8]
        :param metric: forwarded to `GenericAveragePrecision`
        :param min_instance_size: forwarded to `GenericAveragePrecision`
        :param input_channel: index of the input channel to be segmented
        :param kwargs: accepted and ignored
        """
        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
        self.thresholds = [0.4, 0.5, 0.6, 0.7, 0.8] if thresholds is None else thresholds
        assert isinstance(self.thresholds, list)
        self.input_channel = input_channel

    def input_to_seg(self, input, target=None):
        """Returns one connected-components segmentation of the selected channel per threshold, stacked."""
        foreground = input[self.input_channel]
        # threshold and run connected components (1-connectivity) for every threshold
        return np.stack([
            measure.label((foreground > th).astype(np.uint8), background=0, connectivity=1)
            for th in self.thresholds
        ])
class BlobsBoundaryAveragePrecision(GenericAveragePrecision):
    """
    Computes Average Precision given foreground prediction, boundary prediction and ground truth instance segmentation.
    Segmentation mask is computed as (P_mask - P_boundary) > th followed by a connected component
    """

    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, **kwargs):
        """
        :param thresholds: list of thresholds applied to (P_mask - P_boundary); defaults to [0.3, 0.4, 0.5, 0.6, 0.7]
        :param metric: forwarded to `GenericAveragePrecision`
        :param min_instance_size: forwarded to `GenericAveragePrecision`
        :param kwargs: accepted and ignored
        """
        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
        self.thresholds = [0.3, 0.4, 0.5, 0.6, 0.7] if thresholds is None else thresholds
        assert isinstance(self.thresholds, list)

    def input_to_seg(self, input, target=None):
        """Returns one connected-components segmentation of (input[0] - input[1]) per threshold, stacked."""
        # P_mask - P_boundary
        difference = input[0] - input[1]
        # threshold and run connected components (1-connectivity) for every threshold
        return np.stack([
            measure.label((difference > th).astype(np.uint8), background=0, connectivity=1)
            for th in self.thresholds
        ])
class BoundaryAveragePrecision(GenericAveragePrecision):
    """
    Computes Average Precision given boundary prediction and ground truth instance segmentation.
    """

    def __init__(self, thresholds=None, min_instance_size=None, input_channel=0, metric='ap', **kwargs):
        """
        :param thresholds: list of thresholds applied to the boundary channel; defaults to [0.3, 0.4, 0.5, 0.6]
        :param min_instance_size: forwarded to `GenericAveragePrecision`
        :param input_channel: index of the input channel holding the boundary prediction
        :param metric: forwarded to `GenericAveragePrecision` ('ap' or 'acc'), same as in the sibling
            Blobs* metrics; previously a `metric` argument was swallowed by `kwargs` and silently ignored
        :param kwargs: accepted and ignored
        """
        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
        if thresholds is None:
            thresholds = [0.3, 0.4, 0.5, 0.6]
        assert isinstance(thresholds, list)
        self.thresholds = thresholds
        self.input_channel = input_channel

    def input_to_seg(self, input, target=None):
        """Returns one connected-components segmentation of the inverted thresholded channel per threshold, stacked."""
        input = input[self.input_channel]
        segs = []
        for th in self.thresholds:
            # everything NOT above the threshold is foreground for connected components (1-connectivity)
            seg = measure.label(np.logical_not(input > th).astype(np.uint8), background=0, connectivity=1)
            segs.append(seg)
        return np.stack(segs)
class PSNR:
    """
    Computes Peak Signal to Noise Ratio. Use e.g. as an eval metric for denoising task
    """

    def __init__(self, **kwargs):
        """Takes no configuration; keyword arguments are accepted and ignored."""

    def __call__(self, input, target):
        """
        :param input: prediction, converted with `convert_to_numpy`
        :param target: reference, converted with `convert_to_numpy`
        :return: result of `peak_signal_noise_ratio(target, input)`
        """
        prediction, reference = convert_to_numpy(input, target)
        # the reference (target) is passed first, the prediction second
        return peak_signal_noise_ratio(reference, prediction)
class MSE:
    """
    Computes MSE between input and target
    """

    def __init__(self, **kwargs):
        """Takes no configuration; keyword arguments are accepted and ignored."""

    def __call__(self, input, target):
        """
        :param input: prediction, converted with `convert_to_numpy`
        :param target: reference, converted with `convert_to_numpy`
        :return: result of `mean_squared_error(input, target)`
        """
        prediction, reference = convert_to_numpy(input, target)
        return mean_squared_error(prediction, reference)
def get_evaluation_metric(config):
    """
    Returns the evaluation metric function based on provided configuration

    The class named by `config['eval_metric']['name']` is looked up in the
    `pytorch3dunet.unet3d.metrics` module and instantiated with the whole 'eval_metric'
    dictionary (including 'name') as keyword arguments.

    :param config: (dict) a top level configuration object containing the 'eval_metric' key
    :return: an instance of the evaluation metric
    """
    assert 'eval_metric' in config, 'Could not find evaluation metric configuration'
    metric_config = config['eval_metric']
    class_name = metric_config['name']

    # resolve the metric class by name from the metrics module
    metrics_module = importlib.import_module('pytorch3dunet.unet3d.metrics')
    metric_class = getattr(metrics_module, class_name)
    return metric_class(**metric_config)