Upload 4 files
Browse files- app.py +347 -45
- llm_utils.py +78 -0
- retrieval_utils.py +246 -0
- utils.py +68 -0
app.py
CHANGED
@@ -1,49 +1,351 @@
|
|
1 |
import gradio as gr
|
2 |
-
from statistics import mean
|
3 |
-
from torch.utils.data import Dataset
|
4 |
-
from collections import OrderedDict
|
5 |
-
import xml.etree.ElementTree as ET
|
6 |
import openai # For GPT-3 API ...
|
7 |
-
import os
|
8 |
-
import multiprocessing
|
9 |
-
import json
|
10 |
-
import numpy as np
|
11 |
-
import random
|
12 |
-
import torch
|
13 |
-
import torchtext
|
14 |
import re
|
15 |
-
import
|
16 |
-
import
|
17 |
-
import
|
18 |
-
|
19 |
-
import
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
|
|
1 |
import json
import os
import re
import threading
from collections import Counter

import gradio as gr
import openai # For GPT-3 API ...

from llm_utils import *
from retrieval_utils import *
from utils import *
|
10 |
+
|
11 |
+
# SECURITY: never commit API keys to source control. The key previously
# hard-coded here was leaked in this commit and must be revoked. Read the
# key from the environment instead (defaults to "" so the UI still starts;
# API calls will fail clearly until OPENAI_API_KEY is set).
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
|
12 |
+
|
13 |
+
# Prompt suffixes shared by the whole pipeline: COT_PROMPT elicits zero-shot
# chain-of-thought reasoning; DIRECT_ANS_PROMPT marks the final answer inside
# few-shot demonstrations.
COT_PROMPT = "Let's think step by step."
DIRECT_ANS_PROMPT = "The answer is"

# Example questions shown in the UI — one (or more) per supported task type:
# symbolic-letter, commonsense-verify (x3), commonsense-mc, symbolic-coin,
# arithmetic.
EXAMPLES = ['Take the last letters of each words in \"Kristopher Deb Jake Tammy\" and concatenate them.', \
    'is the brain located in the torso?', 'Is entire Common Era minuscule to lifespan of some trees?', 'Did the Football War last at least a month?', \
    'What would someone use a personal key for? Answer Choices: (A) car stand (B) at hotel (C) own home (D) front door (E) bus depot', \
    'A coin is heads up. Isela flips the coin. Leslie flips the coin. Stacy flips the coin. Ingrid does not flip the coin. Is the coin still heads up? Note that \"flip\" here means \"reverse\".', \
    'Marco and his dad went strawberry picking. Together they collected strawberries that weighed 36 pounds. On the way back Marco \' dad lost 8 pounds of strawberries. Marco\'s strawberries now weighed 12 pounds. How much did his dad\'s strawberries weigh now?']


# NOTE(review): `global` at module level is a no-op; kept only as a marker of
# intent. `lock` serializes appends to the feedback log files across
# concurrent Gradio requests.
global lock #global lock, repo
lock = threading.Lock()
|
34 |
+
|
35 |
+
def answer_extraction_prompt(datatype):
    """Return the answer-extraction suffix appended after a rationale.

    The suffix steers the LLM to emit its final answer in the format
    expected for the given task type; unknown/undefined types fall back
    to a generic prompt.
    """
    suffix_by_type = {
        "commonsense-mc": "\nTherefore, among A through E, the answer is",
        "commonsense-verify": "\nTherefore, the answer (Yes or No) is",
        "arithmetic": "\nTherefore, the answer (arabic numerals) is",
        "symbolic-letter": "\nTherefore, the answer is",
        "symbolic-coin": "\nTherefore, the answer (Yes or No) is",
    }
    return suffix_by_type.get(datatype, "\nTherefore, the answer is")
|
49 |
+
|
50 |
+
|
51 |
+
def zero_shot(datatype, question, engine):
    """Ask the LLM for a direct answer (no chain-of-thought).

    Builds "Q: <question>\nA: <answer prompt>" where the answer prompt is
    the per-datatype extraction suffix with its "\nTherefore, " prefix
    removed and re-capitalized. An empty cleaned answer becomes "VOID".
    """
    extraction = answer_extraction_prompt(datatype).replace("\nTherefore, ", "")
    extraction = extraction[0].upper() + extraction[1:]
    prompt = "Q: " + question + "\n" + "A: " + extraction
    answer = decoder_for_gpt3(prompt, max_length=32, engine=engine)
    answer = answer_cleansing_zero_shot(datatype, answer)
    return answer if answer != "" else "VOID"
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def highlight_knowledge(entities, retrieved_knowledge):
    """Wrap every case-insensitive occurrence of each entity in markdown
    highlighting and return the annotated string.

    Each distinct surface form (e.g. "Brain" vs "brain") is replaced
    separately so the original casing is preserved inside the highlight.

    Args:
        entities: iterable of entity strings to highlight.
        retrieved_knowledge: the text to annotate.
    """
    str_md = retrieved_knowledge
    for ent in entities:
        surface_forms = {}
        # BUGFIX: the entity is literal text, not a regex pattern — without
        # re.escape, entities like "C++" or "U.S." crash or mis-match.
        for m in re.finditer(re.escape(ent), retrieved_knowledge, re.IGNORECASE):
            form = retrieved_knowledge[m.start():m.end()]
            if form not in surface_forms:
                surface_forms[form] = '<span style="background-color: lightcoral"> **' + form + '** </span>'
        for e_ori, e_md in surface_forms.items():
            str_md = str_md.replace(e_ori, e_md)
    return str_md
