File size: 4,854 Bytes
32980f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from typing import List

from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU
from adaptor.lang_module import LangModule
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset

from priming_objective import Priming

training_arguments = AdaptationArguments(output_dir="train_dir_AQA_info_large_ru",
                                         learning_rate=2e-5,  # we set LR=2e-4 for pre-training experiments
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         # stopping_strategy=StoppingStrategy.NUM_STEPS_TOTAL,
                                         do_train=True,
                                         do_eval=True,
                                         warmup_steps=1000,
                                         max_steps=10000,
                                         gradient_accumulation_steps=30,
                                         eval_steps=500,
                                         logging_steps=10,
                                         save_steps=500,
                                         num_train_epochs=5,
                                         evaluation_strategy="steps",
                                         save_total_limit=10,
                                         stopping_patience=10)
eval_examples = 200

# priming
num_demonstrations = 3


def _construct_priming_prompt(previous_examples: List[str], current_example: str) -> str:
    return " ".join(previous_examples + [current_example])


lang_module = LangModule("google/mt5-large")

# priming
per_type_examples = {}

qa_en = load_dataset("adversarial_qa", "adversarialQA")
qa_train = qa_en["train"].filter(lambda entry: len(entry["context"]) < 2000)

val_metrics = [BLEU(**{"additional_sep_char": "▁"})]

# Adversarial QA dataset & objective:


def _get_firstword_categories(data) -> List[str]:
    return [question.split()[0] if not question.startswith("To")
            else " ".join(question.split()[:2])
            for question in data["question"]]


q_answering_en = Priming(lang_module,
                         max_eval_samples=eval_examples,
                         demos_selection_strategy="informative",
                         texts_or_path=qa_train["question"],
                         text_pair_or_path=qa_train["context"],
                         val_texts_or_path=qa_en["validation"]["question"][-eval_examples:],
                         val_text_pair_or_path=qa_en["validation"]["context"][-eval_examples:],
                         labels_or_path=[a["text"][0] for a in qa_train["answers"]],
                         val_labels_or_path=[a["text"][0] for a in qa_en["validation"]["answers"]][-eval_examples:],
                         train_question_categories=_get_firstword_categories(qa_train),
                         val_question_categories=_get_firstword_categories(qa_en["validation"])[-eval_examples:],
                         batch_size=1,
                         val_evaluators=val_metrics,
                         # val_evaluators=val_metrics,
                         source_lang_id="en",
                         objective_id="AQA-en")

qa_ru = load_dataset("sberquad")
qa_ru_train = qa_ru["train"].filter(lambda entry: len(entry["context"]) < 800)


skipped = 0

q_answering_ru = Priming(lang_module,
                         max_eval_samples=eval_examples,
                         demos_selection_strategy="informative",
                         texts_or_path=qa_ru_train["question"],
                         text_pair_or_path=qa_ru_train["context"],
                         val_texts_or_path=qa_ru["validation"]["question"][-eval_examples:],
                         val_text_pair_or_path=qa_ru["validation"]["context"][-eval_examples:],
                         labels_or_path=[a["text"][0] for a in qa_ru_train["answers"]],
                         val_labels_or_path=[a["text"][0] for a in qa_ru["validation"]["answers"]][-eval_examples:],
                         train_question_categories=_get_firstword_categories(qa_ru_train),
                         val_question_categories=_get_firstword_categories(qa_ru["validation"])[-eval_examples:],
                         batch_size=1,
                         val_evaluators=val_metrics,
                         # val_evaluators=val_metrics,
                         source_lang_id="ru",
                         objective_id="SQuAD-ru")

schedule = ParallelSchedule(objectives=[q_answering_en,
                                        q_answering_ru
                                        ],
                            args=training_arguments)

adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()