import uuid
import json
import re

from loguru import logger

# NOTE: the exact import paths below are an assumption; the message and history
# classes used throughout this module come from LangChain.
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_community.chat_message_histories import ChatMessageHistory

# run_inference_step, vehicle_status, and schema_json are expected to be provided
# by other modules of this project (they are used below but not defined here).

def use_tool(tool_call, tools):
    """Dispatch a parsed tool call to the tool with a matching name and return its result."""
    func_name = tool_call["name"]
    kwargs = tool_call["arguments"]
    for tool in tools:
        if tool.name == func_name:
            return tool.invoke(input=kwargs)
    raise ValueError(f"Tool {func_name} not found.")


def parse_tool_calls(text):
    """Extract JSON tool calls wrapped in <tool_call> tags from a model response.

    Returns a (tool_calls, errors) tuple. Raises ValueError when the response
    mixes free text with a tool call, or when it echoes a <tool_response> block.
    """
    logger.debug(f"Start parsing tool_calls: {text}")
    pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"

    if not text.startswith("<tool_call>"):
        if "<tool_call>" in text:
            raise ValueError("<text_and_tool_call>")

        if "<tool_response>" in text:
            raise ValueError("<tool_response>")
        return [], []

    matches = re.findall(pattern, text, re.DOTALL)
    tool_calls = []
    errors = []
    for match in matches:
        try:
            tool_call = json.loads(match)
            tool_calls.append(tool_call)
        except json.JSONDecodeError as e:
            errors.append(f"Invalid JSON in tool call: {e}")

    logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
    return tool_calls, errors
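
# Illustrative example of the wire format this parser expects (the tool name and
# arguments below are hypothetical):
#
#     <tool_call>
#     {"name": "get_weather", "arguments": {"city": "Berlin"}}
#     </tool_call>
#
# For that input, parse_tool_calls returns
# ([{"name": "get_weather", "arguments": {"city": "Berlin"}}], []).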

def process_response(user_query, res, history, tools, depth):
    """Execute any tool calls found in the model response.

    Returns a (to_continue, tool_calls, errors) tuple; to_continue is True when
    tool results (or correction hints) were appended to history and another
    inference step is needed.
    """
    logger.debug(f"Processing response: {res}")
    tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
    tool_call_id = uuid.uuid4().hex
    try:
        tool_calls, errors = parse_tool_calls(res)
    except ValueError as e:
        if "<text_and_tool_call>" in str(e):
            tool_results += "<tool_response>If you need to call a tool your response must be wrapped in <tool_call></tool_call>. Try again, you are great.</tool_response>"
            history.add_message(
                ToolMessage(content=tool_results, tool_call_id=tool_call_id)
            )
            return True, [], []
        if "<tool_response>" in str(e):
            tool_results += "<tool_response>Tool results are not allowed in the response.</tool_response>"
            history.add_message(
                ToolMessage(content=tool_results, tool_call_id=tool_call_id)
            )
            return True, [], []
        # Re-raise any ValueError we did not explicitly translate above.
        raise
    # TODO: Handle errors
    if not tool_calls:
        logger.debug("No tool calls found in response.")
        return False, tool_calls, errors
    # tool_results = ""

    for tool_call in tool_calls:
        # TODO: Extra Validation
        # Call the function
        try:
            result = use_tool(tool_call, tools)
            logger.debug(f"Tool call {tool_call} result: {result}")
            if isinstance(result, tuple):
                result = result[1]
            tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
        except Exception as e:
            logger.error(f"Error calling tool: {e}")
            tool_results += f"<tool_response>\nError calling tool: {e}\n</tool_response>\n"
    # Currently this only mimics OpenAI's behavior,
    # but it could also be used for tracking function calls.

    tool_results = tool_results.strip()
    print(f"Tool results: {tool_results}")
    history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
    return True, tool_calls, errors
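
# Shape of the ToolMessage content assembled by process_response (placeholder
# values), assuming depth 1 and a single successful tool call:
#
#     Agent iteration 1 to assist with user query: <the user query>
#     <tool_response>
#     <result returned by the tool>
#     </tool_response>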


def process_query(
    user_query: str,
    history: ChatMessageHistory,
    user_preferences,
    tools,
    backend="ollama",
):
    """Drive the agent loop: prompt the model, execute any requested tool calls,
    and repeat for up to 10 iterations until a final answer is produced."""
    # Add vehicle status to the history
    user_query_status = f"consider the vehicle status:\n{vehicle_status()[0]}\nwhen responding to the following query:\n{user_query}"
    history.add_message(HumanMessage(content=user_query_status))
    for depth in range(10):
        # out = run_inference_step(depth, history, tools, schema_json)
        out = run_inference_step(
            depth,
            history,
            tools,
            schema_json,
            user_preferences=user_preferences,
            backend=backend,
        )
        logger.info(f"Inference step result:\n{out}")
        history.add_message(AIMessage(content=out))
        to_continue, tool_calls, errors = process_response(
            user_query, out, history, tools, depth
        )
        if errors:
            history.add_message(AIMessage(content=f"Errors in tool calls: {errors}"))

        if not to_continue:
            print(f"This is the answer, no more iterations: {out}")
            return out
        # Otherwise, the tool results are already in history; continue the loop.
    # If we get here, something went wrong.
    history.add_message(
        AIMessage(content="Sorry, I am not sure how to help you with that.")
    )
    return "Sorry, I am not sure how to help you with that."