|
79 |
+
|
80 |
+
def zero_cot_consi(question, engine):
    """Sample several zero-shot chain-of-thought rationales for *question*."""
    input = "Q: " + question + "\n" + "A: " + COT_PROMPT
    cot_responses = decoder_for_gpt3_consistency(input,max_length=256, engine=engine) #list of cots
    return cot_responses

def auto_cot_consi(question, demo_text, engine):
    """Sample several few-shot CoT rationales, prefixing the demonstrations."""
    input = demo_text + "Q: " + question + "\n" + "A: " + COT_PROMPT
    cot_responses = decoder_for_gpt3_consistency(input,max_length=256, engine=engine) #list of cots
    return cot_responses
|
89 |
+
|
90 |
+
|
91 |
+
def cot_revision(datatype, question, ori_cots, knowledge, engine):
    """Revise each original rationale so it is consistent with *knowledge*.

    For every rationale in *ori_cots*, asks the LLM to rewrite it given the
    retrieved knowledge, then extracts a cleaned final answer from the
    revised rationale.

    Returns:
        (corrected_rationales, corrected_answers) — parallel lists, one
        entry per input rationale.
    """
    ANS_EXTRACTION_PROMPT = answer_extraction_prompt(datatype)
    corrected_rationales = []
    corrected_answers = []
    correction_prompt = "Question: " + "[ " + question + "]\n"
    correction_prompt += "Knowledge: " + "[ " + knowledge + "]\n"
    for ori_r in ori_cots:
        cor_p = correction_prompt + "Original rationale: " + "[ " + ori_r + "]\n"
        cor_p += "With Knowledge given, output the revised rationale for Question in a precise and certain style by thinking step by step: "
        corrected_rationale = decoder_for_gpt3(cor_p,max_length=256, temperature=0.7, engine=engine)
        corrected_rationale = corrected_rationale.strip()
        corrected_rationales.append(corrected_rationale)
        # Second LLM call: extract a short final answer from the revision.
        input = "Q: " + question + "\n" + "A: " + corrected_rationale + ANS_EXTRACTION_PROMPT
        ans = decoder_for_gpt3(input, max_length=32, temperature=0.7, engine=engine)
        ans = answer_cleansing_zero_shot(datatype, ans)
        corrected_answers.append(ans)
    return corrected_rationales, corrected_answers
|
108 |
+
|
109 |
+
|
110 |
+
def consistency(arr):
    """Majority-vote over *arr*.

    Returns the most frequent item and a dict mapping every distinct item
    to its relative frequency (used to populate the gr.Label widget).
    """
    total = len(arr)
    counts = Counter(arr)
    ranked = counts.most_common(total)
    winner = ranked[0][0]
    freqs = {item: times / total for item, times in ranked}
    return winner, freqs
|
120 |
+
|
121 |
+
|
122 |
+
## todo: git pull
def record_feedback(single_data, feedback, store_flag):
    """Append one feedback record to the per-datatype log file.

    Args:
        single_data: dict describing the interaction; must contain 'datatype'.
        feedback: 'agree' / 'disagree' / 'uncertain'.
        store_flag: only write to disk when the user consented.

    Returns gradio updates hiding the three feedback buttons and showing a
    thank-you message.
    """
    global lock
    print(f"Logging feedback...")
    datatype = single_data['datatype']
    data_dir = './data_pool/{dataname}_feedback'.format(dataname=datatype)

    # BUGFIX: `with lock` guarantees release even if the write raises; the
    # original acquire()/release() pair left the lock held on error, which
    # would deadlock every later feedback click.
    with lock:
        if store_flag:
            single_data.update({'feedback':feedback})
            with open(data_dir, "a") as f:
                f.write(json.dumps(single_data) + "\n")
    print(f"Logging finished...")
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
           gr.update(value="😃 Thank you for your valuable feedback!")
|
139 |
+
|
140 |
+
|
141 |
+
def _feedback_record(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans):
    # Single place to build the record dict shared by the three feedback
    # buttons (was copy-pasted three times).
    return {
        'question': input_question, 'datatype': datatype, 'zshot_ans': zshot_ans,
        'adapter_ans': our_ans, 'self_know': self_know, 'kb_know': kb_know,
        'refine_know': refine_know, 'cor_ans': cor_ans, 'feedback': ""}

def record_feedback_agree(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans, store_flag):
    """Log an 'agree' feedback event."""
    single_data = _feedback_record(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans)
    return record_feedback(single_data, 'agree', store_flag)

def record_feedback_disagree(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans, store_flag):
    """Log a 'disagree' feedback event."""
    single_data = _feedback_record(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans)
    return record_feedback(single_data, "disagree", store_flag)

def record_feedback_uncertain(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans, store_flag):
    """Log an 'uncertain' feedback event."""
    single_data = _feedback_record(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans)
    return record_feedback(single_data, 'uncertain', store_flag)
|
159 |
+
|
160 |
+
# NOTE(review): this `reset` is dead code — a second `def reset()` defined
# further down shadows it before any event wiring happens, and the second
# one returns a different number of updates. Kept as-is; consider removing.
def reset():
    return gr.update(value=""), gr.update(value=""), \
           gr.update(visible=False), gr.update(value="", label=""), gr.update(value="", label=""), gr.update(value="", label=""), \
           gr.update(value=""), gr.update(value=""), gr.update(value=""), gr.update(value="")
|
164 |
+
|
165 |
+
|
166 |
+
def identify_type(question, engine):
    """Classify *question* into one of the five supported task types.

    Reads few-shot demonstrations from ./demos/type, appends the question,
    and asks the LLM; the reply is lower-cased and normalized by
    type_cleasing before being shown in the datatype dropdown.
    """
    with open('./demos/type', 'r') as f:
        typedemo = f.read()
    typedemo += "Question: " + question + "\nOutput the Type, choosing from <'arithmetic','commonsense-mc','commonsense-verify','symbolic-coin', 'symbolic-letter'>: "
    response = decoder_for_gpt3(typedemo, 32, temperature=0, engine=engine)
    response = response.strip().lower()
    # Map the free-text reply onto a known type label.
    response = type_cleasing(response)
    return response
