# kitt/core/model.py
import json
import re
import uuid
from langchain.memory import ChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_core.utils.function_calling import convert_to_openai_function
import ollama
from ollama import Client
from pydantic import BaseModel
from loguru import logger
from kitt.skills import vehicle_status
class FunctionCall(BaseModel):
arguments: dict
"""
The arguments to call the function with, as generated by the model in JSON
format. Note that the model does not always generate valid JSON, and may
hallucinate parameters not defined by your function schema. Validate the
arguments in your code before calling your function.
"""
name: str
"""The name of the function to call."""
schema_json = json.loads(FunctionCall.schema_json())
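# Illustrative sketch of what schema_json roughly contains (the exact dict
# comes from FunctionCall's pydantic JSON schema and may differ by version):
# {
#     "title": "FunctionCall",
#     "type": "object",
#     "properties": {
#         "arguments": {"title": "Arguments", "type": "object"},
#         "name": {"title": "Name", "type": "string"},
#     },
#     "required": ["arguments", "name"],
# }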
HRMS_SYSTEM_PROMPT = """<|im_start|>system
You are a function calling AI agent with self-recursion.
You can call only one function at a time and analyse the data you get from the function response.
You are provided with function signatures within <tools></tools> XML tags.
You may use agentic frameworks for reasoning and planning to help with the user query.
Please call a function and wait for function results to be provided to you in the next iteration.
Don't make assumptions about what values to plug into function arguments.
Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
Don't make assumptions about tool results if <tool_response> XML tags are not present, since the function hasn't been executed yet.
Analyze the data once you get the results and call another function.
At each iteration, please continue adding your analysis to the previous summary.
Your final response should directly answer the user query. Don't describe what you are doing, just do it.
Here are the available tools:
<tools> {tools} </tools>
If the provided function signatures don't include the function you need to call, you may write executable Python code in markdown syntax and call the code_interpreter() function as follows:
<tool_call>
{{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
</tool_call>
Make sure that the JSON object above, including the code markdown block, is parseable with json.loads() and that the XML block is parseable with xml ElementTree.
When using tools, use only the tools provided, do not make up any data, and do not provide any explanation as to which tool you are using and why.
Example 1:
User: How is the weather today?
Assistant:
<tool_call>
{{"arguments": {{"location": ""}}, "name": "get_weather"}}
</tool_call>
Example 2:
User: Is there a Spa nearby?
Assistant:
<tool_call>
{{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interests"}}
</tool_call>
Example 3:
User: How long will it take to get to the destination?
Assistant:
<tool_call>
{{"arguments": {{"destination": ""}}, "name": "calculate_route"}}
When asked for the weather or points of interest, use the appropriate tool with the current location of the car. Unless the user provides a location, then use that location.
Always assume user wants to travel by car.
Use the following pydantic model json schema for each tool call you will make:
{schema}
At the very first turn you don't have <tool_results>, so you shouldn't make up the results.
Please keep a running summary with analysis of previous function results and summaries from previous iterations.
Do not stop calling functions until the task has been accomplished or you've reached the maximum of 10 iterations.
If you plan to continue with analysis, always call another function.
For each function call return a valid JSON object (using double quotes) with the function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{{"arguments": <args-dict>, "name": <function-name>}}
</tool_call>
<|im_end|>"""
AI_PREAMBLE = """
<|im_start|>assistant
"""
HRMS_TEMPLATE_USER = """
<|im_start|>user
{user_input}<|im_end|>"""
HRMS_TEMPLATE_ASSISTANT = """
<|im_start|>assistant
{assistant_response}<|im_end|>"""
HRMS_TEMPLATE_TOOL_RESULT = """
<|im_start|>tool
{result}
<|im_end|>"""
def append_message(prompt, h):
if h.type == "human":
prompt += HRMS_TEMPLATE_USER.format(user_input=h.content)
elif h.type == "ai":
prompt += HRMS_TEMPLATE_ASSISTANT.format(assistant_response=h.content)
elif h.type == "tool":
prompt += HRMS_TEMPLATE_TOOL_RESULT.format(result=h.content)
return prompt
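# Illustrative example: a HumanMessage with content "How is the weather today?"
# is appended to the prompt as
#   <|im_start|>user
#   How is the weather today?<|im_end|>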
def get_prompt(template, history, tools, schema, car_status=None):
if not car_status:
# car_status = vehicle.dict()
car_status = vehicle_status()[0]
# "vehicle_status": vehicle_status_fn()[0]
kwargs = {
"history": history,
"schema": schema,
"tools": tools,
"car_status": car_status,
}
prompt = template.format(**kwargs).replace("{{", "{").replace("}}", "}")
if history:
for h in history.messages:
prompt = append_message(prompt, h)
# if input:
# prompt += USER_QUERY_TEMPLATE.format(user_input=input)
return prompt
def use_tool(tool_call, tools):
func_name = tool_call["name"]
kwargs = tool_call["arguments"]
for tool in tools:
if tool.name == func_name:
return tool.invoke(input=kwargs)
return None
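# Illustrative usage, assuming a LangChain tool named "get_weather" is
# registered in `tools` (hypothetical tool for this example):
#   use_tool({"name": "get_weather", "arguments": {"location": "here"}}, tools)
# returns get_weather.invoke(input={"location": "here"}), or None when no tool
# with that name is found.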
def parse_tool_calls(text):
logger.debug(f"Start parsing tool_calls: {text}")
pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
    # A response that mixes free text with a <tool_call> is rejected via a
    # sentinel ValueError so the caller can ask the model to retry.
    if not text.startswith("<tool_call>"):
        if "<tool_call>" in text:
            raise ValueError("<text_and_tool_call>")
        return [], []
matches = re.findall(pattern, text, re.DOTALL)
tool_calls = []
errors = []
for match in matches:
try:
tool_call = json.loads(match)
tool_calls.append(tool_call)
except json.JSONDecodeError as e:
errors.append(f"Invalid JSON in tool call: {e}")
logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
return tool_calls, errors
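# Illustrative example:
#   parse_tool_calls('<tool_call>\n{"arguments": {"location": ""}, "name": "get_weather"}\n</tool_call>')
# returns ([{"arguments": {"location": ""}, "name": "get_weather"}], []);
# text without any <tool_call> tag returns ([], []).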
def process_response(user_query, res, history, tools, depth):
"""Returns True if the response contains tool calls, False otherwise."""
logger.debug(f"Processing response: {res}")
tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
tool_call_id = uuid.uuid4().hex
try:
tool_calls, errors = parse_tool_calls(res)
except ValueError as e:
if "<text_and_tool_call>" in str(e):
tool_results += f"A mix of text and tool_call was found, you must either answer the query in a short sentence or use tool_call not both. Try again, this time only using tool_call."
history.add_message(
ToolMessage(content=tool_results, tool_call_id=tool_call_id)
)
return True, [], []
# TODO: Handle errors
if not tool_calls:
return False, tool_calls, errors
# tool_results = ""
for tool_call in tool_calls:
# TODO: Extra Validation
# Call the function
try:
result = use_tool(tool_call, tools)
if isinstance(result, tuple):
result = result[1]
tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
        except Exception as e:
            # Surface tool failures to the model instead of swallowing them.
            logger.error(f"Error calling tool: {e}")
            tool_results += f"<tool_response>\nError calling tool: {e}\n</tool_response>\n"
    # Currently only to mimic OpenAI's behavior,
    # but it could be used for tracking function calls.
tool_results = tool_results.strip()
print(f"Tool results: {tool_results}")
history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
return True, tool_calls, errors
def run_inference_step(depth, history, tools, schema_json, dry_run=False):
    # If we decide to call a function, we need to generate the prompt for the
    # model based on the conversation history so far, so the agent loop in
    # process_query can continue with the model's next step.
openai_tools = [convert_to_openai_function(tool) for tool in tools]
prompt = get_prompt(HRMS_SYSTEM_PROMPT, history, openai_tools, schema_json)
print(f"Prompt is:{prompt + AI_PREAMBLE}\n------------------\n")
data = {
"prompt": prompt
+ "\nThis is the first turn and you don't have <tool_results> to analyze yet"
+ AI_PREAMBLE,
# "streaming": False,
# "model": "smangrul/llama-3-8b-instruct-function-calling",
# "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
# "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
# "model": "interstellarninja/hermes-2-pro-llama-3-8b",
"model": "dolphin-llama3:8b",
# "model": "dolphin-llama3:70b",
"raw": True,
"options": {
"temperature": 0.8,
# "max_tokens": 1500,
"num_predict": 1500,
"mirostat": 1,
# "mirostat_tau": 2,
"repeat_penalty": 1.5,
"top_k": 25,
"top_p": 0.5,
# "num_predict": 1500,
# "max_tokens": 1500,
},
}
if dry_run:
print(prompt + AI_PREAMBLE)
return "Didn't really run it."
client = Client(host='http://localhost:11444')
# out = ollama.generate(**data)
out = client.generate(**data)
logger.debug(f"Response from model: {out}")
res = out["response"]
return res
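# To inspect the generated prompt without contacting the Ollama server, pass
# dry_run=True, e.g. run_inference_step(0, history, tools, schema_json, dry_run=True).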
def process_query(user_query: str, history: ChatMessageHistory, tools):
# Add vehicle status to the history
user_query_status = (
f"Given that:\n{vehicle_status()[0]}\nAnswer the following:\n{user_query}"
)
history.add_message(HumanMessage(content=user_query_status))
for depth in range(10):
out = run_inference_step(depth, history, tools, schema_json)
print(f"Inference step result:\n{out}\n------------------\n")
history.add_message(AIMessage(content=out))
to_continue, tool_calls, errors = process_response(
user_query, out, history, tools, depth
)
if errors:
history.add_message(AIMessage(content=f"Errors in tool calls: {errors}"))
if not to_continue:
print(f"This is the answer, no more iterations: {out}")
return out
        # Otherwise, the tool results were already added to the history; continue the loop.
# If we get here something went wrong.
history.add_message(
AIMessage(content="Sorry, I am not sure how to help you with that.")
)
return "Sorry, I am not sure how to help you with that."