# Detsutut's picture
# Update evaluation.py
# 65e7536 verified
import evaluate
import torch
from enum import Enum
from tqdm import tqdm
class AssertionType(Enum):
    """Closed set of assertion labels that can be attached to an entity.

    The integer values are assigned explicitly rather than with ``auto()``.
    NOTE(review): presumably they line up with the classifier's label ids;
    confirm against ``model.config.id2label`` before relying on that.
    """

    PRESENT = 0
    ABSENT = 1
    POSSIBLE = 2
class EntityWithAssertion:
    """Pair an entity string with the assertion type assigned to it."""

    def __init__(self, entity: str, assertion_type: AssertionType):
        # Plain attribute storage: nothing is validated, copied or normalised.
        self.entity, self.assertion_type = entity, assertion_type

    def __repr__(self) -> str:
        """Render as ``<assertion name>: <entity>``, e.g. ``ABSENT: fever``."""
        label = self.assertion_type.name
        return f"{label}: {self.entity}"
def classify_assertions_in_sentences(sentences, model, tokenizer, batch_size=32):
    """Predict a class index for every sentence, processing them in batches.

    Args:
        sentences: Sized, sliceable collection of input texts. Each slice of
            up to ``batch_size`` items is handed to ``tokenizer`` as-is.
        model: Callable invoked as ``model(**batch)`` whose result exposes a
            ``logits`` tensor. Its device is read from ``model.device`` or,
            when that attribute is missing, from its first parameter.
        tokenizer: Callable invoked with ``return_tensors="pt"``,
            ``padding=True`` and ``truncation=True``; its result must support
            ``.to(device)`` and ``**`` unpacking.
        batch_size: Maximum number of sentences per forward pass.

    Returns:
        The per-batch ``argmax(logits, dim=1)`` results concatenated into a
        single tensor (one index per sentence, assuming logits are shaped
        ``(batch, num_classes)``). An empty ``sentences`` yields an empty
        ``torch.long`` tensor on the model's device.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Follow the model instead of hard-coding "cuda": the inputs must live on
    # the same device as the weights, and a GPU may not be available at all.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    total = len(sentences)
    if total == 0:
        # torch.cat() rejects an empty list, so short-circuit explicitly.
        return torch.empty(0, dtype=torch.long, device=device)

    predictions = []
    for start in tqdm(range(0, total, batch_size)):
        chunk = sentences[start:start + batch_size]
        batch = tokenizer(
            chunk, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**batch)
        predictions.append(torch.argmax(outputs.logits, dim=1))
    return torch.cat(predictions)
def input_classification(model, tokenizer, x: str = None, all_classes=False):
    """Classify one sentence, asking for it on stdin when none is supplied.

    Args:
        model: Callable invoked as ``model(**inputs)`` whose result exposes a
            ``logits`` tensor; ``model.config.id2label`` maps class indices
            to label names. Its device is read from ``model.device`` or,
            when that attribute is missing, from its first parameter.
        tokenizer: Callable invoked with ``return_tensors="pt"``,
            ``padding=True`` and ``truncation=True``; its result must support
            ``.to(device)`` and ``**`` unpacking.
        x: Text to classify. When ``None`` the user is prompted via
            ``input()``.
        all_classes: When true, return the softmax score of every class
            instead of only the winning label.

    Returns:
        With ``all_classes`` false, the ``id2label`` entry for the argmax of
        the logits. With ``all_classes`` true, a dict mapping each label to
        its softmax probability (as ``float``) for the first row of logits.
    """
    if x is None:
        x = input("Write your sentence and press Enter to continue")

    # Keep the inputs on the same device as the model; leaving them on the
    # CPU breaks as soon as the model has been moved to a GPU.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    encoded = tokenizer(x, return_tensors="pt", padding=True, truncation=True).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoded).logits

    id2label = model.config.id2label
    if all_classes:
        probabilities = torch.softmax(logits, dim=1)[0]
        return {id2label[idx]: float(prob) for idx, prob in enumerate(probabilities)}

    # int() requires a single-element tensor, i.e. exactly one input sequence.
    return id2label[int(torch.argmax(logits, dim=1))]
def compute_results(y, y_hat):
    """Score predictions against references with F1 and accuracy.

    Args:
        y: Reference labels, passed to the metrics as ``references``.
        y_hat: Predicted labels, passed to the metrics as ``predictions``.

    Returns:
        Dict with the keys ``"macro-f1"``, ``"micro-f1"`` and ``"accuracy"``,
        each holding the value reported by the corresponding ``evaluate``
        metric.
    """
    # Both metrics are loaded on every call, F1 first.
    f1_metric = evaluate.load("f1")
    accuracy_metric = evaluate.load("accuracy")

    results = {}
    for averaging in ("macro", "micro"):
        f1_report = f1_metric.compute(
            predictions=y_hat, references=y, average=averaging
        )
        results[f"{averaging}-f1"] = f1_report["f1"]

    accuracy_report = accuracy_metric.compute(predictions=y_hat, references=y)
    results["accuracy"] = accuracy_report["accuracy"]
    return results