|
|
|
|
|
from adaptor.objectives.question_answering import ExtractiveQA |
|
import json |
|
|
|
from adaptor.adapter import Adapter |
|
from adaptor.evaluators.question_answering import BLEUForQA |
|
from adaptor.lang_module import LangModule |
|
from adaptor.schedules import ParallelSchedule |
|
from adaptor.utils import AdaptationArguments, StoppingStrategy |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
# Base checkpoint: multilingual BERT, so one backbone can serve both the
# Czech and English QA objectives defined below.
model_name = "bert-base-multilingual-cased"

# LangModule wraps the pretrained model; it is shared by both objectives
# and by the Adapter at the bottom of this script.
lang_module = LangModule(model_name)
|
|
|
# Shared trainer configuration for every objective in the schedule.
training_arguments = AdaptationArguments(output_dir="train_dir",
                                         learning_rate=4e-5,
                                         # Stop only once EVERY objective's convergence
                                         # evaluator has converged (see val_metrics below).
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         do_train=True,
                                         do_eval=True,
                                         warmup_steps=1000,
                                         # NOTE(review): both max_steps and num_train_epochs
                                         # are set; whichever limit is hit first ends training.
                                         max_steps=100000,
                                         gradient_accumulation_steps=1,
                                         # NOTE(review): eval_steps=1 runs evaluation after
                                         # every single step - extremely expensive over up to
                                         # 100k steps; confirm this is intended.
                                         eval_steps=1,
                                         logging_steps=10,
                                         save_steps=1000,
                                         num_train_epochs=30,
                                         evaluation_strategy="steps")
|
|
|
val_metrics = [BLEUForQA(decides_convergence=True)] |
|
|
|
|
|
# Czech SQuAD: load a local JSON dump and keep only examples whose gold
# answer occurs verbatim in the context - a hard requirement for
# extractive QA, where the label must be a span of the context.
# Fix: the original used json.load(open(...)), leaking the file handle
# and relying on the platform default encoding.
with open("data/czech_squad.json", encoding="utf-8") as squad_file:
    squad_dataset = json.load(squad_file)

questions = []
contexts = []
answers = []
skipped = 0

# Keys of the top-level dict are example ids and are not needed here.
for entry in squad_dataset.values():
    answer_texts = entry["answers"]["text"]
    # Guard against entries with an empty answer list (the original would
    # raise IndexError); such entries are counted as skipped like any
    # other non-extractable example.
    if answer_texts and answer_texts[0] in entry["context"]:
        questions.append(entry["question"])
        contexts.append(entry["context"])
        answers.append(answer_texts[0])
    else:
        skipped += 1

print("Skipped examples from SQuAD-cs: %s" % skipped)
|
|
|
# Hold out the last _num_val aligned examples of each list for
# validation; everything before that is used for training.
_num_val = 200

train_questions, val_questions = questions[:-_num_val], questions[-_num_val:]
train_answers, val_answers = answers[:-_num_val], answers[-_num_val:]
train_context, val_context = contexts[:-_num_val], contexts[-_num_val:]
|
|
|
|
|
# Czech extractive-QA objective: questions as texts, contexts as text
# pairs, answer spans as labels. batch_size=1 keeps memory low for the
# long Czech contexts.
# NOTE(review): the variable is named "generative_qa_cs" but the
# objective is ExtractiveQA - the name looks like a leftover from a
# generative formulation; kept unchanged because the schedule below
# references it.
_czech_qa_kwargs = dict(texts_or_path=train_questions,
                        text_pair_or_path=train_context,
                        labels_or_path=train_answers,
                        val_texts_or_path=val_questions,
                        val_text_pair_or_path=val_context,
                        val_labels_or_path=val_answers,
                        batch_size=1,
                        val_evaluators=val_metrics,
                        objective_id="SQUAD-cs")

generative_qa_cs = ExtractiveQA(lang_module, **_czech_qa_kwargs)
|
|
|
|
|
# English SQuAD v1 from the Hugging Face hub; drop extremely long
# contexts so the (question, context) inputs stay a manageable length.
squad_en = load_dataset("squad")
squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)

# SQuAD stores a list of acceptable answer texts per question; keep the
# first one as the single extractive label.
# Fix: removed the unused "question: %s context: %s" formatted lists
# (train/val_contexts_questions_en) - dead code from a generative-QA
# input format; the ExtractiveQA objective below takes question and
# context as separate arguments.
train_answers_en = [a["text"][0] for a in squad_train["answers"]]
val_answers_en = [a["text"][0] for a in squad_en["validation"]["answers"]]
|
|
|
# English extractive-QA objective, sharing lang_module and val_metrics
# with the Czech objective above. Only the first 200 validation examples
# are used, keeping per-evaluation cost bounded (mirrors the 200-example
# Czech validation split).
# NOTE(review): named "generative_qa_en" but constructed as ExtractiveQA;
# name kept unchanged because the schedule below references it.
generative_qa_en = ExtractiveQA(lang_module,
                                texts_or_path=squad_train["question"],
                                text_pair_or_path=squad_train["context"],
                                # Fix: reuse the precomputed first-answer lists
                                # instead of duplicating the comprehensions inline.
                                labels_or_path=train_answers_en,
                                val_texts_or_path=squad_en["validation"]["question"][:200],
                                val_text_pair_or_path=squad_en["validation"]["context"][:200],
                                val_labels_or_path=val_answers_en[:200],
                                batch_size=10,
                                val_evaluators=val_metrics,
                                objective_id="SQUAD-en")
|
|
|
# Combine both objectives under a ParallelSchedule (per its name, the
# objectives are trained alongside each other rather than sequentially -
# see adaptor's schedule docs to confirm the exact batch interleaving).
schedule = ParallelSchedule(objectives=[generative_qa_cs, generative_qa_en],
                            args=training_arguments)

# Adapter drives the multi-objective training loop over the shared
# lang_module; train() runs until a limit in training_arguments or the
# stopping strategy ends it.
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()
|
|