Spaces:
Runtime error
Upload 6 files
- README.md +5 -5
- app.py +146 -0
- gitattributes +34 -0
- questiongenerator.py +345 -0
- requirements (1).txt +15 -0
- run_qg.py +73 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Question Generation Using T5
+emoji: ⚡
+colorFrom: blue
+colorTo: gray
 sdk: gradio
-sdk_version:
+sdk_version: 3.44.1
 app_file: app.py
 pinned: false
 ---
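These front-matter fields are the Space's configuration card: sdk and sdk_version pin the Gradio runtime (matching the gradio==3.44.1 pin in requirements (1).txt below), and app_file names the script the Space runs.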
app.py
ADDED
@@ -0,0 +1,146 @@
import gradio as gr
import requests
import os
import numpy as np
import pandas as pd
import json
import socket
import huggingface_hub
from huggingface_hub import Repository, hf_hub_download  # hf_hub_download is called below, so import it explicitly
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from questiongenerator import QuestionGenerator
import csv
from urllib.request import urlopen
import re as r

qg = QuestionGenerator()

HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_NAME = "question_generation_T5_dataset"
DATASET_REPO_URL = f"https://huggingface.co/datasets/pragnakalp/{DATASET_NAME}"
DATA_FILENAME = "que_gen_logs.csv"
DATA_DIRNAME = "que_gen_logs"  # local working directory for the logging-dataset clone
DATA_FILE = os.path.join(DATA_DIRNAME, DATA_FILENAME)
DATASET_REPO_ID = "pragnakalp/question_generation_T5_dataset"
print("is none?", HF_TOKEN is None)

article_value = """Google was founded in 1998 by Larry Page and Sergey Brin while they were Ph.D. students at Stanford University in California. Together they own about 14 percent of its shares and control 56 percent of the stockholder voting power through supervoting stock. They incorporated Google as a privately held company on September 4, 1998. An initial public offering (IPO) took place on August 19, 2004, and Google moved to its headquarters in Mountain View, California, nicknamed the Googleplex. In August 2015, Google announced plans to reorganize its various interests as a conglomerate called Alphabet Inc. Google is Alphabet's leading subsidiary and will continue to be the umbrella company for Alphabet's Internet interests. Sundar Pichai was appointed CEO of Google, replacing Larry Page who became the CEO of Alphabet."""

# REPOSITORY_DIR = "data"
# LOCAL_DIR = 'data_local'
# os.makedirs(LOCAL_DIR, exist_ok=True)

try:
    # pull the existing log file so new rows are appended rather than starting fresh
    hf_hub_download(
        repo_id=DATASET_REPO_ID,
        filename=DATA_FILENAME,
        cache_dir=DATA_DIRNAME,
        force_filename=DATA_FILENAME
    )
except Exception:
    print("file not found")

repo = Repository(
    local_dir=DATA_DIRNAME, clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)


def getIP():
    """Return the visitor's public IP address, or '' if the lookup fails."""
    ip_address = ''
    try:
        d = str(urlopen('http://checkip.dyndns.com/').read())
        return r.compile(r'Address: (\d+\.\d+\.\d+\.\d+)').search(d).group(1)
    except Exception as e:
        print("Error while getting IP address -->", e)
        return ip_address


def get_location(ip_addr):
    """Look up an approximate geolocation for the given IP address."""
    location = {}
    try:
        req_data = {
            "ip": ip_addr,
            "token": "pkml123"
        }
        url = "https://demos.pragnakalp.com/get-ip-location"
        headers = {'Content-Type': 'application/json'}
        response = requests.request("POST", url, headers=headers, data=json.dumps(req_data))
        response = response.json()
        print("response======>>", response)
        return response
    except Exception as e:
        print("Error while getting location -->", e)
        return location


def generate_questions(article, num_que):
    result = ''
    if article.strip():
        if num_que is None or num_que == '':
            num_que = 3  # default when the count box is left empty
        generated_questions_list = qg.generate(article, num_questions=int(num_que))
        summarized_data = {
            "generated_questions": generated_questions_list
        }
        generated_questions = summarized_data.get("generated_questions", '')

        for q in generated_questions:
            print(q)
            result = result + q + '\n'
        save_data_and_sendmail(article, generated_questions, num_que)
        print("sending result***!!!!!!", result)
        return result
    else:
        raise gr.Error("Please enter text in the input box!")


"""
Save generated details
"""
def save_data_and_sendmail(article, generated_questions, num_que):
    try:
        ip_address = getIP()
        print(ip_address)
        location = get_location(ip_address)
        print(location)
        add_csv = [article, generated_questions, num_que, ip_address, location]
        print("data^^^^^", add_csv)
        with open(DATA_FILE, "a") as f:
            writer = csv.writer(f)
            writer.writerow(add_csv)  # append the new log row
        commit_url = repo.push_to_hub()
        print("commit data :", commit_url)

        url = 'https://pragnakalpdev35.pythonanywhere.com/HF_space_que_gen'
        # url = 'http://pragnakalpdev33.pythonanywhere.com/HF_space_question_generator'
        myobj = {'article': article, 'total_que': num_que, 'gen_que': generated_questions, 'ip_addr': ip_address, 'loc': location}
        x = requests.post(url, json=myobj)
        print("myobj^^^^^", myobj)

    except Exception as e:
        return "Error while sending mail " + str(e)

    return "Successfully saved data"


## design 1
inputs = gr.Textbox(value=article_value, lines=5, label="Input Text/Article", elem_id="inp_div")
total_que = gr.Textbox(label="Number of questions to generate", elem_id="inp_div")
outputs = gr.Textbox(label="Generated Questions", lines=6, elem_id="inp_div")

demo = gr.Interface(
    generate_questions,
    [inputs, total_que],
    outputs,
    title="Question Generation Using T5-Base Model",
    css=".gradio-container {background-color: lightgray} #inp_div {background-color: #7FB3D5;}",
    article="""<p style='text-align: center;'>Feel free to give us your <a href="https://www.pragnakalp.com/contact/" target="_blank">feedback</a> on this Question Generation using T5 demo.</p>
    <p style='text-align: center;'>Developed by: <a href="https://www.pragnakalp.com" target="_blank">Pragnakalp Techlabs</a></p>"""
)
demo.launch()
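The Gradio callback above delegates the heavy lifting to qg.generate. A minimal sketch of that call in isolation, with no Gradio UI and no dataset logging (the sample text is a stand-in; the checkpoints download on first use):

from questiongenerator import QuestionGenerator

qg = QuestionGenerator()  # loads the T5 generator and the BERT QA evaluator
text = "Google was founded in 1998 by Larry Page and Sergey Brin while they were Ph.D. students at Stanford University."
# As uploaded, generate() returns early with a plain list of question strings.
for question in qg.generate(text, num_questions=3):
    print(question)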
gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
questiongenerator.py
ADDED
@@ -0,0 +1,345 @@
import os
import sys
import math
import numpy as np
import torch
import spacy
import re
import random
import json
import en_core_web_sm
from string import punctuation

# from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
# from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification


class QuestionGenerator:

    def __init__(self, model_dir=None):

        QG_PRETRAINED = 'iarfmoose/t5-base-question-generator'
        self.ANSWER_TOKEN = '<answer>'
        self.CONTEXT_TOKEN = '<context>'
        self.SEQ_LENGTH = 512

        self.device = torch.device('cpu')
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.qg_tokenizer = AutoTokenizer.from_pretrained(QG_PRETRAINED)
        self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED)
        self.qg_model.to(self.device)

        self.qa_evaluator = QAEvaluator(model_dir)

    def generate(self, article, use_evaluator=True, num_questions=None, answer_style='all'):

        print("Generating questions...\n")

        qg_inputs, qg_answers = self.generate_qg_inputs(article, answer_style)
        print("qg_inputs, qg_answers =>", qg_inputs, qg_answers)
        generated_questions = self.generate_questions_from_inputs(qg_inputs, num_questions)
        print("generated_questions (generate) =>", generated_questions)
        # NOTE: this early return hands back a plain list of question strings,
        # which is what app.py expects; the evaluation/ranking code below is
        # unreachable as uploaded.
        return generated_questions

        message = "{} questions doesn't match {} answers".format(
            len(generated_questions),
            len(qg_answers))
        assert len(generated_questions) == len(qg_answers), message

        if use_evaluator:
            print("Evaluating QA pairs...\n")
            encoded_qa_pairs = self.qa_evaluator.encode_qa_pairs(generated_questions, qg_answers)
            scores = self.qa_evaluator.get_scores(encoded_qa_pairs)
            if num_questions:
                qa_list = self._get_ranked_qa_pairs(generated_questions, qg_answers, scores, num_questions)
            else:
                qa_list = self._get_ranked_qa_pairs(generated_questions, qg_answers, scores)
        else:
            print("Skipping evaluation step.\n")
            qa_list = self._get_all_qa_pairs(generated_questions, qg_answers)

        return qa_list

    def generate_qg_inputs(self, text, answer_style):

        VALID_ANSWER_STYLES = ['all', 'sentences', 'multiple_choice']

        if answer_style not in VALID_ANSWER_STYLES:
            raise ValueError(
                "Invalid answer style {}. Please choose from {}".format(
                    answer_style,
                    VALID_ANSWER_STYLES
                )
            )

        inputs = []
        answers = []

        if answer_style == 'sentences' or answer_style == 'all':
            segments = self._split_into_segments(text)
            for segment in segments:
                sentences = self._split_text(segment)
                prepped_inputs, prepped_answers = self._prepare_qg_inputs(sentences, segment)
                inputs.extend(prepped_inputs)
                answers.extend(prepped_answers)

        if answer_style == 'multiple_choice' or answer_style == 'all':
            sentences = self._split_text(text)
            prepped_inputs, prepped_answers = self._prepare_qg_inputs_MC(sentences)
            inputs.extend(prepped_inputs)
            answers.extend(prepped_answers)

        return inputs, answers

    def generate_questions_from_inputs(self, qg_inputs, num_questions):
        generated_questions = []
        count = 0
        print("num que => ", num_questions)
        for qg_input in qg_inputs:
            if count < int(num_questions):
                question = self._generate_question(qg_input)
                question = question.strip()             # remove leading/trailing spaces
                question = question.strip(punctuation)  # remove trailing question marks
                question += "?"                         # re-add exactly one ?
                if question not in generated_questions:  # skip duplicates
                    generated_questions.append(question)
                    print("question ===> ", question)
                    count += 1
            else:
                return generated_questions
        return generated_questions

    def _split_text(self, text):
        MAX_SENTENCE_LEN = 128

        sentences = re.findall(r'.*?[.!?]', text)

        cut_sentences = []
        for sentence in sentences:
            if len(sentence) > MAX_SENTENCE_LEN:
                cut_sentences.extend(re.split('[,;:)]', sentence))
        # temporary solution to remove useless post-quote sentence fragments
        cut_sentences = [s for s in cut_sentences if len(s.split(" ")) > 5]
        sentences = sentences + cut_sentences

        return list(set([s.strip(" ") for s in sentences]))

    def _split_into_segments(self, text):
        MAX_TOKENS = 490

        paragraphs = text.split('\n')
        tokenized_paragraphs = [self.qg_tokenizer(p)['input_ids'] for p in paragraphs if len(p) > 0]

        # greedily pack whole paragraphs into segments of roughly MAX_TOKENS tokens
        segments = []
        while len(tokenized_paragraphs) > 0:
            segment = []
            while len(segment) < MAX_TOKENS and len(tokenized_paragraphs) > 0:
                paragraph = tokenized_paragraphs.pop(0)
                segment.extend(paragraph)
            segments.append(segment)
        return [self.qg_tokenizer.decode(s) for s in segments]

    def _prepare_qg_inputs(self, sentences, text):
        inputs = []
        answers = []

        for sentence in sentences:
            qg_input = '{} {} {} {}'.format(
                self.ANSWER_TOKEN,
                sentence,
                self.CONTEXT_TOKEN,
                text
            )
            inputs.append(qg_input)
            answers.append(sentence)

        return inputs, answers

    def _prepare_qg_inputs_MC(self, sentences):

        spacy_nlp = en_core_web_sm.load()
        docs = list(spacy_nlp.pipe(sentences, disable=['parser']))
        inputs_from_text = []
        answers_from_text = []

        for i in range(len(sentences)):
            entities = docs[i].ents
            if entities:
                for entity in entities:
                    qg_input = '{} {} {} {}'.format(
                        self.ANSWER_TOKEN,
                        entity,
                        self.CONTEXT_TOKEN,
                        sentences[i]
                    )
                    answers = self._get_MC_answers(entity, docs)
                    inputs_from_text.append(qg_input)
                    answers_from_text.append(answers)

        return inputs_from_text, answers_from_text

    def _get_MC_answers(self, correct_answer, docs):

        entities = []
        for doc in docs:
            entities.extend([{'text': e.text, 'label_': e.label_} for e in doc.ents])

        # remove duplicate elements
        entities_json = [json.dumps(kv) for kv in entities]
        pool = set(entities_json)
        num_choices = min(4, len(pool)) - 1  # -1 because we already have the correct answer

        # add the correct answer
        final_choices = []
        correct_label = correct_answer.label_
        final_choices.append({'answer': correct_answer.text, 'correct': True})
        pool.remove(json.dumps({'text': correct_answer.text, 'label_': correct_answer.label_}))

        # find answers with the same NER label
        matches = [e for e in pool if correct_label in e]

        # if we don't have enough then add some other random answers
        if len(matches) < num_choices:
            choices = matches
            pool = pool.difference(set(choices))
            # random.sample needs a sequence, so convert the set to a list
            choices.extend(random.sample(list(pool), num_choices - len(choices)))
        else:
            choices = random.sample(matches, num_choices)

        choices = [json.loads(s) for s in choices]
        for choice in choices:
            final_choices.append({'answer': choice['text'], 'correct': False})
        random.shuffle(final_choices)
        return final_choices

    def _generate_question(self, qg_input):
        self.qg_model.eval()
        encoded_input = self._encode_qg_input(qg_input)
        with torch.no_grad():
            output = self.qg_model.generate(input_ids=encoded_input['input_ids'])
        # skip special tokens so <pad>/</s> markers don't leak into the question text
        return self.qg_tokenizer.decode(output[0], skip_special_tokens=True)

    def _encode_qg_input(self, qg_input):
        return self.qg_tokenizer(
            qg_input,
            padding='max_length',  # replaces the deprecated pad_to_max_length=True
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

    def _get_ranked_qa_pairs(self, generated_questions, qg_answers, scores, num_questions=10):
        if num_questions > len(scores):
            num_questions = len(scores)
            print("\nWas only able to generate {} questions. For more questions, please input a longer text.".format(num_questions))

        qa_list = []
        for i in range(num_questions):
            index = scores[i]
            qa = self._make_dict(
                generated_questions[index].split('?')[0] + '?',
                qg_answers[index])
            qa_list.append(qa)
        return qa_list

    def _get_all_qa_pairs(self, generated_questions, qg_answers):
        qa_list = []
        for i in range(len(generated_questions)):
            qa = self._make_dict(
                generated_questions[i].split('?')[0] + '?',
                qg_answers[i])
            qa_list.append(qa)
        return qa_list

    def _make_dict(self, question, answer):
        qa = {}
        qa['question'] = question
        qa['answer'] = answer
        return qa


class QAEvaluator:

    def __init__(self, model_dir=None):

        QAE_PRETRAINED = 'iarfmoose/bert-base-cased-qa-evaluator'
        self.SEQ_LENGTH = 512

        self.device = torch.device('cpu')
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.qae_tokenizer = AutoTokenizer.from_pretrained(QAE_PRETRAINED)
        self.qae_model = AutoModelForSequenceClassification.from_pretrained(QAE_PRETRAINED)
        self.qae_model.to(self.device)

    def encode_qa_pairs(self, questions, answers):
        encoded_pairs = []
        for i in range(len(questions)):
            encoded_qa = self._encode_qa(questions[i], answers[i])
            encoded_pairs.append(encoded_qa.to(self.device))
        return encoded_pairs

    def get_scores(self, encoded_qa_pairs):
        scores = {}
        self.qae_model.eval()
        with torch.no_grad():
            for i in range(len(encoded_qa_pairs)):
                scores[i] = self._evaluate_qa(encoded_qa_pairs[i])

        # return pair indices sorted from highest to lowest score
        return [k for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)]

    def _encode_qa(self, question, answer):
        # for multiple-choice answers, evaluate against the choice marked correct
        if type(answer) is list:
            for a in answer:
                if a['correct']:
                    correct_answer = a['answer']
        else:
            correct_answer = answer
        return self.qae_tokenizer(
            text=question,
            text_pair=correct_answer,
            padding='max_length',
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        )

    def _evaluate_qa(self, encoded_qa_pair):
        output = self.qae_model(**encoded_qa_pair)
        return output[0][0][1]


def print_qa(qa_list, show_answers=True):
    for i in range(len(qa_list)):
        space = ' ' * int(np.where(i < 9, 3, 4))  # wider space for 2-digit question numbers

        print('{}) Q: {}'.format(i + 1, qa_list[i]['question']))

        answer = qa_list[i]['answer']

        # print a list of multiple choice answers
        if type(answer) is list:

            if show_answers:
                print('{}A: 1.'.format(space),
                      answer[0]['answer'],
                      np.where(answer[0]['correct'], '(correct)', ''))
                for j in range(1, len(answer)):
                    print('{}{}.'.format(space + '   ', j + 1),
                          answer[j]['answer'],
                          np.where(answer[j]['correct'] == True, '(correct)', ''))
            else:
                print('{}A: 1.'.format(space),
                      answer[0]['answer'])
                for j in range(1, len(answer)):
                    print('{}{}.'.format(space + '   ', j + 1),
                          answer[j]['answer'])
            print('')

        # print full sentence answers
        else:
            if show_answers:
                print('{}A:'.format(space), answer, '\n')
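For reference, _get_MC_answers returns a shuffled list of answer dicts with exactly one entry marked correct, which is the shape print_qa expects for multiple-choice output; a sketch with illustrative values:

mc_answer = [
    {'answer': 'Larry Page', 'correct': True},
    {'answer': 'Sundar Pichai', 'correct': False},
    {'answer': 'Alphabet Inc.', 'correct': False},
    {'answer': 'Mountain View', 'correct': False},
]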
requirements (1).txt
ADDED
@@ -0,0 +1,15 @@
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
Flask==1.1.2
future==0.18.2
gradio==3.44.1
Jinja2==2.11.2
joblib==0.17.0
markupsafe==2.0.1
numpy
requests==2.24.0
sentencepiece==0.1.99
spacy
torch==2.0.1
tqdm==4.51.0
transformers==4.30.2
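Note: Spaces install dependencies from a file named requirements.txt at the repo root. As uploaded, the name requirements (1).txt (a browser duplicate-download suffix) would likely be ignored, which is consistent with the Runtime error status shown above.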
run_qg.py
ADDED
@@ -0,0 +1,73 @@
import argparse
import numpy as np
from questiongenerator import QuestionGenerator
from questiongenerator import print_qa


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text_dir",
        default=None,
        type=str,
        required=True,
        help="Path to the text file that will be used as context for question generation.",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        type=str,
        help="The folder that the trained model checkpoints are in.",
    )
    parser.add_argument(
        "--num_questions",
        default=10,
        type=int,
        help="The desired number of questions to generate.",
    )
    parser.add_argument(
        "--answer_style",
        default="all",
        type=str,
        help="The desired type of answers. Choose from ['all', 'sentences', 'multiple_choice']",
    )
    parser.add_argument(
        "--show_answers",
        default='True',
        type=parse_bool_string,
        help="Whether or not you want the answers to be visible. Choose from ['True', 'False']",
    )
    parser.add_argument(
        "--use_qa_eval",
        default='True',
        type=parse_bool_string,
        help="Whether or not you want the generated questions to be filtered for quality. Choose from ['True', 'False']",
    )
    args = parser.parse_args()

    with open(args.text_dir, 'r') as file:
        text_file = file.read()

    qg = QuestionGenerator(args.model_dir)

    qa_list = qg.generate(
        text_file,
        num_questions=int(args.num_questions),
        answer_style=args.answer_style,
        use_evaluator=args.use_qa_eval
    )
    print_qa(qa_list, show_answers=args.show_answers)


# taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def parse_bool_string(s):
    if isinstance(s, bool):
        return s
    if s.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif s.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


if __name__ == "__main__":
    main()
|