from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import Dict, List, Any
import itertools
from nltk import sent_tokenize
# import torch
import nltk


class PreTrainedPipeline():
    def __init__(self, path=""):
        """
        Preload all the elements you are going to need at inference,
        for instance the model and the tokenizer. This function is only
        called once, so do all the heavy processing and I/O here.
        """
        nltk.download('punkt')
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model_type = "t5"
        # self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = "cpu"
        self.model.to(self.device)

    def __call__(self, inputs: str, max_words_per_answer: int = 3):
        if len(inputs) == 0:
            return []
        # Normalize whitespace before sentence splitting.
        inputs = " ".join(inputs.split())
        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))
        if len(flat_answers) == 0:
            return []
        questions, qg_examples = self.prepare_and_generate_questions(sents, answers)
        output = [{'answer': example['answer'], 'question': que}
                  for example, que in zip(qg_examples, questions)]
        output = self.clean_generated_QAs(output, max_words_per_answer)
        return output

    def prepare_and_generate_questions(self, sents, answers):
        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        return questions, qg_examples

    def clean_answers_list_of_lists(self, answers):
        clean_answers = []
        for answer_list in answers:
            # Drop the trailing fragment left after the final <sep> token,
            # then deduplicate the remaining answers.
            answer_list = answer_list[:-1]
            answer_list = list(set([a.strip() for a in answer_list]))
            clean_answers.append(answer_list)
        return clean_answers

    def _extract_answers(self, context):
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        inputs = self._tokenize(inputs, padding=True, truncation=True)
        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
        )
        dec = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
        # The model emits the answers of one sentence separated by <sep> tokens.
        answers = [item.split('<sep>') for item in dec]
        answers = self.clean_answers_list_of_lists(answers)
        return sents, answers

    def _prepare_inputs_for_ans_extraction(self, text):
        sents = sent_tokenize(text)

        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    # Highlight the target sentence with <hl> markers.
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
                source_text = source_text.strip()

            if self.model_type == "t5":
                source_text = source_text + " </s>"
            inputs.append(source_text)

        return sents, inputs

    def _tokenize(self, inputs, padding=True, truncation=True,
                  add_special_tokens=True, max_length=512):
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt",
        )
        return inputs

    def _generate_questions(self, inputs):
        inputs = self._tokenize(inputs, padding=True, truncation=True)
        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
            num_beams=4,
        )
        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        return questions

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        inputs = []
        for i, answer in enumerate(answers):
            if len(answer) == 0:
                continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]
                answer_text = self.remove_pad(answer_text)
                answer_text = answer_text.strip()

                try:
                    ans_start_idx = sent.lower().index(answer_text.lower())
                except ValueError:
                    # The answer is not in the sentence, so we skip this one.
                    continue

                # Highlight the answer span inside its sentence.
                sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text):]}"
                sents_copy[i] = sent

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "t5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs

    def clean_generated_QAs(self, generated_QAs, max_words_per_answer):
        clean_QAs = []
        answers_used = set()
        # Only allow one question per answer; keep the first occurrence.
        for qa in generated_QAs:
            answer_word_length = len(qa['answer'].strip().split())
            if qa['answer'] in answers_used or answer_word_length > max_words_per_answer:
                continue
            answers_used.add(qa['answer'])
            clean_QAs.append(qa)
        return clean_QAs

    def remove_pad(self, text):
        # Strip the <pad> tokens left over from decoding with special tokens.
        if "<pad>" in text:
            return text.replace("<pad>", "")
        return text
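

# --- Usage sketch (not part of the original handler) -----------------------
# A minimal local smoke test, assuming `path` points at a T5 question-
# generation checkpoint trained with the <hl>/<sep> markers used above;
# the "." below is a placeholder for wherever such a model lives on disk.
if __name__ == "__main__":
    pipeline = PreTrainedPipeline(path=".")  # placeholder path (assumption)
    text = (
        "Python is a programming language. It was created by "
        "Guido van Rossum and first released in 1991."
    )
    for qa in pipeline(text, max_words_per_answer=5):
        print(f"Q: {qa['question']}\nA: {qa['answer']}\n")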