# Audio-classification metrics utilities.
# (Originally published as a Hugging Face Space, "Running on Zero".)
import numpy as np
import torch
from scipy import stats
from sklearn import metrics
def d_prime(auc):
    """Convert an ROC-AUC value to the d-prime sensitivity index.

    Args:
        auc: Area under the ROC curve, a float in (0, 1).

    Returns:
        d' = Phi^{-1}(auc) * sqrt(2), where Phi^{-1} is the standard
        normal inverse CDF (percent-point function).
    """
    # Call ppf on the norm class directly: no need to instantiate a frozen
    # distribution, and the result no longer shadows the function name.
    return stats.norm.ppf(auc) * np.sqrt(2.0)
def concat_all_gather(tensor):
    """Gather `tensor` from every distributed rank and concatenate on dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # Pre-allocate one receive buffer per rank, shaped like the local tensor.
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)
def calculate_stats(output, target):
    """Calculate per-class statistics including mAP components, curves, etc.

    Args:
        output: 2d array, (samples_num, classes_num) — model scores.
        target: 2d array, (samples_num, classes_num) — binary labels.

    Returns:
        stats: list of one statistics dict per class, each holding
        subsampled precision/recall and FPR/FNR curves, the class AP,
        and the (dataset-level) accuracy.
    """
    classes_num = target.shape[-1]
    # Accuracy is only meaningful for single-label classification such as
    # ESC-50, not for multi-label datasets such as AudioSet.
    acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
    # Subsampling stride for the curves, to keep the result small.
    # Hoisted out of the loop: it is the same for every class.
    save_every_steps = 1000
    stats = []
    # Class-wise statistics
    for k in range(classes_num):
        # Average precision for this class.
        avg_precision = metrics.average_precision_score(
            target[:, k], output[:, k], average=None
        )
        # Precision/recall curve; thresholds are not needed.
        precisions, recalls, _ = metrics.precision_recall_curve(
            target[:, k], output[:, k]
        )
        # ROC curve (FPR, TPR); thresholds are not needed.
        fpr, tpr, _ = metrics.roc_curve(target[:, k], output[:, k])
        # Renamed from `dict` so the builtin is not shadowed.
        class_stats = {
            "precisions": precisions[0::save_every_steps],
            "recalls": recalls[0::save_every_steps],
            "AP": avg_precision,
            "fpr": fpr[0::save_every_steps],
            "fnr": 1.0 - tpr[0::save_every_steps],
            # NOTE: acc is not class-wise; repeated per class to keep a
            # consistent shape with the other metrics.
            "acc": acc,
        }
        stats.append(class_stats)
    return stats