|
import json |
|
from typing import List |
|
|
|
from adaptor.adapter import Adapter |
|
from adaptor.evaluators.generative import BLEU |
|
from adaptor.lang_module import LangModule |
|
from adaptor.schedules import ParallelSchedule |
|
from adaptor.utils import AdaptationArguments, StoppingStrategy |
|
from datasets import load_dataset |
|
|
|
from priming_objective import Priming |
|
|
|
training_arguments = AdaptationArguments(output_dir="train_dir_SQuAD_random_large", |
|
learning_rate=2e-5, |
|
stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED, |
|
|
|
do_train=True, |
|
do_eval=True, |
|
warmup_steps=1000, |
|
max_steps=10000, |
|
gradient_accumulation_steps=30, |
|
eval_steps=500, |
|
logging_steps=10, |
|
save_steps=500, |
|
num_train_epochs=5, |
|
evaluation_strategy="steps", |
|
save_total_limit=10, |
|
stopping_patience=10) |
|
eval_examples = 200 |
|
|
|
|
|
num_demonstrations = 3 |
|
|
|
|
|
def _construct_priming_prompt(previous_examples: List[str], current_example: str) -> str: |
|
return " ".join(previous_examples + [current_example]) |
|
|
|
|
|
lang_module = LangModule("google/mt5-large") |
|
|
|
|
|
per_type_examples = {} |
|
|
|
qa_en = load_dataset("squad") |
|
qa_train = qa_en["train"].filter(lambda entry: len(entry["context"]) < 2000) |
|
|
|
val_metrics = [BLEU(**{"additional_sep_char": "▁"})] |
|
|
|
|
|
|
|
|
|
def _get_en_qa_categories(data) -> List[str]: |
|
return [question.split()[0] if not question.startswith("To") |
|
else " ".join(question.split()[:2]) |
|
for question in data["question"]] |
|
|
|
|
|
q_answering_en = Priming(lang_module, |
|
max_eval_samples=eval_examples, |
|
demos_selection_strategy="random", |
|
texts_or_path=qa_train["question"], |
|
text_pair_or_path=qa_train["context"], |
|
val_texts_or_path=qa_en["validation"]["question"][-eval_examples:], |
|
val_text_pair_or_path=qa_en["validation"]["context"][-eval_examples:], |
|
labels_or_path=[a["text"][0] for a in qa_train["answers"]], |
|
val_labels_or_path=[a["text"][0] for a in qa_en["validation"]["answers"]][-eval_examples:], |
|
train_question_categories=_get_en_qa_categories(qa_train), |
|
val_question_categories=_get_en_qa_categories(qa_en["validation"])[-eval_examples:], |
|
batch_size=1, |
|
val_evaluators=val_metrics, |
|
|
|
source_lang_id="en", |
|
objective_id="AQA-en") |
|
|
|
|
|
|
|
squad_cs_dataset = json.load(open("czech_squad_4-sents.json")) |
|
|
|
skipped = 0 |
|
|
|
questions_cs = [] |
|
contexts_cs = [] |
|
answers_cs = [] |
|
categories_cs = [] |
|
|
|
for i, entry in squad_cs_dataset.items(): |
|
if len(entry["context"]) > 800: |
|
skipped += 1 |
|
continue |
|
|
|
questions_cs.append(entry["question"]) |
|
contexts_cs.append(entry["context"]) |
|
answers_cs.append(entry["answers"]["text"][0]) |
|
categories_cs.append(entry["answer_type"]) |
|
|
|
print("Skipped cs examples: %s" % skipped) |
|
|
|
q_answering_cs = Priming(lang_module, |
|
max_eval_samples=eval_examples, |
|
demos_selection_strategy="random", |
|
texts_or_path=questions_cs[:-eval_examples], |
|
text_pair_or_path=contexts_cs[:-eval_examples], |
|
val_texts_or_path=questions_cs[-eval_examples:], |
|
val_text_pair_or_path=contexts_cs[-eval_examples:], |
|
labels_or_path=answers_cs[:-eval_examples], |
|
val_labels_or_path=answers_cs[-eval_examples:], |
|
train_question_categories=categories_cs[:-eval_examples], |
|
val_question_categories=categories_cs[-eval_examples:], |
|
batch_size=1, |
|
val_evaluators=val_metrics, |
|
source_lang_id="cs", |
|
objective_id="SQUAD-cs") |
|
|
|
schedule = ParallelSchedule(objectives=[q_answering_en, q_answering_cs], |
|
args=training_arguments) |
|
|
|
adapter = Adapter(lang_module, schedule, args=training_arguments) |
|
adapter.train() |
|
|