|
import gradio as gr |
|
from langchain.tools import Tool |
|
from langchain_community.utilities import GoogleSearchAPIWrapper |
|
import os |
|
|
|
from langchain.tools import Tool |
|
from langchain_community.utilities import GoogleSearchAPIWrapper |
|
|
|
|
|
def get_search(query: str = "", k: int = 1):
    """Run a Google search and return the top ``k`` result entries.

    Args:
        query: Search string sent through ``GoogleSearchAPIWrapper``.
        k: Number of results requested from the wrapper.

    Returns:
        The list of result dicts produced by ``search.results`` (``get_content``
        reads the ``'link'`` key of the first one), or ``None`` when the search
        produced nothing usable.
    """
    search = GoogleSearchAPIWrapper(k=k)

    def search_results(query):
        return search.results(query, k)

    tool = Tool(
        name="Google Search Snippets",
        description="Search Google for recent results.",
        func=search_results,
    )

    ref_text = tool.run(query)
    # An entry carrying a 'Result' key is presumably the wrapper's "no results"
    # placeholder (TODO confirm against the wrapper). An empty list is treated
    # the same way instead of raising IndexError on ref_text[0].
    if not ref_text or 'Result' in ref_text[0]:
        return None
    return ref_text
|
|
|
from langchain_community.document_transformers import Html2TextTransformer |
|
from langchain_community.document_loaders import AsyncHtmlLoader |
|
def get_page_content(link: str):
    """Download ``link`` and return its HTML converted to plain text.

    Returns ``None`` if the load/transform pipeline yields no documents.
    """
    html_docs = AsyncHtmlLoader([link]).load()
    text_docs = Html2TextTransformer().transform_documents(html_docs)
    if not text_docs:
        return None
    return text_docs[0].page_content
|
|
|
import tiktoken |
|
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Count the tokens ``string`` encodes to under the tiktoken encoding ``encoding_name``."""
    encoder = tiktoken.get_encoding(encoding_name)
    return len(encoder.encode(string))
|
|
|
def chunk_text_by_sentence(text, chunk_size=2048):
    """Return the leading run of ``". "``-separated sentences of ``text`` that fits in ``chunk_size`` tokens.

    Sentences are accumulated and re-joined with ``". "`` until adding the next
    one would exceed the budget. That first chunk is returned as a string.
    Because later chunks are never used, scanning stops as soon as the first
    chunk is complete instead of tokenising the rest of the text.

    If the very first sentence alone is over budget, it is cut down with
    ``chunk_text_front``. The previous behaviour returned an empty string
    in that case.
    """
    curr_chunk = []
    for sentence in text.split('. '):
        # The extra 2 presumably budgets for the ". " separator added by the join.
        if num_tokens_from_string(". ".join(curr_chunk)) + num_tokens_from_string(sentence) + 2 <= chunk_size:
            curr_chunk.append(sentence)
        elif curr_chunk:
            # The first chunk is complete; remaining sentences cannot change it.
            break
        else:
            # A single sentence larger than the whole budget.
            return chunk_text_front(sentence, chunk_size)
    return ". ".join(curr_chunk)
|
|
|
def chunk_text_front(text, chunk_size = 2048):
    """Return (approximately) the first ``chunk_size`` tokens of ``text``.

    Text with fewer than ``chunk_size`` tokens, and empty text, is returned
    unchanged. Otherwise the text is cut at a character offset proportional
    to ``chunk_size / total_tokens``. The cut is therefore an estimate, not an
    exact token boundary. A negative ``chunk_size`` is treated as 0 (returns
    ``""``) rather than slicing from the end of the string.
    """
    tokens = num_tokens_from_string(text)
    # `tokens == 0` avoids 0/0 when text is empty and chunk_size <= 0.
    if tokens < chunk_size or tokens == 0:
        return text
    ratio = float(max(chunk_size, 0)) / tokens
    char_num = int(len(text) * ratio)
    return text[:char_num]
|
|
|
def chunk_texts(text, chunk_size = 2048):
    """Split ``text`` into roughly equal parts of about ``chunk_size`` tokens or fewer.

    Returns ``[text]`` when the text is already under ``chunk_size`` tokens.
    Otherwise returns ``[part_1, ..., part_n]``, whose concatenation is
    ``text``. Parts are cut by character count, not on token or word
    boundaries; earlier parts get one extra character when the length does
    not divide evenly.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size!r}")

    tokens = num_tokens_from_string(text)
    if tokens < chunk_size:
        return [text]

    # The +1 part presumably gives slack, since character-proportional cuts
    # only approximate token counts.
    n = int(tokens / chunk_size) + 1
    part_length, extra = divmod(len(text), n)

    parts = []
    start = 0
    for i in range(n):
        end = start + part_length + (1 if i < extra else 0)
        parts.append(text[start:end])
        start = end
    return parts
