|
from datasets import load_dataset |
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Optional |
|
import random |
|
import matplotlib.pyplot as plt |
|
from score import calculate_gpt4o_scores, BENCHMARK_SCORES |
|
|
|
|
|
|
|
# Registry of the benchmarks offered in the quiz, keyed by an internal id.
# Every entry carries a display name, the Hugging Face dataset path and the
# answer format ("multiple_choice" or "free_text"). Entries may also pin a
# dataset config name and split; DATASETS falls back to no config and the
# "train" split when those keys are absent.
BENCHMARKS = {
    "icelandic-winogrande": dict(
        name="Winogrande",
        path="mideind/icelandic-winogrande",
        type="multiple_choice",
    ),
    "grammatical-error-detection": dict(
        name="Málfræðivillur",
        path="mideind/icelandic-sentences-gec",
        type="multiple_choice",
    ),
    "icelandic-inflection-all": dict(
        name="Fallbeygingarpróf",
        path="mideind/icelandic-inflection-all-flat",
        type="free_text",
    ),
    "icelandic-belebele": dict(
        name="Belebele",
        path="facebook/belebele",
        config_name="isl_Latn",
        split="test",
        type="multiple_choice",
    ),
    "icelandic-arc-challenge": dict(
        name="ARC Challenge",
        path="mideind/icelandic-arc-challenge",
        type="multiple_choice",
    ),
    "icelandic-wiki-qa": dict(
        name="Wikipediapróf",
        path="mideind/icelandic_wiki_qa",
        type="free_text",
    ),
}
|
|
|
# Eagerly load every registered benchmark once, at import time, keyed by the
# same ids as BENCHMARKS. A missing "config_name" is passed through as None and
# a missing "split" defaults to "train".
DATASETS = {
    benchmark_id: load_dataset(
        spec["path"],
        name=spec.get("config_name"),
        split=spec.get("split", "train"),
    )
    for benchmark_id, spec in BENCHMARKS.items()
}
|
|
|
|
|
|
|
def winogrande_preprocessing(sample):
    """Turn a Winogrande row into a two-option quiz sample.

    The blank marker ``_`` in the sentence is widened to ``________`` so it is
    visible in the rendered HTML question.

    Args:
        sample: A dataset row with ``sentence``, ``option1``, ``option2`` and
            ``answer`` (the string ``"1"`` selects ``option1``; anything else
            selects ``option2``).

    Returns:
        A dict with ``question``, ``options`` (a 2-tuple), ``answer`` and
        ``instruction``.
    """
    blanked_sentence = sample["sentence"].replace("_", "________")
    first_option, second_option = sample["option1"], sample["option2"]
    return {
        "question": (
            "Lestu eftirfarandi málsgrein:<p style='margin-left: 20px;'>"
            f"<i>{blanked_sentence}</i></p><br>Hvor valkostanna passar betur í eyðuna?"
        ),
        "options": (first_option, second_option),
        "answer": first_option if sample["answer"] == "1" else second_option,
        "instruction": "Valkostir",
    }
|
|
|
|
|
def icelandic_sentence_gec_preprocessing(sample):
    """Turn a grammatical-error-detection row into a two-option quiz sample.

    The question asks whether the sentence contains an error, so the answer is
    "Villa" (error) when the row's ``correct`` flag is false and "Engin villa"
    (no error) when it is true.

    Args:
        sample: A dataset row with ``sentence`` and ``correct``. ``correct``
            is compared case-insensitively against ``"true"`` after ``str()``
            conversion, so both a string flag and a boolean column work.

    Returns:
        A dict with ``question``, ``options``, ``answer`` and ``instruction``.
    """
    new_sample = {}
    new_sample["question"] = (
        f"Inniheldur eftirfarandi málsgrein villu?<p style='margin-left: 25px;'><i>{sample['sentence']}</i></p>"
    )
    new_sample["options"] = "Villa", "Engin villa"
    # str() makes a bool True compare equal to "true" as well.
    is_correct = str(sample["correct"]).strip().lower() == "true"
    # A correct sentence has no error; an incorrect one does.
    new_sample["answer"] = "Engin villa" if is_correct else "Villa"
    new_sample["instruction"] = "Valkostir"
    return new_sample
