Spaces:
Paused
Paused
import re | |
import string | |
from lighteval.tasks.lighteval_task import LightevalTaskConfig | |
from lighteval.metrics import Metrics, MetricCategory | |
from lighteval.metrics.utils import CorpusLevelMetric, MetricUseCase | |
from aenum import extend_enum | |
import numpy as np | |
from lighteval.tasks.requests import Doc | |
from Levenshtein import distance | |
import collections | |
from lighteval.utils import as_list | |
def get_tokens(s):
    """Split an answer string into normalized tokens.

    A falsy input (``None`` or an empty string) yields an empty list.
    Anything else is passed through ``normalize_answer`` and then split
    on whitespace.
    """
    return normalize_answer(s).split() if s else []
# Matches the standalone English articles that normalization drops.
ARTICLES_REGEX = re.compile(r"\b(a|an|the)\b", re.UNICODE)

# Deletion table for every character in string.punctuation, built once at
# import time instead of rebuilding a set on every call.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def normalize_answer(s):
    """Return a canonical form of an answer string for token comparison.

    Steps, in order:
      1. remove the literal markers ``<pad>`` and ``</s>`` and strip
         surrounding whitespace;
      2. lowercase;
      3. delete every character in ``string.punctuation`` (ASCII only);
      4. replace standalone ``a`` / ``an`` / ``the`` with a space;
      5. collapse all whitespace runs into single spaces.

    Args:
        s: The raw answer text.

    Returns:
        The normalized string (possibly empty).
    """
    text = s.replace('<pad>', '').replace('</s>', '').strip()
    text = text.lower()
    # One C-level pass; equivalent to filtering characters one by one.
    text = text.translate(_PUNCTUATION_TABLE)
    text = ARTICLES_REGEX.sub(" ", text)
    return " ".join(text.split())
def compute_f1(a_gold, a_pred):
    """Token-overlap F1 between a gold answer and a predicted answer.

    Both strings are tokenized with ``get_tokens``. The overlap is the
    size of the multiset intersection of the two token lists.

    Returns:
        ``1`` or ``0`` when either side has no tokens (1 only if both are
        empty), ``0`` when no token is shared, otherwise the harmonic
        mean of precision and recall as a float.
    """
    reference = get_tokens(a_gold)
    candidate = get_tokens(a_pred)

    # A "no answer" on either side only scores when both sides agree.
    if not reference or not candidate:
        return int(reference == candidate)

    shared = collections.Counter(reference) & collections.Counter(candidate)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0

    precision = overlap / len(candidate)
    recall = overlap / len(reference)
    return (2 * precision * recall) / (precision + recall)
def normalized_edit_similarity(p1, p2):
    """Return one minus the length-normalized edit distance of two strings.

    The edit distance comes from ``Levenshtein.distance`` and is divided
    by the length of the longer string, so identical strings score 1.

    Args:
        p1: First string.
        p2: Second string.

    Returns:
        The similarity as a float. Two empty strings are treated as
        identical and score ``1.0``.
    """
    longest = max(len(p1), len(p2))
    if longest == 0:
        # Both inputs are empty: avoid 0/0 and report a perfect match.
        return 1.0
    return 1 - distance(p1, p2) / longest
def compute_token_edit(a_gold, a_pred):
    """Soft token-level F1 based on normalized edit similarity.

    Both strings are tokenized with ``get_tokens``. Each gold token is
    credited with its best ``normalized_edit_similarity`` against any
    predicted token; the credits are summed into a soft match count that
    is used for both precision and recall.

    Returns:
        ``1`` or ``0`` when either side has no tokens (1 only if both are
        empty), ``0`` when the soft match count is zero, otherwise
        ``2 * P * R / (P + R)`` as a float.

    NOTE(review): the same gold-side sum is divided by the number of
    predicted tokens to get precision, so precision (and the final score)
    can exceed 1 when the gold answer has more tokens than the prediction
    (e.g. gold "aa aa aa" vs. prediction "aa" yields 1.5). Left unchanged
    to keep existing scores comparable; confirm whether this is intended.
    """
    reference = get_tokens(a_gold)
    candidate = get_tokens(a_pred)

    # A "no answer" on either side only scores when both sides agree.
    if not (reference and candidate):
        return int(reference == candidate)

    soft_matches = sum(
        max(normalized_edit_similarity(ref_tok, cand_tok) for cand_tok in candidate)
        for ref_tok in reference
    )
    if soft_matches == 0:
        return 0

    precision = soft_matches / len(candidate)
    recall = soft_matches / len(reference)
    return (2 * precision * recall) / (precision + recall)
