Zihao Wang committed on
Commit
3119d85
1 Parent(s): ba30d19
Files changed (1)
  1. app.py +416 -13
app.py CHANGED
@@ -1,9 +1,13 @@
 import gradio as gr
 from langchain.tools import Tool
 from langchain_community.utilities import GoogleSearchAPIWrapper
-import os
+import os
+
+from langchain.tools import Tool
+from langchain_community.utilities import GoogleSearchAPIWrapper
 
-def get_search(query:str="", k:int=1):
+
+def get_search(query:str="", k:int=1): # get the top-k resources with Google
     search = GoogleSearchAPIWrapper(k=k)
     def search_results(query):
         return search.results(query, k)
@@ -13,17 +17,416 @@ def get_search(query:str="", k:int=1):
         func=search_results,
     )
     ref_text = tool.run(query)
-    return ref_text
-def search(query:str):
-    search_result = get_search(query,1)[0]
-    title = search_result['title']
-    link = search_result['link']
-    return_str = f"""title: {title}\nlink: {link}"""
-    print(return_str)
-    return return_str
-
-demo = gr.Interface(fn=search, inputs="textbox", outputs="textbox")
-
-if __name__ == "__main__":
-    demo.launch()
+    if 'Result' not in ref_text[0].keys():
+        return ref_text
+    else:
+        return None
+
+from langchain_community.document_transformers import Html2TextTransformer
+from langchain_community.document_loaders import AsyncHtmlLoader
+def get_page_content(link:str):
+    loader = AsyncHtmlLoader([link])
+    docs = loader.load()
+    html2text = Html2TextTransformer()
+    docs_transformed = html2text.transform_documents(docs)
+    if len(docs_transformed) > 0:
+        return docs_transformed[0].page_content
+    else:
+        return None
+
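For reference, a minimal sketch of how these two helpers chain together, assuming GOOGLE_API_KEY and GOOGLE_CSE_ID are set in the environment for GoogleSearchAPIWrapper:

    hits = get_search("retrieval augmented thoughts", k=1)
    if hits:  # get_search returns None when Google reports no good result
        text = get_page_content(hits[0]['link'])
        print(text[:200] if text else "page could not be parsed")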
+import tiktoken
+def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
+    """Returns the number of tokens in a text string."""
+    encoding = tiktoken.get_encoding(encoding_name)
+    num_tokens = len(encoding.encode(string))
+    return num_tokens
+
+def chunk_text_by_sentence(text, chunk_size=2048):
+    """Chunk the text into sentence-aligned pieces of fewer than `chunk_size` tokens and return the first chunk."""
+    sentences = text.split('. ')
+    chunked_text = []
+    curr_chunk = []
+    # Add sentences one by one, keeping every chunk under `chunk_size` tokens
+    for sentence in sentences:
+        if num_tokens_from_string(". ".join(curr_chunk)) + num_tokens_from_string(sentence) + 2 <= chunk_size:
+            curr_chunk.append(sentence)
+        else:
+            chunked_text.append(". ".join(curr_chunk))
+            curr_chunk = [sentence]
+    # Append the last chunk
+    if curr_chunk:
+        chunked_text.append(". ".join(curr_chunk))
+    return chunked_text[0]
+
+def chunk_text_front(text, chunk_size = 2048):
+    '''
+    get the first `chunk_size` tokens of text, approximated by character ratio
+    '''
+    chunked_text = ""
+    tokens = num_tokens_from_string(text)
+    if tokens < chunk_size:
+        return text
+    else:
+        ratio = float(chunk_size) / tokens
+        char_num = int(len(text) * ratio)
+        return text[:char_num]
+
+def chunk_texts(text, chunk_size = 2048):
+    '''
+    chunk the text into n parts, return a list of texts
+    [text, text, text]
+    '''
+    tokens = num_tokens_from_string(text)
+    if tokens < chunk_size:
+        return [text]
+    else:
+        texts = []
+        n = int(tokens/chunk_size) + 1
+        # Length of each part
+        part_length = len(text) // n
+        # If the split is uneven, the first parts each take one extra character
+        extra = len(text) % n
+        parts = []
+        start = 0
+
+        for i in range(n):
+            # Each of the first `extra` parts gets one extra character
+            end = start + part_length + (1 if i < extra else 0)
+            parts.append(text[start:end])
+            start = end
+        return parts
+
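A quick sketch of how these chunkers behave; the sample text and sizes are illustrative only:

    sample = ("Tokens are counted with tiktoken. " * 300).strip()
    print(num_tokens_from_string(sample))           # total token count
    parts = chunk_texts(sample, chunk_size=512)     # character slices of roughly equal size
    print(len(parts), [num_tokens_from_string(p) for p in parts])
    print(chunk_text_front(sample, 512)[:80])       # approximately the first 512 tokens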
+from datetime import datetime
+from utils import *
+
+from openai import OpenAI
+import os
+
+chatgpt_system_prompt = f'''
+You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.
+Knowledge cutoff: 2023-04
+Current date: {datetime.now().strftime('%Y-%m-%d')}
+'''
+
+def get_draft(question):
+    # Getting the draft answer
+    draft_prompt = '''
+IMPORTANT:
+Try to answer this question/instruction with step-by-step thoughts and make the answer more structured.
+Use `\n\n` to split the answer into several paragraphs.
+Just respond to the instruction directly. DO NOT add additional explanations or an introduction in the answer unless you are asked to.
+'''
+    openai_client = OpenAI()
+    draft = openai_client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "system",
+                "content": chatgpt_system_prompt
+            },
+            {
+                "role": "user",
+                "content": f"{question}" + draft_prompt
+            }
+        ],
+        temperature = 1.0
+    ).choices[0].message.content
+    return draft
+
+def split_draft(draft, split_char = '\n\n'):
+    # Split the draft into multiple paragraphs
+    # split_char: '\n\n'
+    draft_paragraphs = draft.split(split_char)
+    draft_paragraphs = [d for d in draft_paragraphs if d]
+    # print(f"The draft answer has {len(draft_paragraphs)}")
+    return draft_paragraphs
+
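In the pipeline these two run back to back; a small sketch (requires OPENAI_API_KEY; the question is illustrative):

    draft = get_draft("Explain how RAT revises a chain of thought.")
    steps = split_draft(draft)   # list of non-empty paragraphs
    print(len(steps), steps[0][:80])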
+def get_query(question, answer):
+    query_prompt = '''
+I want to verify the content correctness of the given question, especially the last sentences.
+Please summarize the content with the corresponding question.
+This summarization will be used as a query to search with the Bing search engine.
+The query should be short but needs to be specific enough that Bing can find related knowledge or pages.
+You can also use search syntax to make the query short and clear enough for the search engine to find relevant language data.
+Try to make the query as relevant as possible to the last few sentences in the content.
+**IMPORTANT**
+Just output the query directly. DO NOT add additional explanations or an introduction in the answer unless you are asked to.
+'''
+    openai_client = OpenAI()
+    query = openai_client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "system",
+                "content": chatgpt_system_prompt
+            },
+            {
+                "role": "user",
+                "content": f"##Question: {question}\n\n##Content: {answer}\n\n##Instruction: {query_prompt}"
+            }
+        ],
+        temperature = 1.0
+    ).choices[0].message.content
+    return query
+
+def get_content(query):
+    res = get_search(query, 1)
+    if not res:
+        print(">>> No good Google Search Result was found")
+        return None
+    search_results = res[0]
+    link = search_results['link'] # title, snippet
+    res = get_page_content(link)
+    if not res:
+        print(f">>> No content was found in {link}")
+        return None
+    retrieved_text = res
+    trunked_texts = chunk_texts(retrieved_text, 1500)
+    trunked_texts = [trunked_text.replace('\n', " ") for trunked_text in trunked_texts]
+    return trunked_texts
+
+def get_revise_answer(question, answer, content):
+    revise_prompt = '''
+I want to revise the answer according to retrieved related text of the question in WIKI pages.
+You need to check whether the answer is correct.
+If you find some errors in the answer, revise the answer to make it better.
+If you find some necessary details are ignored, add them to make the answer more plausible according to the related text.
+If you find the answer is right and does not need more details, just output the original answer directly.
+**IMPORTANT**
+Try to keep the structure (multiple paragraphs with subtitles) in the revised answer and make it more structured for understanding.
+Add more details from the retrieved text to the answer.
+Split the paragraphs with `\n\n` characters.
+Just output the revised answer directly. DO NOT add additional explanations or an announcement in the revised answer unless you are asked to.
+'''
+    openai_client = OpenAI()
+    revised_answer = openai_client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "system",
+                "content": chatgpt_system_prompt
+            },
+            {
+                "role": "user",
+                "content": f"##Existing Text in Wiki Web: {content}\n\n##Question: {question}\n\n##Answer: {answer}\n\n##Instruction: {revise_prompt}"
+            }
+        ],
+        temperature = 1.0
+    ).choices[0].message.content
+    return revised_answer
+
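A single revision round in isolation (requires OPENAI_API_KEY plus the Google keys; the strings are illustrative):

    q = "When was the RAT paper released?"
    a = "RAT was proposed in 2024."
    chunks = get_content(get_query(q, a)) or []
    if chunks:
        print(get_revise_answer(q, a, chunks[0])[:200])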
+def get_query_wrapper(q, question, answer):
+    result = get_query(question, answer)
+    q.put(result) # put the result into the queue
+
+def get_content_wrapper(q, query):
+    result = get_content(query)
+    q.put(result) # put the result into the queue
+
+def get_revise_answer_wrapper(q, question, answer, content):
+    result = get_revise_answer(question, answer, content)
+    q.put(result)
+
+from multiprocessing import Process, Queue
+def run_with_timeout(func, timeout, *args, **kwargs):
+    q = Queue() # queue for inter-process communication
+    # Spawn a process to run the given function, passing the queue plus *args and **kwargs
+    p = Process(target=func, args=(q, *args), kwargs=kwargs)
+    p.start()
+    # Wait for the process to finish, or time out
+    p.join(timeout)
+    if p.is_alive():
+        print(f"{datetime.now()} [INFO] Function {str(func)} timed out ({timeout}s); terminating the process...")
+        p.terminate() # terminate the process
+        p.join() # make sure the process has exited
+        result = None # no result on timeout
+    else:
+        print(f"{datetime.now()} [INFO] Function {str(func)} finished successfully")
+        result = q.get() # fetch the result from the queue
+    return result
+
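The queue-and-terminate pattern can be exercised on its own; a sketch with a deliberately slow worker (names and delays are illustrative; on spawn-based platforms run it under an `if __name__ == "__main__":` guard):

    def slow_worker(q, x):
        import time
        time.sleep(10)   # longer than the timeout below
        q.put(x * 2)

    print(run_with_timeout(slow_worker, 2, 21))   # -> None, terminated after 2s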
+from difflib import unified_diff
+from IPython.display import display, HTML
+
+def generate_diff_html(text1, text2):
+    diff = unified_diff(text1.splitlines(keepends=True),
+                        text2.splitlines(keepends=True),
+                        fromfile='text1', tofile='text2')
+
+    diff_html = ""
+    for line in diff:
+        if line.startswith('+'):
+            diff_html += f"<div style='color:green;'>{line.rstrip()}</div>"
+        elif line.startswith('-'):
+            diff_html += f"<div style='color:red;'>{line.rstrip()}</div>"
+        elif line.startswith('@'):
+            diff_html += f"<div style='color:blue;'>{line.rstrip()}</div>"
+        else:
+            diff_html += f"{line.rstrip()}<br>"
+    return diff_html
+
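For example, two revisions of a one-line answer produce a colorized unified diff; display(HTML(...)) renders it in a notebook, while a plain script only prints the markup:

    html = generate_diff_html("RAT drafts first.\n", "RAT drafts, then revises.\n")
    print(html)   # red <div> for removed lines, green for added, blue for @@ hunk headers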
+newline_char = '\n'
+
+def rat(question):
+    print(f"{datetime.now()} [INFO] Generating draft...")
+    draft = get_draft(question)
+    print(f"{datetime.now()} [INFO] Draft obtained")
+    # print(f"##################### DRAFT #######################")
+    # print(draft)
+    # print(f"##################### END #######################")
+
+    print(f"{datetime.now()} [INFO] Processing draft...")
+    draft_paragraphs = split_draft(draft)
+    print(f"{datetime.now()} [INFO] Draft split into {len(draft_paragraphs)} parts")
+    answer = ""
+    for i, p in enumerate(draft_paragraphs):
+        print(str(i)*80)
+        print(f"{datetime.now()} [INFO] Revising part {i+1}/{len(draft_paragraphs)}...")
+        answer = answer + '\n\n' + p
+        # print(f"[{i}/{len(draft_paragraphs)}] Original Answer:\n{answer.replace(newline_char, ' ')}")
+
+        # query = get_query(question, answer)
+        print(f"{datetime.now()} [INFO] Generating the corresponding query...")
+        res = run_with_timeout(get_query_wrapper, 3, question, answer)
+        if not res:
+            print(f"{datetime.now()} [INFO] Query generation timed out; skipping the following steps...")
+            continue
+        else:
+            query = res
+        print(f">>> {i}/{len(draft_paragraphs)} Query: {query.replace(newline_char, ' ')}")
+
+        print(f"{datetime.now()} [INFO] Fetching page content...")
+        # content = get_content(query)
+        res = run_with_timeout(get_content_wrapper, 5, query)
+        if not res:
+            print(f"{datetime.now()} [INFO] Fetching page content timed out; skipping the following steps...")
+            continue
+        else:
+            content = res
+
+        for j, c in enumerate(content):
+            if j > 2:
+                break
+            print(f"{datetime.now()} [INFO] Revising the answer with page content...[{j}/{min(len(content),3)}]")
+            # answer = get_revise_answer(question, answer, c)
+            res = run_with_timeout(get_revise_answer_wrapper, 15, question, answer, c)
+            if not res:
+                print(f"{datetime.now()} [INFO] Answer revision timed out; skipping the following steps...")
+                continue
+            else:
+                diff_html = generate_diff_html(answer, res)
+                display(HTML(diff_html))
+                answer = res
+            print(f"{datetime.now()} [INFO] Answer revision finished [{j}/{min(len(content),3)}]")
+        # print(f"[{i}/{len(draft_paragraphs)}] REVISED ANSWER:\n {answer.replace(newline_char, ' ')}")
+        # print()
+    return draft, answer
+    # return answer
+
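End to end, rat() is what the demo's buttons invoke; a sketch of driving it directly (assumes OPENAI_API_KEY, GOOGLE_API_KEY and GOOGLE_CSE_ID are set in the environment):

    draft, final = rat("Why does retrieval help long-horizon generation?")
    print(final)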
+from utils import *
+
+page_title = "RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation"
+page_md = """
+# RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation
+
+We explore how iteratively revising a chain of thoughts with the help of information retrieval significantly improves large language models' reasoning and generation ability in long-horizon generation tasks, while hugely mitigating hallucination. In particular, the proposed method — retrieval-augmented thoughts (RAT) — revises each thought step one by one with retrieved information relevant to the task query, the current and the past thought steps, after the initial zero-shot CoT is generated.
+
+Applying RAT to various base models substantially improves their performance on various long-horizon generation tasks, relatively increasing rating scores by an average of 13.63% on code generation, 16.96% on mathematical reasoning, 19.2% on creative writing, and 42.78% on embodied task planning.
+
+Feel free to try our demo!
+
+"""
+
+def clear_func():
+    return "", "", ""
+
+def set_openai_api_key(api_key):
+    if api_key and api_key.startswith("sk-") and len(api_key) > 50:
+        # The OpenAI() clients above read the key from the environment,
+        # so export it there in addition to the module-level attribute.
+        os.environ["OPENAI_API_KEY"] = api_key
+        import openai
+        openai.api_key = api_key
+
+with gr.Blocks(title = page_title) as demo:
+
+    gr.Markdown(page_md)
+
+    with gr.Row():
+        chatgpt_box = gr.Textbox(
+            label = "ChatGPT",
+            placeholder = "Response from ChatGPT with zero-shot chain-of-thought.",
+            elem_id = "chatgpt"
+        )
+
+    with gr.Row():
+        stream_box = gr.Textbox(
+            label = "Streaming",
+            placeholder = "Interactive response with RAT...",
+            elem_id = "stream",
+            lines = 10,
+            visible = False
+        )
+
+    with gr.Row():
+        rat_box = gr.Textbox(
+            label = "RAT",
+            placeholder = "Final response with RAT ...",
+            elem_id = "rat",
+            lines = 6
+        )
+
+    with gr.Column(elem_id="instruction_row"):
+        with gr.Row():
+            instruction_box = gr.Textbox(
+                label = "instruction",
+                placeholder = "Enter your instruction here",
+                lines = 2,
+                elem_id="instruction",
+                interactive=True,
+                visible=True
+            )
+        with gr.Row():
+            model_radio = gr.Radio(["gpt-3.5-turbo", "gpt-4", "GPT-4-turbo"], elem_id="model_radio", value="gpt-3.5-turbo",
+                                   label='GPT model: ', show_label=True, interactive=True, visible=True)
+            openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter",
+                                                show_label=False, lines=1, type='password')
+
+    openai_api_key_textbox.change(set_openai_api_key,
+                                  inputs=[openai_api_key_textbox],
+                                  outputs=[])
+
+    with gr.Row():
+        submit_btn = gr.Button(
+            value="submit", visible=True, interactive=True
+        )
+        clear_btn = gr.Button(
+            value="clear", visible=True, interactive=True
+        )
+        regenerate_btn = gr.Button(
+            value="regenerate", visible=True, interactive=True
+        )
+
+    submit_btn.click(
+        fn = rat,
+        inputs = [instruction_box],
+        outputs = [chatgpt_box, rat_box]
+    )
+
+    clear_btn.click(
+        fn = clear_func,
+        inputs = [],
+        outputs = [instruction_box, chatgpt_box, rat_box]
+    )
+
+    regenerate_btn.click(
+        fn = rat,
+        inputs = [instruction_box],
+        outputs = [chatgpt_box, rat_box]
+    )
+
+    examples = gr.Examples(
+        examples=[
+            "I went to the supermarket yesterday.",
+            "Helen is a good swimmer."],
+        inputs=[instruction_box]
+    )
+
+demo.launch(server_name="0.0.0.0", debug=True)