sasan committed on
Commit 962f893
1 Parent(s): bd669ec

chore: Add code interpreter skill and update vehicle status template

.vscode/launch.json ADDED
@@ -0,0 +1,15 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "RUN KITT",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "main.py",
+            "console": "integratedTerminal"
+        }
+    ]
+}
kitt/core/model.py CHANGED
@@ -6,6 +6,7 @@ from langchain.memory import ChatMessageHistory
 from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
 from langchain_core.utils.function_calling import convert_to_openai_function
 import ollama
+from ollama import Client
 from pydantic import BaseModel
 from loguru import logger
 
@@ -27,12 +28,10 @@ class FunctionCall(BaseModel):
 
 
 schema_json = json.loads(FunctionCall.schema_json())
-HRMS_SYSTEM_PROMPT = """<|begin_of_text|>
-<|im_start|>system
+HRMS_SYSTEM_PROMPT = """<|im_start|>system
 You are a function calling AI agent with self-recursion.
 You can call only one function at a time and analyse data you get from function response.
 You are provided with function signatures within <tools></tools> XML tags.
-{car_status}
 
 You may use agentic frameworks for reasoning and planning to help with user query.
 Please call a function and wait for function results to be provided to you in the next iteration.
@@ -67,8 +66,14 @@ Assistant:
 {{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interests"}}
 </tool_call>
 
-When asked for the weather or points of interest, use the appropriate tool with the current location of the car. Unless the user provides a location, then use that location.
+Example 3:
+User: How long will it take to get to the destination?
+Assistant:
+<tool_call>
+{{"arguments": {{"destination": ""}}, "name": "calculate_route"}}
 
+When asked for the weather or points of interest, use the appropriate tool with the current location of the car. Unless the user provides a location, then use that location.
+Always assume user wants to travel by car.
 
 Use the following pydantic model json schema for each tool call you will make:
 {schema}
@@ -145,6 +150,8 @@ def parse_tool_calls(text):
     pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
 
     if not text.startswith("<tool_call>"):
+        if "<tool_call>" in text:
+            raise ValueError("<text_and_tool_call>")
         return [], []
 
     matches = re.findall(pattern, text, re.DOTALL)
@@ -164,12 +171,22 @@ def parse_tool_calls(text):
 def process_response(user_query, res, history, tools, depth):
     """Returns True if the response contains tool calls, False otherwise."""
     logger.debug(f"Processing response: {res}")
-    tool_calls, errors = parse_tool_calls(res)
+    tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
+    tool_call_id = uuid.uuid4().hex
+    try:
+        tool_calls, errors = parse_tool_calls(res)
+    except ValueError as e:
+        if "<text_and_tool_call>" in str(e):
+            tool_results += f"A mix of text and tool_call was found, you must either answer the query in a short sentence or use tool_call not both. Try again, this time only using tool_call."
+            history.add_message(
+                ToolMessage(content=tool_results, tool_call_id=tool_call_id)
+            )
+            return True, [], []
     # TODO: Handle errors
     if not tool_calls:
         return False, tool_calls, errors
     # tool_results = ""
-    tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
+
     for tool_call in tool_calls:
         # TODO: Extra Validation
         # Call the function
@@ -185,12 +202,11 @@ def process_response(user_query, res, history, tools, depth):
 
     tool_results = tool_results.strip()
     print(f"Tool results: {tool_results}")
-    tool_call_id = uuid.uuid4().hex
     history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
     return True, tool_calls, errors
 
 
-def run_inference_step(history, tools, schema_json, dry_run=False):
+def run_inference_step(depth, history, tools, schema_json, dry_run=False):
     # If we decide to call a function, we need to generate the prompt for the model
     # based on the history of the conversation so far.
     # not break the loop
@@ -199,17 +215,26 @@ def run_inference_step(history, tools, schema_json, dry_run=False):
     print(f"Prompt is:{prompt + AI_PREAMBLE}\n------------------\n")
 
     data = {
-        "prompt": prompt + AI_PREAMBLE,
+        "prompt": prompt
+        + "\nThis is the first turn and you don't have <tool_results> to analyze yet"
+        + AI_PREAMBLE,
         # "streaming": False,
         # "model": "smangrul/llama-3-8b-instruct-function-calling",
         # "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
         # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
-        "model": "interstellarninja/hermes-2-pro-llama-3-8b",
+        # "model": "interstellarninja/hermes-2-pro-llama-3-8b",
+        "model": "dolphin-llama3:8b",
+        # "model": "dolphin-llama3:70b",
        "raw": True,
         "options": {
             "temperature": 0.8,
             # "max_tokens": 1500,
             "num_predict": 1500,
+            "mirostat": 1,
+            # "mirostat_tau": 2,
+            "repeat_penalty": 1.5,
+            "top_k": 25,
+            "top_p": 0.5,
             # "num_predict": 1500,
             # "max_tokens": 1500,
         },
@@ -218,8 +243,10 @@ def run_inference_step(history, tools, schema_json, dry_run=False):
     if dry_run:
         print(prompt + AI_PREAMBLE)
         return "Didn't really run it."
-
-    out = ollama.generate(**data)
+
+    client = Client(host='http://localhost:11444')
+    # out = ollama.generate(**data)
+    out = client.generate(**data)
     logger.debug(f"Response from model: {out}")
     res = out["response"]
 
@@ -227,18 +254,20 @@
 
 
 def process_query(user_query: str, history: ChatMessageHistory, tools):
-    history.add_message(HumanMessage(content=user_query))
+    # Add vehicle status to the history
+    user_query_status = (
+        f"Given that:\n{vehicle_status()[0]}\nAnswer the following:\n{user_query}"
+    )
+    history.add_message(HumanMessage(content=user_query_status))
     for depth in range(10):
-        out = run_inference_step(history, tools, schema_json)
+        out = run_inference_step(depth, history, tools, schema_json)
         print(f"Inference step result:\n{out}\n------------------\n")
         history.add_message(AIMessage(content=out))
         to_continue, tool_calls, errors = process_response(
             user_query, out, history, tools, depth
         )
         if errors:
-            history.add_message(
-                AIMessage(content=f"Errors in tool calls: {errors}")
-            )
+            history.add_message(AIMessage(content=f"Errors in tool calls: {errors}"))
 
         if not to_continue:
             print(f"This is the answer, no more iterations: {out}")
kitt/skills/__init__.py CHANGED
@@ -6,6 +6,7 @@ from .weather import get_weather_current_location, get_weather, get_forecast
 from .routing import find_route
 from .poi import search_points_of_interests, search_along_route_w_coordinates
 from .vehicle import vehicle_status
+from .interpreter import code_interpreter
 
 
 
kitt/skills/interpreter.py ADDED
@@ -0,0 +1,52 @@
+import inspect
+
+# From https://github.com/NousResearch/Hermes-Function-Calling
+
+def code_interpreter(code_markdown: str) -> dict | str:
+    """
+    Execute the provided Python code string on the terminal using exec.
+
+    The string should contain valid, executable and pure Python code in markdown syntax.
+    Code should also import any required Python packages.
+
+    Args:
+        code_markdown (str): The Python code with markdown syntax to be executed.
+            For example: ```python\n<code-string>\n```
+
+    Returns:
+        dict | str: A dictionary containing variables declared and values returned by function calls,
+            or an error message if an exception occurred.
+
+    Note:
+        Use this function with caution, as executing arbitrary code can pose security risks. Use it only for numerical calculations.
+    """
+    try:
+        # Extracting code from Markdown code block
+        code_lines = code_markdown.split('\n')[1:-1]
+        code_without_markdown = '\n'.join(code_lines)
+
+        # Create a new namespace for code execution
+        exec_namespace = {}
+
+        # Execute the code in the new namespace
+        exec(code_without_markdown, exec_namespace)
+
+        # Collect variables and function call results
+        result_dict = {}
+        for name, value in exec_namespace.items():
+            if callable(value):
+                try:
+                    result_dict[name] = value()
+                except TypeError:
+                    # If the function requires arguments, attempt to call it with arguments from the namespace
+                    arg_names = inspect.getfullargspec(value).args
+                    args = {arg_name: exec_namespace.get(arg_name) for arg_name in arg_names}
+                    result_dict[name] = value(**args)
+            elif not name.startswith('_'):  # Exclude variables starting with '_'
+                result_dict[name] = value
+
+        return result_dict
+
+    except Exception as e:
+        error_message = f"An error occurred: {e}"
+        return error_message
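
A quick way to exercise the new skill directly (a usage sketch; the expected result follows from the code above, since top-level names that don't start with an underscore are collected):

from kitt.skills import code_interpreter

snippet = """```python
x = 21
y = x * 2
```"""

print(code_interpreter(snippet))  # {'x': 21, 'y': 42}
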
kitt/skills/routing.py CHANGED
@@ -90,10 +90,41 @@ def find_route_tomtom(
 }, response
 
 
-def find_route(destination=""):
-    """This function finds a route to a destination and returns the distance and the estimated time to go to a specific destination\
-    from the current location.
-    :param destination (string): Required. The destination
+def find_route_a_to_b(origin="", destination=""):
+    """Get a route between origin and destination.
+
+    Args:
+        origin (string): Optional. The origin name.
+        destination (string): Optional. The destination name.
+    """
+    if not destination:
+        destination = vehicle.destination
+    lat_dest, lon_dest = find_coordinates(destination)
+    print(f"lat_dest: {lat_dest}, lon_dest: {lon_dest}")
+
+    if not origin:
+        # Extract the latitude and longitude of the vehicle
+        vehicle_coordinates = getattr(vehicle, "location_coordinates")
+        lat_depart, lon_depart = vehicle_coordinates
+    else:
+        lat_depart, lon_depart = find_coordinates(origin)
+    print(f"lat_depart: {lat_depart}, lon_depart: {lon_depart}")
+
+    date = getattr(vehicle, "date")
+    time = getattr(vehicle, "time")
+    departure_time = f"{date}T{time}"
+
+    trip_info, raw_response = find_route_tomtom(
+        lat_depart, lon_depart, lat_dest, lon_dest, departure_time
+    )
+    return _format_tomtom_trip_info(trip_info, destination)
+
+
+def find_route(destination):
+    """Get a route to a destination from the current location of the vehicle.
+
+    Args:
+        destination (string): Optional. The destination name.
     """
     if not destination:
         destination = vehicle.destination
@@ -114,7 +145,13 @@ def find_route(destination=""):
     trip_info, raw_response = find_route_tomtom(
         lat_depart, lon_depart, lat_dest, lon_dest, departure_time
     )
+    return _format_tomtom_trip_info(trip_info, destination)
+
+
+# raw_response["routes"][0]["legs"][0]["points"]
+
 
+def _format_tomtom_trip_info(trip_info, destination="destination"):
     distance, duration, arrival_time = (
         trip_info["distance_m"],
         trip_info["duration_s"],
@@ -138,5 +175,4 @@
     arrival_hour_display = arrival_time.strftime("%H:%M")
 
     # return the distance and time
-    return f"The route to {destination} is {distance_km:.2f} km which takes {time_display}. Leaving now, the arrival time is estimated at {arrival_hour_display}."
-    # raw_response["routes"][0]["legs"][0]["points"]
+    return f"The route to {destination} is {distance_km:.2f} km which takes {time_display}. Leaving now, the arrival time is estimated at {arrival_hour_display}."
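The point of this refactor is that find_route and find_route_a_to_b now share one formatter. The helper's middle section (the km and duration conversion) is unchanged and not shown in the diff, so this standalone sketch fills that gap with assumed conversions; only the input fields and the final sentence are taken from the diff:

from datetime import datetime, timedelta

def _format_trip_info_sketch(trip_info, destination="destination"):
    # Same inputs as _format_tomtom_trip_info: metres, seconds, arrival datetime.
    distance_km = trip_info["distance_m"] / 1000
    duration = timedelta(seconds=trip_info["duration_s"])
    hours, remainder = divmod(duration.seconds, 3600)
    minutes = remainder // 60
    time_display = f"{hours} hours {minutes} minutes" if hours else f"{minutes} minutes"
    arrival_hour_display = trip_info["arrival_time"].strftime("%H:%M")
    return (
        f"The route to {destination} is {distance_km:.2f} km which takes {time_display}. "
        f"Leaving now, the arrival time is estimated at {arrival_hour_display}."
    )

print(_format_trip_info_sketch(
    {"distance_m": 12500, "duration_s": 1080, "arrival_time": datetime(2024, 5, 20, 14, 45)},
    destination="Spa",
))
# The route to Spa is 12.50 km which takes 18 minutes. ...
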
kitt/skills/vehicle.py CHANGED
@@ -1,13 +1,9 @@
 from .common import vehicle
 
 
-STATUS_TEMPLATE = """
-The current location is:{location}
-The current Geo coordinates: {lat}, {lon}
-The current time: {time}
-The current date: {date}
-The current destination is: {destination}
-"""
+STATUS_TEMPLATE = """The current location is: {location} ({lat}, {lon})
+The current date and time: {date} {time}
+The current destination is: {destination}"""
 
 
 def vehicle_status() -> tuple[str, dict[str, str]]:
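
Rendered with sample values, the reworked template now reads as three lines instead of five (the values below are illustrative only):

STATUS_TEMPLATE = """The current location is: {location} ({lat}, {lon})
The current date and time: {date} {time}
The current destination is: {destination}"""

print(
    STATUS_TEMPLATE.format(
        location="Luxembourg Gare",
        lat=49.5999,
        lon=6.1333,
        date="2024-05-20",
        time="14:45",
        destination="Kirchberg Campus",
    )
)
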
main.py CHANGED
@@ -26,6 +26,7 @@ from kitt.skills import (
     do_anything_else,
     date_time_info,
     get_weather_current_location,
+    code_interpreter,
 )
 from kitt.skills import extract_func_args
 from kitt.core import voice_options, tts_gradio
@@ -124,6 +125,7 @@ tools = [
     StructuredTool.from_function(search_along_route),
     StructuredTool.from_function(date_time_info),
     StructuredTool.from_function(get_weather_current_location),
+    StructuredTool.from_function(code_interpreter),
     # StructuredTool.from_function(do_anything_else),
 ]
 
@@ -201,6 +203,8 @@ def run_model(query, voice_character, state):
         return run_nexusraven_model(query, voice_character)
     elif model == "llama3":
         return run_llama3_model(query, voice_character)
+    return "Error running model", None
+
 
 
 def calculate_route_gradio(origin, destination):
@@ -259,12 +263,19 @@ def save_and_transcribe_audio(audio):
         y = y.astype(np.float32)
         y /= np.max(np.abs(y))
         text = transcriber({"sampling_rate": sr, "raw": y})["text"]
+        gr.Info(f"Transcribed text is: {text}\nProcessing the input...")
+
     except Exception as e:
        print(f"Error: {e}")
-        return "Error transcribing audio"
+        return "Error transcribing audio."
     return text
 
 
+def save_and_transcribe_run_model(audio, voice_character, state):
+    text = save_and_transcribe_audio(audio)
+    out_text, out_voice = run_model(text, voice_character, state)
+    return text, out_text, out_voice
+
 # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
 # in "Insecure origins treated as secure", enable it and relaunch chrome
 
@@ -337,6 +348,18 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
     input_text = gr.Textbox(
         value="How is the weather?", label="Input text", interactive=True
     )
+    with gr.Accordion("Debug"):
+        input_audio_debug = gr.Audio(
+            type="numpy",
+            sources=["microphone"],
+            label="Input audio",
+            elem_id="input_audio",
+        )
+        input_text_debug = gr.Textbox(
+            value="How is the weather?",
+            label="Input text",
+            interactive=True,
+        )
     vehicle_status = gr.JSON(
         value=vehicle.model_dump_json(), label="Vehicle status"
     )
@@ -370,6 +393,11 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
         inputs=[input_text, voice_character, state],
         outputs=[output_text, output_audio],
     )
+    input_text_debug.submit(
+        fn=run_model,
+        inputs=[input_text, voice_character, state],
+        outputs=[output_text, output_audio],
+    )
 
     # Set the vehicle status based on the trip progress
     trip_progress.release(
@@ -380,7 +408,10 @@
 
     # Save and transcribe the audio
     input_audio.stop_recording(
-        fn=save_and_transcribe_audio, inputs=[input_audio], outputs=[input_text]
+        fn=save_and_transcribe_run_model, inputs=[input_audio, voice_character, state], outputs=[input_text, output_text, output_audio]
+    )
+    input_audio_debug.stop_recording(
+        fn=save_and_transcribe_audio, inputs=[input_audio_debug], outputs=[input_text_debug]
     )
 
     # Clear the history
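
With this wiring, stopping a recording on the main audio input runs transcription and the model in a single event, while the new debug widgets keep the transcription-only path. A minimal Gradio sketch of the same pattern (handler bodies and component names here are placeholders, not the app's):

import gradio as gr

def transcribe(audio):
    # placeholder for save_and_transcribe_audio
    return "how is the weather?"

def transcribe_and_run(audio):
    # same shape as save_and_transcribe_run_model: one event, several outputs
    text = transcribe(audio)
    answer = f"model answer for: {text}"
    return text, answer

with gr.Blocks() as demo:
    mic = gr.Audio(sources=["microphone"], type="numpy", label="Input audio")
    text_box = gr.Textbox(label="Input text")
    answer_box = gr.Textbox(label="Output text")
    # stop_recording fans the transcript out to the textbox and the model answer
    mic.stop_recording(fn=transcribe_and_run, inputs=[mic], outputs=[text_box, answer_box])

demo.launch()
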