|
|
|
from datetime import datetime |
|
|
|
from openai import OpenAI |
|
import openai |
|
import os |
|
|
|
# System prompt sent as the "system" message by every OpenAI call in this file.
# NOTE(review): this f-string is evaluated once, when the module is executed,
# so "Current date" does not advance while a long-running server stays up.
# NOTE(review): the text claims a GPT-4 model, while the calls in this file
# request "gpt-3.5-turbo".
chatgpt_system_prompt = f'''

You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.

Knowledge cutoff: 2023-04

Current date: {datetime.now().strftime('%Y-%m-%d')}

'''
|
|
|
def get_draft(question, model="gpt-3.5-turbo"):
    """Ask the chat model for an initial step-by-step answer to ``question``.

    The prompt asks for paragraphs separated by blank lines, so that
    ``split_draft`` can cut the answer into revision steps.

    Args:
        question: The user's question/instruction.
        model: OpenAI chat model name. The default is the model that used to
            be hard-coded here.

    Returns:
        The message content of the first completion choice.
    """
    # `\\n\\n` makes the model see the literal two-character separator. A bare
    # `\n\n` in this non-raw literal became two real line breaks wrapped in
    # backticks, which made the instruction unreadable.
    draft_prompt = '''

IMPORTANT:

Try to answer this question/instruction with step-by-step thoughts and make the answer more structural.

Use `\\n\\n` to split the answer into several paragraphs.

Just respond to the instruction directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.

    '''

    openai_client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
    draft = openai_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": chatgpt_system_prompt},
            {"role": "user", "content": f"{question}" + draft_prompt},
        ],
        temperature=1.0,
    ).choices[0].message.content
    return draft
|
|
|
def split_draft(draft, split_char = '\n\n'):
    """Cut ``draft`` at every ``split_char`` and keep only the non-empty pieces, in order."""
    return list(filter(None, draft.split(split_char)))
|
|
|
def get_query(question, answer, model="gpt-3.5-turbo"):
    """Ask the chat model to condense ``question`` and ``answer`` into a web-search query.

    The prompt asks for a short query that focuses on the last sentences of
    the answer so far, i.e. the paragraph currently being verified.

    Args:
        question: The user's original question/instruction.
        answer: The answer accumulated so far.
        model: OpenAI chat model name.

    Returns:
        The model's reply with surrounding whitespace stripped. Falsy replies
        are returned unchanged.
    """
    # The query is executed through Google (GoogleSearchAPIWrapper in
    # get_search), so the prompt names Google rather than Bing.
    query_prompt = '''

I want to verify the content correctness of the given question, especially the last sentences.

Please summarize the content with the corresponding question.

This summarization will be used as a query to search with Google search engine.

The query should be short but need to be specific to promise Google can find related knowledge or pages.

You can also use search syntax to make the query short and clear enough for the search engine to find relevant language data.

Try to make the query as relevant as possible to the last few sentences in the content.

**IMPORTANT**

Just output the query directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.

    '''

    openai_client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
    query = openai_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": chatgpt_system_prompt},
            {"role": "user", "content": f"##Question: {question}\n\n##Content: {answer}\n\n##Instruction: {query_prompt}"},
        ],
        temperature=1.0,
    ).choices[0].message.content
    # Stray leading/trailing newlines would otherwise be sent to the search API.
    return query.strip() if query else query
|
|
|
def get_content(query, chunk_size=1500):
    """Fetch the top Google hit for ``query`` and return its text in chunks.

    Args:
        query: Search query string.
        chunk_size: Approximate token size of each returned chunk (was a
            hard-coded 1500).

    Returns:
        A list of text chunks with newlines replaced by spaces, or ``None``
        when the search found nothing or the page yielded no text.
    """
    results = get_search(query, 1)
    if not results:
        print(">>> No good Google Search Result was found")
        return None

    link = results[0]['link']
    page_text = get_page_content(link)
    if not page_text:
        print(f">>> No content was found in {link}")
        return None

    return [chunk.replace('\n', " ") for chunk in chunk_texts(page_text, chunk_size)]