|
|
|
|
|
def inflection_all_preprocessing(sample):
    """Turn an inflection row into a free-text quiz sample.

    The grammatical case and number codes are spelled out as the Icelandic
    words used inside the question sentence.

    Args:
        sample: A dataset row with ``noun_phrase``, ``case`` (one of
            ``nf``/``þf``/``þgf``/``ef``), ``plurality`` (``et`` or ``ft``)
            and ``inflection`` (the expected answer).

    Returns:
        A dict with ``question``, ``answer`` and ``instruction``; there is no
        ``options`` key because the user types the answer.

    Raises:
        KeyError: If ``case`` or ``plurality`` holds an unknown code.
    """
    case_words = {
        "nf": "nefnifalli",
        "þf": "þolfalli",
        "þgf": "þágufalli",
        "ef": "eignarfalli",
    }
    number_words = {"et": "eintölu", "ft": "fleirtölu"}

    phrase = sample["noun_phrase"]
    case_word = case_words[sample["case"]]
    number_word = number_words[sample["plurality"]]

    return {
        "question": f"Hvernig beygist <i>„{phrase}“</i> í {case_word} {number_word}?",
        "answer": sample["inflection"],
        "instruction": "Skrifaðu réttu beyginguna.",
    }
|
|
|
|
|
def belebele_preprocessing(sample):
    """Turn a Belebele reading-comprehension row into a four-option quiz sample.

    Args:
        sample: A dataset row with ``flores_passage``, ``question``,
            ``mc_answer1`` … ``mc_answer4`` and ``correct_answer_num`` (a
            1-based index, possibly as a string).

    Returns:
        A dict with ``question`` (passage plus question as HTML), ``options``
        (list of the four answers), ``answer`` and ``instruction``.
    """
    passage = sample["flores_passage"]
    prompt = f'Lestu eftirfarandi texta:<p style="margin-left: 25px;"><i>{passage}</i></p>\n\n{sample["question"]}'
    choices = [sample[f"mc_answer{position}"] for position in range(1, 5)]
    # correct_answer_num is 1-based; the list is 0-based.
    answer_position = int(sample["correct_answer_num"])
    return {
        "question": prompt,
        "options": choices,
        "answer": choices[answer_position - 1],
        "instruction": "Veldu réttasta svarið.",
    }
|
|
|
|
|
def arc_challenge_preprocessing(sample):
    """Turn an ARC-Challenge row into a multiple-choice quiz sample.

    Args:
        sample: A dataset row with ``question``, ``answerKey`` and a
            ``choices`` mapping holding parallel ``text`` and ``label`` lists.

    Returns:
        A dict with ``question``, ``options`` (the row's own ``choices["text"]``
        list), ``answer`` (the text whose label equals ``answerKey``) and
        ``instruction``.

    Raises:
        ValueError: If ``answerKey`` is not among the labels.
    """
    choices = sample["choices"]
    choice_texts = choices["text"]
    answer_position = choices["label"].index(sample["answerKey"])
    return {
        "question": sample["question"],
        "options": choice_texts,
        "answer": choice_texts[answer_position],
        "instruction": "Veldu réttasta svarið.",
    }
|
|
|
|
|
def wikipedia_preprocessing(sample):
    """Turn a Wikipedia QA row into a free-text quiz sample.

    Args:
        sample: A dataset row with ``query`` and ``answer``.

    Returns:
        A dict with ``question`` (the row's ``query``), ``answer`` and
        ``instruction``.
    """
    return {
        "question": sample["query"],
        "answer": sample["answer"],
        "instruction": "Skrifaðu svarið þitt að neðan.",
    }
|
|
|
|
|
@dataclass
class QuizState:
    """Progress of one quiz session; mutated in place as the user advances."""

    # Benchmark id (a key of BENCHMARKS / DATASETS).
    benchmark_name: str
    # Preprocessed question dicts, in the order they are shown.
    samples: list[dict[str, Any]]
    # 0-based index into ``samples`` of the question currently displayed.
    current_question: int
    # One slot per sample; None until the user answers that question.
    user_answers: list[Optional[str]]
    # Expected answer for each sample, parallel to ``samples``.
    correct_answers: list[str]
    # Set once the last question has been submitted.
    quiz_completed: bool
    # Per-question scores; None until the quiz is scored.
    user_scores: list[Optional[float]]
|
|
|
|
|
@dataclass
class QuestionData:
    """Everything needed to render a single question screen."""

    # Markdown heading such as "### Spurning 1 af 5".
    question_num: str
    # Question text (may contain HTML).
    question: str
    # Answer choices for multiple-choice benchmarks; None for free-text ones.
    options: Optional[list[str]]
    # Answer previously given to this question, if any.
    answer: Optional[str]
    # Label for the forward button ("Næsta" or "Klára" on the last question).
    next_button_text: str
    # Whether the "previous" button should be shown.
    previous_button_visibility: bool
    # Short instruction shown with the question.
    instruction: str = ""
