File size: 16,526 Bytes
bffea63 3119d85 bffea63 3119d85 bffea63 3119d85 1097625 3119d85 a57c21c f0d180c 3119d85 a57c21c f0d180c 3119d85 a57c21c f0d180c 3119d85 bffea63 3119d85 bffea63 3119d85 a57c21c 3119d85 bffea63 3119d85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 |
import gradio as gr
from langchain.tools import Tool
from langchain_community.utilities import GoogleSearchAPIWrapper
import os
from langchain.tools import Tool
from langchain_community.utilities import GoogleSearchAPIWrapper
def get_search(query:str="", k:int=1): # get the top-k resources with google
    """Return the top-*k* Google results for *query*, or None if there are none.

    The search goes through ``GoogleSearchAPIWrapper.results`` wrapped in a
    LangChain ``Tool``.  Each result is a dict; callers read its ``'link'`` key.
    A first result containing a ``'Result'`` key is treated as the wrapper's
    "nothing found" placeholder.  An empty result list is also mapped to None.

    Args:
        query: Search query string.
        k: Number of results to request.

    Returns:
        The list of result dicts, or None when nothing usable was found.
    """
    search = GoogleSearchAPIWrapper(k=k)

    def search_results(query):
        return search.results(query, k)

    tool = Tool(
        name="Google Search Snippets",
        description="Search Google for recent results.",
        func=search_results,
    )
    ref_text = tool.run(query)
    # Check emptiness first: indexing [0] on an empty list would raise IndexError.
    if not ref_text or 'Result' in ref_text[0]:
        return None
    return ref_text
from langchain_community.document_transformers import Html2TextTransformer
from langchain_community.document_loaders import AsyncHtmlLoader
def get_page_content(link:str):
    """Download *link* and return its HTML converted to plain text.

    Returns the text of the first transformed document, or None when the
    loader produced no documents.
    """
    raw_docs = AsyncHtmlLoader([link]).load()
    text_docs = Html2TextTransformer().transform_documents(raw_docs)
    return text_docs[0].page_content if text_docs else None
import tiktoken
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Count the tokens *string* encodes to under tiktoken's *encoding_name*."""
    return len(tiktoken.get_encoding(encoding_name).encode(string))
def chunk_text_by_sentence(text, chunk_size=2048, count_tokens=None):
    """Return the leading run of whole sentences of *text* that fits in *chunk_size* tokens.

    The text is split on ". " and sentences are accumulated (re-joined with
    ". ") while the tokens of the chunk so far, plus the next sentence, plus
    2 tokens of slack for the separator stay within ``chunk_size``.

    Despite the name, only the FIRST chunk is returned, as a ``str``.  If the
    first sentence alone exceeds ``chunk_size``, that sentence is returned
    whole; an empty string is no longer returned in that case.

    Args:
        text: Text to chunk.
        chunk_size: Token budget for the returned chunk.
        count_tokens: Optional callable mapping a string to its token count.
            Defaults to ``num_tokens_from_string``.

    Returns:
        The first chunk as a string.
    """
    if count_tokens is None:
        count_tokens = num_tokens_from_string
    curr_chunk = []
    for sentence in text.split('. '):
        fits = count_tokens(". ".join(curr_chunk)) + count_tokens(sentence) + 2 <= chunk_size
        # An empty chunk always takes the sentence, so an oversized first
        # sentence becomes its own chunk instead of yielding "".
        if fits or not curr_chunk:
            curr_chunk.append(sentence)
        else:
            # The first chunk is complete.  Later chunks were never returned,
            # so stop instead of tokenizing the rest of the text.
            break
    return ". ".join(curr_chunk)
