File size: 16,526 Bytes
bffea63
 
 
3119d85
 
 
 
bffea63
3119d85
 
bffea63
 
 
 
 
 
 
 
 
3119d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1097625
3119d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57c21c
f0d180c
3119d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57c21c
f0d180c
3119d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57c21c
f0d180c
3119d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bffea63
3119d85
 
 
 
 
 
 
 
 
 
 
bffea63
3119d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57c21c
3119d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bffea63
3119d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
import gradio as gr
from langchain.tools import Tool
from langchain_community.utilities import GoogleSearchAPIWrapper
import os 

from langchain.tools import Tool
from langchain_community.utilities import GoogleSearchAPIWrapper


def get_search(query:str="", k:int=1):
    """Fetch the top-k Google results for `query` through a LangChain tool.

    Returns the list produced by the search wrapper, or None when the first
    entry contains a 'Result' key (presumably the wrapper's "nothing found"
    placeholder -- confirm against the langchain_community docs).
    """
    wrapper = GoogleSearchAPIWrapper(k=k)
    snippet_tool = Tool(
        name="Google Search Snippets",
        description="Search Google for recent results.",
        func=lambda query: wrapper.results(query, k),
    )
    hits = snippet_tool.run(query)
    if 'Result' in hits[0]:
        return None
    return hits

from langchain_community.document_transformers import Html2TextTransformer
from langchain_community.document_loaders import AsyncHtmlLoader
def get_page_content(link:str):
    """Download `link` and return the page converted from HTML to text.

    Returns the `page_content` of the first transformed document, or None
    when the transformer produced no documents.
    """
    fetched = AsyncHtmlLoader([link]).load()
    converted = Html2TextTransformer().transform_documents(fetched)
    if not converted:
        return None
    return converted[0].page_content