|
|
|
def get_revise_answer(question, answer, content, model="gpt-3.5-turbo"):
    """Ask the chat model to check ``answer`` against retrieved text and revise it.

    Args:
        question: The user's original question/instruction.
        answer: The answer accumulated so far.
        content: One chunk of retrieved web-page text used as evidence.
        model: OpenAI chat model name.

    Returns:
        The revised answer (message content of the first completion choice).
    """
    # `\\n\\n` makes the model see the literal separator. A bare `\n\n` in this
    # non-raw literal became two real line breaks wrapped in backticks.
    revise_prompt = '''

I want to revise the answer according to retrieved related text of the question in WIKI pages.

You need to check whether the answer is correct.

If you find some errors in the answer, revise the answer to make it better.

If you find some necessary details are ignored, add it to make the answer more plausible according to the related text.

If you find the answer is right and do not need to add more details, just output the original answer directly.

**IMPORTANT**

Try to keep the structure (multiple paragraphs with its subtitles) in the revised answer and make it more structual for understanding.

Add more details from retrieved text to the answer.

Split the paragraphs with `\\n\\n` characters.

Just output the revised answer directly. DO NOT add additional explanations or annoucement in the revised answer unless you are asked to.

    '''

    openai_client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
    revised_answer = openai_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": chatgpt_system_prompt},
            {"role": "user", "content": f"##Existing Text in Wiki Web: {content}\n\n##Question: {question}\n\n##Answer: {answer}\n\n##Instruction: {revise_prompt}"},
        ],
        temperature=1.0,
    ).choices[0].message.content
    return revised_answer
|
|
|
def get_query_wrapper(q, question, answer):
    """Process target for ``run_with_timeout``: push ``get_query(question, answer)`` onto ``q``."""
    q.put(get_query(question, answer))
|
|
|
def get_content_wrapper(q, query):
    """Process target for ``run_with_timeout``: push ``get_content(query)`` onto ``q``."""
    q.put(get_content(query))
|
|
|
def get_revise_answer_wrapper(q, question, answer, content):
    """Process target for ``run_with_timeout``: push the revised answer onto ``q``."""
    q.put(get_revise_answer(question, answer, content))
|
|
|
from multiprocessing import Process, Queue |
|
def run_with_timeout(func, timeout, *args, **kwargs):
    """Run ``func(q, *args, **kwargs)`` in a child process and return what it puts on ``q``.

    ``func`` must deliver its result with ``q.put(result)``, as the
    ``*_wrapper`` helpers do. The parent waits at most ``timeout`` seconds;
    ``None`` means wait indefinitely.

    Two failure modes are handled:

    * The result is read from the queue *before* joining the child. A child
      that has put a large object cannot exit until the parent drains the
      pipe, so join-then-get reported big results (e.g. a long web page) as
      timeouts.
    * The child's liveness is polled, so a child that died before putting
      anything (e.g. an exception in ``func``) is detected promptly. The old
      code blocked forever on ``q.get()`` in that case.

    Returns:
        The object the child put on the queue, or ``None`` on timeout or when
        the child exited without producing a result.
    """
    import time
    from queue import Empty

    poll_interval = 0.1
    q = Queue()
    p = Process(target=func, args=(q, *args), kwargs=kwargs)
    p.start()

    deadline = None if timeout is None else time.monotonic() + timeout
    result, received = None, False
    while True:
        wait = poll_interval if deadline is None else min(poll_interval, deadline - time.monotonic())
        if wait <= 0:
            break
        try:
            result = q.get(timeout=wait)
            received = True
            break
        except Empty:
            if not p.is_alive():
                # A child flushes its queued data into the pipe before exiting,
                # so a result that raced with the exit is still readable here.
                try:
                    result = q.get(timeout=poll_interval)
                    received = True
                except Empty:
                    pass
                break

    if received:
        print(f"{datetime.now()} [INFO] 函数{str(func)}执行成功完成")
    elif p.is_alive():
        print(f"{datetime.now()} [INFO] 函数{str(func)}执行已超时({timeout}s),正在终止进程...")
    else:
        print(f"{datetime.now()} [INFO] 函数{str(func)}执行失败(退出码{p.exitcode}),未返回结果")

    # Reap the child. After a delivered result, give it a moment to exit on
    # its own; on timeout (or if it lingers) terminate it.
    p.join(1 if received else 0)
    if p.is_alive():
        p.terminate()
        p.join()
    q.close()
    return result
