# Provenance (residue of the hosting web page this file was copied from):
# commit 5d21832 by Shaltiel - "Removed xml tags from heq score" (4.18 kB).
import re
import string
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.metrics import Metrics, MetricCategory
from lighteval.metrics.utils import CorpusLevelMetric, MetricUseCase
from aenum import extend_enum
import numpy as np
from lighteval.tasks.requests import Doc
from Levenshtein import distance
import collections
from lighteval.utils import as_list
def get_tokens(s):
    """Split *s* into normalized tokens.

    Falsy input (``None`` or ``""``) yields an empty list; anything else is
    passed through :func:`normalize_answer` and split on whitespace.
    """
    return normalize_answer(s).split() if s else []
ARTICLES_REGEX = re.compile(r"\b(a|an|the)\b", re.UNICODE)

# Deletion table for ASCII punctuation (string.punctuation), built once at import.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def normalize_answer(s):
    """Normalize an answer string before token comparison.

    Steps, in order:

    1. remove the literal ``<pad>`` and ``</s>`` markers (case-sensitive, so this
       happens before lower-casing) and strip surrounding whitespace;
    2. lower-case;
    3. delete ASCII punctuation (``string.punctuation``); other scripts'
       punctuation is kept;
    4. replace the English articles "a", "an", "the" with a space;
    5. collapse whitespace runs to single spaces and trim the ends.
    """
    text = s.replace('<pad>', '').replace('</s>', '').strip()
    text = text.lower()
    text = text.translate(_PUNCTUATION_TABLE)
    text = ARTICLES_REGEX.sub(" ", text)
    return " ".join(text.split())
def compute_f1(a_gold, a_pred):
    """Token-level F1 between a gold answer and a prediction (SQuAD style).

    Both strings are tokenized with :func:`get_tokens`. Token overlap is counted
    with multiplicity (multiset intersection).

    Returns:
        ``int`` 1/0 when either side has no tokens (1 only if both are empty),
        ``0`` when no token overlaps, otherwise the F1 as a ``float``.
    """
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    if not gold_toks or not pred_toks:
        # No-answer case: full credit only when both sides are empty.
        return int(gold_toks == pred_toks)
    shared = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)
def normalized_edit_similarity(p1, p2):
    """Return ``1 - distance(p1, p2) / max(len(p1), len(p2))``.

    ``distance`` is the Levenshtein edit distance, which never exceeds the
    length of the longer string, so the result lies in [0, 1] and equals 1 for
    identical strings. Two empty strings are identical and score 1.0; without
    the guard below that pair would raise ``ZeroDivisionError``.
    """
    longest = max(len(p1), len(p2))
    if longest == 0:
        return 1.0
    return 1 - distance(p1, p2) / longest
def compute_token_edit(a_gold, a_pred):
    """Soft token F1 that uses normalized edit similarity as the match score.

    Each token is greedily matched to its most similar token on the other side
    (via :func:`normalized_edit_similarity`):

    * recall    = mean over gold tokens of the best similarity to any pred token;
    * precision = mean over pred tokens of the best similarity to any gold token.

    The score is their harmonic mean. Both means are bounded by 1, so the result
    lies in [0, 1]. If either side has no tokens, the score is 1 when both are
    empty and 0 otherwise.

    NOTE: precision used to be the *gold-side* similarity sum divided by the
    number of predicted tokens. That value (and the resulting F1) could exceed 1
    whenever several gold tokens matched the same predicted token, e.g. gold
    "x x" vs. pred "x" scored 1.33. Scores therefore differ from the previous
    implementation.
    """
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    if not gold_toks or not pred_toks:
        # No-answer case: full credit only when both sides are empty.
        return int(gold_toks == pred_toks)
    # sim[i][j]: similarity between gold token i and predicted token j.
    sim = [[normalized_edit_similarity(g, p) for p in pred_toks] for g in gold_toks]
    recall = sum(max(row) for row in sim) / len(gold_toks)
    precision = sum(max(col) for col in zip(*sim)) / len(pred_toks)
    if precision + recall == 0:
        return 0
    return 2 * precision * recall / (precision + recall)