import tiktoken
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Count how many tokens `string` encodes to under the named tiktoken encoding."""
    return len(tiktoken.get_encoding(encoding_name).encode(string))

def chunk_text_by_sentence(text, chunk_size=2048):
    """Return the leading chunk of `text`, cut on sentence boundaries.

    Sentences (split on '. ') are accumulated while the token count of the
    chunk so far, plus the next sentence, plus 2 tokens of slack for the
    '. ' separator, stays within `chunk_size`.

    Args:
        text: The text to cut.
        chunk_size: Token budget for the returned chunk.

    Returns:
        The first chunk as a string. If the very first sentence already
        exceeds the budget it is returned on its own; previously an empty
        string was returned in that case and the content was lost.
    """
    first_chunk = []
    for sentence in text.split('. '):
        used = num_tokens_from_string(". ".join(first_chunk))
        if used + num_tokens_from_string(sentence) + 2 <= chunk_size:
            first_chunk.append(sentence)
        elif first_chunk:
            # The first chunk is full. Only that chunk is returned, so there
            # is no point tokenizing the remainder of the text.
            break
        else:
            # Nothing collected yet and this sentence alone is over budget:
            # hand it back as its own chunk instead of an empty string.
            return sentence
    return ". ".join(first_chunk)

def chunk_text_front(text, chunk_size = 2048):
    """Return roughly the first `chunk_size` tokens of `text`.

    Text below the budget is returned unchanged. Otherwise the text is cut at
    a character position proportional to chunk_size / token_count, so the cut
    is an approximation by characters rather than an exact token boundary.
    """
    total_tokens = num_tokens_from_string(text)
    if total_tokens >= chunk_size:
        keep_chars = int(len(text) * (float(chunk_size) / total_tokens))
        return text[:keep_chars]
    return text

def chunk_texts(text, chunk_size = 2048):
    """Split `text` into consecutive pieces and return them as a list.

    Text below `chunk_size` tokens comes back as a one-element list. Otherwise
    the text is divided by character count into int(tokens / chunk_size) + 1
    pieces of near-equal length.
    """
    token_count = num_tokens_from_string(text)
    if token_count < chunk_size:
        return [text]

    n = int(token_count / chunk_size) + 1
    # Every piece gets `base` characters; the first `extra` pieces get one more
    # so the remainder is spread over the leading pieces.
    base, extra = divmod(len(text), n)
    return [
        text[i * base + min(i, extra):(i + 1) * base + min(i + 1, extra)]
        for i in range(n)
    ]
    
from datetime import datetime

from openai import OpenAI
import openai
import os

# System message sent as the first chat message by get_draft, get_query and
# get_revise_answer.
# NOTE: this f-string is evaluated once, when the module is loaded, so the
# "Current date" line reflects process start rather than the time of each request.
chatgpt_system_prompt = f'''
You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.
Knowledge cutoff: 2023-04
Current date: {datetime.now().strftime('%Y-%m-%d')}
'''

def get_draft(question):
    """Ask the chat model for an initial, paragraph-structured draft answer.

    The question is sent with an instruction suffix asking for step-by-step,
    blank-line separated paragraphs. The API key is read from the
    OPENAI_API_KEY environment variable. Returns the text of the first choice.
    """
    draft_prompt = '''
IMPORTANT:
Try to answer this question/instruction with step-by-step thoughts and make the answer more structural.
Use `\n\n` to split the answer into several paragraphs.
Just respond to the instruction directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.
'''
    client = OpenAI(api_key = os.getenv('OPENAI_API_KEY'))
    messages = [
        {"role": "system", "content": chatgpt_system_prompt},
        {"role": "user", "content": f"{question}" + draft_prompt},
    ]
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature = 1.0,
    )
    return response.choices[0].message.content

def split_draft(draft, split_char = '\n\n'):
    """Split `draft` on `split_char` and return the non-empty pieces in order."""
    return list(filter(None, draft.split(split_char)))

def get_query(question, answer):
    """Ask the chat model to condense the answer so far into a search query.

    `question` and `answer` are embedded in a single user message together
    with the instruction text below. The API key is read from the
    OPENAI_API_KEY environment variable. Returns the text of the first choice.
    """
    query_prompt = '''
I want to verify the content correctness of the given question, especially the last sentences.
Please summarize the content with the corresponding question.
This summarization will be used as a query to search with Bing search engine.
The query should be short but need to be specific to promise Bing can find related knowledge or pages.
You can also use search syntax to make the query short and clear enough for the search engine to find relevant language data.
Try to make the query as relevant as possible to the last few sentences in the content.
**IMPORTANT**
Just output the query directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.
'''
    client = OpenAI(api_key = os.getenv('OPENAI_API_KEY'))
    messages = [
        {"role": "system", "content": chatgpt_system_prompt},
        {
            "role": "user",
            "content": f"##Question: {question}\n\n##Content: {answer}\n\n##Instruction: {query_prompt}",
        },
    ]
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature = 1.0,
    )
    return response.choices[0].message.content

def get_content(query):
    """Search for `query`, fetch the top hit and return its text in pieces.

    Returns a list of text pieces (newlines replaced by spaces, split with a
    1500-token budget), or None when the search or the page fetch yields
    nothing.
    """
    hits = get_search(query, 1)
    if not hits:
        print(">>> No good Google Search Result was found")
        return None

    # Only the link of the first hit is used.
    link = hits[0]['link']
    page_text = get_page_content(link)
    if not page_text:
        print(f">>> No content was found in {link}")
        return None

    return [piece.replace('\n', " ") for piece in chunk_texts(page_text, 1500)]

def get_revise_answer(question, answer, content):
    """Ask the chat model to revise `answer` using the retrieved `content`.

    The retrieved text, the question and the current answer are embedded in a
    single user message together with the instruction text below. The API key
    is read from the OPENAI_API_KEY environment variable. Returns the text of
    the first choice.
    """
    revise_prompt = '''
I want to revise the answer according to retrieved related text of the question in WIKI pages.
You need to check whether the answer is correct.
If you find some errors in the answer, revise the answer to make it better.
If you find some necessary details are ignored, add it to make the answer more plausible according to the related text.
If you find the answer is right and do not need to add more details, just output the original answer directly.
**IMPORTANT**
Try to keep the structure (multiple paragraphs with its subtitles) in the revised answer and make it more structual for understanding.
Add more details from retrieved text to the answer.
Split the paragraphs with `\n\n` characters.
Just output the revised answer directly. DO NOT add additional explanations or annoucement in the revised answer unless you are asked to.
'''
    client = OpenAI(api_key = os.getenv('OPENAI_API_KEY'))
    messages = [
        {"role": "system", "content": chatgpt_system_prompt},
        {
            "role": "user",
            "content": f"##Existing Text in Wiki Web: {content}\n\n##Question: {question}\n\n##Answer: {answer}\n\n##Instruction: {revise_prompt}",
        },
    ]
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature = 1.0,
    )
    return response.choices[0].message.content

def get_query_wrapper(q, question, answer):
    """Run get_query and hand its result back through the queue `q`."""
    q.put(get_query(question, answer))

def get_content_wrapper(q, query):
    """Run get_content and hand its result back through the queue `q`."""
    q.put(get_content(query))

def get_revise_answer_wrapper(q, question, answer, content):
    """Run get_revise_answer and hand its result back through the queue `q`."""
    q.put(get_revise_answer(question, answer, content))

from multiprocessing import Process, Queue
def run_with_timeout(func, timeout, *args, **kwargs):
    """Run `func(q, *args, **kwargs)` in a child process with a time limit.

    `func` receives a multiprocessing queue as its first argument and is
    expected to put exactly one result on it.

    Args:
        func: Callable executed in the child process.
        timeout: Seconds to wait for a result (None waits indefinitely).
        *args, **kwargs: Forwarded to `func` after the queue.

    Returns:
        The object the child put on the queue, or None when no result arrived
        within `timeout` (the child is still running, or it exited without
        putting anything, e.g. because it raised).
    """
    from queue import Empty  # raised by Queue.get() when the wait times out

    q = Queue()
    p = Process(target=func, args=(q, *args), kwargs=kwargs)
    p.start()
    try:
        # Read the result BEFORE joining. A child that has queued a large
        # object cannot exit until the parent drains the pipe, so joining
        # first turned every large result into a spurious timeout.
        result = q.get(timeout=timeout)
    except Empty:
        result = None
        if p.is_alive():
            print(f"{datetime.now()} [INFO] 函数{str(func)}执行已超时({timeout}s),正在终止进程...")
        else:
            # The child exited without producing a result; an unbounded
            # q.get() would block forever here.
            print(f"{datetime.now()} [INFO] 函数{str(func)}未返回结果,进程已退出(exitcode={p.exitcode})")
    else:
        print(f"{datetime.now()} [INFO] 函数{str(func)}执行成功完成")
        # Give the child a moment to finish exiting on its own.
        p.join(1)
    if p.is_alive():
        p.terminate()
        p.join()  # make sure the child has been reaped
    return result

from difflib import unified_diff
from IPython.display import display, HTML

def generate_diff_html(text1, text2):
    """Render a unified diff of two texts as a small HTML fragment.

    Lines starting with '+' are wrapped in a green <div>, '-' in red and '@'
    in blue; every other line is emitted as-is followed by <br>. Trailing
    whitespace of each diff line is dropped.

    The line text is HTML-escaped (&, <, >) before being embedded, so markup
    contained in the compared texts is displayed literally instead of being
    interpreted by the browser.

    Args:
        text1: Original text.
        text2: Revised text.

    Returns:
        The HTML string; empty when the texts do not differ.
    """
    from html import escape

    diff = unified_diff(text1.splitlines(keepends=True),
                        text2.splitlines(keepends=True),
                        fromfile='text1', tofile='text2')

    colors = {'+': 'green', '-': 'red', '@': 'blue'}
    fragments = []
    for line in diff:
        # quote=False: the text is placed in element content, not attributes.
        safe = escape(line.rstrip(), quote=False)
        color = colors.get(line[:1])
        if color is None:
            fragments.append(f"{safe}<br>")
        else:
            fragments.append(f"<div style='color:{color};'>{safe}</div>")
    return "".join(fragments)

# Named newline so it can be used inside f-string replacement fields (see rat);
# a backslash is not allowed in an f-string expression before Python 3.12.
newline_char = '\n'

def rat(question):
    """Answer `question` with retrieval-augmented revision of a draft.

    A draft is generated and split into paragraphs. Paragraphs are appended
    to the running answer one at a time; after each append a search query is
    derived from the answer so far, a web page is fetched for it, and the
    answer is revised against up to three pieces of that page. Every remote
    step runs in a child process with a time limit; when a step yields
    nothing, the remaining steps for that paragraph are skipped and the
    answer is kept as it is.

    Args:
        question: The user's instruction/question.

    Returns:
        A `(draft, answer)` tuple: the initial draft and the revised answer.
    """
    print(f"{datetime.now()} [INFO] 生成草稿中...")
    draft = get_draft(question)
    print(f"{datetime.now()} [INFO] 获得草稿")

    print(f"{datetime.now()} [INFO] 处理草稿...")
    draft_paragraphs = split_draft(draft)
    print(f"{datetime.now()} [INFO] 草稿被切分为{len(draft_paragraphs)}部分")
    answer = ""
    for i, p in enumerate(draft_paragraphs):
        print(str(i)*80)
        print(f"{datetime.now()} [INFO] 修改第{i+1}/{len(draft_paragraphs)}部分...")
        # Put the blank-line separator only BETWEEN paragraphs so the answer
        # never starts with an empty separator.
        answer = f"{answer}\n\n{p}" if answer else p

        print(f"{datetime.now()} [INFO] 生成对应Query...")
        query = run_with_timeout(get_query_wrapper, 3, question, answer)
        if not query:
            print(f"{datetime.now()} [INFO] 生成检索词超时,跳过后续步骤...")
            continue
        print(f">>> {i}/{len(draft_paragraphs)} Query: {query.replace(newline_char, ' ')}")

        print(f"{datetime.now()} [INFO] 获取网页内容...")
        content = run_with_timeout(get_content_wrapper, 5, query)
        if not content:
            print(f"{datetime.now()} [INFO] 获取网页内容超时,跳过后续步骤...")
            continue

        # Revise against at most the first three pieces of the page.
        for j, c in enumerate(content[:3]):
            print(f"{datetime.now()} [INFO] 根据网页内容修改对应答案...[{j}/{min(len(content),3)}]")
            res = run_with_timeout(get_revise_answer_wrapper, 15, question, answer, c)
            if not res:
                print(f"{datetime.now()} [INFO] 修改答案超时,跳过后续步骤...")
                continue
            # Show what the revision changed before adopting it.
            diff_html = generate_diff_html(answer, res)
            display(HTML(diff_html))
            answer = res
            print(f"{datetime.now()} [INFO] 答案修改完成[{j}/{min(len(content),3)}]")
    return draft, answer

# Static page text: `page_title` is passed as the title of the gr.Blocks app and
# `page_md` is rendered with gr.Markdown at the top of the page.
page_title = "RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation"
page_md = """
# RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation

