import logging
import random
from typing import Iterable, Union, Dict, List, Optional
import torch
from adaptor.objectives.seq2seq import Sequence2Sequence
from transformers import BatchEncoding
logger = logging.getLogger()
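# language-specific QA prompt templates, filled with (question, context)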
priming_formats = {
"QA": {"cs": "Otázka: %s Kontext: %s Odpověď:",
"en": "Question: %s Context: %s Answer:",
"ru": "Вопрос: %s Контекст: %s Отвечать:"}}
class Priming(Sequence2Sequence):
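    """
    Sequence2Sequence objective with in-context priming: every predicted QA prompt is prefixed
    with a set of solved demonstrations drawn from the same question category.
    """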
def __init__(self, *args,
train_question_categories: Iterable[str],
max_eval_samples: int,
val_question_categories: Optional[Iterable[str]] = None,
min_num_demonstrations: int = 2,
max_num_demonstrations: int = 5,
demos_infer_batch_size: int = 32,
demos_selection_strategy: str = "hard",
difficulty_sample: int = 64,
max_input_length: int = 8000,
**kwargs):
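        """
        :param train_question_categories: categories aligned with the training questions, one per sample.
        :param max_eval_samples: maximum number of evaluation samples to iterate.
        :param val_question_categories: categories aligned with the validation questions, one per sample.
        :param min_num_demonstrations: lower bound on the number of demonstrations per primed prompt.
        :param max_num_demonstrations: upper bound on the number of demonstrations per primed prompt.
        :param demos_infer_batch_size: batch size used when scoring candidate demonstrations.
        :param demos_selection_strategy: "hard", "informative", or "random".
        :param difficulty_sample: number of candidate demonstrations scored per step of the "hard" strategy.
        :param max_input_length: maximum character length of the primed input.
        """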
super().__init__(*args, **kwargs)
self.train_question_categories = list(train_question_categories)
self.val_question_categories = list(val_question_categories) if val_question_categories is not None else None
self.min_num_demonstrations = min_num_demonstrations
self.max_num_demonstrations = max_num_demonstrations
self.demos_infer_batch_size = demos_infer_batch_size
self.demos_selection_strategy = demos_selection_strategy
self.difficulty_sample = difficulty_sample
self.max_input_length = max_input_length
self.max_eval_samples = max_eval_samples
def _construct_qa_prompt(self, question: str, context: str) -> str:
return priming_formats["QA"][self.source_lang_id] % (question, context)
def _construct_demonstration(self, prompt: str, answer: str) -> str:
return "%s %s " % (prompt, answer)
def _construct_primed_prompt(self, primed_demonstrations: List[str], prompt: str) -> str:
return " ".join(primed_demonstrations) + " " + prompt
def forced_generation_score(self, input_texts: List[str], forced_output: str) -> torch.FloatTensor:
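        """
        Estimates how likely the model is to generate `forced_output` for each of `input_texts`.
        The target is force-decoded and its per-token log-probabilities are summed per input.
        :param input_texts: primed prompts to score.
        :param forced_output: the expected answer, shared by all inputs.
        :return: a tensor of exponentiated summed log-probabilities, one score per input text.
        """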
inputs = self.tokenizer(input_texts, return_tensors="pt", padding="longest", truncation=True)
inputs = inputs.to(self.compatible_head_model.device)
with self.tokenizer.as_target_tokenizer():
output_ids = self.tokenizer(forced_output, return_tensors="pt", padding="longest",
truncation=True).input_ids.to(self.compatible_head_model.device)
forced_outputs = self.compatible_head_model.prepare_decoder_input_ids_from_labels(output_ids)
forced_outputs = forced_outputs.to(self.compatible_head_model.device)
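        # score the forced target for every input in the batch in a single forward pass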
outputs = self.compatible_head_model(**inputs,
decoder_input_ids=forced_outputs.expand(inputs.input_ids.shape[0], -1))
output_log_probs = outputs.logits.log_softmax(-1)
forced_output_logits = torch.gather(output_log_probs, -1,
output_ids.expand(inputs.input_ids.shape[0], -1).unsqueeze(-1))
forced_output_log_score = forced_output_logits.sum((-1, -2))
# we do not need to normalize, as all the targets are the same <=> same length
return forced_output_log_score.double().exp()
def _pick_most_difficult_demo(self,
selected_demos: List[str],
next_demo_cands: List[str],
predict_prompt: str,
predicted_answer: str) -> int:
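        """
        Scores each candidate demonstration appended to the already-selected ones and returns the index
        of the candidate under which the correct answer gets the lowest forced-generation score,
        i.e. the most difficult demonstration.
        """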
with torch.no_grad():
difficulties = torch.empty(0, device=self.compatible_head_model.device, dtype=torch.float)
for batch_offset in range(0, len(next_demo_cands), self.demos_infer_batch_size):
next_demo_cands_batch = next_demo_cands[batch_offset: batch_offset + self.demos_infer_batch_size]
primed_prompts = [self._construct_primed_prompt(selected_demos + [demo], predict_prompt)
for demo in next_demo_cands_batch]
cands_difficulty = self.forced_generation_score(primed_prompts, predicted_answer)
difficulties = torch.hstack((difficulties, cands_difficulty))
assert difficulties.argmin() < len(next_demo_cands)
        return int(difficulties.argmin())
def _get_inputs_iterator(self, split: str) -> Iterable[Union[BatchEncoding, Dict[str, torch.Tensor]]]:
"""
Creates a default iterator over encodings with aligned input and output texts.
:param split: Data split. `train` or `eval`.
:return: Iterator of model input encodings.
"""
# we materialize all samples in memory, so that we can heuristically pick the combinations
questions, contexts, answers = (list(it) for it in self._per_split_iterators(split))
question_categories = self.train_question_categories if split == "train" else self.val_question_categories
        assert len(questions) == len(contexts) == len(answers) == len(question_categories), \
            "Given numbers of questions, contexts, answers and categories do not match."
prompts = [self._construct_qa_prompt(q, c) for q, c in zip(questions, contexts)]
features_batch = []
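        # map each question category to the positions of its samples,
        # so that demonstrations are drawn from the same category as the predicted sample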
cat_index = {cat: [i for i, sample_cat in enumerate(question_categories) if cat == sample_cat]
for cat in set(question_categories)}
retrieved_samples = 0
for idx, sample_category in enumerate(question_categories):
if not cat_index[sample_category]:
logger.warning("No samples within the category %s", sample_category)
continue
pred_prompt, pred_answer = prompts[idx], answers[idx]
picked_demonstrations = []
            # pick the number of demonstrations uniformly from the configured range
expected_num_demonstrations = random.randint(self.min_num_demonstrations, self.max_num_demonstrations)
while len(picked_demonstrations) < expected_num_demonstrations:
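                # stop adding demonstrations once the already-picked ones exceed the maximum input length (in characters)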
if sum(map(len, picked_demonstrations)) > self.max_input_length:
logger.warning("Skipping too long prompt.")
break
if self.demos_selection_strategy == "hard":
                    # pick the most difficult demonstration out of a random sample of candidates
                    # the predicted sample may occur among the candidates, but the hard strategy never picks it:
                    # a demonstration that leaks the answer makes generation easier, not harder
if len(cat_index[sample_category]) <= 1:
# we can not construct informative demonstrations for categories of a single item
break
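                    # sample a pool of candidates (with replacement) from the same category and keep the hardest one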
samples_idx = random.choices(cat_index[sample_category], k=self.difficulty_sample)
cand_demonstrations = [self._construct_demonstration(prompts[i], answers[i]) for i in samples_idx]
selected_index = self._pick_most_difficult_demo(picked_demonstrations, cand_demonstrations,
pred_prompt, pred_answer)
picked_demonstrations.append(cand_demonstrations[selected_index])
elif self.demos_selection_strategy == "informative":
if len(cat_index[sample_category]) <= 1:
# we can not construct informative demonstrations for categories of a single item
break
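                    # draw a random position starting from 1 so that the fallback below (position - 1) never goes out of range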
selected_cat_index = random.randint(1, len(cat_index[sample_category])-1)
selected_index = cat_index[sample_category][selected_cat_index]
if selected_index == idx:
# we do not want to expose the predicted sample in demonstrations
selected_index = cat_index[sample_category][selected_cat_index-1]
picked_demonstration = self._construct_demonstration(prompts[selected_index],
answers[selected_index])
picked_demonstrations.append(picked_demonstration)
elif self.demos_selection_strategy == "random":
# evaluation: do not infer samples' difficulty, pick randomly
selected_index = random.randint(1, len(prompts)-1)
if selected_index == idx:
# we do not want to expose the predicted sample in demonstrations
selected_index -= 1
picked_demonstration = self._construct_demonstration(prompts[selected_index],
answers[selected_index])
picked_demonstrations.append(picked_demonstration)
else:
raise ValueError("Unknown demon selection strategy: '%s'" % self.demos_selection_strategy)
if len(picked_demonstrations) != expected_num_demonstrations:
                # skip samples for which the expected number of demonstrations could not be collected
continue
# encode a yielded batch
primed_prompt = self._construct_primed_prompt(picked_demonstrations, pred_prompt)
primed_prompt_encoding = self.tokenizer(primed_prompt, truncation=True)
label_encoding = self.tokenizer(pred_answer, truncation=True)
features_batch.append({"input_ids": primed_prompt_encoding.input_ids,
"attention_mask": primed_prompt_encoding.attention_mask,
"labels": label_encoding.input_ids})
if len(features_batch) == self.batch_size:
yield self.collator(features_batch)
features_batch = []
retrieved_samples += 1
if split == "eval" and retrieved_samples >= self.max_eval_samples:
# custom evaluation break - we need all samples in set to match categories,
# but do not want to iterate them all
break
if features_batch:
# yield last nonempty residual batch
yield self.collator(features_batch)
def _compute_loss(self,
lm_logit_outputs: torch.FloatTensor,
labels: torch.LongTensor,
inputs: Optional[Union[BatchEncoding, Dict[str, torch.Tensor]]] = None) -> torch.FloatTensor:
        # customization for the mT5 model, whose tokenizer.vocab_size is set incorrectly
        # this should be fixed in an upcoming release of adaptor (>=0.1.5)
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(lm_logit_outputs.flatten(end_dim=1), labels.flatten())
return lm_loss