# TODO: BEFORE RUNNING: pip install git+https://github.com/gaussalgo/adaptor.git@QA_to_objectives
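"""Joint extractive QA training for English and Czech.

Fine-tunes a single multilingual transformer on English SQuAD and a Czech
SQuAD translation as two parallel objectives of the adaptor library.
"""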
import json

from adaptor.adapter import Adapter
from adaptor.evaluators.question_answering import BLEUForQA
from adaptor.lang_module import LangModule
from adaptor.objectives.question_answering import ExtractiveQA
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset
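
# LangModule wraps the pretrained checkpoint; its transformer body is shared
# by all objectives declared below. NOTE: this repository is named after
# xlm-roberta-large; set model_name accordingly if reproducing that checkpoint.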
model_name = "bert-base-multilingual-cased"
lang_module = LangModule(model_name)
training_arguments = AdaptationArguments(output_dir="train_dir",
                                         learning_rate=4e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         do_train=True,
                                         do_eval=True,
                                         warmup_steps=1000,
                                         max_steps=100000,
                                         gradient_accumulation_steps=1,
                                         eval_steps=1,
                                         logging_steps=10,
                                         save_steps=1000,
                                         num_train_epochs=30,
                                         evaluation_strategy="steps")
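# With StoppingStrategy.ALL_OBJECTIVES_CONVERGED, training runs until every
# objective's convergence criterion is met, or until max_steps. Note that
# eval_steps=1 evaluates after every optimization step, which is costly;
# consider raising it for long runs. Convergence itself is decided by the
# BLEU evaluator below (decides_convergence=True).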
val_metrics = [BLEUForQA(decides_convergence=True)]
# Load the Czech SQuAD data used for both the train and eval splits
with open("data/czech_squad.json", encoding="utf-8") as f:
    squad_dataset = json.load(f)
questions = []
contexts = []
answers = []
skipped = 0
for entry in squad_dataset.values():
    if entry["answers"]["text"][0] in entry["context"]:
        # and len(entry["context"]) < 1024:  # characters; over-long inputs get truncated automatically anyway
        questions.append(entry["question"])
        contexts.append(entry["context"])
        answers.append(entry["answers"]["text"][0])
    else:
        skipped += 1
print("Skipped examples from SQuAD-cs: %s" % skipped)
train_questions = questions[:-200]
val_questions = questions[-200:]
train_answers = answers[:-200]
val_answers = answers[-200:]
train_context = contexts[:-200]
val_context = contexts[-200:]
# declaration of the Czech extractive question answering objective
extractive_qa_cs = ExtractiveQA(lang_module,
                                texts_or_path=train_questions,
                                text_pair_or_path=train_context,
                                labels_or_path=train_answers,
                                val_texts_or_path=val_questions,
                                val_text_pair_or_path=val_context,
                                val_labels_or_path=val_answers,
                                batch_size=1,
                                val_evaluators=val_metrics,
                                objective_id="SQUAD-cs")
# English SQuAD objective
squad_en = load_dataset("squad")
squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)
train_answers_en = [a["text"][0] for a in squad_train["answers"]]
val_answers_en = [a["text"][0] for a in squad_en["validation"]["answers"]]
extractive_qa_en = ExtractiveQA(lang_module,
                                texts_or_path=squad_train["question"],
                                text_pair_or_path=squad_train["context"],
                                labels_or_path=train_answers_en,
                                val_texts_or_path=squad_en["validation"]["question"][:200],
                                val_text_pair_or_path=squad_en["validation"]["context"][:200],
                                val_labels_or_path=val_answers_en[:200],
                                batch_size=10,
                                val_evaluators=val_metrics,
                                objective_id="SQUAD-en")
schedule = ParallelSchedule(objectives=[extractive_qa_cs, extractive_qa_en],
                            args=training_arguments)
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()
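# Persist the trained model. This assumes save_model() is inherited from
# Hugging Face's Trainer, which adaptor's Adapter builds on; the path here
# is illustrative.
adapter.save_model("train_dir/final_model")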