We explore how iterative revising a chain of thoughts with the help of information retrieval significantly improves large language models' reasoning and generation ability in long-horizon generation tasks, while hugely mitigating hallucination. In particular, the proposed method — retrieval-augmented thoughts (RAT) — revises each thought step one by one with retrieved information relevant to the task query, the current and the past thought steps, after the initial zero-shot CoT is generated.

Applying RAT to various base models substantially improves their performances on various long-horizon generation tasks; on average of relatively increasing rating scores by 13.63% on code generation, 16.96% on mathematical reasoning, 19.2% on creative writing, and 42.78% on embodied task planning.

Feel free to try our demo!

"""

def clear_func():
    """Return three empty strings, one per textbox wired to the clear button."""
    return ("",) * 3

def set_openai_api_key(api_key):
    """Store `api_key` in the OPENAI_API_KEY environment variable if it looks valid.

    A key is accepted when it is non-empty, starts with "sk-" and is longer
    than 50 characters; anything else is ignored. Returns None.
    """
    looks_valid = bool(api_key) and api_key.startswith("sk-") and len(api_key) > 50
    if not looks_valid:
        return
    os.environ["OPENAI_API_KEY"] = api_key

# Build the Gradio page: intro text, output boxes, the instruction input with
# model/API-key controls, the action buttons and two example prompts.
with gr.Blocks(title = page_title) as demo:
   
    gr.Markdown(page_md)

    # Output box for the first value returned by `rat` (the initial draft).
    with gr.Row():
        chatgpt_box = gr.Textbox(
            label = "ChatGPT",
            placeholder = "Response from ChatGPT with zero-shot chain-of-thought.",
            elem_id = "chatgpt"
        )

    # Created hidden (visible=False); no callback in this block writes to it.
    with gr.Row():
        stream_box = gr.Textbox(
            label = "Streaming",
            placeholder = "Interactive response with RAT...",
            elem_id = "stream",
            lines = 10,
            visible = False
        )
    
    # Output box for the second value returned by `rat` (the revised answer).
    with gr.Row():
        rat_box = gr.Textbox(
            label = "RAT",
            placeholder = "Final response with RAT ...",
            elem_id = "rat",
            lines = 6
        )

    with gr.Column(elem_id="instruction_row"):
        with gr.Row():
            instruction_box = gr.Textbox(
                label = "instruction",
                placeholder = "Enter your instruction here",
                lines = 2,
                elem_id="instruction",
                interactive=True,
                visible=True
            )
        with gr.Row():
            # NOTE(review): model_radio is not passed to any callback in this
            # block, so the selection is not forwarded to `rat` here.
            model_radio = gr.Radio(["gpt-3.5-turbo", "gpt-4", "GPT-4-turbo"], elem_id="model_radio", value="gpt-3.5-turbo", 
                                label='GPT model: ', show_label=True,interactive=True, visible=True) 
            openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter",
                                        show_label=False, lines=1, type='password')
            
    # Every change of the key textbox is passed to set_openai_api_key.
    openai_api_key_textbox.change(set_openai_api_key,
        inputs=[openai_api_key_textbox],
        outputs=[])

    with gr.Row():
        submit_btn = gr.Button(
            value="submit", visible=True, interactive=True
        )
        clear_btn = gr.Button(
            value="clear", visible=True, interactive=True
        )
        regenerate_btn = gr.Button(
            value="regenerate", visible=True, interactive=True
        )

    # submit: run `rat` on the instruction and fill the draft and RAT boxes.
    submit_btn.click(
        fn = rat,
        inputs = [instruction_box],
        outputs = [chatgpt_box, rat_box]
    )

    # clear: reset the instruction and both output boxes to empty strings.
    clear_btn.click(
        fn = clear_func,
        inputs = [],
        outputs = [instruction_box, chatgpt_box, rat_box]
    )

    # regenerate: same wiring as submit (runs `rat` again on the current instruction).
    regenerate_btn.click(
        fn = rat,
        inputs = [instruction_box],
        outputs = [chatgpt_box, rat_box]
    )

    # Example prompts that populate the instruction box.
    examples = gr.Examples(
        examples=[
            "I went to the supermarket yesterday.", 
            "Helen is a good swimmer."],
        inputs=[instruction_box]
        )

# Start the server only when this file is executed as a script. This module
# spawns worker processes with multiprocessing; under the "spawn"/"forkserver"
# start methods each child re-imports the main module, and an unguarded launch
# would try to start another server in every worker.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", debug=True)