File size: 4,066 Bytes
2319518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
from typing import Dict, Iterator, List, Optional

from qwen_agent.actions.base import Action
from qwen_agent.tools import call_plugin

# One description entry per tool, joined into PROMPT_REACT's {tool_descs} slot.
# Filled via str.format with name_for_model, name_for_human,
# description_for_model and the JSON-serialized parameters; any literal
# braces added to this template must therefore be doubled ("{{" / "}}").
TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""

# ReAct prompt template, filled via str.format with {tool_descs}, {tool_names}
# and {query}. The "Action:", "Action Input:" and "Observation:" labels below
# are the same markers the LLM stop words and the action parser in this module
# search for, so keep the wording in sync if this text is edited.
PROMPT_REACT = """Answer the following questions as best you can. You have access to the following tools:

{tool_descs}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {query}"""


def _build_react_instruction(query: str, functions: List[Dict]):
    """Render the ReAct prompt for ``query`` advertising the given tools.

    Every spec in ``functions`` must provide ``name_for_model``,
    ``name_for_human``, ``description_for_model`` and ``parameters``; the
    latter is serialized with ``json.dumps(..., ensure_ascii=False)``.

    Returns the filled ``PROMPT_REACT`` string. Tool descriptions are
    separated by a blank line, and tool names by a bare comma (no space).
    """
    specs = list(functions)  # iterate the specs twice below

    rendered = '\n\n'.join(
        TOOL_DESC.format(
            name_for_model=spec['name_for_model'],
            name_for_human=spec['name_for_human'],
            description_for_model=spec['description_for_model'],
            parameters=json.dumps(spec['parameters'], ensure_ascii=False),
        ) for spec in specs)
    names = ','.join(spec['name_for_model'] for spec in specs)

    return PROMPT_REACT.format(tool_descs=rendered,
                               tool_names=names,
                               query=query)


def _parse_last_action(text):
    plugin_name, plugin_args = '', ''
    i = text.rfind('\nAction:')
    j = text.rfind('\nAction Input:')
    k = text.rfind('\nObservation:')
    if 0 <= i < j:  # If the text has `Action` and `Action input`,
        if k < j:  # but does not contain `Observation`,
            # then it is likely that `Observation` is ommited by the LLM,
            # because the output text may have discarded the stop word.
            text = text.rstrip() + '\nObservation:'  # Add it back.
        k = text.rfind('\nObservation:')
        plugin_name = text[i + len('\nAction:'):j].strip()
        plugin_args = text[j + len('\nAction Input:'):k].strip()
        text = text[:k]  # Discard '\nObservation:'.
    return plugin_name, plugin_args, text


# TODO: Decide when a parameter (such as history) belongs in __init__() and when it belongs in run().
class ReAct(Action):
    """ReAct-style action that interleaves LLM reasoning with plugin calls.

    Each turn asks the LLM to continue a single growing transcript. If the
    turn ends with an ``Action`` / ``Action Input`` pair, the plugin is run via
    ``call_plugin`` and its result is appended as an ``Observation``. The loop
    ends when a turn requests no tool or the turn budget is exhausted.
    """

    def _run(self,
             user_request,
             functions: Optional[List[Dict]] = None,
             history: Optional[List[Dict]] = None,
             lang: str = 'en',
             max_turn: int = 5) -> Iterator[str]:
        """Stream the ReAct transcript produced for ``user_request``.

        Args:
            user_request: The user's question, substituted into the prompt.
            functions: Tool specs rendered into the prompt by
                ``_build_react_instruction``; ``None`` means no tools.
            history: Earlier chat messages placed before the prompt. It must
                not end with a ``'user'`` message, because the current request
                is appended here.
            lang: Accepted for interface compatibility; this implementation
                does not use it.
            max_turn: Maximum number of LLM calls (default 5, the previously
                hard-coded limit). If the budget runs out right after a tool
                call, the stream ends with the pending ``Thought:`` prompt.

        Yields:
            The LLM text of each turn, prefixed with a space or newline so the
            fragments concatenate into a readable transcript. After each tool
            call, it also yields ``"\\nObservation: <result>\\nThought:"``.
        """
        functions = functions or []
        prompt = _build_react_instruction(user_request, functions)

        messages = []
        if history:
            assert history[-1][
                'role'] != 'user', 'The history should not include the latest user query.'
            messages.extend(history)
        # Every later turn is appended to this one message (a fresh dict, so
        # the caller's history entries are never mutated).
        messages.append({'role': 'user', 'content': prompt})

        for _ in range(max_turn):
            output = self.llm.chat(
                messages=messages,
                stream=False,  # TODO: streaming is not supported by this loop yet.
                stop=['Observation:', 'Observation:\n'],
            )
            action, action_input, output = _parse_last_action(output)
            # After an observation the transcript ends with "\nThought:", so the
            # continuation joins with a space; on the first turn it starts a
            # new line after the question.
            sep = ' ' if messages[-1]['content'].endswith('\nThought:') else '\n'
            if not output.startswith(sep):
                output = sep + output
            yield output
            if not action:
                break  # No tool call parsed: treat this turn as the last one.
            observation = call_plugin(action, action_input)
            observation = f'\nObservation: {observation}\nThought:'
            yield observation
            messages[-1]['content'] += output + observation