def tlnls(a_gold, a_pred):
    """Score *a_pred* against *a_gold*, choosing the comparison by content.

    If digits make up at least half of the characters of the raw
    (un-normalized) prediction, exact token F1 (:func:`compute_f1`) is used.
    Presumably this is because a near-miss number should not earn partial
    credit (confirm). Otherwise the edit-similarity soft F1
    (:func:`compute_token_edit`) is used. An empty prediction falls into the
    exact-F1 branch.
    """
    digits = sum(ch.isdigit() for ch in a_pred)
    if 2 * digits >= len(a_pred):
        return compute_f1(a_gold, a_pred)
    return compute_token_edit(a_gold, a_pred)
# Matches anything shaped like an XML/HTML tag, e.g. "<answer>" or "</answer>".
_XML_TAG_REGEX = re.compile(r'<[^>]+>')


def heq_eval_fn(golds: list[str], predictions: list[str], formatted_doc: Doc = None):
    """Sample-level score: the best :func:`tlnls` of the prediction over all golds.

    Args:
        golds: acceptable reference answers; the prediction is scored against
            each one and the maximum is kept.
        predictions: model outputs for this sample; exactly one is expected.
        formatted_doc: not used here (presumably accepted to match lighteval's
            sample-level metric call signature - confirm).

    Returns:
        The highest ``tlnls(gold, prediction)`` across ``golds``.

    Raises:
        ValueError: if ``predictions`` does not contain exactly one item
            (previously an empty list slipped through and raised IndexError).
    """
    if len(predictions) != 1:
        raise ValueError("Predictions should have one item")
    # Strip tag markup so wrapping the answer in tags does not affect the score.
    pred = _XML_TAG_REGEX.sub('', predictions[0])
    return max(tlnls(gold, pred) for gold in golds)
# Corpus-level metric: heq_eval_fn scores each generated sample, and np.mean
# averages the per-sample scores over the whole evaluation split.
heq_tlnls_metric = CorpusLevelMetric(
    metric="heq_tlnls",
    category=MetricCategory.GENERATIVE,
    use_case=MetricUseCase.ACCURACY,
    sample_level_fn=heq_eval_fn,
    corpus_level_fn=np.mean,
    higher_is_better=True,
)
# Add the metric to the Metrics enum under this name so task configs can list it.
extend_enum(Metrics, 'heq_tlnls_metric', heq_tlnls_metric)
def heq_prompt_fn(line, task_name: str = None):
    """Convert one dataset row into a lighteval ``Doc``.

    The row's ``"prompt"`` becomes the query. Every entry of ``"response"``
    becomes a choice, and all of them are marked as gold, so a prediction is
    compared against every reference answer. No instruction text is added.
    """
    query = line["prompt"]
    references = line["response"]
    return Doc(
        task_name=task_name,
        query=query,
        choices=references,
        gold_index=list(range(len(references))),
        instruction="",
    )
# A simple task (like hellaswag): one dataset subset with a single evaluation
# attached to it.
heq_task = LightevalTaskConfig(
    name="heq-qa-tlnls",
    suite=["custom"],
    # Referenced by name: must be defined in this file or imported from
    # src/lighteval/tasks/tasks_prompt_formatting.py.
    prompt_function="heq_prompt_fn",
    # Data: the "heq" split of the "default" subset, used for evaluation.
    hf_repo="dicta-hebrew-llm-leaderboard/tests",
    hf_subset="default",
    hf_avail_splits=["heq"],
    evaluation_splits=["heq"],
    # The metric added to Metrics via extend_enum under 'heq_tlnls_metric'.
    metric=['heq_tlnls_metric'],
    # Generation ends at a newline; generation_size caps the output length
    # (presumably in tokens - confirm against lighteval).
    generation_size=64,
    stop_sequence=['\n'],
)
# Ensure stop_sequence is stored as a list.
heq_task.stop_sequence = as_list(heq_task.stop_sequence)