|
174 |
+
|
175 |
+
def load_examples(datatype):
    # NOTE(review): appears unused — not wired to any UI event below. Also,
    # EXAMPLES is defined as a flat list in this file, so indexing it with a
    # string datatype would raise TypeError if this were ever called; verify
    # against the commented-out dict version of EXAMPLES before wiring.
    return gr.update(examples=EXAMPLES[datatype])
|
177 |
+
|
178 |
+
|
179 |
+
def self_construction(datatype):
    """Load the demonstration set for *datatype* and render it.

    Returns gradio updates: stage title, clustering figure, the demos as
    highlighted markdown, the raw demo text (hidden state fed to the
    revision stage), and updates hiding all downstream result widgets.
    """
    if datatype == "arithmetic":
        fig_adr = './figs/multiarith.png'
        demo_path = './demos/multiarith'
    elif datatype == "commonsense-mc":
        fig_adr = './figs/commonsensqa.png'
        demo_path = './demos/commonsensqa'
    elif datatype == "commonsense-verify":
        fig_adr = './figs/strategyqa.png'
        demo_path = './demos/strategyqa'
    elif datatype == "symbolic-coin":
        fig_adr = './figs/coin_flip.png'
        demo_path = './demos/coin_flip'
    elif datatype == "symbolic-letter":
        fig_adr = './figs/last_letters.png'
        demo_path = './demos/last_letters'
    else:
        # NOTE(review): falls through with fig_adr/demo_path unbound, so an
        # 'UNDEFINED' datatype raises NameError at the open() below.
        pass ##todo: datatype == 'UNDEFINED'

    # Load the demonstrations (question / rationale / answer triples).
    x, z, y =[], [], []
    with open(demo_path, encoding="utf-8") as f:
        json_data = json.load(f)
        json_data = json_data["demo"]
        for line in json_data:
            x.append(line["question"])
            z.append(line["rationale"])
            y.append(line["pred_ans"])
    index_list = list(range(len(x)))

    # demo_text feeds the LLM prompt; demo_md is the styled UI rendering.
    demo_md, demo_text = "", ""
    for i in index_list:
        demo_text += x[i] + " " + z[i] + " " + \
                     DIRECT_ANS_PROMPT + " " + y[i] + ".\n\n"
        demo_md += '<span style="background-color: #E0A182">' + "Q: "+ '</span>' + x[i][3:-3] + \
                   "<br>" + '<span style="background-color: #DD97AF">' + "A: "+ '</span>' + z[i] + " " + \
                   DIRECT_ANS_PROMPT + " " + y[i] + ".\n\n"

    return gr.update(value="## 🔭 Self construction..."), gr.update(visible=True, label="Visualization of clustering", value=fig_adr), \
           gr.update(visible=True, value=demo_md), gr.update(value=demo_text), \
           gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
221 |
+
|
222 |
+
def self_retrieval(input_question, engine):
    """Retrieve entities plus LLM- and KB-sourced knowledge for the question.

    Returns gradio updates: stage title, illustration, a markdown view with
    the entities highlighted inside each knowledge block, the raw strings
    stored into hidden textboxes for the refinement stage, and updates
    hiding the downstream result widgets.
    """
    entities, self_retrieve_knowledge, kb_retrieve_knowledge = retrieve_for_question(input_question, engine)

    entities_string = ", ".join(entities)
    retr_md = "### ENTITIES:" + "<br>" + "> "+ entities_string + "\n\n"
    retr_md += "### LLM-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entities,self_retrieve_knowledge) + "\n\n"
    retr_md += "### KB-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entities, kb_retrieve_knowledge) + "\n\n"

    return gr.update(value="## 📚 Self retrieval..."), gr.update(visible=True, label="", value='./figs/self-retrieval.png'), \
           gr.update(value=retr_md), \
           gr.update(value=entities_string), gr.update(value=self_retrieve_knowledge), gr.update(value=kb_retrieve_knowledge), \
           gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
234 |
+
|
235 |
+
def self_refinement(input_question, entities, self_retrieve_knowledge, kb_retrieve_knowledge, engine):
    """Merge LLM- and KB-retrieved knowledge into one refined paragraph.

    *entities* arrives as the comma-joined string stored by self_retrieval
    and is split back into a list before highlighting.
    """
    refine_knowledge = refine_for_question(input_question, engine, self_retrieve_knowledge, kb_retrieve_knowledge)

    retr_md = "### ENTITIES:" + "<br>" + "> " + entities + "\n\n"
    # NOTE(review): str.strip('<p>') strips the *characters* '<', 'p', '>'
    # from the ends, not the literal tag — confirm this matches how the
    # hidden textbox wraps its value.
    entities = entities.strip().strip('<p>').strip('</p>').split(", ")
    retr_md += "### LLM-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entities, self_retrieve_knowledge) + "\n\n"
    retr_md += "### KB-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entities, kb_retrieve_knowledge) + "\n\n"
    refine_md = retr_md + "### REFINED-KNOWLEDGE:" + "<br>" + "> "
    refine_md += highlight_knowledge(entities, refine_knowledge)

    return gr.update(value="## 🪄 Self refinement..."), gr.update(visible=True, label="", value='./figs/self-refinement.png'), \
           gr.update(value=refine_md), gr.update(value=refine_knowledge), \
           gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
