---
datasets:
- hotpot_qa
- canard
---
# Model Card for T5-LM-Large_Canard-HotpotQA-rephrase
This model is trained on three objectives: (1) generating answers for the Canard dataset, (2) generating answers for HotpotQA, and (3) rephrasing Canard questions using the previous turns of the conversation.
## Training
The model was trained using the following script, exported from the corresponding Jupyter notebook. All details, including the request (prompt) format, can be inferred from the code.
The best checkpoint was selected by the minimal loss over all three training objectives.
```python
import datasets
# Canard QA splits augmented with retrieved contexts.
canard_train_augm = datasets.load_from_disk("canard_train_augm_full.hf")  # constructed in notebook: 2.1_construct_qa_dataset.ipynb
canard_test_augm = datasets.load_from_disk("canard_test_augm_full.hf")
canard_df = canard_train_augm.to_pandas()
# Validation must come from the held-out split; building this frame from
# canard_train_augm would validate (and select checkpoints) on training data.
canard_test_df = canard_test_augm.to_pandas()
### Curation of seq2seq input contexts and labels
import random
def input_context_from_sample(row: dict, max_length=5) -> str:
    """Build the seq2seq input prompt for the Canard question-answering objective.

    The prompt contains the previous conversation, the current question and up
    to ``max_length`` numbered search results: at most ``max_length - 1``
    passages from ``row["retrieved_contexts"]`` plus ``row["true_contexts"]``,
    shuffled into random order. It ends with the ``Current Answer:`` cue.

    :param row: a Canard record with ``History``, ``Question``,
        ``retrieved_contexts`` (must support ``.tolist()``) and ``true_contexts``.
    :param max_length: maximum number of search results in the prompt.
    :return: the prompt string.
    """
    history = row["History"]
    parts = ["Previous conversation:", "\nQuestion: ", ", ".join(history[:3])]
    # The first three history items are joined into one opening line; the
    # remaining items alternate answer / follow-up question.
    for idx in range(3, len(history), 2):
        parts.append("\nAnswer: " + history[idx])
        if idx + 1 < len(history):
            parts.append("\nQuestion: " + history[idx + 1])
    parts.append("\n\nCurrent Question: " + row["Question"])
    parts.append("\nSearch results:")

    passages = row["retrieved_contexts"].tolist()[:max_length - 1] + [row["true_contexts"]]
    # Shuffle so the true passage does not always appear in the same slot.
    random.shuffle(passages)
    for number, passage in enumerate(passages, start=1):
        parts.append(f"\n[{number}]: " + passage.replace("CANNOTANSWER", ""))

    parts.append("\nCurrent Answer: ")
    return "".join(parts)
def rephrasing_context_from_sample(row: dict) -> str:
    """Build the input prompt for the question-rephrasing objective.

    The prompt lists the previous conversation and the current question, and
    ends with the ``More specific question:`` cue that the model completes.

    :param row: a Canard record with ``History`` and ``Question``.
    :return: the prompt string.
    """
    history = row["History"]
    lines = ["Previous conversation:", "Question: " + ", ".join(history[:3])]
    # After the first three items, turns alternate: answer, question, answer, ...
    for position, utterance in enumerate(list(history[3:])):
        speaker = "Answer" if position % 2 == 0 else "Question"
        lines.append(speaker + ": " + utterance)
    lines.append("")  # blank line before the current question
    lines.append("Current Question: " + row["Question"])
    lines.append("More specific question: ")
    return "\n".join(lines)
def hotpotqa_context(row: dict, include_search_results: bool = False) -> str:
    """Build the input prompt for the auxiliary HotpotQA objective.

    The prompt follows the Canard template: ``Current Question:``, a
    ``Search results:`` list with one numbered slot per paragraph of
    ``row["context"]["sentences"]``, and the ``Current Answer:`` cue.

    Paragraph text is written into the slots only when
    ``include_search_results`` is True. The default (False) reproduces the
    prompts used by this training script, in which the append was disabled, so
    every slot was left empty during training.

    :param row: a HotpotQA record with a ``question`` string and a ``context``
        mapping whose ``sentences`` entry holds one sequence of sentence
        strings per paragraph.
    :param include_search_results: append each paragraph's text after its number.
    :return: the prompt string.
    """
    prompt = "Current Question: " + row["question"] + "\nSearch results:"
    # Each paragraph is a sequence of sentences; join them into one passage.
    # (Named `sentences` so it does not shadow the prompt being built.)
    passages = [" ".join(sentences) for sentences in row["context"]["sentences"]]
    for number, passage in enumerate(passages, start=1):
        prompt += "\n[%s]: " % number
        if include_search_results:
            prompt += passage.replace("CANNOTANSWER", "")
    prompt += "\nCurrent Answer: "
    return prompt
def _canard_answer_label(answer):
    """Return the target text for a Canard answer ("No answer" for unanswerable turns)."""
    return "No answer" if answer == "CANNOTANSWER" else answer


# Canard QA training pairs. Prompts longer than 20000 characters are dropped
# together with their answers so that texts and labels stay aligned.
canard_train_pairs = [
    (text, label)
    for text, label in zip(canard_df.apply(input_context_from_sample, axis=1).values,
                           canard_df.answer.apply(_canard_answer_label).values)
    if len(text) <= 20000
]
input_texts = [text for text, _ in canard_train_pairs]
labels = [label for _, label in canard_train_pairs]

# Validate on the first 200 held-out rows; texts and labels must cover the
# same rows (previously the labels spanned the whole test frame).
canard_val_df = canard_test_df.iloc[:200]
input_val_texts = canard_val_df.apply(input_context_from_sample, axis=1).values
print("training on %s samples" % len(input_texts))
val_labels = canard_val_df.answer.apply(_canard_answer_label).values

# Rephrasing objective: predict Canard's `Rewrite` of the current question.
rephrasing_inputs = canard_df.apply(rephrasing_context_from_sample, axis=1).values
print(rephrasing_inputs[0])
rephrasing_val_inputs = canard_test_df.apply(rephrasing_context_from_sample, axis=1).values
rephrasing_labels = canard_df.Rewrite.values
rephrasing_val_labels = canard_test_df.Rewrite.values
print(rephrasing_labels[0])
# Training
from adaptor.lang_module import LangModule
from adaptor.evaluators.generative import ROUGE, BLEU
from adaptor.objectives.seq2seq import Sequence2Sequence

# Model wrapper around the google/t5-large-lm-adapt checkpoint.
lang_module = LangModule("google/t5-large-lm-adapt")
# Generation metrics evaluated on the objective's validation texts.
evaluators = [BLEU(), ROUGE()]

# Objective 1: answer Canard questions from the conversation and search results.
seq_qa = Sequence2Sequence(
    lang_module,
    objective_id="Canard",
    texts_or_path=input_texts,
    labels_or_path=labels,
    val_texts_or_path=input_val_texts,
    val_labels_or_path=val_labels,
    val_evaluators=evaluators,
    batch_size=4,
)
# Objective 2: HotpotQA (distractor setting). Load the dataset once and reuse both splits.
hotpot = datasets.load_dataset("hotpot_qa", "distractor")
hotpot_train = hotpot["train"]
hotpot_val = hotpot["validation"]

# Drop prompts longer than 20000 characters together with their answers so
# that texts and labels stay aligned.
hotpot_pairs = [
    (text, answer)
    for text, answer in zip(hotpot_train.to_pandas().apply(hotpotqa_context, axis=1),
                            hotpot_train["answer"])
    if len(text) <= 20000
]
hotpot_inputs = [text for text, _ in hotpot_pairs]
hotpot_answers = [answer for _, answer in hotpot_pairs]

# Only the first 200 validation rows are used, so only build prompts for those.
hotpot_val_inputs = hotpot_val.to_pandas().iloc[:200].apply(hotpotqa_context, axis=1).tolist()

seq_additional_qa = Sequence2Sequence(lang_module,
                                      texts_or_path=hotpot_inputs,
                                      labels_or_path=hotpot_answers,
                                      val_texts_or_path=hotpot_val_inputs,
                                      val_labels_or_path=hotpot_val["answer"][:200],
                                      batch_size=4,
                                      val_evaluators=evaluators,
                                      objective_id="HotpotQA",
                                      share_other_objective_head=seq_qa)
# Objective 3: rewrite the current Canard question into its more specific `Rewrite`.
seq_rephrasing = Sequence2Sequence(
    lang_module,
    objective_id="rephrasing",
    texts_or_path=rephrasing_inputs,
    labels_or_path=rephrasing_labels,
    val_texts_or_path=rephrasing_val_inputs[:200],
    val_labels_or_path=rephrasing_val_labels[:200],
    val_evaluators=evaluators,
    batch_size=4,
    # presumably shares the head of the Canard QA objective — see adaptor docs
    share_other_objective_head=seq_qa,
)
from adaptor.utils import AdaptationArguments, StoppingStrategy
from adaptor.schedules import ParallelSchedule
from adaptor.adapter import Adapter

training_arguments = AdaptationArguments(
    output_dir="checkpoints-chatbot",
    do_train=True,
    do_eval=True,
    # Optimisation.
    learning_rate=5e-5,
    warmup_steps=1000,
    gradient_accumulation_steps=8,
    num_train_epochs=10,
    bf16=True,
    # Evaluation, logging and checkpointing cadence (in steps).
    evaluation_strategy="steps",
    eval_steps=200,
    logging_steps=10,
    save_steps=1000,
    save_total_limit=8,
    # Stopping: wait until all objectives have converged, with patience 8.
    stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
    stopping_patience=8,
)

# Combine the three objectives in a ParallelSchedule and train the shared model.
schedule = ParallelSchedule(objectives=[seq_qa, seq_additional_qa, seq_rephrasing],
                            args=training_arguments)
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()
```
## Usage
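Prompt the model in the same format that was used in training: see `input_context_from_sample` (Canard QA), `hotpotqa_context` (HotpotQA) and `rephrasing_context_from_sample` (question rephrasing) in the script above.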
#### Contact
Feel free to ask questions at stefanik{at}gaussalgo.com