|
from typing import List |
|
|
|
from adaptor.adapter import Adapter |
|
from adaptor.evaluators.generative import BLEU |
|
from adaptor.lang_module import LangModule |
|
from adaptor.schedules import ParallelSchedule |
|
from adaptor.utils import AdaptationArguments, StoppingStrategy |
|
from datasets import load_dataset |
|
|
|
from priming_objective import Priming |
|
|
|
training_arguments = AdaptationArguments(output_dir="train_dir_AQA_info_large_ru", |
|
learning_rate=2e-5, |
|
stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED, |
|
|
|
do_train=True, |
|
do_eval=True, |
|
warmup_steps=1000, |
|
max_steps=10000, |
|
gradient_accumulation_steps=30, |
|
eval_steps=500, |
|
logging_steps=10, |
|
save_steps=500, |
|
num_train_epochs=5, |
|
evaluation_strategy="steps", |
|
save_total_limit=10, |
|
stopping_patience=10) |
|
eval_examples = 200 |
|
|
|
|
|
num_demonstrations = 3 |
|
|
|
|
|
def _construct_priming_prompt(previous_examples: List[str], current_example: str) -> str: |
|
return " ".join(previous_examples + [current_example]) |
|
|
|
|
|
lang_module = LangModule("google/mt5-large") |
|
|
|
|
|
per_type_examples = {} |
|
|
|
qa_en = load_dataset("adversarial_qa", "adversarialQA") |
|
qa_train = qa_en["train"].filter(lambda entry: len(entry["context"]) < 2000) |
|
|
|
val_metrics = [BLEU(**{"additional_sep_char": "▁"})] |
|
|
|
|
|
|
|
|
|
def _get_firstword_categories(data) -> List[str]: |
|
return [question.split()[0] if not question.startswith("To") |
|
else " ".join(question.split()[:2]) |
|
for question in data["question"]] |
|
|
|
|
|
q_answering_en = Priming(lang_module, |
|
max_eval_samples=eval_examples, |
|
demos_selection_strategy="informative", |
|
texts_or_path=qa_train["question"], |
|
text_pair_or_path=qa_train["context"], |
|
val_texts_or_path=qa_en["validation"]["question"][-eval_examples:], |
|
val_text_pair_or_path=qa_en["validation"]["context"][-eval_examples:], |
|
labels_or_path=[a["text"][0] for a in qa_train["answers"]], |
|
val_labels_or_path=[a["text"][0] for a in qa_en["validation"]["answers"]][-eval_examples:], |
|
train_question_categories=_get_firstword_categories(qa_train), |
|
val_question_categories=_get_firstword_categories(qa_en["validation"])[-eval_examples:], |
|
batch_size=1, |
|
val_evaluators=val_metrics, |
|
|
|
source_lang_id="en", |
|
objective_id="AQA-en") |
|
|
|
qa_ru = load_dataset("sberquad") |
|
qa_ru_train = qa_ru["train"].filter(lambda entry: len(entry["context"]) < 800) |
|
|
|
|
|
skipped = 0 |
|
|
|
q_answering_ru = Priming(lang_module, |
|
max_eval_samples=eval_examples, |
|
demos_selection_strategy="informative", |
|
texts_or_path=qa_ru_train["question"], |
|
text_pair_or_path=qa_ru_train["context"], |
|
val_texts_or_path=qa_ru["validation"]["question"][-eval_examples:], |
|
val_text_pair_or_path=qa_ru["validation"]["context"][-eval_examples:], |
|
labels_or_path=[a["text"][0] for a in qa_ru_train["answers"]], |
|
val_labels_or_path=[a["text"][0] for a in qa_ru["validation"]["answers"]][-eval_examples:], |
|
train_question_categories=_get_firstword_categories(qa_ru_train), |
|
val_question_categories=_get_firstword_categories(qa_ru["validation"])[-eval_examples:], |
|
batch_size=1, |
|
val_evaluators=val_metrics, |
|
|
|
source_lang_id="ru", |
|
objective_id="SQuAD-ru") |
|
|
|
schedule = ParallelSchedule(objectives=[q_answering_en, |
|
q_answering_ru |
|
], |
|
args=training_arguments) |
|
|
|
adapter = Adapter(lang_module, schedule, args=training_arguments) |
|
adapter.train() |
|
|