# Text2Question / questiongenerator.py
import re
import random
import json

import numpy as np
import torch
import en_core_web_sm
from string import punctuation
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
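
# Pipeline overview: QuestionGenerator builds '<answer> ... <context> ...'
# prompts from the input text, generates candidate questions with the T5
# model 'iarfmoose/t5-base-question-generator', and QAEvaluator ranks the
# resulting question/answer pairs with the BERT classifier
# 'iarfmoose/bert-base-cased-qa-evaluator'.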


class QuestionGenerator:

    def __init__(self, model_dir=None):
        QG_PRETRAINED = 'iarfmoose/t5-base-question-generator'
        self.ANSWER_TOKEN = '<answer>'
        self.CONTEXT_TOKEN = '<context>'
        self.SEQ_LENGTH = 512

        self.device = torch.device('cpu')
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.qg_tokenizer = AutoTokenizer.from_pretrained(QG_PRETRAINED)
        self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED)
        self.qg_model.to(self.device)

        self.qa_evaluator = QAEvaluator(model_dir)

    def generate(self, article, use_evaluator=True, num_questions=None, answer_style='all'):
        print("Generating questions...\n")
        qg_inputs, qg_answers = self.generate_qg_inputs(article, answer_style)
        generated_questions = self.generate_questions_from_inputs(qg_inputs, num_questions)

        if len(generated_questions) != len(qg_answers):
            # Capping generation at num_questions (or deduplication) breaks the
            # one-to-one pairing with qg_answers, so the pairs cannot be
            # evaluated; return the question strings directly.
            return generated_questions

        if use_evaluator:
            print("Evaluating QA pairs...\n")
            encoded_qa_pairs = self.qa_evaluator.encode_qa_pairs(generated_questions, qg_answers)
            scores = self.qa_evaluator.get_scores(encoded_qa_pairs)
            if num_questions:
                qa_list = self._get_ranked_qa_pairs(generated_questions, qg_answers, scores, num_questions)
            else:
                qa_list = self._get_ranked_qa_pairs(generated_questions, qg_answers, scores)
        else:
            print("Skipping evaluation step.\n")
            qa_list = self._get_all_qa_pairs(generated_questions, qg_answers)
        return qa_list

    def generate_qg_inputs(self, text, answer_style):
        VALID_ANSWER_STYLES = ['all', 'sentences', 'multiple_choice']
        if answer_style not in VALID_ANSWER_STYLES:
            raise ValueError(
                "Invalid answer style {}. Please choose from {}".format(
                    answer_style, VALID_ANSWER_STYLES))

        inputs = []
        answers = []

        # Full-sentence answers: pair each sentence with its surrounding segment.
        if answer_style in ('sentences', 'all'):
            segments = self._split_into_segments(text)
            for segment in segments:
                sentences = self._split_text(segment)
                prepped_inputs, prepped_answers = self._prepare_qg_inputs(sentences, segment)
                inputs.extend(prepped_inputs)
                answers.extend(prepped_answers)

        # Multiple-choice answers: use named entities as answers and distractors.
        if answer_style in ('multiple_choice', 'all'):
            sentences = self._split_text(text)
            prepped_inputs, prepped_answers = self._prepare_qg_inputs_MC(sentences)
            inputs.extend(prepped_inputs)
            answers.extend(prepped_answers)

        return inputs, answers

    def generate_questions_from_inputs(self, qg_inputs, num_questions):
        generated_questions = []
        count = 0
        for qg_input in qg_inputs:
            # Stop once num_questions unique questions have been generated;
            # num_questions=None means generate from every input.
            if num_questions is not None and count >= int(num_questions):
                break
            question = self._generate_question(qg_input)
            question = question.strip()             # remove surrounding whitespace
            question = question.strip(punctuation)  # remove any trailing question marks
            question += '?'                         # add exactly one back
            if question not in generated_questions:
                generated_questions.append(question)
                count += 1
        return generated_questions

    def _split_text(self, text):
        MAX_SENTENCE_LEN = 128
        sentences = re.findall(r'.*?[.!?]', text)
        cut_sentences = []
        for sentence in sentences:
            # Split overly long sentences on internal punctuation.
            if len(sentence) > MAX_SENTENCE_LEN:
                cut_sentences.extend(re.split('[,;:)]', sentence))
        # temporary solution to remove useless post-quote sentence fragments
        cut_sentences = [s for s in cut_sentences if len(s.split(" ")) > 5]
        sentences = sentences + cut_sentences
        return list(set([s.strip(" ") for s in sentences]))

    def _split_into_segments(self, text):
        MAX_TOKENS = 490
        paragraphs = text.split('\n')
        tokenized_paragraphs = [self.qg_tokenizer(p)['input_ids'] for p in paragraphs if len(p) > 0]

        # Greedily pack whole paragraphs into segments of roughly MAX_TOKENS tokens.
        segments = []
        while len(tokenized_paragraphs) > 0:
            segment = []
            while len(segment) < MAX_TOKENS and len(tokenized_paragraphs) > 0:
                paragraph = tokenized_paragraphs.pop(0)
                segment.extend(paragraph)
            segments.append(segment)
        return [self.qg_tokenizer.decode(s) for s in segments]

    def _prepare_qg_inputs(self, sentences, text):
        inputs = []
        answers = []
        for sentence in sentences:
            qg_input = '{} {} {} {}'.format(
                self.ANSWER_TOKEN, sentence, self.CONTEXT_TOKEN, text)
            inputs.append(qg_input)
            answers.append(sentence)
        return inputs, answers
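
    # Example qg_input string produced above (illustrative values):
    #   '<answer> It was completed in 1889. <context> The Eiffel Tower ... 1889.'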

    def _prepare_qg_inputs_MC(self, sentences):
        spacy_nlp = en_core_web_sm.load()
        docs = list(spacy_nlp.pipe(sentences, disable=['parser']))
        inputs_from_text = []
        answers_from_text = []
        for i in range(len(sentences)):
            entities = docs[i].ents
            if entities:
                for entity in entities:
                    qg_input = '{} {} {} {}'.format(
                        self.ANSWER_TOKEN, entity, self.CONTEXT_TOKEN, sentences[i])
                    answers = self._get_MC_answers(entity, docs)
                    inputs_from_text.append(qg_input)
                    answers_from_text.append(answers)
        return inputs_from_text, answers_from_text

    def _get_MC_answers(self, correct_answer, docs):
        entities = []
        for doc in docs:
            entities.extend([{'text': e.text, 'label_': e.label_} for e in doc.ents])

        # Serialize to JSON strings so duplicates can be removed with a set.
        entities_json = [json.dumps(kv) for kv in entities]
        pool = set(entities_json)
        num_choices = min(4, len(pool)) - 1  # -1 because we already have the correct answer

        # add the correct answer
        final_choices = []
        correct_label = correct_answer.label_
        final_choices.append({'answer': correct_answer.text, 'correct': True})
        pool.remove(json.dumps({'text': correct_answer.text, 'label_': correct_answer.label_}))

        # prefer distractors with the same NER label as the correct answer
        matches = [e for e in pool if correct_label in e]

        # if we don't have enough then add some other random answers
        if len(matches) < num_choices:
            choices = matches
            pool = pool.difference(set(choices))
            # random.sample requires a sequence, not a set, on Python >= 3.11
            choices.extend(random.sample(list(pool), num_choices - len(choices)))
        else:
            choices = random.sample(matches, num_choices)

        choices = [json.loads(s) for s in choices]
        for choice in choices:
            final_choices.append({'answer': choice['text'], 'correct': False})
        random.shuffle(final_choices)
        return final_choices
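
    # Example shape of the choice list returned above (illustrative values):
    #   [{'answer': 'Gustave Eiffel', 'correct': True},
    #    {'answer': 'Paris', 'correct': False},
    #    {'answer': '1889', 'correct': False}]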

    def _generate_question(self, qg_input):
        self.qg_model.eval()
        encoded_input = self._encode_qg_input(qg_input)
        with torch.no_grad():
            output = self.qg_model.generate(input_ids=encoded_input['input_ids'])
        # skip_special_tokens keeps <pad>/</s> markers out of the question text
        return self.qg_tokenizer.decode(output[0], skip_special_tokens=True)

    def _encode_qg_input(self, qg_input):
        return self.qg_tokenizer(
            qg_input,
            padding='max_length',  # pad_to_max_length is deprecated in recent transformers
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

    def _get_ranked_qa_pairs(self, generated_questions, qg_answers, scores, num_questions=10):
        if num_questions > len(scores):
            num_questions = len(scores)
            print("\nWas only able to generate {} questions. For more questions, please input a longer text.".format(num_questions))

        qa_list = []
        for i in range(num_questions):
            # scores is a list of pair indices sorted best-first (see QAEvaluator.get_scores)
            index = scores[i]
            qa = self._make_dict(
                generated_questions[index].split('?')[0] + '?',
                qg_answers[index])
            qa_list.append(qa)
        return qa_list

    def _get_all_qa_pairs(self, generated_questions, qg_answers):
        qa_list = []
        for i in range(len(generated_questions)):
            qa = self._make_dict(
                generated_questions[i].split('?')[0] + '?',
                qg_answers[i])
            qa_list.append(qa)
        return qa_list

    def _make_dict(self, question, answer):
        qa = {}
        qa['question'] = question
        qa['answer'] = answer
        return qa


class QAEvaluator:

    def __init__(self, model_dir=None):
        QAE_PRETRAINED = 'iarfmoose/bert-base-cased-qa-evaluator'
        self.SEQ_LENGTH = 512

        self.device = torch.device('cpu')
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.qae_tokenizer = AutoTokenizer.from_pretrained(QAE_PRETRAINED)
        self.qae_model = AutoModelForSequenceClassification.from_pretrained(QAE_PRETRAINED)
        self.qae_model.to(self.device)

    def encode_qa_pairs(self, questions, answers):
        encoded_pairs = []
        for i in range(len(questions)):
            encoded_qa = self._encode_qa(questions[i], answers[i])
            encoded_pairs.append(encoded_qa.to(self.device))
        return encoded_pairs

    def get_scores(self, encoded_qa_pairs):
        scores = {}
        self.qae_model.eval()
        with torch.no_grad():
            for i in range(len(encoded_qa_pairs)):
                scores[i] = self._evaluate_qa(encoded_qa_pairs[i])
        # return the pair indices sorted from highest to lowest score
        return [k for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)]

    def _encode_qa(self, question, answer):
        # For multiple-choice answers, evaluate against the correct choice only.
        if type(answer) is list:
            for a in answer:
                if a['correct']:
                    correct_answer = a['answer']
        else:
            correct_answer = answer
        return self.qae_tokenizer(
            text=question,
            text_pair=correct_answer,
            padding='max_length',  # pad_to_max_length is deprecated in recent transformers
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        )

    def _evaluate_qa(self, encoded_qa_pair):
        output = self.qae_model(**encoded_qa_pair)
        # logit for the positive ("valid QA pair") class of the single item in the batch
        return output[0][0][1]


def print_qa(qa_list, show_answers=True):
    for i in range(len(qa_list)):
        space = ' ' * int(np.where(i < 9, 3, 4))  # wider space for 2-digit question numbers
        print('{}) Q: {}'.format(i + 1, qa_list[i]['question']))
        answer = qa_list[i]['answer']

        # print a list of multiple-choice answers
        if type(answer) is list:
            if show_answers:
                print('{}A: 1.'.format(space),
                      answer[0]['answer'],
                      np.where(answer[0]['correct'], '(correct)', ''))
                for j in range(1, len(answer)):
                    print('{}{}.'.format(space + '   ', j + 1),
                          answer[j]['answer'],
                          np.where(answer[j]['correct'], '(correct)', ''))
            else:
                print('{}A: 1.'.format(space), answer[0]['answer'])
                for j in range(1, len(answer)):
                    print('{}{}.'.format(space + '   ', j + 1), answer[j]['answer'])
            print('')
        # print full-sentence answers
        else:
            if show_answers:
                print('{}A:'.format(space), answer, '\n')
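

# Minimal usage sketch (illustrative; not part of the original module). The
# sample text is a placeholder, and the first run downloads both pretrained
# models from the Hugging Face Hub.
if __name__ == '__main__':
    qg = QuestionGenerator()
    sample_text = (
        "The Eiffel Tower is a wrought-iron lattice tower in Paris, France. "
        "It was designed by Gustave Eiffel's company and completed in 1889."
    )
    result = qg.generate(sample_text, use_evaluator=True, answer_style='sentences')
    if result and isinstance(result[0], dict):
        print_qa(result, show_answers=True)
    else:
        # pairing with answers was lost (e.g. deduplication), so we only have strings
        for q in result:
            print(q)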