|
import torch |
|
from collections import Counter |
|
|
|
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class index -> class name.
id2label = {
    0: "none",
    1: "title",
    2: "content",
    3: "author",
    4: "date",
    5: "header",
    6: "footer",
    7: "rail",
    8: "advertisement",
    9: "navigation",
}

# Inverse mapping: class name -> class index.
label2id = {label: idx for idx, label in id2label.items()}

# Tag strings positioned by class index: id2label's keys are 0..9 in
# insertion order, so label_list[i] == "B-" + id2label[i].
# NOTE(review): the "B-" prefix presumably targets IOB-style sequence
# metrics -- confirm against the code that consumes these tags.
label_list = [f"B-{label}" for label in id2label.values()]
|
|
|
def get_labels(predictions, references, *, labels=None, ignore_index=-100):
    """Convert predicted and gold class ids into lists of tag strings.

    Positions whose gold id equals ``ignore_index`` are dropped from both
    outputs, so the two returned structures stay aligned element for element.

    Args:
        predictions: Tensor of integer class ids. It is iterated row by row
            in lockstep with ``references``; assumed to have the same shape
            as ``references`` (e.g. ``(batch, seq)``) -- TODO confirm with
            the caller.
        references: Tensor of gold class ids, with ``ignore_index`` marking
            positions to skip.
        labels: Sequence mapping a class id to its tag string. Defaults to
            the module-level ``label_list``.
        ignore_index: Gold id marking positions that are excluded from the
            output. Defaults to ``-100``.

    Returns:
        A ``(true_predictions, true_labels)`` tuple. Each element is a list
        with one inner list of tag strings per row.
    """
    tag_names = label_list if labels is None else labels

    # ``.cpu()`` returns the tensor unchanged when it already lives in host
    # memory, so one code path serves CPU and GPU tensors alike and the
    # function no longer depends on a global device setting. No ``clone()``
    # is needed because the arrays are only read, never modified.
    y_pred = predictions.detach().cpu().numpy()
    y_true = references.detach().cpu().numpy()

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(y_pred, y_true):
        # Filter once per row and reuse the surviving pairs for both outputs.
        kept = [(p, g) for p, g in zip(pred_row, gold_row) if g != ignore_index]
        true_predictions.append([tag_names[p] for p, _ in kept])
        true_labels.append([tag_names[g] for _, g in kept])

    return true_predictions, true_labels