def chunk_text_front(text, chunk_size = 2048, count_tokens=None):
    '''
    Return approximately the first `chunk_size` tokens of *text*.

    Texts with fewer than `chunk_size` tokens, or with no tokens at all, are
    returned unchanged.  Otherwise the text is cut by character position,
    keeping the fraction chunk_size / total_tokens of its characters.  This
    assumes tokens are spread evenly over the characters, so the cut is an
    estimate, not an exact token boundary.

    Args:
        text: Text to truncate.
        chunk_size: Target number of tokens to keep.
        count_tokens: Optional callable mapping a string to its token count.
            Defaults to `num_tokens_from_string`.
    '''
    if count_tokens is None:
        count_tokens = num_tokens_from_string
    tokens = count_tokens(text)
    # `tokens <= 0` also avoids a ZeroDivisionError when chunk_size <= 0.
    if tokens <= 0 or tokens < chunk_size:
        return text
    ratio = float(chunk_size) / tokens
    char_num = int(len(text) * ratio)
    return text[:char_num]
def chunk_texts(text, chunk_size = 2048, count_tokens=None):
    '''
    Split *text* into contiguous character slices, each averaging fewer than
    `chunk_size` tokens.

    If the whole text has fewer than `chunk_size` tokens, it is returned as a
    single-element list.  Otherwise it is cut into
    n = int(tokens / chunk_size) + 1 slices whose lengths differ by at most
    one character; the first len(text) % n slices get the extra character.
    The split is by character count, so slices may break mid-word, and each
    slice's token count is only approximately tokens / n.

    Args:
        text: Text to split.
        chunk_size: Token budget per slice.
        count_tokens: Optional callable mapping a string to its token count.
            Defaults to `num_tokens_from_string`.

    Returns:
        A list of strings that concatenate back to *text*, e.g. [text, text, text].
    '''
    if count_tokens is None:
        count_tokens = num_tokens_from_string
    tokens = count_tokens(text)
    if tokens < chunk_size:
        return [text]
    # One more part than the integer quotient, so parts average below chunk_size tokens.
    n = int(tokens / chunk_size) + 1
    part_length, extra = divmod(len(text), n)
    parts = []
    start = 0
    for i in range(n):
        # The first `extra` parts absorb the remainder, one character each.
        end = start + part_length + (1 if i < extra else 0)
        parts.append(text[start:end])
        start = end
    return parts
from datetime import datetime
from openai import OpenAI
import openai
import os
# System prompt sent as the "system" message in every chat-completion call
# in this file.
# NOTE: this f-string is evaluated once, when the module is imported.
# "Current date" therefore reflects process start-up, not the time of a
# request; a long-running server will report a stale date.
chatgpt_system_prompt = f'''
You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.
Knowledge cutoff: 2023-04
Current date: {datetime.now().strftime('%Y-%m-%d')}
'''
def get_draft(question, model="gpt-3.5-turbo"):
    """Ask the chat model for an initial zero-shot, step-by-step answer.

    The instruction appended to *question* asks for a structured answer split
    into paragraphs.

    Args:
        question: The user's question or instruction.
        model: OpenAI chat model id (default "gpt-3.5-turbo").

    Returns:
        The draft text, or "" when the response carries no message content.
        Returning "" keeps downstream string processing from failing on None.
    """
    # Getting the draft answer.
    # NOTE(review): this is not a raw string, so "\n\n" below becomes two real
    # newlines; the model sees a backtick pair around blank lines, not a
    # literal "\n\n".
    draft_prompt = '''
IMPORTANT:
Try to answer this question/instruction with step-by-step thoughts and make the answer more structural.
Use `\n\n` to split the answer into several paragraphs.
Just respond to the instruction directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.
'''
    # Build the client per call so a key stored in OPENAI_API_KEY at runtime is picked up.
    openai_client = OpenAI(api_key = os.getenv('OPENAI_API_KEY'))
    draft = openai_client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": chatgpt_system_prompt
            },
            {
                "role": "user",
                "content": f"{question}" + draft_prompt
            }
        ],
        temperature = 1.0
    ).choices[0].message.content
    # The SDK types message.content as optional; normalise None to "".
    return draft or ""
