Text2Text Generation
Transformers
PyTorch
Safetensors
Czech
English
mt5
Inference Endpoints
mTk-SQuAD_en-SQAD_cs-1B / train_mt5_qa_en_SQuAD+cs_random.py
michal-stefanik's picture
README & Training resources
289fb31
raw
history blame
4.94 kB
import json
from typing import List
from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU
from adaptor.lang_module import LangModule
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset
from priming_objective import Priming
training_arguments = AdaptationArguments(output_dir="train_dir_SQuAD_random_large",
learning_rate=2e-5, # we set LR=2e-4 for pre-training experiments
stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
# stopping_strategy=StoppingStrategy.NUM_STEPS_TOTAL,
do_train=True,
do_eval=True,
warmup_steps=1000,
max_steps=10000,
gradient_accumulation_steps=30,
eval_steps=500,
logging_steps=10,
save_steps=500,
num_train_epochs=5,
evaluation_strategy="steps",
save_total_limit=10,
stopping_patience=10)
eval_examples = 200
# priming
num_demonstrations = 3
def _construct_priming_prompt(previous_examples: List[str], current_example: str) -> str:
return " ".join(previous_examples + [current_example])
lang_module = LangModule("google/mt5-large")
# priming
per_type_examples = {}
qa_en = load_dataset("squad")
qa_train = qa_en["train"].filter(lambda entry: len(entry["context"]) < 2000)
val_metrics = [BLEU(**{"additional_sep_char": "▁"})]
# SQuAD QA dataset & objective:
def _get_en_qa_categories(data) -> List[str]:
return [question.split()[0] if not question.startswith("To")
else " ".join(question.split()[:2])
for question in data["question"]]
q_answering_en = Priming(lang_module,
max_eval_samples=eval_examples,
demos_selection_strategy="random",
texts_or_path=qa_train["question"],
text_pair_or_path=qa_train["context"],
val_texts_or_path=qa_en["validation"]["question"][-eval_examples:],
val_text_pair_or_path=qa_en["validation"]["context"][-eval_examples:],
labels_or_path=[a["text"][0] for a in qa_train["answers"]],
val_labels_or_path=[a["text"][0] for a in qa_en["validation"]["answers"]][-eval_examples:],
train_question_categories=_get_en_qa_categories(qa_train),
val_question_categories=_get_en_qa_categories(qa_en["validation"])[-eval_examples:],
batch_size=1,
val_evaluators=val_metrics,
# val_evaluators=val_metrics,
source_lang_id="en",
objective_id="AQA-en")
# Czech data & objective
squad_cs_dataset = json.load(open("czech_squad_4-sents.json"))
skipped = 0
questions_cs = []
contexts_cs = []
answers_cs = []
categories_cs = []
for i, entry in squad_cs_dataset.items():
if len(entry["context"]) > 800:
skipped += 1
continue
questions_cs.append(entry["question"])
contexts_cs.append(entry["context"])
answers_cs.append(entry["answers"]["text"][0])
categories_cs.append(entry["answer_type"])
print("Skipped cs examples: %s" % skipped)
q_answering_cs = Priming(lang_module,
max_eval_samples=eval_examples,
demos_selection_strategy="random",
texts_or_path=questions_cs[:-eval_examples],
text_pair_or_path=contexts_cs[:-eval_examples],
val_texts_or_path=questions_cs[-eval_examples:],
val_text_pair_or_path=contexts_cs[-eval_examples:],
labels_or_path=answers_cs[:-eval_examples],
val_labels_or_path=answers_cs[-eval_examples:],
train_question_categories=categories_cs[:-eval_examples],
val_question_categories=categories_cs[-eval_examples:],
batch_size=1,
val_evaluators=val_metrics,
source_lang_id="cs",
objective_id="SQUAD-cs")
schedule = ParallelSchedule(objectives=[q_answering_en, q_answering_cs],
args=training_arguments)
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()