# mTk-AdversarialQA_en-SberQuAD_ru-1B / priming_objective.py
# (Hugging Face file-viewer header captured with the source — kept as a comment:)
# michal-stefanik's picture
# README & training scripts
# commit 32980f3
# raw / history / blame / No virus / 10.9 kB
import logging
import random
from typing import Iterable, Union, Dict, List, Optional
import torch
from adaptor.objectives.seq2seq import Sequence2Sequence
from transformers import BatchEncoding
logger = logging.getLogger()
priming_formats = {
"QA": {"cs": "Otázka: %s Kontext: %s Odpověď:",
"en": "Question: %s Context: %s Answer:",
"ru": "Вопрос: %s Контекст: %s Отвечать:"}}
class Priming(Sequence2Sequence):
def __init__(self, *args,
train_question_categories: Iterable[str],
max_eval_samples: int,
val_question_categories: Optional[Iterable[str]] = None,
min_num_demonstrations: int = 2,
max_num_demonstrations: int = 5,
demos_infer_batch_size: int = 32,
demos_selection_strategy: str = "hard",
difficulty_sample: int = 64,
max_input_length: int = 8000,
**kwargs):
super().__init__(*args, **kwargs)
self.train_question_categories = list(train_question_categories)
self.val_question_categories = list(val_question_categories) if val_question_categories is not None else None
self.min_num_demonstrations = min_num_demonstrations
self.max_num_demonstrations = max_num_demonstrations
self.demos_infer_batch_size = demos_infer_batch_size
self.demos_selection_strategy = demos_selection_strategy
self.difficulty_sample = difficulty_sample
self.max_input_length = max_input_length
self.max_eval_samples = max_eval_samples
def _construct_qa_prompt(self, question: str, context: str) -> str:
return priming_formats["QA"][self.source_lang_id] % (question, context)
def _construct_demonstration(self, prompt: str, answer: str) -> str:
return "%s %s " % (prompt, answer)
def _construct_primed_prompt(self, primed_demonstrations: List[str], prompt: str) -> str:
return " ".join(primed_demonstrations) + " " + prompt
def forced_generation_score(self, input_texts: List[str], forced_output: str) -> torch.FloatTensor:
inputs = self.tokenizer(input_texts, return_tensors="pt", padding="longest", truncation=True)
inputs = inputs.to(self.compatible_head_model.device)
with self.tokenizer.as_target_tokenizer():
output_ids = self.tokenizer(forced_output, return_tensors="pt", padding="longest",
truncation=True).input_ids.to(self.compatible_head_model.device)
forced_outputs = self.compatible_head_model.prepare_decoder_input_ids_from_labels(output_ids)
forced_outputs = forced_outputs.to(self.compatible_head_model.device)
outputs = self.compatible_head_model(**inputs,
decoder_input_ids=forced_outputs.expand(inputs.input_ids.shape[0], -1))
output_log_probs = outputs.logits.log_softmax(-1)
forced_output_logits = torch.gather(output_log_probs, -1,
output_ids.expand(inputs.input_ids.shape[0], -1).unsqueeze(-1))
forced_output_log_score = forced_output_logits.sum((-1, -2))
# we do not need to normalize, as all the targets are the same <=> same length
return forced_output_log_score.double().exp()
def _pick_most_difficult_demo(self,
selected_demos: List[str],
next_demo_cands: List[str],
predict_prompt: str,
predicted_answer: str) -> int:
with torch.no_grad():
difficulties = torch.empty(0, device=self.compatible_head_model.device, dtype=torch.float)
for batch_offset in range(0, len(next_demo_cands), self.demos_infer_batch_size):
next_demo_cands_batch = next_demo_cands[batch_offset: batch_offset + self.demos_infer_batch_size]
primed_prompts = [self._construct_primed_prompt(selected_demos + [demo], predict_prompt)
for demo in next_demo_cands_batch]
cands_difficulty = self.forced_generation_score(primed_prompts, predicted_answer)
difficulties = torch.hstack((difficulties, cands_difficulty))
assert difficulties.argmin() < len(next_demo_cands)
return difficulties.argmin()
def _get_inputs_iterator(self, split: str) -> Iterable[Union[BatchEncoding, Dict[str, torch.Tensor]]]:
"""
Creates a default iterator over encodings with aligned input and output texts.
:param split: Data split. `train` or `eval`.
:return: Iterator of model input encodings.
"""
# we materialize all samples in memory, so that we can heuristically pick the combinations
questions, contexts, answers = (list(it) for it in self._per_split_iterators(split))
question_categories = self.train_question_categories if split == "train" else self.val_question_categories
assert len(questions) == len(contexts) == len(answers) == len(question_categories), \
"Given numbers of questions, contexts and answers do not match."
prompts = [self._construct_qa_prompt(q, c) for q, c in zip(questions, contexts)]
features_batch = []
cat_index = {cat: [i for i, sample_cat in enumerate(question_categories) if cat == sample_cat]
for cat in set(question_categories)}
retrieved_samples = 0
for idx, sample_category in enumerate(question_categories):
if not cat_index[sample_category]:
logger.warning("No samples within the category %s", sample_category)
continue
pred_prompt, pred_answer = prompts[idx], answers[idx]
picked_demonstrations = []
# a number of demonstrations is in the specified range
expected_num_demonstrations = random.randint(self.min_num_demonstrations, self.max_num_demonstrations)
while len(picked_demonstrations) < expected_num_demonstrations:
if sum(map(len, picked_demonstrations)) > self.max_input_length:
logger.warning("Skipping too long prompt.")
break
if self.demos_selection_strategy == "hard":
# pick the most difficult examples out of a sample
# we do not need to worry for picking up the predicted sample among demonstrations in hard strategy
if len(cat_index[sample_category]) <= 1:
# we can not construct informative demonstrations for categories of a single item
break
samples_idx = random.choices(cat_index[sample_category], k=self.difficulty_sample)
cand_demonstrations = [self._construct_demonstration(prompts[i], answers[i]) for i in samples_idx]
selected_index = self._pick_most_difficult_demo(picked_demonstrations, cand_demonstrations,
pred_prompt, pred_answer)
picked_demonstrations.append(cand_demonstrations[selected_index])
elif self.demos_selection_strategy == "informative":
if len(cat_index[sample_category]) <= 1:
# we can not construct informative demonstrations for categories of a single item
break
selected_cat_index = random.randint(1, len(cat_index[sample_category])-1)
selected_index = cat_index[sample_category][selected_cat_index]
if selected_index == idx:
# we do not want to expose the predicted sample in demonstrations
selected_index = cat_index[sample_category][selected_cat_index-1]
picked_demonstration = self._construct_demonstration(prompts[selected_index],
answers[selected_index])
picked_demonstrations.append(picked_demonstration)
elif self.demos_selection_strategy == "random":
# evaluation: do not infer samples' difficulty, pick randomly
selected_index = random.randint(1, len(prompts)-1)
if selected_index == idx:
# we do not want to expose the predicted sample in demonstrations
selected_index -= 1
picked_demonstration = self._construct_demonstration(prompts[selected_index],
answers[selected_index])
picked_demonstrations.append(picked_demonstration)
else:
raise ValueError("Unknown demon selection strategy: '%s'" % self.demos_selection_strategy)
if len(picked_demonstrations) != expected_num_demonstrations:
# we omit examples with none or only one demonstration in the category
continue
# encode a yielded batch
primed_prompt = self._construct_primed_prompt(picked_demonstrations, pred_prompt)
primed_prompt_encoding = self.tokenizer(primed_prompt, truncation=True)
label_encoding = self.tokenizer(pred_answer, truncation=True)
features_batch.append({"input_ids": primed_prompt_encoding.input_ids,
"attention_mask": primed_prompt_encoding.attention_mask,
"labels": label_encoding.input_ids})
if len(features_batch) == self.batch_size:
yield self.collator(features_batch)
features_batch = []
retrieved_samples += 1
if split == "eval" and retrieved_samples >= self.max_eval_samples:
# custom evaluation break - we need all samples in set to match categories,
# but do not want to iterate them all
break
if features_batch:
# yield last nonempty residual batch
yield self.collator(features_batch)
def _compute_loss(self,
lm_logit_outputs: torch.FloatTensor,
labels: torch.LongTensor,
inputs: Optional[Union[BatchEncoding, Dict[str, torch.Tensor]]] = None) -> torch.FloatTensor:
# customization for mt5 model, with incorrectly-set tokenizer.vocab_size
# This should be fixed in upcoming release of adaptor (>=0.1.5)
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(lm_logit_outputs.flatten(end_dim=1), labels.flatten())
return lm_loss