249 |
+
|
250 |
+
def self_revision(input_question, datatype, demo_text, refined_knowledge, engine):
    """Sample few-shot CoT rationales, then revise them with the knowledge."""
    print(demo_text)          # debug output
    print(refined_knowledge)  # debug output
    ori_cots = auto_cot_consi(input_question, demo_text, engine)
    cor_cots, cor_ans = cot_revision(datatype, input_question, ori_cots, refined_knowledge, engine)
    cor_cots_md = "### Revised Rationales:" + "\n\n"
    for cor_cot in cor_cots:
        cor_cots_md += "> " + cor_cot + "\n\n"
    # Answers are joined into one string because they travel through a
    # hidden Textbox to the self_consistency stage (which re-splits them).
    cor_ans = ", ".join(cor_ans)

    return gr.update(value="## 🔧 Self revision..."), gr.update(visible=True, label="", value='./figs/self-revision.png'), \
           gr.update(value=cor_cots_md), gr.update(value=cor_ans), \
           gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
263 |
+
|
264 |
+
def self_consistency(cor_ans, datatype, question, engine):
    """Majority-vote the corrected answers and compare with zero-shot.

    *cor_ans* is the ", "-joined string produced by self_revision.
    Returns updates showing the vote distribution, both answers, and the
    feedback buttons.
    """
    cor_ans = cor_ans.strip().split(", ")
    our_ans, ans_dict = consistency(cor_ans)
    zeroshot_ans = zero_shot(datatype, question, engine)

    return gr.update(value="## 🗳 Self consistency..."), gr.update(visible=True, label="", value='./figs/self-consistency.png'), \
           gr.update(value=""), gr.update(value=ans_dict, visible=True), \
           gr.update(visible=True, value=our_ans), gr.update(visible=True, value=zeroshot_ans), \
           gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
           gr.update(visible=True, value='We would appreciate it very much if you could share your feedback. ')
|
274 |
+
|
275 |
+
|
276 |
+
def reset():
    """Clear inputs and hide every result widget (wired to the RESET button).

    Note: this definition shadows the earlier `reset` above; the update
    count here matches the reset_btn output list (12 components).
    """
    return gr.update(value=""), gr.update(value=""), gr.update(value=""), \
           gr.update(visible=False), gr.update(value=""), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
           gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value="")
