sasan committed
Commit bd669ec · 1 Parent(s): 60ee11d

chore: Refactor code to improve vehicle destination handling in calculate_route function

Files changed (3)
  1. kitt/core/model.py +48 -21
  2. kitt/skills/poi.py +7 -5
  3. main.py +7 -5
kitt/core/model.py CHANGED
@@ -13,7 +13,6 @@ from loguru import logger
 from kitt.skills import vehicle_status
 
 
-
 class FunctionCall(BaseModel):
     arguments: dict
     """
@@ -26,6 +25,7 @@ class FunctionCall(BaseModel):
     name: str
     """The name of the function to call."""
 
+
 schema_json = json.loads(FunctionCall.schema_json())
 HRMS_SYSTEM_PROMPT = """<|begin_of_text|>
 <|im_start|>system
@@ -41,7 +41,7 @@ Once you have called a function, results will be fed back to you within <tool_re
 Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
 Analyze the data once you get the results and call another function.
 At each iteration please continue adding the your analysis to previous summary.
-Your final response should directly answer the user query.
+Your final response should directly answer the user query. Don't tell what you are doing, just do it.
 
 
 Here are the available tools:
@@ -53,8 +53,22 @@ If the provided function signatures doesn't have the function you must call, you
 Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
 When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
 
-When asked for the weather, lookup the weather for the current location of the car. Unless the user provides a location, then use that location.
-If asked about points of interest, use the tools available to you. Do not make up points of interest.
+Example 1:
+User: How is the weather today?
+Assistant:
+<tool_call>
+{{"arguments": {{"location": ""}}, "name": "get_weather"}}
+</tool_call>
+
+Example 2:
+User: Is there a Spa nearby?
+Assistant:
+<tool_call>
+{{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interests"}}
+</tool_call>
+
+When asked for the weather or points of interest, use the appropriate tool with the current location of the car. Unless the user provides a location, then use that location.
+
 
 Use the following pydantic model json schema for each tool call you will make:
 {schema}
@@ -83,7 +97,7 @@ HRMS_TEMPLATE_TOOL_RESULT = """
 <|im_end|>"""
 
 
-def append_message(prompt, h):
+def append_message(prompt, h):
     if h.type == "human":
         prompt += HRMS_TEMPLATE_USER.format(user_input=h.content)
     elif h.type == "ai":
@@ -99,11 +113,15 @@ def get_prompt(template, history, tools, schema, car_status=None):
         car_status = vehicle_status()[0]
 
     # "vehicle_status": vehicle_status_fn()[0]
-    kwargs = {"history": history, "schema": schema, "tools": tools, "car_status": car_status}
+    kwargs = {
+        "history": history,
+        "schema": schema,
+        "tools": tools,
+        "car_status": car_status,
+    }
 
-
     prompt = template.format(**kwargs).replace("{{", "{").replace("}}", "}")
-
+
     if history:
         for h in history.messages:
             prompt = append_message(prompt, h)
@@ -124,7 +142,7 @@ def use_tool(tool_call, tools):
 
 def parse_tool_calls(text):
     logger.debug(f"Start parsing tool_calls: {text}")
-    pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
+    pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
 
     if not text.startswith("<tool_call>"):
         return [], []
@@ -138,7 +156,7 @@ def parse_tool_calls(text):
             tool_calls.append(tool_call)
         except json.JSONDecodeError as e:
             errors.append(f"Invalid JSON in tool call: {e}")
-
+
     logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
     return tool_calls, errors
 
@@ -149,7 +167,7 @@ def process_response(user_query, res, history, tools, depth):
     tool_calls, errors = parse_tool_calls(res)
     # TODO: Handle errors
    if not tool_calls:
-        return False
+        return False, tool_calls, errors
     # tool_results = ""
     tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
     for tool_call in tool_calls:
@@ -157,7 +175,7 @@ def process_response(user_query, res, history, tools, depth):
         # Call the function
         try:
             result = use_tool(tool_call, tools)
-            if type(result) == tuple:
+            if isinstance(result, tuple):
                 result = result[1]
             tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
         except Exception as e:
@@ -169,7 +187,7 @@ def process_response(user_query, res, history, tools, depth):
     print(f"Tool results: {tool_results}")
     tool_call_id = uuid.uuid4().hex
     history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
-    return True
+    return True, tool_calls, errors
 
 
 def run_inference_step(history, tools, schema_json, dry_run=False):
@@ -188,12 +206,13 @@ def run_inference_step(history, tools, schema_json, dry_run=False):
         # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
         "model": "interstellarninja/hermes-2-pro-llama-3-8b",
         "raw": True,
-        "options": {"temperature": 0.8,
-                    # "max_tokens": 1500,
-                    "num_predict": 1500,
-                    # "num_predict": 1500,
-                    # "max_tokens": 1500,
-                    }
+        "options": {
+            "temperature": 0.8,
+            # "max_tokens": 1500,
+            "num_predict": 1500,
+            # "num_predict": 1500,
+            # "max_tokens": 1500,
+        },
     }
 
     if dry_run:
@@ -201,7 +220,7 @@ def run_inference_step(history, tools, schema_json, dry_run=False):
         return "Didn't really run it."
 
     out = ollama.generate(**data)
-
+    logger.debug(f"Response from model: {out}")
     res = out["response"]
 
     return res
@@ -213,7 +232,15 @@ def process_query(user_query: str, history: ChatMessageHistory, tools):
         out = run_inference_step(history, tools, schema_json)
         print(f"Inference step result:\n{out}\n------------------\n")
         history.add_message(AIMessage(content=out))
-        if not process_response(user_query, out, history, tools, depth):
+        to_continue, tool_calls, errors = process_response(
+            user_query, out, history, tools, depth
+        )
+        if errors:
+            history.add_message(
+                AIMessage(content=f"Errors in tool calls: {errors}")
+            )
+
+        if not to_continue:
             print(f"This is the answer, no more iterations: {out}")
             return out
         # Otherwise, tools result is already added to history, we just need to continue the loop.
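Note on the new `process_response` contract: it now returns a `(to_continue, tool_calls, errors)` triple on both paths, which is what lets `process_query` surface malformed tool calls back into the chat history instead of silently dropping them. Below is a minimal, self-contained sketch of the parsing step that feeds this contract; the sample model output and the `re.findall`/`re.DOTALL` extraction are assumptions for illustration, since the diff only shows the regex pattern and the error handling:

```python
import json
import re

# The pattern from parse_tool_calls (the diff only changes its quote style).
pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"

# Illustrative model output, shaped like the prompt's Example 2.
text = (
    "<tool_call>\n"
    '{"arguments": {"search_query": "Spa"}, "name": "search_points_of_interests"}\n'
    "</tool_call>"
)

tool_calls, errors = [], []
for raw in re.findall(pattern, text, re.DOTALL):  # assumed extraction step
    try:
        tool_calls.append(json.loads(raw))  # same happy path as parse_tool_calls
    except json.JSONDecodeError as e:
        errors.append(f"Invalid JSON in tool call: {e}")  # same error string

print(tool_calls)  # [{'arguments': {'search_query': 'Spa'}, 'name': 'search_points_of_interests'}]
print(errors)      # []
```

An empty `tool_calls` list now yields `(False, [], errors)`, so the caller can stop iterating and still report `errors`.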
kitt/skills/poi.py CHANGED
@@ -4,7 +4,7 @@ from .common import config, vehicle
 
 
 # Select coordinates at equal distance, including the last one
-def select_equally_spaced_coordinates(coords, number_of_points=10):
+def _select_equally_spaced_coordinates(coords, number_of_points=10):
     n = len(coords)
     selected_coords = []
     interval = max((n - 1) / (number_of_points - 1), 1)
@@ -18,8 +18,10 @@ def select_equally_spaced_coordinates(coords, number_of_points=10):
 
 def search_points_of_interests(search_query="french restaurant"):
     """
-    Return some of the closest points of interest matching the query.
-    :param search_query (string): Required. Describing the type of point of interest depending on what the user wants to do. Make sure to include the type of POI you are looking for. For example italian restaurant, grocery shop, etc.
+    Get some of the closest points of interest matching the query.
+
+    Args:
+        search_query (string): Required. Describing the type of point of interest depending on what the user wants to do. Make sure to include the type of POI you are looking for. For example italian restaurant, grocery shop, etc.
     """
 
     # Extract the latitude and longitude of the vehicle
@@ -103,7 +105,7 @@ def search_along_route_w_coordinates(points: list[tuple[float, float]], query: s
     # The API endpoint for searching along a route
     url = f"https://api.tomtom.com/search/2/searchAlongRoute/{query}.json?key={config.TOMTOM_API_KEY}&maxDetourTime=360&limit=20&sortBy=detourTime"
 
-    points = select_equally_spaced_coordinates(points, number_of_points=20)
+    points = _select_equally_spaced_coordinates(points, number_of_points=20)
 
     # The data payload
     payload = {
@@ -140,4 +142,4 @@ def search_along_route_w_coordinates(points: list[tuple[float, float]], query: s
             + f" \n{name} at {address} would require a detour of {int(detour_time/60)} minutes."
         )
 
-    return answer
+    return answer, data["results"][:5]
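Note on the changed return value: `search_along_route_w_coordinates` now returns an `(answer, results)` pair, and the `isinstance(result, tuple)` branch added in `kitt/core/model.py` keeps only the second element for the `<tool_response>` block. A rough sketch of that interaction, with a stand-in function instead of the real TomTom-backed skill:

```python
def fake_search_along_route(points, query):
    # Stand-in for search_along_route_w_coordinates: mirrors the new
    # (human-readable answer, top-5 structured results) return shape.
    answer = f"There are 2 options for {query} along the route."
    results = [{"poi": {"name": "Spa One"}}, {"poi": {"name": "Spa Two"}}]
    return answer, results

result = fake_search_along_route([(52.52, 13.405)], "Spa")
if isinstance(result, tuple):  # same check process_response now uses
    result = result[1]         # structured results, not the prose answer

print(f"<tool_response>\n{result}\n</tool_response>")
```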
main.py CHANGED
@@ -25,7 +25,7 @@ from kitt.skills import (
     search_along_route_w_coordinates,
     do_anything_else,
     date_time_info,
-    get_weather_current_location
+    get_weather_current_location,
 )
 from kitt.skills import extract_func_args
 from kitt.core import voice_options, tts_gradio
@@ -192,7 +192,6 @@ def run_llama3_model(query, voice_character):
 
 
 def run_model(query, voice_character, state):
-
     model = state.get("model", "nexusraven")
     query = query.strip().replace("'", "")
     print("Query: ", query)
@@ -224,8 +223,9 @@ def update_vehicle_status(trip_progress, origin, destination):
     print(f"Trip progress: {trip_progress}, len: {n_points}, new_coords: {new_coords}")
     vehicle.location_coordinates = new_coords
     vehicle.location = ""
-
-    plot = kitt_utils.plot_route(global_context["route_points"], vehicle=vehicle.location_coordinates)
+    plot = kitt_utils.plot_route(
+        global_context["route_points"], vehicle=vehicle.location_coordinates
+    )
     return vehicle.model_dump_json(), plot
 
 
@@ -373,7 +373,9 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
 
     # Set the vehicle status based on the trip progress
     trip_progress.release(
-        fn=update_vehicle_status, inputs=[trip_progress, origin, destination], outputs=[vehicle_status, map_plot]
+        fn=update_vehicle_status,
+        inputs=[trip_progress, origin, destination],
+        outputs=[vehicle_status, map_plot],
    )
 
     # Save and transcribe the audio
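Note on the reformatted event wiring: `trip_progress.release(...)` is standard Gradio slider wiring; the handler's two return values map one-to-one onto the `outputs` list. A stripped-down sketch of the same pattern, with hypothetical component types since the diff doesn't show how `vehicle_status` and `map_plot` are declared:

```python
import gradio as gr

def update_status(progress, origin, destination):
    # Stand-in for update_vehicle_status: one return value per output component.
    return f'{{"progress": {progress}, "origin": "{origin}"}}', None

with gr.Blocks() as demo:
    origin = gr.Textbox(label="Origin")
    destination = gr.Textbox(label="Destination")
    trip_progress = gr.Slider(0, 100, label="Trip progress")
    vehicle_status = gr.Textbox(label="Vehicle status")  # assumed component type
    map_plot = gr.Plot(label="Map")                      # assumed component type

    # Fires when the user releases the slider, as in create_demo.
    trip_progress.release(
        fn=update_status,
        inputs=[trip_progress, origin, destination],
        outputs=[vehicle_status, map_plot],
    )
```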