import torch


def _take_channels(*xs, ignore_channels=None):
    """Drop the channels listed in ``ignore_channels`` from every tensor.

    Args:
        *xs (torch.Tensor): tensors indexed along dim 1; the channel count is
            read from ``xs[0].shape[1]``.
        ignore_channels: container of channel indices to drop, or ``None``
            to return the inputs untouched.

    Returns:
        The input tuple unchanged when ``ignore_channels`` is ``None``,
        otherwise a list with one channel-filtered tensor per input.
    """
    if ignore_channels is None:
        return xs

    channels = [
        channel
        for channel in range(xs[0].shape[1])
        if channel not in ignore_channels
    ]
    # Build the index explicitly as int64: ``torch.tensor([])`` defaults to a
    # float dtype, which ``index_select`` rejects when every channel is ignored.
    return [
        torch.index_select(
            x,
            dim=1,
            index=torch.tensor(channels, dtype=torch.long, device=x.device),
        )
        for x in xs
    ]


def _threshold(x, threshold=None):
    """Binarize ``x`` against ``threshold``, keeping the dtype of ``x``.

    Args:
        x (torch.Tensor): tensor to binarize
        threshold: cut-off value; ``None`` disables binarization

    Returns:
        torch.Tensor: ``x`` itself when ``threshold`` is ``None``, otherwise a
        tensor of the same dtype holding 1 where ``x > threshold`` and 0
        elsewhere.
    """
    if threshold is None:
        return x
    return (x > threshold).type(x.dtype)


def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Calculate Intersection over Union between ground truth and prediction

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
        ignore_channels: channel indices (dim 1) excluded from the score
    Returns:
        torch.Tensor: scalar tensor with the IoU (Jaccard) score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    intersection = torch.sum(gt * pr)
    union = torch.sum(gt) + torch.sum(pr) - intersection + eps
    return (intersection + eps) / union


jaccard = iou


def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None):
    """Calculate F-score between ground truth and prediction

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        beta (float): positive constant
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
        ignore_channels: channel indices (dim 1) excluded from the score
    Returns:
        torch.Tensor: scalar tensor with the F score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp
    fn = torch.sum(gt) - tp

    beta_sq = beta ** 2
    score = ((1 + beta_sq) * tp + eps) / (
        (1 + beta_sq) * tp + beta_sq * fn + fp + eps
    )
    return score


def accuracy(pr, gt, threshold=0.5, ignore_channels=None):
    """Calculate accuracy score between ground truth and prediction

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        threshold: threshold for outputs binarization
        ignore_channels: channel indices (dim 1) excluded from the score
    Returns:
        torch.Tensor: scalar tensor with the accuracy score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    # Number of positions where prediction and ground truth agree.
    matches = torch.sum(gt == pr, dtype=pr.dtype)
    # ``numel()`` counts elements without reshaping, so non-contiguous
    # ground-truth tensors are accepted (``view(-1)`` raises on them).
    score = matches / gt.numel()
    return score


def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Calculate precision score between ground truth and prediction

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
        ignore_channels: channel indices (dim 1) excluded from the score
    Returns:
        torch.Tensor: scalar tensor with the precision score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp

    score = (tp + eps) / (tp + fp + eps)
    return score


def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Calculate Recall between ground truth and prediction

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
        ignore_channels: channel indices (dim 1) excluded from the score
    Returns:
        torch.Tensor: scalar tensor with the recall score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    tp = torch.sum(gt * pr)
    fn = torch.sum(gt) - tp

    score = (tp + eps) / (tp + fn + eps)
    return score