michal-stefanik commited on
Commit
17fa1b9
1 Parent(s): 08afa8d

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +194 -0
README.md ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets:
3
+ - uva-irlab/canard_quretec
4
+ - hotpot_qa
5
+ ---
6
+
7
+ # Model Card for T5-LM-Large_Canard-HotpotQA-rephrase
8
+
9
+ This model is trained on three objectives: (1) generating answers for the CANARD dataset, (2) generating answers for HotpotQA, and (3) rephrasing questions based on the previous conversation turns of CANARD.
10
+
11
+ ## Training
12
+
13
+ The model was trained using the following script, exported from the corresponding Jupyter notebook. All details, including the expected prompt format, can be inferred from the code.
14
+ The best checkpoint was picked by a minimal loss on all (3) training objectives.
15
+
16
+ ```python
17
import datasets

# Augmented CANARD splits, constructed in notebook: 2.1_construct_qa_dataset.ipynb
canard_train_augm = datasets.load_from_disk("canard_train_augm_full.hf")
canard_test_augm = datasets.load_from_disk("canard_test_augm_full.hf")

canard_df = canard_train_augm.to_pandas()
# BUG FIX: the test dataframe was previously built from canard_train_augm,
# so all "validation" below would silently run on training data.
canard_test_df = canard_test_augm.to_pandas()

### Curation of seq2seq input contexts and labels
import random
26
+
27
def input_context_from_sample(row: dict, max_length=5) -> str:
    """Build the conversational-QA input prompt for one CANARD row.

    Renders the conversation history, the current question, and a list of
    search passages (retrieved ones mixed with the gold passage in a random
    position), ending with the "Current Answer:" generation cue.
    """
    history = row["History"]
    pieces = ["Previous conversation:", "\nQuestion: ", ", ".join(history[:3])]

    # history[3:] alternates answer / question turns
    for idx in range(3, len(history), 2):
        pieces += ["\nAnswer: ", history[idx]]
        if idx + 1 < len(history):
            pieces += ["\nQuestion: ", history[idx + 1]]

    pieces += ["\n\nCurrent Question: ", row["Question"], "\nSearch results:"]

    # shuffle so the gold passage does not always sit at a fixed index
    candidates = row["retrieved_contexts"].tolist()[:max_length - 1] + [row["true_contexts"]]
    random.shuffle(candidates)

    for pos, passage in enumerate(candidates):
        pieces += ["\n[%s]: " % (pos + 1), passage.replace("CANNOTANSWER", "")]

    pieces.append("\nCurrent Answer: ")
    return "".join(pieces)
51
+
52
+
53
def rephrasing_context_from_sample(row: dict) -> str:
    """Build the question-rephrasing prompt for one CANARD row.

    Renders the conversation history followed by the current question and a
    cue asking the model for a more specific (self-contained) question.
    """
    history = row["History"]
    pieces = ["Previous conversation:", "\nQuestion: ", ", ".join(history[:3])]

    # history[3:] alternates answer / question turns
    for idx in range(3, len(history), 2):
        pieces += ["\nAnswer: ", history[idx]]
        if idx + 1 < len(history):
            pieces += ["\nQuestion: ", history[idx + 1]]

    pieces += ["\n\nCurrent Question: ", row["Question"], "\nMore specific question: "]
    return "".join(pieces)
69
+
70
+
71
def hotpotqa_context(row: dict) -> str:
    """Build the QA input prompt for one HotpotQA (distractor) row.

    Renders the question plus the numbered support paragraphs (each paragraph
    is its sentences joined by spaces), ending with the answer cue.
    """
    context = "Current Question: "
    context += row["question"]

    context += "\nSearch results:"
    # NOTE: renamed the comprehension variable — it used to shadow `context`
    all_contexts = [" ".join(sentences) for sentences in row["context"]["sentences"]]

    for i, search_result in enumerate(all_contexts):
        context += "\n[%s]: " % (i + 1)
        # BUG FIX: this append was commented out, so every [i] slot was
        # emitted empty and the search results carried no text at all
        context += search_result

    context += "\nCurrent Answer: "
    return context
84
+
85
+
86
input_texts = canard_df.apply(lambda row: input_context_from_sample(row), axis=1).values
input_val_texts = canard_test_df.iloc[:200].apply(lambda row: input_context_from_sample(row), axis=1).values

# drop over-long contexts (character-based cut-off, not a token count)
too_long_index = [len(t) > 20000 for t in input_texts]
input_texts = [t for i, t in enumerate(input_texts) if not too_long_index[i]]
print("training on %s samples" % len(input_texts))

labels = canard_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values
labels = [lab for i, lab in enumerate(labels) if not too_long_index[i]]

# BUG FIX: validation labels must align with input_val_texts, which is cut to
# the first 200 rows above — previously the full label column was used here,
# so validation texts and labels had mismatched lengths.
val_labels = canard_test_df.answer.iloc[:200].apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values

rephrasing_inputs = canard_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values
print(rephrasing_inputs[0])

rephrasing_val_inputs = canard_test_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values

rephrasing_labels = canard_df.Rewrite.values
rephrasing_val_labels = canard_test_df.Rewrite.values
print(rephrasing_labels[0])
106
+
107
# Training
from adaptor.lang_module import LangModule
from adaptor.evaluators.generative import ROUGE, BLEU
from adaptor.objectives.seq2seq import Sequence2Sequence

# shared backbone: T5-large with LM adaptation
lang_module = LangModule("google/t5-large-lm-adapt")

# generation metrics reported on every validation run
evaluators = [BLEU(), ROUGE()]

# Objective 1: conversational QA on CANARD
seq_qa = Sequence2Sequence(lang_module,
                           texts_or_path=input_texts,
                           labels_or_path=labels,
                           val_texts_or_path=input_val_texts,
                           val_labels_or_path=val_labels,
                           batch_size=4,
                           val_evaluators=evaluators,
                           objective_id="Canard")
126
+
127
# HotpotQA (distractor setting) as the second QA objective.
# Load the dataset once and index both splits, instead of the previous
# duplicate load_dataset() call for train and validation.
hotpot = datasets.load_dataset("hotpot_qa", "distractor")
hotpot_train = hotpot["train"]
hotpot_val = hotpot["validation"]

hotpot_inputs = hotpot_train.to_pandas().apply(hotpotqa_context, axis=1)
hotpot_val_inputs = hotpot_val.to_pandas().apply(hotpotqa_context, axis=1)

# same character-level length cut-off as for CANARD
too_long_index = [len(t) > 20000 for t in hotpot_inputs]

hotpot_inputs = [t for i, t in enumerate(hotpot_inputs) if not too_long_index[i]]
hotpot_answers = [t for i, t in enumerate(hotpot_train["answer"]) if not too_long_index[i]]
137
+
138
# Objectives 2 and 3 reuse the LM head of seq_qa (share_other_objective_head),
# so all three objectives train a single shared model.
shared_kwargs = dict(batch_size=4,
                     val_evaluators=evaluators,
                     share_other_objective_head=seq_qa)

# Objective 2: multi-hop QA on HotpotQA
seq_additional_qa = Sequence2Sequence(lang_module,
                                      texts_or_path=hotpot_inputs,
                                      labels_or_path=hotpot_answers,
                                      val_texts_or_path=hotpot_val_inputs[:200],
                                      val_labels_or_path=hotpot_val["answer"][:200],
                                      objective_id="HotpotQA",
                                      **shared_kwargs)

# Objective 3: question rephrasing on CANARD rewrites
seq_rephrasing = Sequence2Sequence(lang_module,
                                   texts_or_path=rephrasing_inputs,
                                   labels_or_path=rephrasing_labels,
                                   val_texts_or_path=rephrasing_val_inputs[:200],
                                   val_labels_or_path=rephrasing_val_labels[:200],
                                   objective_id="rephrasing",
                                   **shared_kwargs)
158
+
159
from adaptor.utils import AdaptationArguments, StoppingStrategy

# Stop only after every objective has converged; evaluate every 200 steps and
# keep the last few checkpoints (best one picked by minimal loss afterwards).
training_arguments = AdaptationArguments(
    output_dir="checkpoints-chatbot",
    learning_rate=5e-5,
    stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
    stopping_patience=8,
    save_total_limit=8,
    do_train=True,
    do_eval=True,
    bf16=True,
    warmup_steps=1000,
    gradient_accumulation_steps=8,
    logging_steps=10,
    eval_steps=200,
    save_steps=1000,
    num_train_epochs=10,
    evaluation_strategy="steps",
)
176
+
177
from adaptor.schedules import ParallelSchedule
from adaptor.adapter import Adapter

# sample batches from all three objectives in parallel during training
schedule = ParallelSchedule(
    objectives=[seq_qa, seq_additional_qa, seq_rephrasing],
    args=training_arguments,
)
adapter = Adapter(lang_module, schedule, args=training_arguments)

adapter.train()
185
+ ```
186
+
187
+
188
+ ## Usage
189
+
190
+ See the prompting templates used in training to infer the optimal prompting format.
191
+
192
+ #### Contact
193
+
194
+ Feel free to ask questions at stefanik{at}gaussalgo.com