Pragnakal committed on
Commit 453b170
1 Parent(s): 42c092c

Upload 6 files

Files changed (6)
  1. README.md +5 -5
  2. app.py +146 -0
  3. gitattributes +34 -0
  4. questiongenerator.py +345 -0
  5. requirements (1).txt +15 -0
  6. run_qg.py +73 -0
README.md CHANGED
@@ -1,10 +1,10 @@
 ---
-title: T5 Base Question Generator
-emoji: 💻
-colorFrom: pink
-colorTo: yellow
+title: Question Generation Using T5
+emoji:
+colorFrom: blue
+colorTo: gray
 sdk: gradio
-sdk_version: 4.7.1
+sdk_version: 3.44.1
 app_file: app.py
 pinned: false
 ---
app.py ADDED
@@ -0,0 +1,146 @@
import gradio as gr
import requests
import os
import numpy as np
import pandas as pd
import json
import socket
import huggingface_hub
from huggingface_hub import Repository, hf_hub_download
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from questiongenerator import QuestionGenerator
import csv
from urllib.request import urlopen
import re as r

qg = QuestionGenerator()

HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_NAME = "question_generation_T5_dataset"
DATASET_REPO_URL = f"https://huggingface.co/datasets/pragnakalp/{DATASET_NAME}"
DATA_DIRNAME = "que_gen_logs"
DATA_FILENAME = "que_gen_logs.csv"
DATA_FILE = os.path.join(DATA_DIRNAME, DATA_FILENAME)
DATASET_REPO_ID = "pragnakalp/question_generation_T5_dataset"
print("is none?", HF_TOKEN is None)

article_value = """Google was founded in 1998 by Larry Page and Sergey Brin while they were Ph.D. students at Stanford University in California. Together they own about 14 percent of its shares and control 56 percent of the stockholder voting power through supervoting stock. They incorporated Google as a privately held company on September 4, 1998. An initial public offering (IPO) took place on August 19, 2004, and Google moved to its headquarters in Mountain View, California, nicknamed the Googleplex. In August 2015, Google announced plans to reorganize its various interests as a conglomerate called Alphabet Inc. Google is Alphabet's leading subsidiary and will continue to be the umbrella company for Alphabet's Internet interests. Sundar Pichai was appointed CEO of Google, replacing Larry Page who became the CEO of Alphabet."""

# REPOSITORY_DIR = "data"
# LOCAL_DIR = 'data_local'
# os.makedirs(LOCAL_DIR, exist_ok=True)

# Try to fetch the existing log file from the dataset repo (ignored if it doesn't exist yet).
try:
    hf_hub_download(
        repo_id=DATASET_REPO_ID,
        filename=DATA_FILENAME,
        cache_dir=DATA_DIRNAME,
        force_filename=DATA_FILENAME
    )
except Exception:
    print("file not found")

repo = Repository(
    local_dir="que_gen_logs", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)


def getIP():
    """Return the caller's public IP address, or '' if the lookup fails."""
    ip_address = ''
    try:
        d = str(urlopen('http://checkip.dyndns.com/').read())
        return r.compile(r'Address: (\d+\.\d+\.\d+\.\d+)').search(d).group(1)
    except Exception as e:
        print("Error while getting IP address -->", e)
        return ip_address


def get_location(ip_addr):
    """Resolve an IP address to a location via the Pragnakalp lookup service."""
    location = {}
    try:
        ip = ip_addr

        req_data = {
            "ip": ip,
            "token": "pkml123"
        }
        url = "https://demos.pragnakalp.com/get-ip-location"

        # req_data=json.dumps(req_data)
        # print("req_data",req_data)
        headers = {'Content-Type': 'application/json'}

        response = requests.request("POST", url, headers=headers, data=json.dumps(req_data))
        response = response.json()
        print("response======>>", response)
        return response
    except Exception as e:
        print("Error while getting location -->", e)
        return location


def generate_questions(article, num_que):
    result = ''
    if article.strip():
        if num_que is None or num_que == '':
            num_que = 3
        generated_questions_list = qg.generate(article, num_questions=int(num_que))
        summarized_data = {
            "generated_questions": generated_questions_list
        }
        generated_questions = summarized_data.get("generated_questions", '')

        for q in generated_questions:
            print(q)
            result = result + q + '\n'
        save_data_and_sendmail(article, generated_questions, num_que)
        print("sending result***!!!!!!", result)
        return result
    else:
        raise gr.Error("Please enter text in the input box!")


def save_data_and_sendmail(article, generated_questions, num_que):
    """Save generated details to the dataset repo and notify the mail endpoint."""
    try:
        ip_address = getIP()
        print(ip_address)
        location = get_location(ip_address)
        print(location)
        add_csv = [article, generated_questions, num_que, ip_address, location]
        print("data^^^^^", add_csv)
        with open(DATA_FILE, "a") as f:
            writer = csv.writer(f)
            # write the data
            writer.writerow(add_csv)

        commit_url = repo.push_to_hub()
        print("commit data :", commit_url)

        url = 'https://pragnakalpdev35.pythonanywhere.com/HF_space_que_gen'
        # url = 'http://pragnakalpdev33.pythonanywhere.com/HF_space_question_generator'
        myobj = {'article': article, 'total_que': num_que, 'gen_que': generated_questions, 'ip_addr': ip_address, 'loc': location}
        x = requests.post(url, json=myobj)
        print("myobj^^^^^", myobj)

    except Exception as e:
        return "Error while sending mail: " + str(e)

    return "Successfully saved data"


## design 1
inputs = gr.Textbox(value=article_value, lines=5, label="Input Text/Article", elem_id="inp_div")
total_que = gr.Textbox(label="Number of questions to generate", elem_id="inp_div")
outputs = gr.Textbox(label="Generated Questions", lines=6, elem_id="inp_div")

demo = gr.Interface(
    generate_questions,
    [inputs, total_que],
    outputs,
    title="Question Generation Using T5-Base Model",
    css=".gradio-container {background-color: lightgray} #inp_div {background-color: #7FB3D5;}",
    article="""<p style='text-align: center;'>Feel free to give us your <a href="https://www.pragnakalp.com/contact/" target="_blank">feedback</a> on this Question Generation using T5 demo.</p>
    <p style='text-align: center;'>Developed by: <a href="https://www.pragnakalp.com" target="_blank">Pragnakalp Techlabs</a></p>"""
)
demo.launch()
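
The handler logs each request as one CSV row (article, generated questions, requested count, IP address, location) to que_gen_logs/que_gen_logs.csv before pushing the dataset repo. A minimal sketch of reading that log back, using the pandas import app.py already pulls in; the column names are assumptions that simply mirror the order written in add_csv:

import pandas as pd

# Sketch only: column names are assumed from the add_csv order in
# save_data_and_sendmail(); the CSV itself is written without a header row.
log = pd.read_csv(
    "que_gen_logs/que_gen_logs.csv",
    names=["article", "generated_questions", "num_que", "ip_address", "location"],
    header=None,
)
print(log.tail())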
gitattributes ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
questiongenerator.py ADDED
@@ -0,0 +1,345 @@
import os
import sys
import math
import numpy as np
import torch
import spacy
import re
import random
import json
import en_core_web_sm
from string import punctuation

# from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
# from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification


class QuestionGenerator():

    def __init__(self, model_dir=None):

        QG_PRETRAINED = 'iarfmoose/t5-base-question-generator'
        self.ANSWER_TOKEN = '<answer>'
        self.CONTEXT_TOKEN = '<context>'
        self.SEQ_LENGTH = 512

        self.device = torch.device('cpu')
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.qg_tokenizer = AutoTokenizer.from_pretrained(QG_PRETRAINED)
        self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED)
        self.qg_model.to(self.device)

        self.qa_evaluator = QAEvaluator(model_dir)

    def generate(self, article, use_evaluator=True, num_questions=None, answer_style='all'):

        print("Generating questions...\n")

        qg_inputs, qg_answers = self.generate_qg_inputs(article, answer_style)
        print("qg_inputs, qg_answers=>", qg_inputs, qg_answers)
        generated_questions = self.generate_questions_from_inputs(qg_inputs, num_questions)
        print("generated_questions(generate)=>", generated_questions)
        # NOTE: this early return means the evaluator/QA-pair path below is
        # unreachable in this copy; callers receive a plain list of question strings.
        return generated_questions

        message = "{} questions don't match {} answers".format(
            len(generated_questions),
            len(qg_answers))
        assert len(generated_questions) == len(qg_answers), message

        if use_evaluator:

            print("Evaluating QA pairs...\n")

            encoded_qa_pairs = self.qa_evaluator.encode_qa_pairs(generated_questions, qg_answers)
            scores = self.qa_evaluator.get_scores(encoded_qa_pairs)
            if num_questions:
                qa_list = self._get_ranked_qa_pairs(generated_questions, qg_answers, scores, num_questions)
            else:
                qa_list = self._get_ranked_qa_pairs(generated_questions, qg_answers, scores)

        else:
            print("Skipping evaluation step.\n")
            qa_list = self._get_all_qa_pairs(generated_questions, qg_answers)

        return qa_list

    def generate_qg_inputs(self, text, answer_style):

        VALID_ANSWER_STYLES = ['all', 'sentences', 'multiple_choice']

        if answer_style not in VALID_ANSWER_STYLES:
            raise ValueError(
                "Invalid answer style {}. Please choose from {}".format(
                    answer_style,
                    VALID_ANSWER_STYLES
                )
            )

        inputs = []
        answers = []

        if answer_style == 'sentences' or answer_style == 'all':
            segments = self._split_into_segments(text)
            for segment in segments:
                sentences = self._split_text(segment)
                prepped_inputs, prepped_answers = self._prepare_qg_inputs(sentences, segment)
                inputs.extend(prepped_inputs)
                answers.extend(prepped_answers)

        if answer_style == 'multiple_choice' or answer_style == 'all':
            sentences = self._split_text(text)
            prepped_inputs, prepped_answers = self._prepare_qg_inputs_MC(sentences)
            inputs.extend(prepped_inputs)
            answers.extend(prepped_answers)

        return inputs, answers

    def generate_questions_from_inputs(self, qg_inputs, num_questions):
        generated_questions = []
        count = 0
        print("num que => ", num_questions)
        for qg_input in qg_inputs:
            if count < int(num_questions):
                question = self._generate_question(qg_input)

                question = question.strip()             # remove trailing spaces
                question = question.strip(punctuation)  # remove trailing question marks
                question += "?"                         # add a single ?
                if question not in generated_questions:
                    generated_questions.append(question)
                    print("question ===> ", question)
                    count += 1
            else:
                return generated_questions
        return generated_questions

    def _split_text(self, text):
        MAX_SENTENCE_LEN = 128

        sentences = re.findall(r'.*?[.!?]', text)

        cut_sentences = []
        for sentence in sentences:
            if len(sentence) > MAX_SENTENCE_LEN:
                cut_sentences.extend(re.split('[,;:)]', sentence))
        # temporary solution to remove useless post-quote sentence fragments
        cut_sentences = [s for s in cut_sentences if len(s.split(" ")) > 5]
        sentences = sentences + cut_sentences

        return list(set([s.strip(" ") for s in sentences]))

    def _split_into_segments(self, text):
        MAX_TOKENS = 490

        paragraphs = text.split('\n')
        tokenized_paragraphs = [self.qg_tokenizer(p)['input_ids'] for p in paragraphs if len(p) > 0]

        segments = []
        while len(tokenized_paragraphs) > 0:
            segment = []
            while len(segment) < MAX_TOKENS and len(tokenized_paragraphs) > 0:
                paragraph = tokenized_paragraphs.pop(0)
                segment.extend(paragraph)
            segments.append(segment)
        return [self.qg_tokenizer.decode(s) for s in segments]

    def _prepare_qg_inputs(self, sentences, text):
        inputs = []
        answers = []

        for sentence in sentences:
            qg_input = '{} {} {} {}'.format(
                self.ANSWER_TOKEN,
                sentence,
                self.CONTEXT_TOKEN,
                text
            )
            inputs.append(qg_input)
            answers.append(sentence)

        return inputs, answers

    def _prepare_qg_inputs_MC(self, sentences):

        spacy_nlp = en_core_web_sm.load()
        docs = list(spacy_nlp.pipe(sentences, disable=['parser']))
        inputs_from_text = []
        answers_from_text = []

        for i in range(len(sentences)):
            entities = docs[i].ents
            if entities:
                for entity in entities:
                    qg_input = '{} {} {} {}'.format(
                        self.ANSWER_TOKEN,
                        entity,
                        self.CONTEXT_TOKEN,
                        sentences[i]
                    )
                    answers = self._get_MC_answers(entity, docs)
                    inputs_from_text.append(qg_input)
                    answers_from_text.append(answers)

        return inputs_from_text, answers_from_text

    def _get_MC_answers(self, correct_answer, docs):

        entities = []
        for doc in docs:
            entities.extend([{'text': e.text, 'label_': e.label_} for e in doc.ents])

        # remove duplicate elements
        entities_json = [json.dumps(kv) for kv in entities]
        pool = set(entities_json)
        num_choices = min(4, len(pool)) - 1  # -1 because we already have the correct answer

        # add the correct answer
        final_choices = []
        correct_label = correct_answer.label_
        final_choices.append({'answer': correct_answer.text, 'correct': True})
        pool.remove(json.dumps({'text': correct_answer.text, 'label_': correct_answer.label_}))

        # find answers with the same NER label
        matches = [e for e in pool if correct_label in e]

        # if we don't have enough then add some other random answers
        if len(matches) < num_choices:
            choices = matches
            pool = pool.difference(set(choices))
            # random.sample needs a sequence, not a set, on recent Python versions
            choices.extend(random.sample(list(pool), num_choices - len(choices)))
        else:
            choices = random.sample(matches, num_choices)

        choices = [json.loads(s) for s in choices]
        for choice in choices:
            final_choices.append({'answer': choice['text'], 'correct': False})
        random.shuffle(final_choices)
        return final_choices

    def _generate_question(self, qg_input):
        self.qg_model.eval()
        encoded_input = self._encode_qg_input(qg_input)
        with torch.no_grad():
            output = self.qg_model.generate(input_ids=encoded_input['input_ids'])
        return self.qg_tokenizer.decode(output[0])

    def _encode_qg_input(self, qg_input):
        return self.qg_tokenizer(
            qg_input,
            padding='max_length',
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

    def _get_ranked_qa_pairs(self, generated_questions, qg_answers, scores, num_questions=10):
        if num_questions > len(scores):
            num_questions = len(scores)
            print("\nWas only able to generate {} questions. For more questions, please input a longer text.".format(num_questions))

        qa_list = []
        for i in range(num_questions):
            index = scores[i]
            qa = self._make_dict(
                generated_questions[index].split('?')[0] + '?',
                qg_answers[index])
            qa_list.append(qa)
        return qa_list

    def _get_all_qa_pairs(self, generated_questions, qg_answers):
        qa_list = []
        for i in range(len(generated_questions)):
            qa = self._make_dict(
                generated_questions[i].split('?')[0] + '?',
                qg_answers[i])
            qa_list.append(qa)
        return qa_list

    def _make_dict(self, question, answer):
        qa = {}
        qa['question'] = question
        qa['answer'] = answer
        return qa


class QAEvaluator():
    def __init__(self, model_dir=None):

        QAE_PRETRAINED = 'iarfmoose/bert-base-cased-qa-evaluator'
        self.SEQ_LENGTH = 512

        self.device = torch.device('cpu')
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.qae_tokenizer = AutoTokenizer.from_pretrained(QAE_PRETRAINED)
        self.qae_model = AutoModelForSequenceClassification.from_pretrained(QAE_PRETRAINED)
        self.qae_model.to(self.device)

    def encode_qa_pairs(self, questions, answers):
        encoded_pairs = []
        for i in range(len(questions)):
            encoded_qa = self._encode_qa(questions[i], answers[i])
            encoded_pairs.append(encoded_qa.to(self.device))
        return encoded_pairs

    def get_scores(self, encoded_qa_pairs):
        scores = {}
        self.qae_model.eval()
        with torch.no_grad():
            for i in range(len(encoded_qa_pairs)):
                scores[i] = self._evaluate_qa(encoded_qa_pairs[i])

        # return pair indices sorted from highest to lowest score
        return [k for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)]

    def _encode_qa(self, question, answer):
        if type(answer) is list:
            for a in answer:
                if a['correct']:
                    correct_answer = a['answer']
        else:
            correct_answer = answer
        return self.qae_tokenizer(
            text=question,
            text_pair=correct_answer,
            padding='max_length',
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        )

    def _evaluate_qa(self, encoded_qa_pair):
        output = self.qae_model(**encoded_qa_pair)
        return output[0][0][1]


def print_qa(qa_list, show_answers=True):
    for i in range(len(qa_list)):
        space = ' ' * int(np.where(i < 9, 3, 4))  # wider space for 2-digit question numbers

        print('{}) Q: {}'.format(i + 1, qa_list[i]['question']))

        answer = qa_list[i]['answer']

        # print a list of multiple choice answers
        if type(answer) is list:

            if show_answers:
                print('{}A: 1.'.format(space),
                      answer[0]['answer'],
                      np.where(answer[0]['correct'], '(correct)', ''))
                for j in range(1, len(answer)):
                    print('{}{}.'.format(space + '   ', j + 1),
                          answer[j]['answer'],
                          np.where(answer[j]['correct'] == True, '(correct)', ''))

            else:
                print('{}A: 1.'.format(space),
                      answer[0]['answer'])
                for j in range(1, len(answer)):
                    print('{}{}.'.format(space + '   ', j + 1),
                          answer[j]['answer'])
            print('')

        # print full sentence answers
        else:
            if show_answers:
                print('{}A:'.format(space), answer, '\n')
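
Because generate() returns early with a plain list of question strings, app.py never exercises QAEvaluator. If you want to rank question/answer pairs yourself, a minimal sketch built on the class as defined above (the example strings are hypothetical):

from questiongenerator import QAEvaluator

# Sketch: rank QA pairs directly with the evaluator. The questions/answers
# below are made-up examples; get_scores() returns pair indices sorted
# from highest to lowest evaluator score.
evaluator = QAEvaluator()
questions = ["When was Google founded?", "Who became the CEO of Alphabet?"]
answers = [
    "Google was founded in 1998 by Larry Page and Sergey Brin.",
    "Larry Page became the CEO of Alphabet.",
]

encoded = evaluator.encode_qa_pairs(questions, answers)
ranked = evaluator.get_scores(encoded)
print(ranked)  # e.g. [0, 1] or [1, 0]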
requirements (1).txt ADDED
@@ -0,0 +1,15 @@
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
Flask==1.1.2
future==0.18.2
gradio==3.44.1
Jinja2==2.11.2
joblib==0.17.0
markupsafe==2.0.1
numpy
requests==2.24.0
sentencepiece==0.1.99
spacy
torch==2.0.1
tqdm==4.51.0
transformers==4.30.2
run_qg.py ADDED
@@ -0,0 +1,73 @@
import argparse
import numpy as np
from questiongenerator import QuestionGenerator
from questiongenerator import print_qa


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text_dir",
        default=None,
        type=str,
        required=True,
        help="The text that will be used as context for question generation.",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        type=str,
        help="The folder that the trained model checkpoints are in.",
    )
    parser.add_argument(
        "--num_questions",
        default=10,
        type=int,
        help="The desired number of questions to generate.",
    )
    parser.add_argument(
        "--answer_style",
        default="all",
        type=str,
        help="The desired type of answers. Choose from ['all', 'sentences', 'multiple_choice']",
    )
    parser.add_argument(
        "--show_answers",
        default='True',
        type=parse_bool_string,
        help="Whether or not you want the answers to be visible. Choose from ['True', 'False']",
    )
    parser.add_argument(
        "--use_qa_eval",
        default='True',
        type=parse_bool_string,
        help="Whether or not you want the generated questions to be filtered for quality. Choose from ['True', 'False']",
    )
    args = parser.parse_args()

    with open(args.text_dir, 'r') as file:
        text_file = file.read()

    qg = QuestionGenerator(args.model_dir)

    qa_list = qg.generate(
        text_file,
        num_questions=int(args.num_questions),
        answer_style=args.answer_style,
        use_evaluator=args.use_qa_eval
    )
    print_qa(qa_list, show_answers=args.show_answers)


# taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def parse_bool_string(s):
    if isinstance(s, bool):
        return s
    if s.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif s.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


if __name__ == "__main__":
    main()
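
For reference, a sketch of the same call path without the argparse wrapper ("article.txt" is a hypothetical input file). Note that with the early return in this copy of questiongenerator.py, generate() yields question strings rather than the question/answer dicts that print_qa expects, so the questions are printed directly here:

from questiongenerator import QuestionGenerator

# Sketch: mirror run_qg.py's core call without argparse. "article.txt" is a
# placeholder path; generate() in this copy returns a list of question strings.
with open("article.txt", "r") as f:
    text = f.read()

qg = QuestionGenerator()
questions = qg.generate(text, num_questions=10, answer_style="all")
for q in questions:
    print(q)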