File size: 1,140 Bytes
5a69a9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
from collections import Counter
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class index -> label name; the position in the tuple is the class index.
id2label = dict(
    enumerate(
        (
            "none",
            "title",
            "content",
            "author",
            "date",
            "header",
            "footer",
            "rail",
            "advertisement",
            "navigation",
        )
    )
)

# Inverse mapping: label name -> class index.
label2id = {name: index for index, name in id2label.items()}

# "B-"-prefixed tag names, ordered so that label_list[i] matches id2label[i].
label_list = [f"B-{name}" for name in id2label.values()]
def get_labels(predictions, references):
    """Convert batched label-id tensors into aligned lists of label strings.

    Positions whose reference id is ``-100`` (the ignored index used for
    special tokens) are dropped from both outputs, so the two returned
    structures stay aligned element for element.

    Args:
        predictions: Tensor of predicted label ids, iterated row by row
            (presumably shaped ``(batch, seq_len)`` -- TODO confirm against
            the caller).
        references: Tensor of gold label ids laid out like ``predictions``,
            with ``-100`` at the positions to ignore.

    Returns:
        A ``(true_predictions, true_labels)`` tuple. Each element is a list
        with one inner list per row, holding the ``label_list`` strings for
        the positions that were kept.
    """
    # Transform the prediction and reference tensors to numpy arrays.
    # ``Tensor.cpu()`` hands back the same tensor when it already lives on the
    # CPU, so a single path covers both CPU and GPU inputs. No clone is needed
    # because the arrays are only read below.
    y_pred = predictions.detach().cpu().numpy()
    y_true = references.detach().cpu().numpy()

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(y_pred, y_true):
        # Remove the ignored index (special tokens). Filtering once per row
        # guarantees both outputs drop exactly the same positions.
        kept = [(p, g) for p, g in zip(pred_row, gold_row) if g != -100]
        true_predictions.append([label_list[p] for p, _ in kept])
        true_labels.append([label_list[g] for _, g in kept])
    return true_predictions, true_labels