import torch import numpy as np import cv2 def _threshold(x, threshold=None): if threshold is not None: return (x > threshold).type(x.dtype) else: return x def _list_tensor(x, y): m = torch.nn.Sigmoid() if type(x) is list: x = torch.tensor(np.array(x)) y = torch.tensor(np.array(y)) if x.min() < 0: x = m(x) else: x, y = x, y if x.min() < 0: x = m(x) return x, y def iou(pr, gt, eps=1e-7, threshold = 0.5): pr_, gt_ = _list_tensor(pr, gt) pr_ = _threshold(pr_, threshold=threshold) gt_ = _threshold(gt_, threshold=threshold) intersection = torch.sum(gt_ * pr_,dim=[1,2,3]) union = torch.sum(gt_,dim=[1,2,3]) + torch.sum(pr_,dim=[1,2,3]) - intersection return ((intersection + eps) / (union + eps)).cpu().numpy() def dice(pr, gt, eps=1e-7, threshold = 0.5): pr_, gt_ = _list_tensor(pr, gt) pr_ = _threshold(pr_, threshold=threshold) gt_ = _threshold(gt_, threshold=threshold) intersection = torch.sum(gt_ * pr_,dim=[1,2,3]) union = torch.sum(gt_,dim=[1,2,3]) + torch.sum(pr_,dim=[1,2,3]) return ((2. * intersection +eps) / (union + eps)).cpu().numpy() def SegMetrics(pred, label, metrics): metric_list = [] if isinstance(metrics, str): metrics = [metrics, ] for i, metric in enumerate(metrics): if not isinstance(metric, str): continue elif metric == 'iou': metric_list.append(np.mean(iou(pred, label))) elif metric == 'dice': metric_list.append(np.mean(dice(pred, label))) else: raise ValueError('metric %s not recognized' % metric) if pred is not None: metric = np.array(metric_list) else: raise ValueError('metric mistakes in calculations') return metric