def split_draft(draft, split_char = '\n\n'):
    """Break *draft* into paragraphs on *split_char*, dropping empty pieces.

    Only exactly-empty strings are discarded; whitespace-only pieces are kept.
    """
    return list(filter(None, draft.split(split_char)))
def get_query(question, answer, model="gpt-3.5-turbo"):
    """Turn the question plus the answer-so-far into a short web-search query.

    The prompt asks the model to focus on the last sentences of *answer*, so
    retrieval targets the most recently added part of the answer.  The prompt
    text mentions Bing, but retrieval in this file uses Google search.

    Args:
        question: The original user question.
        answer: The (partial) answer accumulated so far.
        model: OpenAI chat model id (default "gpt-3.5-turbo").

    Returns:
        The query text, or "" when the response carries no message content.
    """
    query_prompt = '''
I want to verify the content correctness of the given question, especially the last sentences.
Please summarize the content with the corresponding question.
This summarization will be used as a query to search with Bing search engine.
The query should be short but need to be specific to promise Bing can find related knowledge or pages.
You can also use search syntax to make the query short and clear enough for the search engine to find relevant language data.
Try to make the query as relevant as possible to the last few sentences in the content.
**IMPORTANT**
Just output the query directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.
'''
    # Build the client per call so a key stored in OPENAI_API_KEY at runtime is picked up.
    openai_client = OpenAI(api_key = os.getenv('OPENAI_API_KEY'))
    query = openai_client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": chatgpt_system_prompt
            },
            {
                "role": "user",
                "content": f"##Question: {question}\n\n##Content: {answer}\n\n##Instruction: {query_prompt}"
            }
        ],
        temperature = 1.0
    ).choices[0].message.content
    # The SDK types message.content as optional; normalise None to "".
    return query or ""
def get_content(query):
    """Search Google for *query* and return the top hit's page text in pieces.

    The page text is split with ``chunk_texts(..., 1500)``, and newlines in
    each piece are flattened to spaces.  Returns None when the search finds
    nothing or the page yields no text.
    """
    hits = get_search(query, 1)
    if not hits:
        print(">>> No good Google Search Result was found")
        return None
    top_link = hits[0]['link']  # the hit dict also carries title and snippet
    page_text = get_page_content(top_link)
    if not page_text:
        print(f">>> No content was found in {top_link}")
        return None
    return [piece.replace('\n', " ") for piece in chunk_texts(page_text, 1500)]
def get_revise_answer(question, answer, content, model="gpt-3.5-turbo"):
    """Ask the chat model to check and revise *answer* against retrieved *content*.

    The prompt asks the model to fix errors, add missing details from the
    retrieved text, keep the paragraph structure, and otherwise return the
    answer unchanged.

    Args:
        question: The original user question.
        answer: The current answer to revise.
        content: One chunk of retrieved web text.
        model: OpenAI chat model id (default "gpt-3.5-turbo").

    Returns:
        The revised answer, or "" when the response carries no message content.
    """
    # NOTE(review): this is not a raw string, so "\n\n" below becomes two real
    # newlines rather than a literal "\n\n".
    revise_prompt = '''
I want to revise the answer according to retrieved related text of the question in WIKI pages.
You need to check whether the answer is correct.
If you find some errors in the answer, revise the answer to make it better.
If you find some necessary details are ignored, add it to make the answer more plausible according to the related text.
If you find the answer is right and do not need to add more details, just output the original answer directly.
**IMPORTANT**
Try to keep the structure (multiple paragraphs with its subtitles) in the revised answer and make it more structual for understanding.
Add more details from retrieved text to the answer.
Split the paragraphs with `\n\n` characters.
Just output the revised answer directly. DO NOT add additional explanations or annoucement in the revised answer unless you are asked to.
'''
    # Build the client per call so a key stored in OPENAI_API_KEY at runtime is picked up.
    openai_client = OpenAI(api_key = os.getenv('OPENAI_API_KEY'))
    revised_answer = openai_client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": chatgpt_system_prompt
            },
            {
                "role": "user",
                "content": f"##Existing Text in Wiki Web: {content}\n\n##Question: {question}\n\n##Answer: {answer}\n\n##Instruction: {revise_prompt}"
            }
        ],
        temperature = 1.0
    ).choices[0].message.content
    # The SDK types message.content as optional; normalise None to "".
    return revised_answer or ""
