# resrer-pegasus-x / eval.py
# Uploaded by seonglae, commit a9082f6 ("Training in progress, step 500"), 2.67 kB.
import re
import string
import unicodedata
from evaluate import evaluator, QuestionAnsweringEvaluator
from datasets import load_dataset
def evaluate_dataset(id: str, subset: str, metric: str = 'squad_v2',
                     question_col: str = 'question', context_col: str = 'retrieved',
                     predict_col: str = 'predicted', id_col: str = 'question',
                     label_col: str = 'answer', labeling: bool = True,
                     split: str = 'train'):
    """Score predictions already stored in a dataset with a SQuAD-style metric.

    No model or pipeline is run: references are built with the
    question-answering evaluator's ``prepare_data`` and the stored predictions
    are scored with ``compute_metric``.

    Args:
        id: Dataset name/path passed to ``datasets.load_dataset``.
        subset: Configuration name passed to ``datasets.load_dataset``.
        metric: Metric name passed to ``prepare_metric``. When it is
            ``'squad_v2'`` every prediction also gets
            ``no_answer_probability = 0.0``.
        question_col: Column holding the question text.
        context_col: Column holding the (retrieved) context text.
        predict_col: Column holding the predicted answer text.
        id_col: Column used as the example id that pairs predictions with
            references. NOTE(review): the default is the question text, so
            duplicate questions would share an id. Confirm this is intended.
        label_col: Column holding the reference answers.
        labeling: If True, each reference's answers (iterated as strings) are
            rewritten into the ``{'answer_start': [...], 'text': [...]}``
            layout, with ``answer_start`` taken from ``str.find`` on the
            context (``-1`` when the answer does not occur in it).
        split: Dataset split to evaluate (default ``'train'``).

    Returns:
        The result of ``compute_metric`` for the chosen metric.
    """
    referee: QuestionAnsweringEvaluator = evaluator("question-answering")
    # NOTE(review): no pipeline is invoked below, so this flag presumably does
    # not influence the score. It is kept for parity with the previous
    # behaviour and may mutate state shared with other evaluator instances.
    # Verify before removing.
    referee.PIPELINE_KWARGS["handle_impossible_answer"] = True

    data = load_dataset(id, subset)[split]
    metric_input, qa = referee.prepare_data(
        data, question_col, context_col, id_col, label_col)

    # References: convert plain answer lists into the SQuAD answer layout.
    if labeling:
        contexts = qa['context']
        for i, reference in enumerate(metric_input['references']):
            answers = reference['answers']
            reference['answers'] = {
                'answer_start': [contexts[i].find(answer) for answer in answers],
                'text': answers,
            }

    # Predictions: one entry per row, in the same row order as the references.
    predictions = []
    for row in data:
        prediction = {'prediction_text': row[predict_col], 'id': row[id_col]}
        if metric == 'squad_v2':
            prediction['no_answer_probability'] = 0.
        predictions.append(prediction)
    metric_input['predictions'] = predictions

    metric_module = referee.prepare_metric(metric)
    return referee.compute_metric(metric_module, metric_inputs=metric_input)
def evaluate_dataset_manual(id: str, subset: str, predict_col: str = 'predicted',
                            label_col: str = 'answer', split: str = 'train') -> float:
    """Return the fraction of rows whose prediction matches a reference answer.

    A row counts as correct when ``regex_match_score(prediction, answer)`` is
    True for at least one answer in the row's ``label_col``. Each reference
    answer is compiled as a regular expression and matched at the start of the
    prediction, so plain-text answers containing regex metacharacters may fail
    to compile (scoring False) or match more than their literal text.

    Args:
        id: Dataset name/path passed to ``datasets.load_dataset``.
        subset: Configuration name passed to ``datasets.load_dataset``.
        predict_col: Column holding the predicted answer text.
        label_col: Column holding the iterable of reference answers.
        split: Dataset split to evaluate (default ``'train'``).

    Returns:
        Accuracy in ``[0.0, 1.0]``. A row with no reference answers counts as
        incorrect.

    Raises:
        ValueError: If the selected split has no rows.
    """
    rows = load_dataset(id, subset)[split]
    total = len(rows)
    if total == 0:
        raise ValueError(
            f"split {split!r} of {id}/{subset} is empty; cannot compute a score")
    # any() is False for an empty answer list, where max() would raise.
    hits = sum(
        any(regex_match_score(row[predict_col], answer) for answer in row[label_col])
        for row in rows)
    return hits / total
def normalize_answer(s):
    """Canonicalize an answer string for lenient comparison.

    The transformations run in this order:
    Unicode NFD decomposition, lower-casing, deletion of every character in
    ``string.punctuation``, replacement of the standalone words "a", "an" and
    "the" with a space, and collapsing of whitespace runs into single spaces
    (leading/trailing whitespace is dropped).
    """
    text = unicodedata.normalize("NFD", s).lower()
    # Delete all ASCII punctuation in one pass.
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
def exact_match_score(prediction, ground_truth):
    """Return True when both strings are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth
def regex_match_score(prediction, ground_truth):
    """Report whether ``prediction`` starts with a match of the regex ``ground_truth``.

    ``ground_truth`` is compiled case-insensitively with the UNICODE and
    MULTILINE flags. Matching uses ``Pattern.match``, so the pattern must match
    at the beginning of ``prediction``; it is not searched for further in.
    A pattern that fails to compile scores False.
    """
    flags = re.IGNORECASE | re.UNICODE | re.MULTILINE
    try:
        pattern = re.compile(ground_truth, flags=flags)
    except re.error:
        # Not a valid regex (e.g. a plain-text answer with an unbalanced bracket).
        return False
    return pattern.match(prediction) is not None