"""Classification metrics: confusion matrix, accuracy, precision, recall, F1.

Adapted from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/utils/util.py
``confusion_matrix`` has been rewritten to count with ``torch.bincount``
instead of materialising an ``(*batch, N, L, L, 2)`` comparison tensor; its
results are unchanged for the inputs the original accepted.
"""

import torch
from torch import Tensor


def accuracy_precision_recall_f1(
    y_pred: Tensor, y_true: Tensor, average: bool = True
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Calculates the accuracy, precision, recall and f1 score given the
    predicted and true labels.

    Per-class scores come from :func:`confusion_matrix`, so the class axis
    has one entry per label in ``range(L)``, where ``L`` is one more than the
    largest label in either input. Positions where either label is negative
    are ignored by the per-class scores but still counted by the accuracy.

    Conventions for undefined ratios (kept from the original implementation):

    * a class that is never predicted gets precision 0;
    * a class that never occurs in ``y_true`` gets recall 1;
    * a class whose precision and recall are both 0 gets F1 0.

    Consequently, labels below the maximum that never appear in either input
    contribute precision 0, recall 1 and F1 0 to the macro averages.

    Args:
        y_pred (Tensor): predicted integer labels, shape ``(*batch, N)``
        y_true (Tensor): true integer labels, broadcastable with ``y_pred``
        average (bool): if True, macro-average precision, recall and F1 over
            the class axis; otherwise return the per-class scores

    Returns:
        a tuple of the accuracy, precision, recall and f1 score. Accuracy has
        shape ``batch``; the other three have shape ``batch`` when
        ``average`` is True and ``(*batch, L)`` otherwise.
    """
    M = confusion_matrix(y_pred, y_true)

    # True positives: diagonal of each (true, pred) matrix.
    tp = M.diagonal(dim1=-2, dim2=-1).float()

    # Column sums: how often each class was predicted.
    precision_den = M.sum(-2)
    precision = torch.where(
        precision_den == 0, torch.zeros_like(tp), tp / precision_den
    )

    # Row sums: how often each class occurs in y_true.
    recall_den = M.sum(-1)
    recall = torch.where(recall_den == 0, torch.ones_like(tp), tp / recall_den)

    f1_den = precision + recall
    f1 = torch.where(
        f1_den == 0, torch.zeros_like(tp), 2 * (precision * recall) / f1_den
    )

    accuracy = (y_pred == y_true).float().mean(-1)
    if average:
        return accuracy, precision.mean(-1), recall.mean(-1), f1.mean(-1)
    return accuracy, precision, recall, f1


def confusion_matrix(y_pred: Tensor, y_true: Tensor) -> Tensor:
    """Creates a confusion matrix given the predicted and true labels.

    Counting is done with a single ``torch.bincount`` call, so memory is
    linear in ``N`` plus the size of the output, rather than proportional to
    ``N * L * L``.

    Args:
        y_pred (Tensor): predicted integer labels, shape ``(*batch, N)``
        y_true (Tensor): true integer labels, broadcastable with ``y_pred``

    Returns:
        an int64 tensor of shape ``(*batch, L, L)``, where ``L`` is one more
        than the largest label found in either input (one ``L`` for the whole
        batch). Entry ``[..., i, j]`` counts the positions whose true label is
        ``i`` and predicted label is ``j``. Positions where either label is
        negative are not counted.
    """
    # One label count for the whole batch so every matrix has the same size.
    num_labels = max(int(y_pred.max().item()), int(y_true.max().item())) + 1
    num_labels = max(num_labels, 0)  # every label negative -> empty matrix
    n_cells = num_labels * num_labels

    y_true, y_pred = torch.broadcast_tensors(y_true.long(), y_pred.long())
    batch_shape = y_true.shape[:-1]
    # Collapse all leading dims into a single batch dim: (B, N).
    y_true = y_true.reshape(-1, y_true.shape[-1])
    y_pred = y_pred.reshape(-1, y_pred.shape[-1])
    batch_size = y_true.shape[0]

    # Each batch row owns its own block of L*L bins; within the block the
    # (true, pred) cell is stored row-major.
    offsets = torch.arange(batch_size, device=y_true.device).unsqueeze(-1)
    index = offsets * n_cells + y_true * num_labels + y_pred

    # Pairs with a negative label lie outside the grid (the original
    # element-wise comparison never matched them). Send them to one extra
    # overflow bin that is dropped after counting.
    overflow = batch_size * n_cells
    valid = (y_true >= 0) & (y_pred >= 0)
    index = torch.where(valid, index, torch.full_like(index, overflow))

    counts = torch.bincount(index.flatten(), minlength=overflow + 1)[:overflow]
    return counts.reshape(*batch_shape, num_labels, num_labels)