|
|
|
|
|
class BenchmarkQuiz:
    """Run a short quiz over one benchmark and compare the user with the models.

    A single QuizState is kept in ``self.state``. It is created by
    ``start_quiz`` and then advanced by ``next_question`` /
    ``previous_question``. When the last question is answered the quiz is
    scored, and a bar chart comparing the user with ``BENCHMARK_SCORES`` is
    produced.
    """

    def __init__(self):
        # No quiz is active until start_quiz() is called.
        self.state: Optional["QuizState"] = None

    def start_quiz(self, benchmark_name: str) -> "QuizState":
        """Sample questions from *benchmark_name* and reset the quiz state.

        Args:
            benchmark_name: A key of ``DATASETS`` / ``BENCHMARKS``.

        Returns:
            The newly created state, which is also stored on ``self.state``.
        """
        samples = self.load_benchmark(benchmark_name)
        correct_answers = [sample["answer"] for sample in samples]
        self.state = QuizState(
            benchmark_name=benchmark_name,
            samples=samples,
            current_question=0,
            user_answers=[None] * len(samples),
            correct_answers=correct_answers,
            quiz_completed=False,
            user_scores=[None] * len(samples),
        )
        return self.state

    def load_benchmark(
        self, benchmark_name: str, num_samples: int = 5
    ) -> List[Dict[str, Any]]:
        """Draw random rows from a benchmark and convert them to quiz samples.

        Args:
            benchmark_name: A key of ``DATASETS``.
            num_samples: How many questions to draw. Capped at the dataset
                size so a small dataset does not make ``random.sample`` raise.

        Returns:
            Preprocessed sample dicts with ``question``, ``answer`` and
            ``instruction`` keys, plus ``options`` for multiple-choice
            benchmarks.

        Raises:
            KeyError: If *benchmark_name* is not a loaded dataset.
            ValueError: If no preprocessing function is registered for it.
        """
        # Built at call time so the module-level preprocessing functions are
        # resolved when used, not when the class is defined.
        preprocessors = {
            "icelandic-winogrande": winogrande_preprocessing,
            "grammatical-error-detection": icelandic_sentence_gec_preprocessing,
            "icelandic-inflection-all": inflection_all_preprocessing,
            "icelandic-belebele": belebele_preprocessing,
            "icelandic-arc-challenge": arc_challenge_preprocessing,
            "icelandic-wiki-qa": wikipedia_preprocessing,
        }
        dataset = DATASETS[benchmark_name]
        preprocess = preprocessors.get(benchmark_name)
        if preprocess is None:
            raise ValueError(
                f"No preprocessing function registered for benchmark {benchmark_name!r}"
            )
        sample_count = min(num_samples, len(dataset))
        random_indices = random.sample(range(len(dataset)), sample_count)
        return [preprocess(sample) for sample in dataset.select(random_indices)]

    def update_question(self) -> "QuestionData":
        """
        Update the question data based on the current state.
        Is called when the user navigates to a new question.

        Returns:
            A QuestionData for ``self.state.current_question``. Its ``answer``
            is the answer the user already gave to that question, if any.
        """
        current_question = self.state.current_question
        sample = self.state.samples[current_question]

        question_num = (
            f"### Spurning {current_question + 1} af {len(self.state.samples)}"
        )
        question = sample["question"]
        options = sample.get("options")  # absent for free-text benchmarks
        answer = self.state.user_answers[current_question]
        # The last question gets a "finish" label instead of "next".
        next_button_text = (
            "Klára" if current_question == len(self.state.samples) - 1 else "Næsta"
        )
        previous_button_visibility = current_question > 0
        instruction = sample.get("instruction", "")

        return QuestionData(
            question_num=question_num,
            question=question,
            options=options,
            answer=answer,
            next_button_text=next_button_text,
            previous_button_visibility=previous_button_visibility,
            instruction=instruction,
        )

    def next_question(self, answer: str) -> Dict[str, Any]:
        """
        Update the state with the user's answer to the current question.
        If the quiz is not completed, return the next question data.
        If the quiz is completed, return the score plot.
        Is called when the user submits an answer.

        Returns:
            ``{"completed": False, "question_data": QuestionData}`` while
            questions remain, otherwise ``{"completed": True, "plot": Figure,
            "results_data": list}``.
        """
        self.state.user_answers[self.state.current_question] = answer
        if self.state.current_question < len(self.state.samples) - 1:
            self.state.current_question += 1
            return {"completed": False, "question_data": self.update_question()}

        self.state.quiz_completed = True
        user_scores = self.calculate_scores()
        self.state.user_scores = user_scores
        plot = self.plot_score(user_scores)
        return {
            "completed": True,
            "plot": plot,
            "results_data": self.get_results_data(),
        }

    def previous_question(self) -> "QuestionData":
        """Step back one question (no-op on the first) and return its data."""
        if self.state.current_question > 0:
            self.state.current_question -= 1
        return self.update_question()

    @staticmethod
    def _normalize_answer(answer: Optional[str]) -> Optional[str]:
        """Trim the ends and collapse internal whitespace; ``None`` passes through.

        This keeps leading, trailing or doubled spaces in a typed answer from
        making an otherwise identical answer score zero. Letter case is left
        untouched on purpose.
        """
        if answer is None:
            return None
        return " ".join(answer.split())

    def calculate_scores(self) -> list[float]:
        """Score every question of the current quiz.

        The Wikipedia QA benchmark is delegated to ``calculate_gpt4o_scores``
        with the queries, user answers and reference answers. Every other
        benchmark scores 1.0 when the whitespace-normalized user answer equals
        the whitespace-normalized correct answer, else 0.0.
        """
        if self.state.benchmark_name == "icelandic-wiki-qa":
            queries = [sample["question"] for sample in self.state.samples]
            return calculate_gpt4o_scores(
                queries, self.state.user_answers, self.state.correct_answers
            )

        normalize = self._normalize_answer
        return [
            float(normalize(user_answer) == normalize(correct_answer))
            for user_answer, correct_answer in zip(
                self.state.user_answers, self.state.correct_answers
            )
        ]

    def plot_score(self, user_scores: List[float]):
        """Return a horizontal bar chart of the user's score next to the models'.

        Args:
            user_scores: Per-question scores; ``None`` entries count as 0.

        Returns:
            A matplotlib Figure. It is closed in pyplot (removed from pyplot's
            figure registry) before being returned, so finished quizzes do not
            accumulate open figures. The Figure object itself can still be
            saved or rendered by the caller.
        """
        # Unscored questions count as 0; an empty quiz scores 0 instead of
        # dividing by zero.
        numeric_scores = [score or 0.0 for score in user_scores]
        user_score = (
            sum(numeric_scores) / len(numeric_scores) if numeric_scores else 0.0
        )
        scores = {**BENCHMARK_SCORES[self.state.benchmark_name], "Þú": 100 * user_score}
        # Ascending order places the highest score at the top of a barh chart.
        scores = dict(sorted(scores.items(), key=lambda item: item[1]))
        colors = ["tab:green" if name == "Þú" else "tab:blue" for name in scores]

        fig, ax = plt.subplots(figsize=(10, 6), dpi=250)
        ax.spines[["left", "top", "right"]].set_visible(False)
        ax.barh(
            list(scores.keys()),
            list(scores.values()),
            height=0.6,
            color=colors,
        )
        ax.set_axisbelow(True)
        ax.xaxis.grid(True, linestyle="--", alpha=0.6)
        ax.set_title(
            f"{BENCHMARKS[self.state.benchmark_name]['name']}: Svona stóðstu þig miðað við mállíkönin",
            pad=20,
        )
        ax.set_xlabel("Stig (%)")
        ax.set_xlim(0, 100)
        # Lay out this figure explicitly rather than pyplot's "current" figure.
        fig.tight_layout()
        # Drop pyplot's reference so a long-running app does not keep one
        # figure alive per completed quiz.
        plt.close(fig)
        return fig

    def get_results_data(self) -> List[Dict[str, Any]]:
        """Return one summary dict per question for the results view.

        Each dict holds the 1-based ``question_num``, the ``question``, the
        user's and correct answers, ``options`` (None for free-text),
        ``instruction`` and the ``points`` scored.
        """
        return [
            {
                "question_num": i + 1,
                "question": sample["question"],
                "user_answer": user_answer,
                "correct_answer": correct_answer,
                "options": sample.get("options"),
                "instruction": sample.get("instruction", ""),
                "points": score,
            }
            for i, (sample, user_answer, correct_answer, score) in enumerate(
                zip(
                    self.state.samples,
                    self.state.user_answers,
                    self.state.correct_answers,
                    self.state.user_scores,
                )
            )
        ]
|
|