leaderboard / scorer.py
Clémentine
Updated system to connect the different repos
3d87820
raw
history blame
No virus
2.37 kB
import json
import re
import string
import numpy as np
def normalize_text(text: str) -> str:
    """Normalize text for lenient answer comparison (adapted from QuAC).

    The input is split on single spaces and every token is lower-cased,
    canonicalized if it is a number, stripped of ASCII punctuation and
    cleared of the English articles "a", "an" and "the". Tokens that end up
    empty are dropped and the remaining ones are joined with single spaces.

    Numeric tokens are parsed *before* punctuation is removed, so the decimal
    point, sign and exponent sign survive: "3.5" becomes "3.5" instead of
    "35.0". Commas and trailing sentence punctuation are ignored while
    parsing, so "1,000" becomes "1000.0" and "3.50." becomes "3.5". Tokens
    that only look numeric once all punctuation is gone (e.g. "$100") are
    still canonicalized, to "100.0".

    Args:
        text: Raw text to normalize.

    Returns:
        The normalized text, or an empty string if no token survives.
    """
    # Translation table deleting every ASCII punctuation character
    # (str.translate-based stripping, originally from Grégoire's code).
    strip_punctuation = str.maketrans("", "", string.punctuation)

    def homogeneize_numbers(token: str) -> str:
        "Return the canonical str(float(...)) form of a numeric token, else the token itself"
        try:
            return str(float(token))
        except ValueError:
            return token

    def normalize_token(token: str) -> str:
        "Normalize a single space-delimited token"
        token = token.lower()

        # First try to read the token as a number while its punctuation is
        # still meaningful; stripping "." or "-" first would turn "3.5" into
        # "35" and "-5" into "5".
        candidate = token.replace(",", "").rstrip(".;:!?")
        try:
            return str(float(candidate))
        except ValueError:
            pass

        # Not directly numeric: strip punctuation, then retry the number
        # parse so tokens such as "$100" or "10%" are still canonicalized.
        token = homogeneize_numbers(token.translate(strip_punctuation))
        token = re.sub(r"\b(a|an|the)\b", " ", token)
        # Collapse any whitespace left inside the token (tabs, newlines, or
        # the blanks introduced by article removal).
        return " ".join(token.split())

    normalized_tokens = (normalize_token(token) for token in text.split(" "))
    return " ".join(token for token in normalized_tokens if token)
def extract_answer(input_str: str, prompt_sep: str = 'FINAL ANSWER: ') -> str:
    """Return the text that follows the last ``prompt_sep`` in ``input_str``.

    Surrounding whitespace is removed from the result. When the separator
    does not occur, the whole (stripped) input is returned.
    """
    # Only the final piece matters; everything before the last separator is discarded.
    *_, tail = input_str.split(prompt_sep)
    return tail.strip()
def extract_bow(input_str: str) -> list[str]:
    """Split ``input_str`` into its bag of words.

    The split is on single space characters only, so consecutive spaces
    produce empty-string entries and other whitespace stays inside tokens.
    """
    bag_of_words = input_str.split(sep=" ")
    return bag_of_words
def numbers_equals_in_bow(gold_list: list, pred_list: list) -> bool:
    """Check that every number in the gold bag of words is in the prediction's.

    A token counts as a number when ``float()`` accepts it; numbers are
    compared through their ``str(float(...))`` form. Non-numeric tokens on
    either side are ignored.

    Returns:
        False as soon as a gold number is missing from the prediction,
        True otherwise (including when gold contains no number at all).
    """

    def canonical_number(token):
        "Canonical string form of a numeric token, or None when it is not a number"
        try:
            return str(float(token))
        except ValueError:
            return None

    # Numbers found in the prediction bag of words.
    predicted = {canonical_number(token) for token in pred_list}
    predicted.discard(None)

    # Lazily walk the gold tokens so the check stops at the first missing number.
    gold_numbers = (canonical_number(token) for token in gold_list)
    return all(number in predicted for number in gold_numbers if number is not None)
def affix_quasi_exact_match(gold: str, pred: str) -> float:
    """Score a prediction by prefix/suffix match against the gold answer.

    Both strings are passed through ``normalize_text``. The score is 1 when
    the normalized prediction starts or ends with the normalized gold answer
    AND ``numbers_equals_in_bow`` confirms that every number among the raw
    space-separated gold tokens also appears among the raw prediction tokens.
    Otherwise the score is 0.

    Args:
        gold: Reference answer.
        pred: Predicted answer; an empty/falsy prediction always scores 0.

    Returns:
        1 for a match, 0 otherwise.
    """
    if not pred:
        return 0

    normalized_pred = normalize_text(pred)
    normalized_gold = normalize_text(gold)

    if not normalized_gold:
        # An empty string is a prefix of every string, so a gold answer that
        # normalizes to nothing would otherwise award 1 to any prediction.
        # Only accept a prediction that also normalizes to nothing.
        return 0 if normalized_pred else 1

    is_affix = normalized_pred.startswith(normalized_gold) or normalized_pred.endswith(normalized_gold)
    if not is_affix:
        return 0

    # The number check deliberately runs on the raw inputs, not the
    # normalized ones.
    if numbers_equals_in_bow(extract_bow(gold), extract_bow(pred)):
        return 1
    return 0
def question_scorer(gold: str, pred: str) -> float:
    """Score ``pred`` against ``gold``.

    Delegates entirely to ``affix_quasi_exact_match`` and returns its result.
    """
    score = affix_quasi_exact_match(gold=gold, pred=pred)
    return score