|
|
|
from difflib import unified_diff |
|
from IPython.display import display, HTML |
|
|
|
def generate_diff_html(text1, text2):
    """Render a unified diff of ``text1`` -> ``text2`` as an HTML fragment.

    Each diff line is colour-coded by its first character: ``+`` is green,
    ``-`` is red and ``@`` (hunk headers) is blue. The ``+++``/``---`` file
    headers are therefore coloured too. Any other line is emitted followed by
    ``<br>``. Line text is HTML-escaped so markup inside the compared answers
    is shown literally instead of being rendered or breaking the fragment.
    Identical inputs produce ``""``.
    """
    import html

    colours = {'+': 'green', '-': 'red', '@': 'blue'}
    pieces = []
    for line in unified_diff(text1.splitlines(keepends=True),
                             text2.splitlines(keepends=True),
                             fromfile='text1', tofile='text2'):
        body = html.escape(line.rstrip())
        colour = colours.get(line[:1])
        if colour:
            pieces.append(f"<div style='color:{colour};'>{body}</div>")
        else:
            pieces.append(f"{body}<br>")
    return "".join(pieces)
|
|
|
# Named constant for "\n" so it can be used inside f-string expressions
# (e.g. `query.replace(newline_char, ' ')` in rat); before Python 3.12 an
# f-string expression could not contain a backslash.
newline_char = "\n"
|
|
|
def rat(question):
    """Retrieval-Augmented Thoughts: draft an answer, then revise it paragraph by paragraph.

    1. Generate a zero-shot chain-of-thought draft with ``get_draft`` and split
       it into paragraphs.
    2. For each paragraph:
       a. Append it to the running answer.
       b. Derive a search query from the question and the answer so far.
       c. Fetch the top page's text.
       d. Let the model revise the answer against up to ``max_chunks`` chunks
          of that text.

    Each query/content/revision step runs in a child process via
    ``run_with_timeout``. A step that times out or returns nothing is skipped.

    Returns:
        ``(draft, answer)``: the raw draft and the revised answer.
    """
    max_chunks = 3  # retrieved text chunks used per paragraph

    print(f"{datetime.now()} [INFO] 生成草稿中...")
    draft = get_draft(question)
    print(f"{datetime.now()} [INFO] 获得草稿")

    print(f"{datetime.now()} [INFO] 处理草稿...")
    # `or ""` keeps a missing (None) draft from crashing split_draft.
    draft_paragraphs = split_draft(draft or "")
    print(f"{datetime.now()} [INFO] 草稿被切分为{len(draft_paragraphs)}部分")

    answer = ""
    for i, p in enumerate(draft_paragraphs):
        print(str(i)*80)
        print(f"{datetime.now()} [INFO] 修改第{i+1}/{len(draft_paragraphs)}部分...")
        # Grow the answer one paragraph at a time, without a leading blank
        # separator before the first paragraph.
        answer = f"{answer}\n\n{p}" if answer else p

        print(f"{datetime.now()} [INFO] 生成对应Query...")
        query = run_with_timeout(get_query_wrapper, 3, question, answer)
        if not query:
            print(f"{datetime.now()} [INFO] 生成检索词超时,跳过后续步骤...")
            continue
        print(f">>> {i}/{len(draft_paragraphs)} Query: {query.replace(newline_char, ' ')}")

        print(f"{datetime.now()} [INFO] 获取网页内容...")
        content = run_with_timeout(get_content_wrapper, 5, query)
        if not content:
            print(f"{datetime.now()} [INFO] 获取网页内容超时,跳过后续步骤...")
            continue

        chunks = content[:max_chunks]
        for j, c in enumerate(chunks):
            print(f"{datetime.now()} [INFO] 根据网页内容修改对应答案...[{j}/{len(chunks)}]")
            revised = run_with_timeout(get_revise_answer_wrapper, 15, question, answer, c)
            if not revised:
                print(f"{datetime.now()} [INFO] 修改答案超时,跳过后续步骤...")
                continue
            display(HTML(generate_diff_html(answer, revised)))
            answer = revised
            print(f"{datetime.now()} [INFO] 答案修改完成[{j}/{len(chunks)}]")

    return draft, answer
|
|
|
|
|
# Browser-tab title passed to gr.Blocks, and the Markdown introduction
# rendered at the top of the demo page.
page_title = "RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation"