|
280 |
+
|
281 |
+
#theme from: https://huggingface.co/spaces/gradio/theme-gallery
#EveryPizza/Cartoony-Gradio-Theme
#JohnSmith9982/small_and_pretty
#bethecloud/storj_theme
#gradio/soft
# ---- UI layout and event wiring -------------------------------------------
with gr.Blocks(theme="bethecloud/storj_theme", css="#process_btn {background-color:#8BA3C5}") as demo:
    gr.Markdown("# 🌟 通用自适应的推理增强系统 (Unified-Adapter) 🌟")
    with gr.Row():
        with gr.Column(scale=4):
            input_question = gr.Textbox(placeholder="Input question here, or select an example from below.", label="Input Question",lines=2)
            store_flag = gr.Checkbox(label="Store data",value=True, interactive=True, info="If you agree to store data for research and development use:")
            single_data = gr.JSON(visible=False)
        with gr.Column(scale=3):
            engine = gr.Dropdown(choices=['gpt-3.5-turbo','text-davinci-003', 'text-davinci-002', 'text-curie-001', 'text-babbage-001', 'text-ada-001'],
                                 label="Engine", value="text-davinci-003", interactive=True, info="Choose the engine and have a try!")
            reset_btn = gr.Button(value='RESET')
    examples = gr.Examples(examples=EXAMPLES, inputs=[input_question])

    with gr.Row():
        with gr.Column(scale=1):
            type_btn = gr.Button(value="Self-identification", variant='primary', scale=1, elem_id="process_btn")
        with gr.Column(scale=3):
            datatype = gr.Dropdown(choices=['arithmetic','commonsense-mc','commonsense-verify','symbolic-letter','symbolic-coin','UNDEFINED'],
                                   label="Input Type", info="If you disagree with our output, please select manually.", scale=3)

    # Hidden textboxes carry intermediate state between pipeline stages.
    demo_text = gr.Textbox(visible=False)
    entities = gr.Textbox(visible=False)
    self_know = gr.Textbox(visible=False)
    kb_know = gr.Textbox(visible=False)
    refine_know = gr.Textbox(visible=False)
    cor_ans = gr.Textbox(visible=False)
    with gr.Row():
        const_btn = gr.Button(value='Self-construction', variant='primary', elem_id="process_btn")
        retr_btn = gr.Button(value='Self-retrieval', variant='primary', elem_id="process_btn")
        refine_btn = gr.Button(value='Self-refinement', variant='primary', elem_id="process_btn")
        revis_btn = gr.Button(value='Self-revision', variant='primary', elem_id="process_btn")
        consis_btn = gr.Button(value='Self-consistency', variant='primary', elem_id="process_btn")

    sub_title = gr.Markdown()
    with gr.Row():
        with gr.Column(scale=2):
            plot = gr.Image(label="Visualization of clustering", visible=False)
        with gr.Column(scale=3):
            md = gr.Markdown()
            label = gr.Label(visible=False, label="Consistency Predictions")
            ans_ours = gr.Textbox(label="Unified-Adapter Answer",visible=False)
            ans_zeroshot = gr.Textbox(label="Zero-shot Answer", visible=False)
            with gr.Row():
                feedback_agree = gr.Button(value='😊 Agree', variant='secondary', visible=False)
                feedback_disagree = gr.Button(value='🙁 Disagree', variant='secondary', visible=False)
                feedback_uncertain = gr.Button(value='🤔 Uncertain', variant='secondary', visible=False)
            feedback_ack = gr.Markdown(value='', visible=True, interactive=False)

    # Pipeline stage buttons: each stage reads the hidden state produced by
    # the previous one and hides the downstream result widgets.
    type_btn.click(identify_type, inputs=[input_question, engine], outputs=[datatype])
    const_btn.click(self_construction, inputs=[datatype], outputs=[sub_title, plot, md, demo_text, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    retr_btn.click(self_retrieval, inputs=[input_question, engine], outputs=[sub_title, plot, md, entities, self_know, kb_know, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    refine_btn.click(self_refinement, inputs=[input_question, entities, self_know, kb_know, engine], outputs=[sub_title, plot, md, refine_know, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    revis_btn.click(self_revision, inputs=[input_question, datatype, demo_text, refine_know, engine], outputs=[sub_title, plot, md, cor_ans, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    consis_btn.click(self_consistency, inputs=[cor_ans, datatype, input_question, engine], outputs=[sub_title, plot, md, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    reset_btn.click(reset, inputs=[], outputs=[input_question, datatype, sub_title, plot, md, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])

    # Feedback buttons log the full interaction (when the user consented).
    feedback_agree.click(record_feedback_agree, inputs=[input_question, datatype, ans_ours, ans_zeroshot, self_know, kb_know, refine_know, cor_ans ,store_flag], outputs=[feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    feedback_disagree.click(record_feedback_disagree, inputs=[input_question, datatype, ans_ours, ans_zeroshot, self_know, kb_know, refine_know, cor_ans ,store_flag], outputs=[feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    feedback_uncertain.click(record_feedback_uncertain, inputs=[input_question, datatype, ans_ours, ans_zeroshot, self_know, kb_know, refine_know, cor_ans ,store_flag], outputs=[feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])


demo.launch()
|
349 |
+
|
350 |
+
|
351 |
|
llm_utils.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import openai
|
3 |
+
|
4 |
+
# NOTE: the API key is configured by the application entry point (app.py).
# Never hard-code keys here — the commented-out key previously on this line
# was leaked and must be revoked.
|
5 |
+
|
6 |
+
|
7 |
+
# Sentence Generator (Decoder) for GPT-3 ...
|
8 |
+
# Sentence Generator (Decoder) for GPT-3 ...
def decoder_for_gpt3(input, max_length, temperature=0, engine="text-davinci-003"):
    """Query the OpenAI API once and return the completion text.

    Uses the chat endpoint for gpt-3.5-turbo and the legacy completion
    endpoint for every other engine.

    Args:
        input: full prompt string.
        max_length: max_tokens for the completion.
        temperature: sampling temperature (0 = greedy).
        engine: OpenAI model name.
    """
    # GPT-3 API allows each users execute the API within 60 times in a minute ...
    if engine == "gpt-3.5-turbo":
        time.sleep(1)  # crude client-side rate limiting
        response = openai.ChatCompletion.create(
            model=engine,
            messages=[
                {"role": "user", "content": input}
            ],
            max_tokens=max_length,
            temperature=temperature,
            stop=None
        )
        response = response["choices"][0]["message"]["content"]

    else:
        time.sleep(1)  # crude client-side rate limiting
        response = openai.Completion.create(
            model=engine,
            prompt=input,
            max_tokens=max_length,
            stop=None,
            temperature=temperature
        )
        response = response["choices"][0]["text"]
    return response
|
35 |
+
|
36 |
+
def decoder_for_gpt3_consistency(input, max_length, temp=0.7, n=5, engine="text-davinci-003"):
    """Sample *n* completions for self-consistency voting.

    BUGFIX: the number of samples requested from the API was hard-coded to
    5, so n > 5 raised IndexError when indexing choices and n < 5 wasted
    tokens; both calls now honour the *n* parameter.

    Args:
        input: full prompt string.
        max_length: max_tokens per completion.
        temp: sampling temperature (must be > 0 for diverse samples).
        n: number of completions to sample.
        engine: OpenAI model name.

    Returns:
        list of n completion strings.
    """
    # GPT-3 API allows each users execute the API within 60 times in a minute ...
    if engine == "gpt-3.5-turbo":
        time.sleep(1)  # crude client-side rate limiting
        responses = openai.ChatCompletion.create(
            model=engine,
            messages=[
                {"role": "user", "content": input}
            ],
            max_tokens=max_length,
            temperature=temp,
            top_p=1,
            n=n,
            stop=["\n"],
        )
        responses = [choice["message"]["content"] for choice in responses["choices"][:n]]
    else:
        time.sleep(1)  # crude client-side rate limiting
        responses = openai.Completion.create(
            model=engine,
            prompt=input,
            max_tokens=max_length,
            temperature=temp,
            stop=["\n"],
            n=n,
            logprobs=5,
            top_p=1,
        )
        responses = [choice["text"] for choice in responses["choices"][:n]]

    return responses
|
67 |
+
|
68 |
+
def zero_shot(question, engine="gpt-3.5-turbo"):
    """Answer a multiple-choice commonsense question directly via the chat API.

    Generalized: the model is now a parameter; the default preserves the
    previously hard-coded "gpt-3.5-turbo", so existing callers are
    unaffected.

    Args:
        question: question text including the answer choices (A-E).
        engine: OpenAI chat model name.
    """
    input = question + " " + "Among A through E, the answer is"
    response = openai.ChatCompletion.create(
        model=engine,
        messages=[
            {"role": "system", "content": "You are a helpful assistant that answer commonsense questions."},
            {"role": "user", "content": input}
        ]
    )
    response = response["choices"][0]["message"]["content"]
    return response
|
retrieval_utils.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Modified from https://github.com/RuochenZhao/Verify-and-Edit
|
3 |
+
'''
|
4 |
+
|
5 |
+
import wikipedia
|
6 |
+
import wikipediaapi
|
7 |
+
import spacy
|
8 |
+
import numpy as np
|
9 |
+
import ngram
|
10 |
+
#import nltk
|
11 |
+
import torch
|
12 |
+
import sklearn
|
13 |
+
#from textblob import TextBlob
|
14 |
+
from nltk import tokenize
|
15 |
+
from sentence_transformers import SentenceTransformer
|
16 |
+
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoder, DPRContextEncoderTokenizer
|
17 |
+
from llm_utils import decoder_for_gpt3
|
18 |
+
from utils import entity_cleansing, knowledge_cleansing
|
19 |
+
|
20 |
+
# Shared NLP resources, loaded once at import time (import is therefore slow
# and requires the models to be available locally or downloadable).
wiki_wiki = wikipediaapi.Wikipedia('en')
nlp = spacy.load("en_core_web_sm")
# spaCy entity labels considered knowledge-worthy for retrieval.
ENT_TYPE = ['EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC', 'NORP', 'ORG', 'PERSON', 'PRODUCT', 'WORK_OF_ART']

# DPR dual encoders: context encoder for passages, question encoder for the
# query; both capped at 512 tokens.
CTX_ENCODER = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
CTX_TOKENIZER = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", model_max_length = 512)
Q_ENCODER = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
Q_TOKENIZER = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base", model_max_length = 512)
|
28 |
+
|
29 |
+
|
30 |
+
## todo: extract entities from ConceptNet
|
31 |
+
def find_ents(text, engine):
    """Extract named entities from *text* with spaCy; fall back to the LLM.

    Only entities whose spaCy label is in ENT_TYPE are kept. If spaCy finds
    none, the LLM is asked to list the entities and the reply is cleaned by
    entity_cleansing.
    """
    doc = nlp(text)
    valid_ents = []
    for ent in doc.ents:
        if ent.label_ in ENT_TYPE:
            valid_ents.append(ent.text)
    # In case the entity list is empty: resort to the LLM to extract entities.
    if valid_ents == []:
        input = "Question: " + "[ " + text + "]\n"
        input += "Output the entities in Question separated by comma: "
        response = decoder_for_gpt3(input, 32, engine=engine)
        valid_ents = entity_cleansing(response)
    return valid_ents
|
44 |
+
|
45 |
+
|
46 |
+
def relevant_pages_for_ents(valid_ents, topk = 5):
    '''
    Input: a list of valid entities
    Output: a list of lists containing the topk Wikipedia page titles for
    each entity (empty input yields an empty list)
    '''
    if not valid_ents:
        return []
    return [wikipedia.search(entity)[:topk] for entity in valid_ents]
|
59 |
+
|
60 |
+
|
61 |
+
def relevant_pages_for_text(text, topk = 5):
    """Top-k Wikipedia page titles for a free-text query."""
    return wikipedia.search(text)[:topk]
|
63 |
+
|
64 |
+
|
65 |
+
def get_wiki_objs(pages):
    '''
    Input: a list of title lists (one inner list per entity)
    Output: a list of wikipediaapi page-object lists with the same shape
    '''
    if not pages:
        return []
    return [[wiki_wiki.page(title) for title in entity_titles]
            for entity_titles in pages]
|
77 |
+
|
78 |
+
|
79 |
+
def get_linked_pages(wiki_pages, topk = 5):
    """Collect the pages linked from *wiki_pages*.

    Links are gathered in page order; topk == -1 keeps them all, otherwise
    only the first topk are returned.
    """
    collected = []
    for page in wiki_pages:
        collected.extend(page.links.values())
    if topk == -1:
        return collected
    return collected[:topk]
|
86 |
+
|
87 |
+
|
88 |
+
def get_texts_to_pages(pages, topk = 2):
    '''
    Input: list of list of pages
    Output: list of list of texts — the first *topk* sentences of each
    page's text, re-joined into one string
    '''
    total_texts = []
    for ve_pages in pages:
        ve_texts = []
        for p in ve_pages:
            text = p.text
            # Keep only the leading sentences: enough context, bounded cost.
            text = tokenize.sent_tokenize(text)[:topk]
            text = ' '.join(text)
            ve_texts.append(text)
        total_texts.append(ve_texts)
    return total_texts
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
def DPR_embeddings(q_encoder, q_tokenizer, question):
    """Embed a question with a DPR encoder/tokenizer pair.

    Returns the embedding of the first sequence as a numpy array.
    NOTE(review): max_length=5 truncates the question to a handful of
    tokens, which looks too small for real questions — confirm intended.
    """
    tokenized = q_tokenizer(question, return_tensors="pt", max_length=5, truncation=True)
    with torch.no_grad():
        try:
            embedding = q_encoder(**tokenized)[0][0]
        except Exception as err:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Keep the debug prints and the
            # original exception type, but chain the cause for diagnosis.
            print(question)
            print(tokenized['input_ids'].size())
            raise Exception('end') from err
    return embedding.numpy()
|
117 |
+
|
118 |
+
def model_embeddings(sentence, model):
    """Encode one sentence with a sentence-transformer model.

    Returns the single embedding vector (shape 384 for MiniLM models).
    """
    # encode() takes a batch; we pass a batch of one and unpack it.
    (vector,) = model.encode([sentence])
    return vector
|
121 |
+
|
122 |
+
##todo: plus overlap filtering
|
123 |
+
def filtering_retrieved_texts(question, ent_texts, retr_method="wikipedia_dpr", topk=1):
    # Keep, for each entity, the topk retrieved snippets most similar to the
    # question, then de-duplicate across entities (order preserved).
    # retr_method: "ngram" (unigram overlap), "wikipedia_dpr" (DPR
    # dual-encoder, uses module-level Q_/CTX_ encoders), anything else falls
    # back to a SentenceTransformer model.
    filtered_texts = []
    for texts in ent_texts:
        if texts != []: #not empty list
            if retr_method == "ngram":
                # Unigram-overlap similarity: larger score = more similar,
                # so take the largest topk indices.
                pars = np.array([ngram.NGram.compare(question, sent, N=1) for sent in texts])
                #argsort: smallest to biggest
                pars = pars.argsort()[::-1][:topk]
            else:
                if retr_method == "wikipedia_dpr":
                    sen_embeds = [DPR_embeddings(Q_ENCODER, Q_TOKENIZER, question)]
                    par_embeds = [DPR_embeddings(CTX_ENCODER, CTX_TOKENIZER, s) for s in texts]
                else:
                    # NOTE(review): the model is re-loaded on every call;
                    # hoisting it to module level would be much cheaper.
                    embedding_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
                    sen_embeds = [model_embeddings(question, embedding_model)]
                    par_embeds = [model_embeddings(s, embedding_model) for s in texts]
                # Pairwise *distances*: smaller = more similar, so take the
                # smallest topk indices of the single question row.
                pars = sklearn.metrics.pairwise.pairwise_distances(sen_embeds, par_embeds)
                pars = pars.argsort(axis=1)[0][:topk]
            filtered_texts += [texts[i] for i in pars]
    # dict.fromkeys de-duplicates while preserving first-seen order.
    filtered_texts = list(dict.fromkeys(filtered_texts))
    return filtered_texts
|
144 |
+
|
145 |
+
def join_knowledge(filtered_texts):
    """Concatenate knowledge snippets into one space-separated string.

    str.join already returns "" for an empty list, so the previous explicit
    empty-list branch was redundant and has been removed (same behavior).
    """
    return " ".join(filtered_texts)
|
149 |
+
|
150 |
+
def retrieve_for_question_kb(question, engine, know_type="entity_know", no_links=False):
    """Retrieve Wikipedia knowledge relevant to a question.

    Extracts entities with the LLM, finds candidate pages for them (and
    optionally for the raw question text), expands disambiguation pages,
    optionally follows links, then filters page snippets by similarity.

    Returns (valid_ents, joint_knowledge); joint_knowledge is "" when no
    pages were found.
    """
    valid_ents = find_ents(question, engine)
    print(valid_ents)

    # find pages
    page_titles = []
    if "entity" in know_type:
        pages_for_ents = relevant_pages_for_ents(valid_ents, topk = 5) #list of list
        if pages_for_ents != []:
            page_titles += pages_for_ents
    if "question" in know_type:
        # NOTE(review): this appends a *flat* list of titles alongside the
        # list-of-lists above; get_wiki_objs would then iterate characters of
        # each title string. Confirm this path is actually exercised.
        pages_for_question = relevant_pages_for_text(question, topk = 5)
        if pages_for_question != []:
            page_titles += pages_for_question
    pages = get_wiki_objs(page_titles) #list of list
    if pages == []:
        # Fix: previously returned a bare "" here while every other path
        # returns a 2-tuple, so callers unpacking
        # `entities, knowledge = ...` crashed on the empty case.
        return valid_ents, ""
    new_pages = []
    assert page_titles != []
    assert pages != []

    print(page_titles)
    # Replace disambiguation pages by the pages they link to.
    for i, ve_pt in enumerate(page_titles):
        new_ve_pages = []
        for j, pt in enumerate(ve_pt):
            if 'disambiguation' in pt:
                new_ve_pages += get_linked_pages([pages[i][j]], topk=-1)
            else:
                new_ve_pages += [pages[i][j]]
        new_pages.append(new_ve_pages)

    pages = new_pages

    if not no_links:
        # add linked pages
        for ve_pages in pages:
            ve_pages += get_linked_pages(ve_pages, topk=5)
            # NOTE(review): rebinding the loop variable below does NOT write
            # the de-duplicated list back into `pages` — the dedup is a
            # no-op. Left as-is to preserve current behavior; confirm intent.
            ve_pages = list(dict.fromkeys(ve_pages))
    # get texts
    texts = get_texts_to_pages(pages, topk=1)
    filtered_texts = filtering_retrieved_texts(question, texts)
    joint_knowledge = join_knowledge(filtered_texts)

    return valid_ents, joint_knowledge
|
196 |
+
|
197 |
+
def retrieve_for_question(question, engine, retrieve_source="llm_kb"):
    """Retrieve knowledge for a question from the LLM and/or Wikipedia KB.

    retrieve_source selects the sources: any value containing "llm" queries
    the LLM itself; any value containing "kb" queries Wikipedia.

    Returns (entities, self_retrieve_knowledge, kb_retrieve_knowledge);
    fields for a skipped source come back empty.
    """
    # Fix: initialise all three results up front. Previously a
    # retrieve_source without "llm" or without "kb" (e.g. "llm_only") left
    # the corresponding names unbound and the return raised NameError.
    entities = []
    self_retrieve_knowledge = ""
    kb_retrieve_knowledge = ""

    # Retrieve knowledge from LLM
    if "llm" in retrieve_source:
        self_retrieve_prompt = "Question: " + "[ " + question + "]\n"
        self_retrieve_prompt += "Necessary knowledge about the question by not answering the question: "
        self_retrieve_knowledge = decoder_for_gpt3(self_retrieve_prompt, 256, engine=engine)
        self_retrieve_knowledge = knowledge_cleansing(self_retrieve_knowledge)
        print("------Self_Know------")
        print(self_retrieve_knowledge)

    # Retrieve knowledge from KB
    if "kb" in retrieve_source:
        entities, kb_retrieve_knowledge = retrieve_for_question_kb(question, engine, no_links=True)
        if kb_retrieve_knowledge != "":
            print("------KB_Know------")
            print(kb_retrieve_knowledge)

    return entities, self_retrieve_knowledge, kb_retrieve_knowledge
|
215 |
+
|
216 |
+
def refine_for_question(question, engine, self_retrieve_knowledge, kb_retrieve_knowledge, retrieve_source="llm_kb"):
    """Merge/refine LLM- and KB-retrieved knowledge into a single string.

    - "llm_only": pass the LLM knowledge through unchanged.
    - "kb_only": ask the LLM to condense the KB knowledge ("" stays "").
    - "llm_kb": ask the LLM to cross-check Knowledge_1 (LLM) against
      Knowledge_2 (KB); falls back to the LLM knowledge when the KB is empty.

    Raises ValueError for an unrecognised retrieve_source (previously this
    fell through and raised NameError at the return).
    """
    # Refine knowledge
    if retrieve_source == "llm_only":
        refine_knowledge = self_retrieve_knowledge
    elif retrieve_source == "kb_only":
        if kb_retrieve_knowledge != "":
            refine_prompt = "Question: " + "[ " + question + "]\n"
            refine_prompt += "Knowledge: " + "[ " + kb_retrieve_knowledge + "]\n"
            refine_prompt += "Based on Knowledge, output the brief and refined knowledge necessary for Question by not giving the answer: "
            refine_knowledge = decoder_for_gpt3(refine_prompt, 256, engine=engine)
            print("------Refined_Know------")
            print(refine_knowledge)
        else:
            refine_knowledge = ""
    elif retrieve_source == "llm_kb":
        if kb_retrieve_knowledge != "":
            refine_prompt = "Knowledge_1: " + "[ " + self_retrieve_knowledge + "]\n"
            refine_prompt += "Knowledge_2: " + "[ " + kb_retrieve_knowledge + "]\n"
            refine_prompt += "By using Knowledge_2 to check Knowledge_1, output the brief and correct knowledge: "
            refine_knowledge = decoder_for_gpt3(refine_prompt, 256, engine=engine)
            refine_knowledge = knowledge_cleansing(refine_knowledge)
            print("------Refined_Know------")
            print(refine_knowledge)
        else:
            refine_knowledge = self_retrieve_knowledge
    else:
        raise ValueError("unknown retrieve_source: " + repr(retrieve_source))

    return refine_knowledge
|
utils.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
def answer_cleansing_zero_shot(dataset, pred, must_choice=False):
    """Extract the final answer token from a zero-shot model prediction.

    Args:
        dataset: task family — "commonsense-mc", "arithmetic",
            "commonsense-verify", "symbolic-coin" or "symbolic-letter".
        pred: raw model output text.
        must_choice: arithmetic-only; the answer is a choice A-D rather
            than a number.

    Returns:
        The first candidate found, or "" when there is none.

    Raises:
        ValueError: for an unknown dataset.
    """
    pred = pred.strip()
    # Fix: the membership tests used parenthesised strings, e.g.
    # `dataset in ("commonsense-mc")`, which is a *substring* check on a
    # string (so dataset="mc" matched). Proper one-element tuples now.
    if dataset in ("commonsense-mc",):
        pred = re.findall(r'A|B|C|D|E', pred)
    elif dataset in ("arithmetic",):
        if must_choice:
            pred = re.findall(r'A|B|C|D', pred)
        else:
            # Drop thousands separators so "1,234" parses as one number.
            pred = pred.replace(",", "")
            pred = re.findall(r'-?\d+\.?\d*', pred)
    elif dataset in ("commonsense-verify", "symbolic-coin"):
        pred = pred.lower()
        pred = re.sub(r"\"|\'|\n|\.|\s|\:|\,", " ", pred)
        pred = pred.split(" ")
        pred = [i for i in pred if i in ("yes", "no")]
    elif dataset == "symbolic-letter":
        pred = re.sub(r"\"|\'|\n|\.|\s", "", pred)
        pred = [pred]
    else:
        raise ValueError("dataset is not properly defined ...")

    # If there is no candidate in list, null is set.
    if len(pred) == 0:
        pred = ""
    else:
        # choose the first element in list ...
        pred = pred[0]

    # (For arithmetic tasks) if a word ends with period, it will be omitted ...
    if pred != "":
        if pred[-1] == ".":
            pred = pred[:-1]

    return pred
|
41 |
+
|
42 |
+
def type_cleasing(type):
    """Normalise a predicted task-type string to one of the known labels.

    Returns the first known label occurring in the text, or "UNDEFINED".
    """
    match = re.search(r'arithmetic|commonsense-mc|commonsense-verify|symbolic-coin|symbolic-letter', type)
    return match.group(0) if match else "UNDEFINED"
|
49 |
+
|
50 |
+
|
51 |
+
def entity_cleansing(ent):
    """Parse an LLM entity listing into a clean list of entity strings."""
    # Newlines, bullet dashes and periods all act as comma separators.
    normalised = re.sub("\n|\s*-\s*|\.", ",", ent)
    # Drop empty fields first, then trim whitespace (original order kept:
    # fields that are pure whitespace survive as "" after the strip).
    return [piece.strip() for piece in normalised.split(",") if piece != ""]
|
56 |
+
|
57 |
+
def knowledge_cleansing(knowledge):
    """Tidy an LLM knowledge snippet.

    Trims surrounding whitespace, drops "No, " lead-ins, and collapses each
    whitespace character (newlines, tabs) to a single space.
    """
    cleaned = knowledge.strip()
    if cleaned.startswith("No, "):
        # Mirrors the original: re.sub removes *every* "No, " occurrence,
        # not only the leading one.
        cleaned = re.sub("No, ", "", cleaned)
    return re.sub("\s", " ", cleaned)
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|