Anni123 committed
Commit
49079cf
1 Parent(s): 7ead7d1

Upload 4 files

Files changed (4)
  1. app.py +347 -45
  2. llm_utils.py +78 -0
  3. retrieval_utils.py +246 -0
  4. utils.py +68 -0
app.py CHANGED
@@ -1,49 +1,351 @@
1
  import gradio as gr
2
- from statistics import mean
3
- from torch.utils.data import Dataset
4
- from collections import OrderedDict
5
- import xml.etree.ElementTree as ET
6
  import openai # For GPT-3 API ...
7
- import os
8
- import multiprocessing
9
- import json
10
- import numpy as np
11
- import random
12
- import torch
13
- import torchtext
14
  import re
15
- import random
16
- import time
17
- import datetime
18
- import pandas as pd
19
- import sys
20
-
21
- openai.api_key = os.getenv("api_key")
22
-
23
- def greet(question):
24
- input = question + '\n\n' + "|step|subquestion|process|result|"
25
- response = openai.ChatCompletion.create(
26
- model="gpt-3.5-turbo",
27
- messages=[
28
- {"role": "system", "content": "You are a helpful assistant that generate table to solve reasoning problem."},
29
- {"role": "user", "content": input},
30
-
31
- ]
32
- )
33
- response = response["choices"][0]["message"]["content"]
34
- return "|step|subquestion|process|result|\n" + response
35
-
36
-
37
- iface = gr.Interface(
38
- fn=greet,
39
- inputs="text",
40
- outputs="text",
41
- title="Tab-CoT: Zero-Shot Tabular Chain-of-Thought",
42
- examples=[
43
- ["Tommy is fundraising for his charity by selling brownies for $3 a slice and cheesecakes for $4 a slice. If Tommy sells 43 brownies and 23 slices of cheesecake, how much money does Tommy raise?"],
44
- ["Judy teaches 5 dance classes, every day, on the weekdays and 8 classes on Saturday. If each class has 15 students and she charges $15.00 per student, how much money does she make in 1 week?"],
45
- ["According to its nutritional info, a bag of chips has 250 calories per serving. If a 300g bag has 5 servings, how many grams can you eat if your daily calorie target is 2000 and you have already consumed 1800 calories?"],
46
- ]
47
- )
48
- iface.launch()
1
  import gradio as gr
2
  import openai # For GPT-3 API ...
3
  import re
4
+ import os
+ import threading
5
+ import json
6
+ from collections import Counter
7
+ from llm_utils import *
8
+ from utils import *
9
+ from retrieval_utils import *
10
+
11
+ openai.api_key = os.getenv("api_key") # read the key from the environment; never hardcode secrets
12
+
13
+ COT_PROMPT = "Let's think step by step."
14
+ DIRECT_ANS_PROMPT = "The answer is"
15
+
16
+ #EXAMPLES = {
17
+ # 'arithmetic': ['Marco and his dad went strawberry picking. Together they collected strawberries that weighed 36 pounds. On the way back Marco \' dad lost 8 pounds of strawberries. Marco\'s strawberries now weighed 12 pounds. How much did his dad\'s strawberries weigh now?'],
18
+ # 'commonsense-verify': [['is the brain located in the torso?'], ['Is entire Common Era minuscule to lifespan of some trees?'], ['Did the Football War last at least a month?']],
19
+ # 'commonsense-mc': ['What would someone use a personal key for? Answer Choices: (A) car stand (B) at hotel (C) own home (D) front door (E) bus depot', ],
20
+ # 'symbolic-letter': ['Take the last letters of each words in \"Kristopher Deb Jake Tammy\" and concatenate them.'],
21
+ # 'symbolic-coin': ['A coin is heads up. Isela flips the coin. Leslie flips the coin. Stacy flips the coin. Ingrid does not flip the coin. Is the coin still heads up? Note that \"flip\" here means \"reverse\".']
22
+ #}
23
+
24
+ EXAMPLES = ['Take the last letters of each words in \"Kristopher Deb Jake Tammy\" and concatenate them.', \
25
+ 'is the brain located in the torso?', 'Is entire Common Era minuscule to lifespan of some trees?', 'Did the Football War last at least a month?', \
26
+ 'What would someone use a personal key for? Answer Choices: (A) car stand (B) at hotel (C) own home (D) front door (E) bus depot', \
27
+ 'A coin is heads up. Isela flips the coin. Leslie flips the coin. Stacy flips the coin. Ingrid does not flip the coin. Is the coin still heads up? Note that \"flip\" here means \"reverse\".', \
28
+ 'Marco and his dad went strawberry picking. Together they collected strawberries that weighed 36 pounds. On the way back Marco\'s dad lost 8 pounds of strawberries. Marco\'s strawberries now weighed 12 pounds. How much did his dad\'s strawberries weigh now?']
29
+
30
+
31
+
32
+ global lock #global lock, repo
33
+ lock = threading.Lock()
34
+
35
+ def answer_extraction_prompt(datatype):
36
+ if datatype == "commonsense-mc":
37
+ ans_prompt = "\nTherefore, among A through E, the answer is"
38
+ elif datatype == "commonsense-verify":
39
+ ans_prompt = "\nTherefore, the answer (Yes or No) is"
40
+ elif datatype == "arithmetic":
41
+ ans_prompt = "\nTherefore, the answer (arabic numerals) is"
42
+ elif datatype == "symbolic-letter":
43
+ ans_prompt = "\nTherefore, the answer is"
44
+ elif datatype == "symbolic-coin":
45
+ ans_prompt = "\nTherefore, the answer (Yes or No) is"
46
+ else: #if datatype == "Undefined"
47
+ ans_prompt = "\nTherefore, the answer is"
48
+ return ans_prompt
49
+
50
+
51
+ def zero_shot(datatype, question, engine):
52
+ ANS_EXTRACTION_PROMPT = answer_extraction_prompt(datatype)
53
+ ANS_EXTRACTION_PROMPT = ANS_EXTRACTION_PROMPT.replace("\nTherefore, ", "")
54
+ ANS_EXTRACTION_PROMPT = ANS_EXTRACTION_PROMPT[0].upper() + ANS_EXTRACTION_PROMPT[1:]
55
+ input = "Q: " + question + "\n" + "A: " + ANS_EXTRACTION_PROMPT
56
+ ans_response = decoder_for_gpt3(input, max_length=32, engine=engine)
57
+ ans_response = answer_cleansing_zero_shot(datatype, ans_response)
58
+ if ans_response == "":
59
+ ans_response = "VOID"
60
+ return ans_response
61
+
62
+
63
+
64
+ def highlight_knowledge(entities, retrieved_knowledge):
65
+ str_md = retrieved_knowledge
66
+ for ent in entities:
67
+ ent_md = {}
68
+ m_pos = re.finditer(ent, retrieved_knowledge, re.IGNORECASE) #[(s,e),(s,e)]
69
+ for m in m_pos:
70
+ s, e = m.start(), m.end()
71
+ if retrieved_knowledge[s:e] not in ent_md.keys():
72
+ ent_ = retrieved_knowledge[s:e]
73
+ ent_md[ent_] = '<span style="background-color: lightcoral"> **' + ent_ + '** </span>'
74
+ for e_ori, e_md in ent_md.items():
75
+ print(e_ori)
76
+ print(e_md)
77
+ str_md = str_md.replace(e_ori, e_md)
78
+ return str_md
79
+
80
+ def zero_cot_consi(question, engine):
81
+ input = "Q: " + question + "\n" + "A: " + COT_PROMPT
82
+ cot_responses = decoder_for_gpt3_consistency(input,max_length=256, engine=engine) #list of cots
83
+ return cot_responses
84
+
85
+ def auto_cot_consi(question, demo_text, engine):
86
+ input = demo_text + "Q: " + question + "\n" + "A: " + COT_PROMPT
87
+ cot_responses = decoder_for_gpt3_consistency(input,max_length=256, engine=engine) #list of cots
88
+ return cot_responses
89
+
90
+
91
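+ # cot_revision: rewrite each original rationale using the refined knowledge, then re-extract an answer from the revised rationale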
+ def cot_revision(datatype, question, ori_cots, knowledge, engine):
92
+ ANS_EXTRACTION_PROMPT = answer_extraction_prompt(datatype)
93
+ corrected_rationales = []
94
+ corrected_answers = []
95
+ correction_prompt = "Question: " + "[ " + question + "]\n"
96
+ correction_prompt += "Knowledge: " + "[ " + knowledge + "]\n"
97
+ for ori_r in ori_cots:
98
+ cor_p = correction_prompt + "Original rationale: " + "[ " + ori_r + "]\n"
99
+ cor_p += "With Knowledge given, output the revised rationale for Question in a precise and certain style by thinking step by step: "
100
+ corrected_rationale = decoder_for_gpt3(cor_p,max_length=256, temperature=0.7, engine=engine)
101
+ corrected_rationale = corrected_rationale.strip()
102
+ corrected_rationales.append(corrected_rationale)
103
+ input = "Q: " + question + "\n" + "A: " + corrected_rationale + ANS_EXTRACTION_PROMPT
104
+ ans = decoder_for_gpt3(input, max_length=32, temperature=0.7, engine=engine)
105
+ ans = answer_cleansing_zero_shot(datatype, ans)
106
+ corrected_answers.append(ans)
107
+ return corrected_rationales, corrected_answers
108
+
109
+
110
+ def consistency(arr):
111
+ len_ans = len(arr)
112
+ arr_acounts = Counter(arr)
113
+ ans_freq_tuple = arr_acounts.most_common(len_ans)
114
+ most_frequent_item, _ = ans_freq_tuple[0]
115
+ ans_dict = {}
116
+ for ans_freq in ans_freq_tuple:
117
+ ans, times = ans_freq
118
+ ans_dict[ans] = times/len_ans
119
+ return most_frequent_item, ans_dict
120
+
121
+
122
+ ## todo: git pull
123
+ def record_feedback(single_data, feedback, store_flag):
124
+ global lock
125
+ print(f"Logging feedback...")
126
+ datatype = single_data['datatype']
127
+ data_dir = './data_pool/{dataname}_feedback'.format(dataname=datatype)
128
+
129
+ lock.acquire()
130
+ if store_flag:
131
+ single_data.update({'feedback':feedback})
132
+ with open(data_dir, "a") as f:
133
+ data_json = json.dumps(single_data)
134
+ f.write(data_json + "\n")
135
+ lock.release()
136
+ print(f"Logging finished...")
137
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
138
+ gr.update(value="😃 Thank you for your valuable feedback!")
139
+
140
+
141
+ def record_feedback_agree(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans, store_flag):
142
+ single_data = {
143
+ 'question': input_question, 'datatype': datatype, 'zshot_ans': zshot_ans,
144
+ 'adapter_ans': our_ans, 'self_know': self_know, 'kb_know': kb_know,
145
+ 'refine_know': refine_know, 'cor_ans': cor_ans, 'feedback': ""}
146
+ return record_feedback(single_data, 'agree', store_flag)
147
+ def record_feedback_disagree(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans, store_flag):
148
+ single_data = {
149
+ 'question': input_question, 'datatype': datatype, 'zshot_ans': zshot_ans,
150
+ 'adapter_ans': our_ans, 'self_know': self_know, 'kb_know': kb_know,
151
+ 'refine_know': refine_know, 'cor_ans': cor_ans, 'feedback': ""}
152
+ return record_feedback(single_data, "disagree", store_flag)
153
+ def record_feedback_uncertain(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans, store_flag):
154
+ single_data = {
155
+ 'question': input_question, 'datatype': datatype, 'zshot_ans': zshot_ans,
156
+ 'adapter_ans': our_ans, 'self_know': self_know, 'kb_know': kb_know,
157
+ 'refine_know': refine_know, 'cor_ans': cor_ans, 'feedback': ""}
158
+ return record_feedback(single_data, 'uncertain', store_flag)
159
+
160
+ # note: this reset() is superseded by the reset() redefined below, just before the UI definition
+ def reset():
161
+ return gr.update(value=""), gr.update(value=""), \
162
+ gr.update(visible=False), gr.update(value="", label=""), gr.update(value="", label=""), gr.update(value="", label=""), \
163
+ gr.update(value=""), gr.update(value=""), gr.update(value=""), gr.update(value="")
164
+
165
+
166
+ def identify_type(question, engine):
167
+ with open('./demos/type', 'r') as f:
168
+ typedemo = f.read()
169
+ typedemo += "Question: " + question + "\nOutput the Type, choosing from <'arithmetic','commonsense-mc','commonsense-verify','symbolic-coin', 'symbolic-letter'>: "
170
+ response = decoder_for_gpt3(typedemo, 32, temperature=0, engine=engine)
171
+ response = response.strip().lower()
172
+ response = type_cleasing(response)
173
+ return response
174
+
175
+ # note: expects the per-type EXAMPLES dict commented out above; not wired to the current UI
+ def load_examples(datatype):
176
+ return gr.update(examples=EXAMPLES[datatype])
177
+
178
+
179
+ def self_construction(datatype):
180
+ if datatype == "arithmetic":
181
+ fig_adr = './figs/multiarith.png'
182
+ demo_path = './demos/multiarith'
183
+ elif datatype == "commonsense-mc":
184
+ fig_adr = './figs/commonsensqa.png'
185
+ demo_path = './demos/commonsensqa'
186
+ elif datatype == "commonsense-verify":
187
+ fig_adr = './figs/strategyqa.png'
188
+ demo_path = './demos/strategyqa'
189
+ elif datatype == "symbolic-coin":
190
+ fig_adr = './figs/coin_flip.png'
191
+ demo_path = './demos/coin_flip'
192
+ elif datatype == "symbolic-letter":
193
+ fig_adr = './figs/last_letters.png'
194
+ demo_path = './demos/last_letters'
195
+ else:
196
+ pass ##todo: datatype == 'UNDEFINED'
197
+
198
+ ## load the corresponding demos
199
+ x, z, y =[], [], []
200
+ with open(demo_path, encoding="utf-8") as f:
201
+ json_data = json.load(f)
202
+ json_data = json_data["demo"]
203
+ for line in json_data:
204
+ x.append(line["question"])
205
+ z.append(line["rationale"])
206
+ y.append(line["pred_ans"])
207
+ index_list = list(range(len(x)))
208
+
209
+ demo_md, demo_text = "", ""
210
+ for i in index_list:
211
+ demo_text += x[i] + " " + z[i] + " " + \
212
+ DIRECT_ANS_PROMPT + " " + y[i] + ".\n\n"
213
+ demo_md += '<span style="background-color: #E0A182">' + "Q: "+ '</span>' + x[i][3:-3] + \
214
+ "<br>" + '<span style="background-color: #DD97AF">' + "A: "+ '</span>' + z[i] + " " + \
215
+ DIRECT_ANS_PROMPT + " " + y[i] + ".\n\n"
216
+
217
+
218
+ return gr.update(value="## 🔭 Self construction..."), gr.update(visible=True, label="Visualization of clustering", value=fig_adr), \
219
+ gr.update(visible=True, value=demo_md), gr.update(value=demo_text), \
220
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
221
+
222
+ def self_retrieval(input_question, engine):
223
+ entities, self_retrieve_knowledge, kb_retrieve_knowledge = retrieve_for_question(input_question, engine)
224
+
225
+ entities_string = ", ".join(entities)
226
+ retr_md = "### ENTITIES:" + "<br>" + "> "+ entities_string + "\n\n"
227
+ retr_md += "### LLM-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entities,self_retrieve_knowledge) + "\n\n"
228
+ retr_md += "### KB-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entities, kb_retrieve_knowledge) + "\n\n"
229
+
230
+ return gr.update(value="## 📚 Self retrieval..."), gr.update(visible=True, label="", value='./figs/self-retrieval.png'), \
231
+ gr.update(value=retr_md), \
232
+ gr.update(value=entities_string), gr.update(value=self_retrieve_knowledge), gr.update(value=kb_retrieve_knowledge), \
233
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
234
+
235
+ def self_refinement(input_question, entities, self_retrieve_knowledge, kb_retrieve_knowledge, engine):
236
+ refine_knowledge = refine_for_question(input_question, engine, self_retrieve_knowledge, kb_retrieve_knowledge)
237
+
238
+ retr_md = "### ENTITIES:" + "<br>" + "> " + entities + "\n\n"
239
+ entities = entities.strip().strip('<p>').strip('</p>').split(", ")
240
+ retr_md += "### LLM-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entities, self_retrieve_knowledge) + "\n\n"
241
+ retr_md += "### KB-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entities, kb_retrieve_knowledge) + "\n\n"
242
+ refine_md = retr_md + "### REFINED-KNOWLEDGE:" + "<br>" + "> "
243
+ refine_md += highlight_knowledge(entities, refine_knowledge)
244
+
245
+
246
+ return gr.update(value="## 🪄 Self refinement..."), gr.update(visible=True, label="", value='./figs/self-refinement.png'), \
247
+ gr.update(value=refine_md), gr.update(value=refine_knowledge), \
248
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
249
+
250
+ def self_revision(input_question, datatype, demo_text, refined_knowledge, engine):
251
+ print(demo_text)
252
+ print(refined_knowledge)
253
+ ori_cots = auto_cot_consi(input_question, demo_text, engine)
254
+ cor_cots, cor_ans = cot_revision(datatype, input_question, ori_cots, refined_knowledge, engine)
255
+ cor_cots_md = "### Revised Rationales:" + "\n\n"
256
+ for cor_cot in cor_cots:
257
+ cor_cots_md += "> " + cor_cot + "\n\n"
258
+ cor_ans = ", ".join(cor_ans)
259
+
260
+ return gr.update(value="## 🔧 Self revision..."), gr.update(visible=True, label="", value='./figs/self-revision.png'), \
261
+ gr.update(value=cor_cots_md), gr.update(value=cor_ans), \
262
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
263
+
264
+ def self_consistency(cor_ans, datatype, question, engine):
265
+ cor_ans = cor_ans.strip().split(", ")
266
+ our_ans, ans_dict = consistency(cor_ans)
267
+ zeroshot_ans = zero_shot(datatype, question, engine)
268
+
269
+ return gr.update(value="## 🗳 Self consistency..."), gr.update(visible=True, label="", value='./figs/self-consistency.png'), \
270
+ gr.update(value=""), gr.update(value=ans_dict, visible=True), \
271
+ gr.update(visible=True, value=our_ans), gr.update(visible=True, value=zeroshot_ans), \
272
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
273
+ gr.update(visible=True, value='We would appreciate it very much if you could share your feedback. ')
274
+
275
+
276
+ def reset():
277
+ return gr.update(value=""), gr.update(value=""), gr.update(value=""), \
278
+ gr.update(visible=False), gr.update(value=""), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
279
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value="")
280
+
281
+ #theme from: https://huggingface.co/spaces/gradio/theme-gallery
282
+ #EveryPizza/Cartoony-Gradio-Theme
283
+ #JohnSmith9982/small_and_pretty
284
+ #bethecloud/storj_theme
285
+ #gradio/soft
286
+ with gr.Blocks(theme="bethecloud/storj_theme", css="#process_btn {background-color:#8BA3C5}") as demo:
287
+ gr.Markdown("# 🌟 Unified Adaptive Reasoning Enhancement System (Unified-Adapter) 🌟")
288
+ with gr.Row():
289
+ with gr.Column(scale=4):
290
+ input_question = gr.Textbox(placeholder="Input question here, or select an example from below.", label="Input Question",lines=2)
291
+ store_flag = gr.Checkbox(label="Store data", value=True, interactive=True, info="Check this box if you agree that your input may be stored for research and development use.")
292
+ single_data = gr.JSON(visible=False)
293
+ with gr.Column(scale=3):
294
+ engine = gr.Dropdown(choices=['gpt-3.5-turbo','text-davinci-003', 'text-davinci-002', 'text-curie-001', 'text-babbage-001', 'text-ada-001'],
295
+ label="Engine", value="text-davinci-003", interactive=True, info="Choose the engine and have a try!")
296
+ reset_btn = gr.Button(value='RESET')
297
+ examples = gr.Examples(examples=EXAMPLES, inputs=[input_question])
298
+
299
+ with gr.Row():
300
+ with gr.Column(scale=1):
301
+ type_btn = gr.Button(value="Self-identification", variant='primary', scale=1, elem_id="process_btn")
302
+ with gr.Column(scale=3):
303
+ datatype = gr.Dropdown(choices=['arithmetic','commonsense-mc','commonsense-verify','symbolic-letter','symbolic-coin','UNDEFINED'],
304
+ label="Input Type", info="If you disagree with our output, please select manually.", scale=3)
305
+
306
+ demo_text = gr.Textbox(visible=False)
307
+ entities = gr.Textbox(visible=False)
308
+ self_know = gr.Textbox(visible=False)
309
+ kb_know = gr.Textbox(visible=False)
310
+ refine_know = gr.Textbox(visible=False)
311
+ cor_ans = gr.Textbox(visible=False)
312
+ with gr.Row():
313
+ const_btn = gr.Button(value='Self-construction', variant='primary', elem_id="process_btn")
314
+ retr_btn = gr.Button(value='Self-retrieval', variant='primary', elem_id="process_btn")
315
+ refine_btn = gr.Button(value='Self-refinement', variant='primary', elem_id="process_btn")
316
+ revis_btn = gr.Button(value='Self-revision', variant='primary', elem_id="process_btn")
317
+ consis_btn = gr.Button(value='Self-consistency', variant='primary', elem_id="process_btn")
318
+
319
+ sub_title = gr.Markdown()
320
+ with gr.Row():
321
+ with gr.Column(scale=2):
322
+ plot = gr.Image(label="Visualization of clustering", visible=False)
323
+ with gr.Column(scale=3):
324
+ md = gr.Markdown()
325
+ label = gr.Label(visible=False, label="Consistency Predictions")
326
+ ans_ours = gr.Textbox(label="Unified-Adapter Answer",visible=False)
327
+ ans_zeroshot = gr.Textbox(label="Zero-shot Answer", visible=False)
328
+ with gr.Row():
329
+ feedback_agree = gr.Button(value='😊 Agree', variant='secondary', visible=False)
330
+ feedback_disagree = gr.Button(value='🙁 Disagree', variant='secondary', visible=False)
331
+ feedback_uncertain = gr.Button(value='🤔 Uncertain', variant='secondary', visible=False)
332
+ feedback_ack = gr.Markdown(value='', visible=True, interactive=False)
333
+
334
+
335
+ type_btn.click(identify_type, inputs=[input_question, engine], outputs=[datatype])
336
+ const_btn.click(self_construction, inputs=[datatype], outputs=[sub_title, plot, md, demo_text, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
337
+ retr_btn.click(self_retrieval, inputs=[input_question, engine], outputs=[sub_title, plot, md, entities, self_know, kb_know, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
338
+ refine_btn.click(self_refinement, inputs=[input_question, entities, self_know, kb_know, engine], outputs=[sub_title, plot, md, refine_know, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
339
+ revis_btn.click(self_revision, inputs=[input_question, datatype, demo_text, refine_know, engine], outputs=[sub_title, plot, md, cor_ans, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
340
+ consis_btn.click(self_consistency, inputs=[cor_ans, datatype, input_question, engine], outputs=[sub_title, plot, md, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
341
+ reset_btn.click(reset, inputs=[], outputs=[input_question, datatype, sub_title, plot, md, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
342
+
343
+ feedback_agree.click(record_feedback_agree, inputs=[input_question, datatype, ans_ours, ans_zeroshot, self_know, kb_know, refine_know, cor_ans ,store_flag], outputs=[feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
344
+ feedback_disagree.click(record_feedback_disagree, inputs=[input_question, datatype, ans_ours, ans_zeroshot, self_know, kb_know, refine_know, cor_ans ,store_flag], outputs=[feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
345
+ feedback_uncertain.click(record_feedback_uncertain, inputs=[input_question, datatype, ans_ours, ans_zeroshot, self_know, kb_know, refine_know, cor_ans ,store_flag], outputs=[feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
346
+
347
+
348
+ demo.launch()
349
+
350
+
351
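For reference, a minimal standard-library sketch of the majority vote implemented by `consistency()` above; the answer strings are invented for illustration:

```python
from collections import Counter

# hypothetical answers produced by cot_revision() for one question
corrected_answers = ["8", "8", "6", "8", "VOID"]

counts = Counter(corrected_answers)
most_frequent, _ = counts.most_common(1)[0]
distribution = {ans: times / len(corrected_answers) for ans, times in counts.items()}

print(most_frequent)   # "8"  -> shown as the Unified-Adapter answer
print(distribution)    # {"8": 0.6, "6": 0.2, "VOID": 0.2} -> shown in the Consistency Predictions label
```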
 
llm_utils.py ADDED
@@ -0,0 +1,78 @@
1
+ import time
2
+ import openai
3
+
4
+ # openai.api_key is set in app.py (read from the environment); do not hardcode API keys here
5
+
6
+
7
+ # Sentence Generator (Decoder) for GPT-3 ...
8
+ def decoder_for_gpt3(input, max_length, temperature=0, engine="text-davinci-003"):
9
+ # The GPT-3 API limits each user to 60 requests per minute ...
10
+ if engine == "gpt-3.5-turbo":
11
+ time.sleep(1)
12
+ response = openai.ChatCompletion.create(
13
+ model=engine,
14
+ messages=[
15
+ #{"role": "system", "content": "You need to answer commonsense questions."},
16
+ {"role": "user", "content": input}
17
+ ],
18
+ max_tokens=max_length,
19
+ temperature=temperature,
20
+ stop=None
21
+ )
22
+ response = response["choices"][0]["message"]["content"]
23
+
24
+ else:
25
+ time.sleep(1)
26
+ response = openai.Completion.create(
27
+ model=engine,
28
+ prompt=input,
29
+ max_tokens=max_length,
30
+ stop=None,
31
+ temperature=temperature
32
+ )
33
+ response = response["choices"][0]["text"]
34
+ return response
35
+
36
+ def decoder_for_gpt3_consistency(input, max_length, temp=0.7, n=5, engine="text-davinci-003"):
37
+ # The GPT-3 API limits each user to 60 requests per minute ...
38
+ if engine == "gpt-3.5-turbo":
39
+ time.sleep(1)
40
+ responses = openai.ChatCompletion.create(
41
+ model=engine,
42
+ messages=[
43
+ {"role": "user", "content": input}
44
+ ],
45
+ max_tokens=max_length,
46
+ temperature=temp,
47
+ top_p=1,
48
+ n=n,
49
+ stop=["\n"],
50
+ )
51
+ responses = [responses["choices"][i]["message"]["content"] for i in range(n)]
52
+ else:
53
+ time.sleep(1)
54
+ responses = openai.Completion.create(
55
+ model=engine,
56
+ prompt=input,
57
+ max_tokens=max_length,
58
+ temperature=temp,
59
+ stop=["\n"],
60
+ n=n,
61
+ logprobs=5,
62
+ top_p=1,
63
+ )
64
+ responses = [responses["choices"][i]["text"] for i in range(n)]
65
+
66
+ return responses
67
+
68
+ def zero_shot(question):
69
+ input = question + " " + "Among A through E, the answer is"
70
+ response = openai.ChatCompletion.create(
71
+ model="gpt-3.5-turbo",
72
+ messages=[
73
+ {"role": "system", "content": "You are a helpful assistant that answers commonsense questions."},
74
+ {"role": "user", "content": input}
75
+ ]
76
+ )
77
+ response = response["choices"][0]["message"]["content"]
78
+ return response
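For reference, a short usage sketch of the two decoders defined above, mirroring how app.py calls them. It assumes an OpenAI API key is available and reuses one of the demo questions; actual outputs will vary:

```python
import openai
from llm_utils import decoder_for_gpt3, decoder_for_gpt3_consistency

openai.api_key = "..."  # placeholder; app.py reads the key from the environment

question = "Is the brain located in the torso?"
prompt = "Q: " + question + "\n" + "A: Let's think step by step."

# single greedy completion (temperature=0 by default)
rationale = decoder_for_gpt3(prompt, max_length=256, engine="gpt-3.5-turbo")

# five sampled rationales (temperature 0.7) used for the self-consistency vote
rationales = decoder_for_gpt3_consistency(prompt, max_length=256, engine="gpt-3.5-turbo")
print(len(rationales))  # 5
```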
retrieval_utils.py ADDED
@@ -0,0 +1,246 @@
1
+ '''
2
+ Modified from https://github.com/RuochenZhao/Verify-and-Edit
3
+ '''
4
+
5
+ import wikipedia
6
+ import wikipediaapi
7
+ import spacy
8
+ import numpy as np
9
+ import ngram
10
+ #import nltk
11
+ import torch
12
+ import sklearn.metrics  # pairwise_distances is accessed via sklearn.metrics.pairwise below
13
+ #from textblob import TextBlob
14
+ from nltk import tokenize
15
+ from sentence_transformers import SentenceTransformer
16
+ from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoder, DPRContextEncoderTokenizer
17
+ from llm_utils import decoder_for_gpt3
18
+ from utils import entity_cleansing, knowledge_cleansing
19
+
20
+ wiki_wiki = wikipediaapi.Wikipedia('en')
21
+ nlp = spacy.load("en_core_web_sm")
22
+ ENT_TYPE = ['EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC', 'NORP', 'ORG', 'PERSON', 'PRODUCT', 'WORK_OF_ART']
23
+
24
+ CTX_ENCODER = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
25
+ CTX_TOKENIZER = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", model_max_length = 512)
26
+ Q_ENCODER = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
27
+ Q_TOKENIZER = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base", model_max_length = 512)
28
+
29
+
30
+ ## todo: extract entities from ConceptNet
31
+ def find_ents(text, engine):
32
+ doc = nlp(text)
33
+ valid_ents = []
34
+ for ent in doc.ents:
35
+ if ent.label_ in ENT_TYPE:
36
+ valid_ents.append(ent.text)
37
+ #in case entity list is empty: resort to LLM to extract entity
38
+ if valid_ents == []:
39
+ input = "Question: " + "[ " + text + "]\n"
40
+ input += "Output the entities in Question separated by comma: "
41
+ response = decoder_for_gpt3(input, 32, engine=engine)
42
+ valid_ents = entity_cleansing(response)
43
+ return valid_ents
44
+
45
+
46
+ def relevant_pages_for_ents(valid_ents, topk = 5):
47
+ '''
48
+ Input: a list of valid entities
49
+ Output: a list of list containing topk pages for each entity
50
+ '''
51
+ if valid_ents == []:
52
+ return []
53
+ titles = []
54
+ for ve in valid_ents:
55
+ title = wikipedia.search(ve)[:topk]
56
+ titles.append(title)
57
+ #titles = list(dict.fromkeys(titles))
58
+ return titles
59
+
60
+
61
+ def relevant_pages_for_text(text, topk = 5):
62
+ return wikipedia.search(text)[:topk]
63
+
64
+
65
+ def get_wiki_objs(pages):
66
+ '''
67
+ Input: a list of list
68
+ Output: a list of list
69
+ '''
70
+ if pages == []:
71
+ return []
72
+ obj_pages = []
73
+ for titles_for_ve in pages:
74
+ pages_for_ve = [wiki_wiki.page(title) for title in titles_for_ve]
75
+ obj_pages.append(pages_for_ve)
76
+ return obj_pages
77
+
78
+
79
+ def get_linked_pages(wiki_pages, topk = 5):
80
+ linked_ents = []
81
+ for wp in wiki_pages:
82
+ linked_ents += list(wp.links.values())
83
+ if topk != -1:
84
+ linked_ents = linked_ents[:topk]
85
+ return linked_ents
86
+
87
+
88
+ def get_texts_to_pages(pages, topk = 2):
89
+ '''
90
+ Input: list of list of pages
91
+ Output: list of list of texts
92
+ '''
93
+ total_texts = []
94
+ for ve_pages in pages:
95
+ ve_texts = []
96
+ for p in ve_pages:
97
+ text = p.text
98
+ text = tokenize.sent_tokenize(text)[:topk]
99
+ text = ' '.join(text)
100
+ ve_texts.append(text)
101
+ total_texts.append(ve_texts)
102
+ return total_texts
103
+
104
+
105
+
106
+ def DPR_embeddings(q_encoder, q_tokenizer, question):
107
+ question_embedding = q_tokenizer(question, return_tensors="pt", max_length=512, truncation=True)  # truncate to the DPR encoder's 512-token limit
108
+ with torch.no_grad():
109
+ try:
110
+ question_embedding = q_encoder(**question_embedding)[0][0]
111
+ except Exception:
112
+ print(question)
113
+ print(question_embedding['input_ids'].size())
114
+ raise Exception('end')
115
+ question_embedding = question_embedding.numpy()
116
+ return question_embedding
117
+
118
+ def model_embeddings(sentence, model):
119
+ embedding = model.encode([sentence])
120
+ return embedding[0] #should return an array of shape 384
121
+
122
+ ##todo: plus overlap filtering
123
+ def filtering_retrieved_texts(question, ent_texts, retr_method="wikipedia_dpr", topk=1):
124
+ filtered_texts = []
125
+ for texts in ent_texts:
126
+ if texts != []: #not empty list
127
+ if retr_method == "ngram":
128
+ pars = np.array([ngram.NGram.compare(question, sent, N=1) for sent in texts])
129
+ #argsort: smallest to biggest
130
+ pars = pars.argsort()[::-1][:topk]
131
+ else:
132
+ if retr_method == "wikipedia_dpr":
133
+ sen_embeds = [DPR_embeddings(Q_ENCODER, Q_TOKENIZER, question)]
134
+ par_embeds = [DPR_embeddings(CTX_ENCODER, CTX_TOKENIZER, s) for s in texts]
135
+ else:
136
+ embedding_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
137
+ sen_embeds = [model_embeddings(question, embedding_model)]
138
+ par_embeds = [model_embeddings(s, embedding_model) for s in texts]
139
+ pars = sklearn.metrics.pairwise.pairwise_distances(sen_embeds, par_embeds)
140
+ pars = pars.argsort(axis=1)[0][:topk]
141
+ filtered_texts += [texts[i] for i in pars]
142
+ filtered_texts = list(dict.fromkeys(filtered_texts))
143
+ return filtered_texts
144
+
145
+ def join_knowledge(filtered_texts):
146
+ if filtered_texts == []:
147
+ return ""
148
+ return " ".join(filtered_texts)
149
+
150
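+ # retrieve Wikipedia pages for the question entities, resolve disambiguation pages, and keep the sentences most similar to the question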
+ def retrieve_for_question_kb(question, engine, know_type="entity_know", no_links=False):
151
+ valid_ents = find_ents(question, engine)
152
+ print(valid_ents)
153
+
154
+ # find pages
155
+ page_titles = []
156
+ if "entity" in know_type:
157
+ pages_for_ents = relevant_pages_for_ents(valid_ents, topk = 5) #list of list
158
+ if pages_for_ents != []:
159
+ page_titles += pages_for_ents
160
+ if "question" in know_type:
161
+ pages_for_question = relevant_pages_for_text(question, topk = 5)
162
+ if pages_for_question != []:
163
+ page_titles += pages_for_question
164
+ pages = get_wiki_objs(page_titles) #list of list
165
+ if pages == []:
166
+ return ""
167
+ new_pages = []
168
+ assert page_titles != []
169
+ assert pages != []
170
+
171
+ print(page_titles)
172
+ #print(pages)
173
+ for i, ve_pt in enumerate(page_titles):
174
+ new_ve_pages = []
175
+ for j, pt in enumerate(ve_pt):
176
+ if 'disambiguation' in pt:
177
+ new_ve_pages += get_linked_pages([pages[i][j]], topk=-1)
178
+ else:
179
+ new_ve_pages += [pages[i][j]]
180
+ new_pages.append(new_ve_pages)
181
+
182
+ pages = new_pages
183
+
184
+ if not no_links:
185
+ # add linked pages
186
+ for ve_pages in pages:
187
+ ve_pages += get_linked_pages(ve_pages, topk=5)
188
+ ve_pages = list(dict.fromkeys(ve_pages))
189
+ #get texts
190
+ texts = get_texts_to_pages(pages, topk=1)
191
+ filtered_texts = filtering_retrieved_texts(question, texts)
192
+ joint_knowledge = join_knowledge(filtered_texts)
193
+
194
+
195
+ return valid_ents, joint_knowledge
196
+
197
+ def retrieve_for_question(question, engine, retrieve_source="llm_kb"):
198
+ # Retrieve knowledge from LLM
199
+ if "llm" in retrieve_source:
200
+ self_retrieve_prompt = "Question: " + "[ " + question + "]\n"
201
+ self_retrieve_prompt += "Necessary knowledge about the question by not answering the question: "
202
+ self_retrieve_knowledge = decoder_for_gpt3(self_retrieve_prompt, 256, engine=engine)
203
+ self_retrieve_knowledge = knowledge_cleansing(self_retrieve_knowledge)
204
+ print("------Self_Know------")
205
+ print(self_retrieve_knowledge)
206
+
207
+ # Retrieve knowledge from KB
208
+ if "kb" in retrieve_source:
209
+ entities, kb_retrieve_knowledge = retrieve_for_question_kb(question, engine, no_links=True)
210
+ if kb_retrieve_knowledge != "":
211
+ print("------KB_Know------")
212
+ print(kb_retrieve_knowledge)
213
+
214
+ return entities, self_retrieve_knowledge, kb_retrieve_knowledge
215
+
216
+ def refine_for_question(question, engine, self_retrieve_knowledge, kb_retrieve_knowledge, retrieve_source="llm_kb"):
217
+
218
+ # Refine knowledge
219
+ if retrieve_source == "llm_only":
220
+ refine_knowledge = self_retrieve_knowledge
221
+ elif retrieve_source == "kb_only":
222
+ if kb_retrieve_knowledge != "":
223
+ refine_prompt = "Question: " + "[ " + question + "]\n"
224
+ refine_prompt += "Knowledge: " + "[ " + kb_retrieve_knowledge + "]\n"
225
+ refine_prompt += "Based on Knowledge, output the brief and refined knowledge necessary for Question by not giving the answer: "
226
+ refine_knowledge = decoder_for_gpt3(refine_prompt, 256, engine=engine)
227
+ print("------Refined_Know------")
228
+ print(refine_knowledge)
229
+ else:
230
+ refine_knowledge = ""
231
+ elif retrieve_source == "llm_kb":
232
+ if kb_retrieve_knowledge != "":
233
+ #refine_prompt = "Question: " + "[ " + question + "]\n"
234
+ refine_prompt = "Knowledge_1: " + "[ " + self_retrieve_knowledge + "]\n"
235
+ refine_prompt += "Knowledge_2: " + "[ " + kb_retrieve_knowledge + "]\n"
236
+ #refine_prompt += "By using Knowledge_2 to check Knowledge_1, output the brief and correct knowledge necessary for Question: "
237
+ refine_prompt += "By using Knowledge_2 to check Knowledge_1, output the brief and correct knowledge: "
238
+ refine_knowledge = decoder_for_gpt3(refine_prompt, 256, engine=engine)
239
+ refine_knowledge = knowledge_cleansing(refine_knowledge)
240
+ #refine_knowledge = kb_retrieve_knowledge + refine_knowledge
241
+ print("------Refined_Know------")
242
+ print(refine_knowledge)
243
+ else:
244
+ refine_knowledge = self_retrieve_knowledge
245
+
246
+ return refine_knowledge
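For reference, a rough end-to-end sketch of the retrieval-and-refinement path above, as app.py drives it. It assumes the spaCy en_core_web_sm model, the NLTK punkt data, the DPR checkpoints, Wikipedia access, and an OpenAI key are all available; outputs depend on the live services:

```python
from retrieval_utils import retrieve_for_question, refine_for_question

engine = "text-davinci-003"
question = "Did the Football War last at least a month?"

# LLM self-retrieved knowledge plus Wikipedia/DPR-filtered knowledge
entities, self_know, kb_know = retrieve_for_question(question, engine)

# merge and clean the two knowledge sources into one refined passage
refined = refine_for_question(question, engine, self_know, kb_know)
print(entities)
print(refined)
```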
utils.py ADDED
@@ -0,0 +1,68 @@
1
+ import re
2
+
3
+
4
+
5
+
6
+
7
+ def answer_cleansing_zero_shot(dataset, pred, must_choice=False):
8
+ pred = pred.strip()
9
+ if dataset == "commonsense-mc":
10
+ pred = re.findall(r'A|B|C|D|E', pred)
11
+ elif dataset == "arithmetic":
12
+ if must_choice:
13
+ pred = re.findall(r'A|B|C|D', pred)
14
+ else:
15
+ pred = pred.replace(",", "")
16
+ pred = [s for s in re.findall(r'-?\d+\.?\d*', pred)]
17
+ elif dataset in ("commonsense-verify", "symbolic-coin"):
18
+ pred = pred.lower()
19
+ pred = re.sub("\"|\'|\n|\.|\s|\:|\,", " ", pred)
20
+ pred = pred.split(" ")
21
+ pred = [i for i in pred if i in ("yes", "no")]
22
+ elif dataset == "symbolic-letter":
23
+ pred = re.sub("\"|\'|\n|\.|\s", "", pred)
24
+ pred = [pred]
25
+ else:
26
+ raise ValueError("dataset is not properly defined ...")
27
+
28
+ # If there is no candidate in list, null is set.
29
+ if len(pred) == 0:
30
+ pred = ""
31
+ else:
32
+ # choose the first element in list ...
33
+ pred = pred[0]
34
+
35
+ # (For arithmetic tasks) if a word ends with period, it will be omitted ...
36
+ if pred != "":
37
+ if pred[-1] == ".":
38
+ pred = pred[:-1]
39
+
40
+ return pred
41
+
42
+ def type_cleasing(type):
43
+ type = re.findall(r'arithmetic|commonsense-mc|commonsense-verify|symbolic-coin|symbolic-letter', type)
44
+ if len(type) == 0:
45
+ type = "UNDEFINED"
46
+ else:
47
+ type = type[0]
48
+ return type
49
+
50
+
51
+ def entity_cleansing(ent):
52
+ ent = re.sub("\n|\s*-\s*|\.", ",", ent)
53
+ ent = ent.split(",")
54
+ ent = [e.strip() for e in ent if e != ""]
55
+ return ent
56
+
57
+ def knowledge_cleansing(knowledge):
58
+ #print("Knowledge Before: " + knowledge)
59
+ knowledge = knowledge.strip()
60
+ if knowledge.startswith("No, "):
61
+ knowledge = re.sub("No, ", "", knowledge)
62
+ knowledge = re.sub("\s"," ", knowledge)
63
+ #print("Knowledge After: " + knowledge)
64
+ return knowledge
65
+
66
+
67
+
68
+
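Finally, a small self-contained sketch of how these cleansing helpers normalise raw model outputs; the input strings are invented for illustration:

```python
from utils import answer_cleansing_zero_shot, type_cleasing, entity_cleansing

print(answer_cleansing_zero_shot("arithmetic", "The answer is 1,350 dollars."))      # "1350"
print(answer_cleansing_zero_shot("commonsense-mc", "the answer is (C) own home."))   # "C"
print(answer_cleansing_zero_shot("symbolic-coin", "No, the coin is not heads up."))  # "no"
print(type_cleasing("this looks like an arithmetic problem"))                        # "arithmetic"
print(entity_cleansing("Football War\n- Central America"))                           # ["Football War", "Central America"]
```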