def get_query_wrapper(q, question, answer):
    """Process target: compute ``get_query(question, answer)`` and put it on queue *q*."""
    q.put(get_query(question, answer))
def get_content_wrapper(q, query):
    """Process target: compute ``get_content(query)`` and put it on queue *q*."""
    q.put(get_content(query))
def get_revise_answer_wrapper(q, question, answer, content):
    """Process target: compute ``get_revise_answer(...)`` and put it on queue *q*."""
    q.put(get_revise_answer(question, answer, content))
from multiprocessing import Process, Queue
def run_with_timeout(func, timeout, *args, **kwargs):
    """Run ``func(q, *args, **kwargs)`` in a child process and return what it puts on ``q``.

    *func* must be a process target that reports its result by calling
    ``q.put(result)``, as the ``*_wrapper`` helpers do.  The parent waits at
    most *timeout* seconds for that result.

    The queue is read while the child is still running, never after joining
    it.  Joining first has two failure modes:
    - A child that has put an object larger than the OS pipe buffer cannot
      exit until the object is consumed, so it would look timed out and its
      result would be lost.
    - A child that died without putting anything would leave ``q.get()``
      blocked forever.

    Returns:
        The object the child put on the queue.  Returns None if the child
        timed out (it is then terminated) or exited without a result.
    """
    import queue
    import time

    q = Queue()  # carries the child's result back to the parent
    p = Process(target=func, args=(q, *args), kwargs=kwargs)
    p.start()
    deadline = time.monotonic() + timeout

    result = None
    received = False
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            # Short polls so a crashed child is noticed promptly.
            result = q.get(timeout=min(remaining, 0.1))
            received = True
            break
        except queue.Empty:
            # An exited child has already flushed everything it put, so an
            # exited child plus an empty queue means no result will come.
            if not p.is_alive() and q.empty():
                break

    if received:
        # The wrappers return right after q.put(); allow a short grace period
        # for the child to exit, then make sure it is gone.
        p.join(1)
        if p.is_alive():
            p.terminate()
            p.join()
        print(f"{datetime.now()} [INFO] 函数{str(func)}执行成功完成")
    elif p.is_alive():
        print(f"{datetime.now()} [INFO] 函数{str(func)}执行已超时({timeout}s),正在终止进程...")
        p.terminate()  # terminate the process
        p.join()  # make sure it has really exited
    else:
        p.join()
        print(f"{datetime.now()} [INFO] 函数{str(func)}未返回结果即退出(exitcode={p.exitcode})")
    return result
from difflib import unified_diff
from IPython.display import display, HTML
def generate_diff_html(text1, text2):
    """Render a line-level unified diff of *text1* -> *text2* as coloured HTML.

    Lines come from ``difflib.unified_diff``.  Formatting by line prefix:
    - '+' lines are wrapped in a green ``<div>``; this includes the ``+++``
      file header.
    - '-' lines are wrapped in a red ``<div>``; this includes the ``---``
      file header.
    - '@' hunk headers are wrapped in a blue ``<div>``.
    - Context lines are emitted followed by ``<br>``.

    Trailing whitespace is stripped from every line.  Line text is
    HTML-escaped, so model output containing ``<``, ``>`` or ``&`` (such as
    code) is shown literally instead of being parsed as markup.

    Returns:
        The HTML fragment ("" when the texts are identical).
    """
    import html

    pieces = []
    for line in unified_diff(text1.splitlines(keepends=True),
                             text2.splitlines(keepends=True),
                             fromfile='text1', tofile='text2'):
        shown = html.escape(line.rstrip(), quote=False)
        if line.startswith('+'):
            pieces.append(f"<div style='color:green;'>{shown}</div>")
        elif line.startswith('-'):
            pieces.append(f"<div style='color:red;'>{shown}</div>")
        elif line.startswith('@'):
            pieces.append(f"<div style='color:blue;'>{shown}</div>")
        else:
            pieces.append(f"{shown}<br>")
    return "".join(pieces)