def tlnls(a_gold, a_pred):
    """Score a prediction against one gold answer.

    If fewer than half of the prediction's characters are digits, the
    edit-similarity score (``compute_token_edit``) is used; otherwise the
    exact token-overlap score (``compute_f1``) is used. An empty
    prediction therefore goes to ``compute_f1``.
    """
    digits = sum(ch.isdigit() for ch in a_pred)
    mostly_digits = digits >= len(a_pred) / 2
    scorer = compute_f1 if mostly_digits else compute_token_edit
    return scorer(a_gold, a_pred)
def heq_eval_fn(golds: list[str], predictions: list[str], formatted_doc: "Doc | None" = None):
    """Sample-level score: best ``tlnls`` of the prediction over all golds.

    Args:
        golds: Acceptable reference answers.
        predictions: Model outputs; must contain exactly one item.
        formatted_doc: Not used by this function.

    Returns:
        The highest ``tlnls`` score between the tag-stripped prediction
        and any gold answer.

    Raises:
        ValueError: If ``predictions`` does not hold exactly one item
            (an empty list previously surfaced as an ``IndexError``), or
            if ``golds`` is empty (raised by ``max``).
    """
    if len(predictions) != 1:
        raise ValueError("Predictions should have one item")
    pred = re.sub('<[^>]+>', '', predictions[0])  # remove xml tags
    return max(tlnls(gold, pred) for gold in golds)
# Metric definition: each sample is scored by heq_eval_fn and the corpus
# score is the mean of the sample scores.
heq_tlnls_metric = CorpusLevelMetric(
    metric="heq_tlnls",
    sample_level_fn=heq_eval_fn,
    corpus_level_fn=np.mean,
    category=MetricCategory.GENERATIVE,
    use_case=MetricUseCase.ACCURACY,
    higher_is_better=True,
)

# Add the metric to the Metrics enum under the name 'heq_tlnls_metric',
# which is the name the task config in this file lists as its metric.
extend_enum(Metrics, 'heq_tlnls_metric', heq_tlnls_metric)
def heq_prompt_fn(line, task_name: str = None):
    """Build the ``Doc`` for one dataset row.

    The row's ``"prompt"`` field becomes the query and its ``"response"``
    field becomes the choices. Every choice index is listed in
    ``gold_index``, i.e. all responses count as gold. The instruction is
    left empty.

    Args:
        line: A dataset row exposing ``"prompt"`` and ``"response"``.
        task_name: Name forwarded to the ``Doc``.
    """
    prompt = line["prompt"]
    answers = line["response"]
    return Doc(
        task_name=task_name,
        query=prompt,
        choices=answers,
        gold_index=list(range(len(answers))),
        instruction="",
    )
# Task configuration: a single subset ("default") of the leaderboard test
# repo, evaluated on its "heq" split with the heq_tlnls metric.
heq_task = LightevalTaskConfig(
    name="heq-qa-tlnls",
    suite=["custom"],
    prompt_function="heq_prompt_fn",  # referenced by name; the function is defined in this file
    hf_repo="dicta-hebrew-llm-leaderboard/tests",
    hf_subset="default",
    hf_avail_splits=["heq"],
    evaluation_splits=["heq"],
    metric=["heq_tlnls_metric"],
    generation_size=64,
    stop_sequence=["\n"],
)

# Pass the configured stop sequence through lighteval's as_list helper.
heq_task.stop_sequence = as_list(heq_task.stop_sequence)