Spaces:
Runtime error
Runtime error
"""
Metrics for multi-label text classification
"""
import numpy as np | |
from scipy.special import expit | |
import itertools | |
import copy | |
import torch | |
import pickle | |
import os | |
import json | |
# Metrics | |
from sklearn.metrics import ( | |
accuracy_score, | |
f1_score, | |
classification_report, | |
precision_score, | |
recall_score, | |
label_ranking_average_precision_score, | |
coverage_error | |
) | |
import time | |
def _precision_at_k(logits, labels, top_values):
    """
    Compute precision@k for every k in ``top_values``.

    :logits : 2-D array-like of per-class scores, one row per sample.
    :labels : 2-D array-like of the same shape; a non-zero entry marks a relevant class.
    :top_values : Iterable of cut-offs k.
    :return : Dict mapping k to (relevant classes among the top-k) / (num_samples * k).
    """
    scores = torch.from_numpy(np.asarray(logits)).float()
    targets = torch.from_numpy(np.asarray(labels))
    num_samples = min(scores.shape[0], targets.shape[0])
    scores, targets = scores[:num_samples], targets[:num_samples]
    # topk raises when k exceeds the number of classes, so cap it.
    k_max = min(max(top_values), scores.shape[-1])
    _, top_idx = torch.topk(scores, k=k_max, dim=-1)
    # hits[i, j] is True when the j-th highest scored class of sample i is relevant.
    hits = torch.gather(targets, -1, top_idx) != 0
    # An empty evaluation set yields 0.0 instead of a division by zero.
    denom = max(num_samples, 1)
    return {k: int(hits[:, :k].sum()) / (denom * k) for k in top_values}


def _score_with_global_threshold(name, probs, fbr, score_fn):
    """
    Score binarized predictions using the global threshold stored under ``fbr[name]``.

    If ``fbr`` has no entry for ``name``, thresholds 0.1, 0.2, ..., 0.9 are tried,
    the first one reaching the best score is stored in ``fbr[name]`` (in place)
    and used.

    :name : Key of the metric inside ``fbr``.
    :probs : Array of per-class probabilities.
    :fbr : Dict of global thresholds; may be mutated.
    :score_fn : Callable taking the 0/1 prediction array and returning a score.
    :return : Tuple (score, 0/1 predictions obtained with the threshold that was used).
    """
    if name in fbr:
        binary = np.where(probs > fbr[name], 1, 0)
        return score_fn(binary), binary
    performance = {
        threshold: score_fn(np.where(probs > threshold, 1, 0))
        for threshold in np.arange(0.1, 1, 0.1)
    }
    best_threshold = max(performance, key=performance.get)
    fbr[name] = best_threshold
    return performance[best_threshold], np.where(probs > best_threshold, 1, 0)


def multilabel_metrics(data_args, id2label, label2id, fbr, training_args=None):
    """
    Metrics function used for multilabel classification.

    :data_args : Not used by this function.
    :id2label : Mapping from class index to class name, used for the rows of the classification report.
    :label2id : Not used by this function.
    :fbr : A dict containing global thresholds to be used for selecting a class.
    We use global thresholds because we want to handle unseen classes,
    for which the threshold is not known in advance.
    Thresholds missing from the dict are searched on the first evaluation
    and written back into it. If it contains the key 'precision', the
    P@k metrics are skipped.
    :training_args : Optional; when given and ``local_rank <= 0``, raw predictions,
    labels and the resulting metrics are written below ``output_dir``.
    :return : ``compute_metrics(p)``, which expects ``p.predictions`` (logits, or a
    tuple whose first element is the logits) and ``p.label_ids`` and returns a dict of metrics.
    """
    def compute_metrics(p):
        is_main_process = training_args is not None and training_args.local_rank <= 0
        run_tag = None
        if is_main_process:
            # Save the predictions. Files are tagged with a time-derived value so
            # successive evaluations do not overwrite each other.
            preds_fol = os.path.join(training_args.output_dir, 'predictions')
            os.makedirs(preds_fol, exist_ok=True)
            # eval_steps may be unset (e.g. per-epoch evaluation); fall back to 1
            # so that building the file name cannot fail.
            eval_steps = training_args.eval_steps
            if eval_steps is None:
                eval_steps = 1
            run_tag = (time.time() + 1) * eval_steps
            with open(os.path.join(preds_fol, f'predictions_{run_tag}.pkl'), 'wb') as fh:
                pickle.dump(p.predictions, fh)
            with open(os.path.join(preds_fol, f'label_ids_{run_tag}.pkl'), 'wb') as fh:
                pickle.dump(p.label_ids, fh)

        # Collect the logits
        print('Here we go!')
        logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        labels = p.label_ids
        # Compute the logistic sigmoid
        preds = expit(logits)

        # METRIC 0: Compute P@1, P@3, P@5
        # Stays empty when the caller disabled it through fbr['precision'].
        precisions_at_k = {}
        if 'precision' not in fbr:
            top_values = [1, 3, 5]
            print(labels)
            # Ranking by logits; the sigmoid is monotonic so the order is the same.
            precisions_at_k = _precision_at_k(logits, labels, top_values)
            print('Evaluation Result: precision@{} = {}'.format(top_values, precisions_at_k))

        # METRIC 1: Compute accuracy (percentage of correct label decisions)
        accuracy, _ = _score_with_global_threshold(
            'accuracy', preds, fbr,
            lambda binary: np.sum(labels == binary) / binary.size * 100,
        )
        # METRIC 2: Compute the subset accuracy
        subset_accuracy, _ = _score_with_global_threshold(
            'subset_accuracy', preds, fbr,
            lambda binary: accuracy_score(labels, binary),
        )
        # METRIC 3: Macro F-1
        macro_f1, _ = _score_with_global_threshold(
            'macro_f1', preds, fbr,
            lambda binary: f1_score(labels, binary, average='macro'),
        )
        # METRIC 4: Micro F-1
        micro_f1, micro_f1_preds = _score_with_global_threshold(
            'micro_f1', preds, fbr,
            lambda binary: f1_score(labels, binary, average='micro'),
        )

        # Multi-label classification report
        # Optimized for Micro F-1: uses the predictions of the micro F-1 threshold.
        try:
            report = classification_report(
                labels, micro_f1_preds,
                target_names=[id2label[i] for i in range(len(id2label))],
            )
        except Exception:
            # Fall back to an unlabeled report when the class names cannot be used.
            report = classification_report(labels, micro_f1_preds)
        print('Classification Report: \n', report)

        return_dict = {
            "accuracy": accuracy,
            "subset_accuracy": subset_accuracy,
            "macro_f1": macro_f1,
            "micro_f1": micro_f1,
            "fbr": fbr
        }
        for k, value in precisions_at_k.items():
            return_dict[f'P@{k}'] = value

        if is_main_process:
            # Best effort: failing to persist the metrics must not abort evaluation.
            try:
                metrics_fol = os.path.join(training_args.output_dir, 'metrics')
                os.makedirs(metrics_fol, exist_ok=True)
                with open(os.path.join(metrics_fol, f'metrics_{run_tag}.json'), 'w') as fh:
                    json.dump(return_dict, fh, indent=2)
            except Exception as e:
                print('Error in metrics', e)
        return return_dict
    return compute_metrics