chore: Refactor code to improve vehicle destination handling in calculate_route function
Browse files- kitt/core/model.py +48 -21
- kitt/skills/poi.py +7 -5
- main.py +7 -5
kitt/core/model.py
CHANGED
@@ -13,7 +13,6 @@ from loguru import logger
|
|
13 |
from kitt.skills import vehicle_status
|
14 |
|
15 |
|
16 |
-
|
17 |
class FunctionCall(BaseModel):
|
18 |
arguments: dict
|
19 |
"""
|
@@ -26,6 +25,7 @@ class FunctionCall(BaseModel):
|
|
26 |
name: str
|
27 |
"""The name of the function to call."""
|
28 |
|
|
|
29 |
schema_json = json.loads(FunctionCall.schema_json())
|
30 |
HRMS_SYSTEM_PROMPT = """<|begin_of_text|>
|
31 |
<|im_start|>system
|
@@ -41,7 +41,7 @@ Once you have called a function, results will be fed back to you within <tool_re
|
|
41 |
Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
|
42 |
Analyze the data once you get the results and call another function.
|
43 |
At each iteration please continue adding the your analysis to previous summary.
|
44 |
-
Your final response should directly answer the user query.
|
45 |
|
46 |
|
47 |
Here are the available tools:
|
@@ -53,8 +53,22 @@ If the provided function signatures doesn't have the function you must call, you
|
|
53 |
Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
|
54 |
When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
|
55 |
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
Use the following pydantic model json schema for each tool call you will make:
|
60 |
{schema}
|
@@ -83,7 +97,7 @@ HRMS_TEMPLATE_TOOL_RESULT = """
|
|
83 |
<|im_end|>"""
|
84 |
|
85 |
|
86 |
-
def append_message(prompt, h):
|
87 |
if h.type == "human":
|
88 |
prompt += HRMS_TEMPLATE_USER.format(user_input=h.content)
|
89 |
elif h.type == "ai":
|
@@ -99,11 +113,15 @@ def get_prompt(template, history, tools, schema, car_status=None):
|
|
99 |
car_status = vehicle_status()[0]
|
100 |
|
101 |
# "vehicle_status": vehicle_status_fn()[0]
|
102 |
-
kwargs = {
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
|
105 |
prompt = template.format(**kwargs).replace("{{", "{").replace("}}", "}")
|
106 |
-
|
107 |
if history:
|
108 |
for h in history.messages:
|
109 |
prompt = append_message(prompt, h)
|
@@ -124,7 +142,7 @@ def use_tool(tool_call, tools):
|
|
124 |
|
125 |
def parse_tool_calls(text):
|
126 |
logger.debug(f"Start parsing tool_calls: {text}")
|
127 |
-
pattern = r
|
128 |
|
129 |
if not text.startswith("<tool_call>"):
|
130 |
return [], []
|
@@ -138,7 +156,7 @@ def parse_tool_calls(text):
|
|
138 |
tool_calls.append(tool_call)
|
139 |
except json.JSONDecodeError as e:
|
140 |
errors.append(f"Invalid JSON in tool call: {e}")
|
141 |
-
|
142 |
logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
|
143 |
return tool_calls, errors
|
144 |
|
@@ -149,7 +167,7 @@ def process_response(user_query, res, history, tools, depth):
|
|
149 |
tool_calls, errors = parse_tool_calls(res)
|
150 |
# TODO: Handle errors
|
151 |
if not tool_calls:
|
152 |
-
return False
|
153 |
# tool_results = ""
|
154 |
tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
|
155 |
for tool_call in tool_calls:
|
@@ -157,7 +175,7 @@ def process_response(user_query, res, history, tools, depth):
|
|
157 |
# Call the function
|
158 |
try:
|
159 |
result = use_tool(tool_call, tools)
|
160 |
-
if
|
161 |
result = result[1]
|
162 |
tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
|
163 |
except Exception as e:
|
@@ -169,7 +187,7 @@ def process_response(user_query, res, history, tools, depth):
|
|
169 |
print(f"Tool results: {tool_results}")
|
170 |
tool_call_id = uuid.uuid4().hex
|
171 |
history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
|
172 |
-
return True
|
173 |
|
174 |
|
175 |
def run_inference_step(history, tools, schema_json, dry_run=False):
|
@@ -188,12 +206,13 @@ def run_inference_step(history, tools, schema_json, dry_run=False):
|
|
188 |
# "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
|
189 |
"model": "interstellarninja/hermes-2-pro-llama-3-8b",
|
190 |
"raw": True,
|
191 |
-
"options": {
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
197 |
}
|
198 |
|
199 |
if dry_run:
|
@@ -201,7 +220,7 @@ def run_inference_step(history, tools, schema_json, dry_run=False):
|
|
201 |
return "Didn't really run it."
|
202 |
|
203 |
out = ollama.generate(**data)
|
204 |
-
|
205 |
res = out["response"]
|
206 |
|
207 |
return res
|
@@ -213,7 +232,15 @@ def process_query(user_query: str, history: ChatMessageHistory, tools):
|
|
213 |
out = run_inference_step(history, tools, schema_json)
|
214 |
print(f"Inference step result:\n{out}\n------------------\n")
|
215 |
history.add_message(AIMessage(content=out))
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
print(f"This is the answer, no more iterations: {out}")
|
218 |
return out
|
219 |
# Otherwise, tools result is already added to history, we just need to continue the loop.
|
|
|
13 |
from kitt.skills import vehicle_status
|
14 |
|
15 |
|
|
|
16 |
class FunctionCall(BaseModel):
|
17 |
arguments: dict
|
18 |
"""
|
|
|
25 |
name: str
|
26 |
"""The name of the function to call."""
|
27 |
|
28 |
+
|
29 |
schema_json = json.loads(FunctionCall.schema_json())
|
30 |
HRMS_SYSTEM_PROMPT = """<|begin_of_text|>
|
31 |
<|im_start|>system
|
|
|
41 |
Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
|
42 |
Analyze the data once you get the results and call another function.
|
43 |
At each iteration please continue adding the your analysis to previous summary.
|
44 |
+
Your final response should directly answer the user query. Don't tell what you are doing, just do it.
|
45 |
|
46 |
|
47 |
Here are the available tools:
|
|
|
53 |
Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
|
54 |
When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
|
55 |
|
56 |
+
Example 1:
|
57 |
+
User: How is the weather today?
|
58 |
+
Assistant:
|
59 |
+
<tool_call>
|
60 |
+
{{"arguments": {{"location": ""}}, "name": "get_weather"}}
|
61 |
+
</tool_call>
|
62 |
+
|
63 |
+
Example 2:
|
64 |
+
User: Is there a Spa nearby?
|
65 |
+
Assistant:
|
66 |
+
<tool_call>
|
67 |
+
{{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interests"}}
|
68 |
+
</tool_call>
|
69 |
+
|
70 |
+
When asked for the weather or points of interest, use the appropriate tool with the current location of the car. Unless the user provides a location, then use that location.
|
71 |
+
|
72 |
|
73 |
Use the following pydantic model json schema for each tool call you will make:
|
74 |
{schema}
|
|
|
97 |
<|im_end|>"""
|
98 |
|
99 |
|
100 |
+
def append_message(prompt, h):
|
101 |
if h.type == "human":
|
102 |
prompt += HRMS_TEMPLATE_USER.format(user_input=h.content)
|
103 |
elif h.type == "ai":
|
|
|
113 |
car_status = vehicle_status()[0]
|
114 |
|
115 |
# "vehicle_status": vehicle_status_fn()[0]
|
116 |
+
kwargs = {
|
117 |
+
"history": history,
|
118 |
+
"schema": schema,
|
119 |
+
"tools": tools,
|
120 |
+
"car_status": car_status,
|
121 |
+
}
|
122 |
|
|
|
123 |
prompt = template.format(**kwargs).replace("{{", "{").replace("}}", "}")
|
124 |
+
|
125 |
if history:
|
126 |
for h in history.messages:
|
127 |
prompt = append_message(prompt, h)
|
|
|
142 |
|
143 |
def parse_tool_calls(text):
|
144 |
logger.debug(f"Start parsing tool_calls: {text}")
|
145 |
+
pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
|
146 |
|
147 |
if not text.startswith("<tool_call>"):
|
148 |
return [], []
|
|
|
156 |
tool_calls.append(tool_call)
|
157 |
except json.JSONDecodeError as e:
|
158 |
errors.append(f"Invalid JSON in tool call: {e}")
|
159 |
+
|
160 |
logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
|
161 |
return tool_calls, errors
|
162 |
|
|
|
167 |
tool_calls, errors = parse_tool_calls(res)
|
168 |
# TODO: Handle errors
|
169 |
if not tool_calls:
|
170 |
+
return False, tool_calls, errors
|
171 |
# tool_results = ""
|
172 |
tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
|
173 |
for tool_call in tool_calls:
|
|
|
175 |
# Call the function
|
176 |
try:
|
177 |
result = use_tool(tool_call, tools)
|
178 |
+
if isinstance(result, tuple):
|
179 |
result = result[1]
|
180 |
tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
|
181 |
except Exception as e:
|
|
|
187 |
print(f"Tool results: {tool_results}")
|
188 |
tool_call_id = uuid.uuid4().hex
|
189 |
history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
|
190 |
+
return True, tool_calls, errors
|
191 |
|
192 |
|
193 |
def run_inference_step(history, tools, schema_json, dry_run=False):
|
|
|
206 |
# "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
|
207 |
"model": "interstellarninja/hermes-2-pro-llama-3-8b",
|
208 |
"raw": True,
|
209 |
+
"options": {
|
210 |
+
"temperature": 0.8,
|
211 |
+
# "max_tokens": 1500,
|
212 |
+
"num_predict": 1500,
|
213 |
+
# "num_predict": 1500,
|
214 |
+
# "max_tokens": 1500,
|
215 |
+
},
|
216 |
}
|
217 |
|
218 |
if dry_run:
|
|
|
220 |
return "Didn't really run it."
|
221 |
|
222 |
out = ollama.generate(**data)
|
223 |
+
logger.debug(f"Response from model: {out}")
|
224 |
res = out["response"]
|
225 |
|
226 |
return res
|
|
|
232 |
out = run_inference_step(history, tools, schema_json)
|
233 |
print(f"Inference step result:\n{out}\n------------------\n")
|
234 |
history.add_message(AIMessage(content=out))
|
235 |
+
to_continue, tool_calls, errors = process_response(
|
236 |
+
user_query, out, history, tools, depth
|
237 |
+
)
|
238 |
+
if errors:
|
239 |
+
history.add_message(
|
240 |
+
AIMessage(content=f"Errors in tool calls: {errors}")
|
241 |
+
)
|
242 |
+
|
243 |
+
if not to_continue:
|
244 |
print(f"This is the answer, no more iterations: {out}")
|
245 |
return out
|
246 |
# Otherwise, tools result is already added to history, we just need to continue the loop.
|
kitt/skills/poi.py
CHANGED
@@ -4,7 +4,7 @@ from .common import config, vehicle
|
|
4 |
|
5 |
|
6 |
# Select coordinates at equal distance, including the last one
|
7 |
-
def
|
8 |
n = len(coords)
|
9 |
selected_coords = []
|
10 |
interval = max((n - 1) / (number_of_points - 1), 1)
|
@@ -18,8 +18,10 @@ def select_equally_spaced_coordinates(coords, number_of_points=10):
|
|
18 |
|
19 |
def search_points_of_interests(search_query="french restaurant"):
|
20 |
"""
|
21 |
-
|
22 |
-
|
|
|
|
|
23 |
"""
|
24 |
|
25 |
# Extract the latitude and longitude of the vehicle
|
@@ -103,7 +105,7 @@ def search_along_route_w_coordinates(points: list[tuple[float, float]], query: s
|
|
103 |
# The API endpoint for searching along a route
|
104 |
url = f"https://api.tomtom.com/search/2/searchAlongRoute/{query}.json?key={config.TOMTOM_API_KEY}&maxDetourTime=360&limit=20&sortBy=detourTime"
|
105 |
|
106 |
-
points =
|
107 |
|
108 |
# The data payload
|
109 |
payload = {
|
@@ -140,4 +142,4 @@ def search_along_route_w_coordinates(points: list[tuple[float, float]], query: s
|
|
140 |
+ f" \n{name} at {address} would require a detour of {int(detour_time/60)} minutes."
|
141 |
)
|
142 |
|
143 |
-
return answer
|
|
|
4 |
|
5 |
|
6 |
# Select coordinates at equal distance, including the last one
|
7 |
+
def _select_equally_spaced_coordinates(coords, number_of_points=10):
|
8 |
n = len(coords)
|
9 |
selected_coords = []
|
10 |
interval = max((n - 1) / (number_of_points - 1), 1)
|
|
|
18 |
|
19 |
def search_points_of_interests(search_query="french restaurant"):
|
20 |
"""
|
21 |
+
Get some of the closest points of interest matching the query.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
search_query (string): Required. Describing the type of point of interest depending on what the user wants to do. Make sure to include the type of POI you are looking for. For example italian restaurant, grocery shop, etc.
|
25 |
"""
|
26 |
|
27 |
# Extract the latitude and longitude of the vehicle
|
|
|
105 |
# The API endpoint for searching along a route
|
106 |
url = f"https://api.tomtom.com/search/2/searchAlongRoute/{query}.json?key={config.TOMTOM_API_KEY}&maxDetourTime=360&limit=20&sortBy=detourTime"
|
107 |
|
108 |
+
points = _select_equally_spaced_coordinates(points, number_of_points=20)
|
109 |
|
110 |
# The data payload
|
111 |
payload = {
|
|
|
142 |
+ f" \n{name} at {address} would require a detour of {int(detour_time/60)} minutes."
|
143 |
)
|
144 |
|
145 |
+
return answer, data["results"][:5]
|
main.py
CHANGED
@@ -25,7 +25,7 @@ from kitt.skills import (
|
|
25 |
search_along_route_w_coordinates,
|
26 |
do_anything_else,
|
27 |
date_time_info,
|
28 |
-
get_weather_current_location
|
29 |
)
|
30 |
from kitt.skills import extract_func_args
|
31 |
from kitt.core import voice_options, tts_gradio
|
@@ -192,7 +192,6 @@ def run_llama3_model(query, voice_character):
|
|
192 |
|
193 |
|
194 |
def run_model(query, voice_character, state):
|
195 |
-
|
196 |
model = state.get("model", "nexusraven")
|
197 |
query = query.strip().replace("'", "")
|
198 |
print("Query: ", query)
|
@@ -224,8 +223,9 @@ def update_vehicle_status(trip_progress, origin, destination):
|
|
224 |
print(f"Trip progress: {trip_progress}, len: {n_points}, new_coords: {new_coords}")
|
225 |
vehicle.location_coordinates = new_coords
|
226 |
vehicle.location = ""
|
227 |
-
|
228 |
-
|
|
|
229 |
return vehicle.model_dump_json(), plot
|
230 |
|
231 |
|
@@ -373,7 +373,9 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
|
|
373 |
|
374 |
# Set the vehicle status based on the trip progress
|
375 |
trip_progress.release(
|
376 |
-
fn=update_vehicle_status,
|
|
|
|
|
377 |
)
|
378 |
|
379 |
# Save and transcribe the audio
|
|
|
25 |
search_along_route_w_coordinates,
|
26 |
do_anything_else,
|
27 |
date_time_info,
|
28 |
+
get_weather_current_location,
|
29 |
)
|
30 |
from kitt.skills import extract_func_args
|
31 |
from kitt.core import voice_options, tts_gradio
|
|
|
192 |
|
193 |
|
194 |
def run_model(query, voice_character, state):
|
|
|
195 |
model = state.get("model", "nexusraven")
|
196 |
query = query.strip().replace("'", "")
|
197 |
print("Query: ", query)
|
|
|
223 |
print(f"Trip progress: {trip_progress}, len: {n_points}, new_coords: {new_coords}")
|
224 |
vehicle.location_coordinates = new_coords
|
225 |
vehicle.location = ""
|
226 |
+
plot = kitt_utils.plot_route(
|
227 |
+
global_context["route_points"], vehicle=vehicle.location_coordinates
|
228 |
+
)
|
229 |
return vehicle.model_dump_json(), plot
|
230 |
|
231 |
|
|
|
373 |
|
374 |
# Set the vehicle status based on the trip progress
|
375 |
trip_progress.release(
|
376 |
+
fn=update_vehicle_status,
|
377 |
+
inputs=[trip_progress, origin, destination],
|
378 |
+
outputs=[vehicle_status, map_plot],
|
379 |
)
|
380 |
|
381 |
# Save and transcribe the audio
|