# Used inside f-string replacement fields, which could not contain a
# backslash before Python 3.12.
newline_char = '\n'


def rat(question):
    """Answer *question* with Retrieval-Augmented Thoughts (RAT).

    Steps:
    1. Ask the model for a zero-shot draft and split it into paragraphs.
    2. Grow the answer one draft paragraph at a time.
    3. After appending each paragraph, generate a search query from the
       answer so far and fetch web content for it.
    4. Let the model revise the accumulated answer against each of the first
       three retrieved chunks.

    Each step runs through ``run_with_timeout``: 3 s for the query, 5 s for
    retrieval, 15 s per revision.  A step that times out or returns nothing
    is skipped, and the answer is carried forward unchanged.

    Returns:
        (draft, answer): the raw draft and the final revised answer.
    """
    print(f"{datetime.now()} [INFO] 生成草稿中...")
    draft = get_draft(question)
    print(f"{datetime.now()} [INFO] 获得草稿")

    print(f"{datetime.now()} [INFO] 处理草稿...")
    draft_paragraphs = split_draft(draft)
    print(f"{datetime.now()} [INFO] 草稿被切分为{len(draft_paragraphs)}部分")

    answer = ""
    for i, p in enumerate(draft_paragraphs):
        print(str(i)*80)
        print(f"{datetime.now()} [INFO] 修改第{i+1}/{len(draft_paragraphs)}部分...")
        # Separate paragraphs with a blank line, but do not prefix the first
        # one; otherwise the answer would start with "\n\n".
        answer = answer + '\n\n' + p if answer else p

        print(f"{datetime.now()} [INFO] 生成对应Query...")
        query = run_with_timeout(get_query_wrapper, 3, question, answer)
        if not query:
            print(f"{datetime.now()} [INFO] 生成检索词超时,跳过后续步骤...")
            continue
        print(f">>> {i}/{len(draft_paragraphs)} Query: {query.replace(newline_char, ' ')}")

        print(f"{datetime.now()} [INFO] 获取网页内容...")
        content = run_with_timeout(get_content_wrapper, 5, query)
        if not content:
            print(f"{datetime.now()} [INFO] 获取网页内容超时,跳过后续步骤...")
            continue

        # Revise against at most the first three retrieved chunks.
        n_used = min(len(content), 3)
        for j, c in enumerate(content[:3]):
            print(f"{datetime.now()} [INFO] 根据网页内容修改对应答案...[{j}/{n_used}]")
            revised = run_with_timeout(get_revise_answer_wrapper, 15, question, answer, c)
            if not revised:
                print(f"{datetime.now()} [INFO] 修改答案超时,跳过后续步骤...")
                continue
            # Show what this revision changed (IPython display).
            display(HTML(generate_diff_html(answer, revised)))
            answer = revised
            print(f"{datetime.now()} [INFO] 答案修改完成[{j}/{n_used}]")
    return draft, answer
# Browser/tab title of the Gradio app; passed as `title` to gr.Blocks.
page_title = "RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation"
# Markdown header shown at the top of the app (rendered with gr.Markdown).
# The leading "#" inside the string is a Markdown heading, not a Python comment.
page_md = """
# RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation
We explore how iterative revising a chain of thoughts with the help of information retrieval significantly improves large language models' reasoning and generation ability in long-horizon generation tasks, while hugely mitigating hallucination. In particular, the proposed method — retrieval-augmented thoughts (RAT) — revises each thought step one by one with retrieved information relevant to the task query, the current and the past thought steps, after the initial zero-shot CoT is generated.
Applying RAT to various base models substantially improves their performances on various long-horizon generation tasks; on average of relatively increasing rating scores by 13.63% on code generation, 16.96% on mathematical reasoning, 19.2% on creative writing, and 42.78% on embodied task planning.
Feel free to try our demo!
"""
def clear_func():
    """Return three empty strings, one for each textbox wired to the clear button."""
    return ("",) * 3