page_md = """

# RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation



We explore how iterative revising a chain of thoughts with the help of information retrieval significantly improves large language models' reasoning and generation ability in long-horizon generation tasks, while hugely mitigating hallucination. In particular, the proposed method — retrieval-augmented thoughts (RAT) — revises each thought step one by one with retrieved information relevant to the task query, the current and the past thought steps, after the initial zero-shot CoT is generated.



Applying RAT to various base models substantially improves their performances on various long-horizon generation tasks; on average of relatively increasing rating scores by 13.63% on code generation, 16.96% on mathematical reasoning, 19.2% on creative writing, and 42.78% on embodied task planning.



Feel free to try our demo!



"""
|
|
|
def clear_func():
    """Return three empty strings, used to blank the instruction, ChatGPT and RAT textboxes."""
    blank = ""
    return blank, blank, blank
|
|
|
def set_openai_api_key(api_key):
    """Store a plausible-looking OpenAI key in the ``OPENAI_API_KEY`` environment variable.

    Surrounding whitespace, common when pasting, is stripped before checking.
    Values that do not start with ``"sk-"`` or are 50 characters or shorter
    (including ``None``/empty) are ignored, and any previously stored key
    stays in place.

    NOTE(review): ``os.environ`` is process-wide and the OpenAI helpers read
    it on every call. In a multi-user deployment, the most recently entered
    key is therefore used for every visitor's requests.
    """
    api_key = (api_key or "").strip()
    if api_key.startswith("sk-") and len(api_key) > 50:
        os.environ["OPENAI_API_KEY"] = api_key
|
|
|
with gr.Blocks(title = page_title) as demo:
    # Page layout, top to bottom:
    # - description
    # - ChatGPT draft
    # - hidden streaming box
    # - RAT answer
    # - instruction/controls column
    gr.Markdown(page_md)

    with gr.Row():
        chatgpt_box = gr.Textbox(
            label = "ChatGPT",
            placeholder = "Response from ChatGPT with zero-shot chain-of-thought.",
            elem_id = "chatgpt"
        )

    with gr.Row():
        # NOTE(review): hidden, and no event handler below writes to it.
        stream_box = gr.Textbox(
            label = "Streaming",
            placeholder = "Interactive response with RAT...",
            elem_id = "stream",
            lines = 10,
            visible = False
        )

    with gr.Row():
        rat_box = gr.Textbox(
            label = "RAT",
            placeholder = "Final response with RAT ...",
            elem_id = "rat",
            lines = 6
        )

    with gr.Column(elem_id="instruction_row"):
        with gr.Row():
            instruction_box = gr.Textbox(
                label = "instruction",
                placeholder = "Enter your instruction here",
                lines = 2,
                elem_id="instruction",
                interactive=True,
                visible=True
            )
        with gr.Row():
            # NOTE(review): model_radio is not an input of any event and
            # rat(question) takes no model argument, so this choice currently
            # has no effect. "gpt-4-turbo" uses OpenAI's lower-case model id.
            model_radio = gr.Radio(["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo"], elem_id="model_radio", value="gpt-3.5-turbo",
                                   label='GPT model: ', show_label=True,interactive=True, visible=True)
            openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter",
                                                show_label=False, lines=1, type='password')

        # Kept at the Column's indentation so the following `with gr.Row():`
        # stays inside the same Column body.
        openai_api_key_textbox.change(set_openai_api_key,
            inputs=[openai_api_key_textbox],
            outputs=[])

        with gr.Row():
            submit_btn = gr.Button(
                value="submit", visible=True, interactive=True
            )
            clear_btn = gr.Button(
                value="clear", visible=True, interactive=True
            )
            regenerate_btn = gr.Button(
                value="regenerate", visible=True, interactive=True
            )

    # rat returns (draft, answer), mapped onto the ChatGPT and RAT boxes.
    submit_btn.click(
        fn = rat,
        inputs = [instruction_box],
        outputs = [chatgpt_box, rat_box]
    )

    clear_btn.click(
        fn = clear_func,
        inputs = [],
        outputs = [instruction_box, chatgpt_box, rat_box]
    )

    # "regenerate" simply re-runs the whole pipeline on the current instruction.
    regenerate_btn.click(
        fn = rat,
        inputs = [instruction_box],
        outputs = [chatgpt_box, rat_box]
    )

    examples = gr.Examples(
        examples=[
            "I went to the supermarket yesterday.",
            "Helen is a good swimmer."],
        inputs=[instruction_box]
    )
|
|
|
# Start the server only when executed as a script. Under the "spawn" and
# "forkserver" multiprocessing start methods, the child processes created by
# run_with_timeout re-import this module, and an unguarded launch() would try
# to start another server inside each of them.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", debug=True)