def set_openai_api_key(api_key):
    """Store a plausible-looking OpenAI API key in the ``OPENAI_API_KEY`` env var.

    Surrounding whitespace, which is common when pasting, is stripped first.
    The key is accepted only if it starts with "sk-" and is longer than 50
    characters.  Anything else, including None or "", is ignored and leaves
    any previously stored value untouched.
    """
    key = (api_key or "").strip()
    if key.startswith("sk-") and len(key) > 50:
        os.environ["OPENAI_API_KEY"] = key
# Gradio UI: shows the zero-shot ChatGPT draft and the RAT-revised answer
# for an instruction.  It also has an API-key box and submit / clear /
# regenerate controls.  The Blocks are built and launched at import time.
with gr.Blocks(title = page_title) as demo:
    gr.Markdown(page_md)
    with gr.Row():
        # Receives the draft (first value returned by `rat`).
        chatgpt_box = gr.Textbox(
            label = "ChatGPT",
            placeholder = "Response from ChatGPT with zero-shot chain-of-thought.",
            elem_id = "chatgpt"
        )
    with gr.Row():
        # Hidden and not connected to any event handler below.
        stream_box = gr.Textbox(
            label = "Streaming",
            placeholder = "Interactive response with RAT...",
            elem_id = "stream",
            lines = 10,
            visible = False
        )
    with gr.Row():
        # Receives the final revised answer (second value returned by `rat`).
        rat_box = gr.Textbox(
            label = "RAT",
            placeholder = "Final response with RAT ...",
            elem_id = "rat",
            lines = 6
        )
    with gr.Column(elem_id="instruction_row"):
        with gr.Row():
            instruction_box = gr.Textbox(
                label = "instruction",
                placeholder = "Enter your instruction here",
                lines = 2,
                elem_id="instruction",
                interactive=True,
                visible=True
            )
        with gr.Row():
            # OpenAI model ids are lowercase, e.g. "gpt-4-turbo".
            # NOTE(review): this selection is not an input of the submit /
            # regenerate handlers below, so `rat` never receives it.
            model_radio = gr.Radio(["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo"], elem_id="model_radio", value="gpt-3.5-turbo",
                                   label='GPT model: ', show_label=True,interactive=True, visible=True)
            openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter",
                                                show_label=False, lines=1, type='password')
            # Validate and store the key whenever the textbox value changes.
            openai_api_key_textbox.change(set_openai_api_key,
                                          inputs=[openai_api_key_textbox],
                                          outputs=[])
        with gr.Row():
            submit_btn = gr.Button(
                value="submit", visible=True, interactive=True
            )
            clear_btn = gr.Button(
                value="clear", visible=True, interactive=True
            )
            regenerate_btn = gr.Button(
                value="regenerate", visible=True, interactive=True
            )
    submit_btn.click(
        fn = rat,
        inputs = [instruction_box],
        outputs = [chatgpt_box, rat_box]
    )
    clear_btn.click(
        fn = clear_func,
        inputs = [],
        outputs = [instruction_box, chatgpt_box, rat_box]
    )
    # Regenerate re-runs the same pipeline on the current instruction.
    regenerate_btn.click(
        fn = rat,
        inputs = [instruction_box],
        outputs = [chatgpt_box, rat_box]
    )
    examples = gr.Examples(
        examples=[
            "I went to the supermarket yesterday.",
            "Helen is a good swimmer."],
        inputs=[instruction_box]
    )

demo.launch(server_name="0.0.0.0", debug=True)