sasan committed on
Commit
e3db752
·
1 Parent(s): 540996c

chore: Update TTS dependencies and remove unused imports

Browse files
kitt/core/__init__.py CHANGED
@@ -1,12 +1,12 @@
1
  import os
2
- from collections import namedtuple
3
- import time
4
  import pathlib
 
 
5
  from typing import List
6
 
7
  import numpy as np
8
  import torch
9
- # from TTS.api import TTS
10
 
11
  os.environ["COQUI_TOS_AGREED"] = "1"
12
 
@@ -18,7 +18,10 @@ file_full_path = pathlib.Path(os.path.realpath(__file__)).parent
18
 
19
  voices = [
20
  Voice(
21
- "Fast", neutral=None, angry=None, speed=1.0,
 
 
 
22
  ),
23
  Voice(
24
  "Attenborough",
 
1
  import os
 
 
2
  import pathlib
3
+ import time
4
+ from collections import namedtuple
5
  from typing import List
6
 
7
  import numpy as np
8
  import torch
9
+ from TTS.api import TTS
10
 
11
  os.environ["COQUI_TOS_AGREED"] = "1"
12
 
 
18
 
19
  voices = [
20
  Voice(
21
+ "Fast",
22
+ neutral="empty",
23
+ angry=None,
24
+ speed=1.0,
25
  ),
26
  Voice(
27
  "Attenborough",
kitt/core/model.py CHANGED
@@ -2,20 +2,21 @@ import ast
2
  import json
3
  import re
4
  import uuid
 
5
  from enum import Enum
6
  from typing import List
7
- import xml.etree.ElementTree as ET
8
 
9
  from langchain.memory import ChatMessageHistory
10
- from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
11
- from langchain_core.utils.function_calling import convert_to_openai_tool
12
  from langchain.tools.base import StructuredTool
 
 
 
13
  from ollama import Client
14
  from pydantic import BaseModel
15
- from loguru import logger
16
 
17
  from kitt.skills import vehicle_status
18
  from kitt.skills.common import config
 
19
  from .validator import validate_function_call_schema
20
 
21
 
@@ -83,8 +84,9 @@ Once you have called a function, results will be fed back to you within <tool_re
83
  Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
84
  Analyze the data once you get the results and call another function.
85
  At each iteration please continue adding the your analysis to previous summary.
86
- Your final response should directly answer the user query. Don't tell what you are doing, just do it.
87
- Keep your responses very concise and to the point. Don't provide any unnecessary information. Don't refer to user preferences as <user_preferences>.
 
88
 
89
 
90
  Tools:
@@ -131,6 +133,16 @@ Assistant:
131
  {{"arguments": {{"destination": "Paris"}}, "name": "set_vehicle_destination"}}
132
  </tool_call>
133
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  Instructions:
136
  At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
@@ -228,9 +240,6 @@ def get_prompt(template, history, tools, schema, user_preferences, car_status=No
228
  return prompt
229
 
230
 
231
-
232
-
233
-
234
  def run_inference_ollama(prompt):
235
  data = {
236
  "prompt": prompt,
 
2
  import json
3
  import re
4
  import uuid
5
+ import xml.etree.ElementTree as ET
6
  from enum import Enum
7
  from typing import List
 
8
 
9
  from langchain.memory import ChatMessageHistory
 
 
10
  from langchain.tools.base import StructuredTool
11
+ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
12
+ from langchain_core.utils.function_calling import convert_to_openai_tool
13
+ from loguru import logger
14
  from ollama import Client
15
  from pydantic import BaseModel
 
16
 
17
  from kitt.skills import vehicle_status
18
  from kitt.skills.common import config
19
+
20
  from .validator import validate_function_call_schema
21
 
22
 
 
84
  Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
85
  Analyze the data once you get the results and call another function.
86
  At each iteration please continue adding the your analysis to previous summary.
87
+ Your final response should directly answer the user query. Don't tell what you are doing, just do it. Do your best to keep your responses to about 1 line. Avoid asking follow up questions as much as possible.
88
+ Keep your responses very concise and to the point. Don't provide any unnecessary information. Do not offer to help with anything other than the user query.
89
+ Don't refer to user preferences as <user_preferences>.
90
 
91
 
92
  Tools:
 
133
  {{"arguments": {{"destination": "Paris"}}, "name": "set_vehicle_destination"}}
134
  </tool_call>
135
 
136
+ Example 5:
137
+ User: Which place is warmer and by how much, dubai or tokyo?
138
+ Assistant:
139
+ <tool_call>
140
+ {{"arguments": {{"location": "Tokyo"}}, "name": "get_weather"}}
141
+ </tool_call>
142
+ <tool_call>
143
+ {{"arguments": {{"location": "Dubai"}}, "name": "get_weather"}}
144
+ </tool_call>
145
+
146
 
147
  Instructions:
148
  At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
 
240
  return prompt
241
 
242
 
 
 
 
243
  def run_inference_ollama(prompt):
244
  data = {
245
  "prompt": prompt,
kitt/core/tts.py CHANGED
@@ -1,14 +1,14 @@
1
  from collections import namedtuple
2
- from replicate import Client
3
- from loguru import logger
4
- from kitt.skills.common import config
5
- import torch
6
 
7
- from parler_tts import ParlerTTSForConditionalGeneration
8
- from transformers import AutoTokenizer
9
  import soundfile as sf
 
 
10
  from melo.api import TTS as MeloTTS
 
 
 
11
 
 
12
 
13
  replicate = Client(api_token=config.REPLICATE_API_KEY)
14
 
@@ -16,7 +16,10 @@ Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
16
 
17
  voices_replicate = [
18
  Voice(
19
- "Fast", neutral=None, angry=None, speed=1.0,
 
 
 
20
  ),
21
  Voice(
22
  "Attenborough",
 
1
  from collections import namedtuple
 
 
 
 
2
 
 
 
3
  import soundfile as sf
4
+ import torch
5
+ from loguru import logger
6
  from melo.api import TTS as MeloTTS
7
+ from parler_tts import ParlerTTSForConditionalGeneration
8
+ from replicate import Client
9
+ from transformers import AutoTokenizer
10
 
11
+ from kitt.skills.common import config
12
 
13
  replicate = Client(api_token=config.REPLICATE_API_KEY)
14
 
 
16
 
17
  voices_replicate = [
18
  Voice(
19
+ "Fast",
20
+ neutral="empty",
21
+ angry=None,
22
+ speed=1.0,
23
  ),
24
  Voice(
25
  "Attenborough",
kitt/core/utils.py CHANGED
@@ -1,11 +1,11 @@
1
  import json
2
  import re
3
- from typing import List, Tuple, Optional, Union
4
 
5
 
6
  def plot_route(points, vehicle: Union[tuple[float, float], None] = None):
7
  import plotly.express as px
8
-
9
  lats = []
10
  lons = []
11
 
@@ -15,9 +15,7 @@ def plot_route(points, vehicle: Union[tuple[float, float], None] = None):
15
  # fig = px.line_geo(lat=lats, lon=lons)
16
  # fig.update_geos(fitbounds="locations")
17
 
18
- fig = px.line_mapbox(
19
- lat=lats, lon=lons, zoom=12, height=600, color_discrete_sequence=["red"]
20
- )
21
 
22
  if vehicle:
23
  fig.add_trace(
@@ -33,21 +31,21 @@ def plot_route(points, vehicle: Union[tuple[float, float], None] = None):
33
  # mapbox_zoom=12,
34
  )
35
  fig.update_geos(fitbounds="locations")
36
- fig.update_layout(margin={"r": 20, "t": 20, "l": 20, "b": 20})
37
  return fig
38
 
39
 
40
  def extract_json_from_markdown(text):
41
  """
42
  Extracts the JSON string from the given text using a regular expression pattern.
43
-
44
  Args:
45
  text (str): The input text containing the JSON string.
46
-
47
  Returns:
48
  dict: The JSON data loaded from the extracted string, or None if the JSON string is not found.
49
  """
50
- json_pattern = r'```json\r?\n(.*?)\r?\n```'
51
  match = re.search(json_pattern, text, re.DOTALL)
52
  if match:
53
  json_string = match.group(1)
@@ -58,4 +56,4 @@ def extract_json_from_markdown(text):
58
  print(f"Error decoding JSON string: {e}")
59
  else:
60
  print("JSON string not found in the text.")
61
- return None
 
1
  import json
2
  import re
3
+ from typing import List, Optional, Tuple, Union
4
 
5
 
6
  def plot_route(points, vehicle: Union[tuple[float, float], None] = None):
7
  import plotly.express as px
8
+
9
  lats = []
10
  lons = []
11
 
 
15
  # fig = px.line_geo(lat=lats, lon=lons)
16
  # fig.update_geos(fitbounds="locations")
17
 
18
+ fig = px.line_mapbox(lat=lats, lon=lons, color_discrete_sequence=["red"])
 
 
19
 
20
  if vehicle:
21
  fig.add_trace(
 
31
  # mapbox_zoom=12,
32
  )
33
  fig.update_geos(fitbounds="locations")
34
+ fig.update_layout(height=600, margin={"r": 20, "t": 20, "l": 20, "b": 20})
35
  return fig
36
 
37
 
38
  def extract_json_from_markdown(text):
39
  """
40
  Extracts the JSON string from the given text using a regular expression pattern.
41
+
42
  Args:
43
  text (str): The input text containing the JSON string.
44
+
45
  Returns:
46
  dict: The JSON data loaded from the extracted string, or None if the JSON string is not found.
47
  """
48
+ json_pattern = r"```json\r?\n(.*?)\r?\n```"
49
  match = re.search(json_pattern, text, re.DOTALL)
50
  if match:
51
  json_string = match.group(1)
 
56
  print(f"Error decoding JSON string: {e}")
57
  else:
58
  print("JSON string not found in the text.")
59
+ return None
kitt/skills/poi.py CHANGED
@@ -1,8 +1,10 @@
1
  import json
2
  import urllib.parse
 
3
  import requests
4
- from loguru import logger
5
  from langchain.tools import tool
 
 
6
  from .common import config, vehicle
7
 
8
 
@@ -20,7 +22,7 @@ def _select_equally_spaced_coordinates(coords, number_of_points=10):
20
 
21
 
22
  @tool
23
- def search_points_of_interest(search_query: str ="french restaurant"):
24
  """
25
  Get some of the closest points of interest matching the query.
26
 
@@ -47,7 +49,7 @@ def search_points_of_interest(search_query: str ="french restaurant"):
47
  "lon": lon,
48
  "radius": 5000,
49
  "idxSet": "POI",
50
- "limit": 50
51
  }
52
 
53
  r = requests.get(url, params=params, timeout=5)
@@ -76,7 +78,7 @@ def search_points_of_interest(search_query: str ="french restaurant"):
76
  output = (
77
  f"There are {len(results)} options in the vicinity. The most relevant are: "
78
  )
79
- return output + ".\n ".join(formatted_results), results[:3]
80
 
81
 
82
  def find_points_of_interest(lat="0", lon="0", type_of_poi="restaurant"):
@@ -96,7 +98,6 @@ def find_points_of_interest(lat="0", lon="0", type_of_poi="restaurant"):
96
 
97
  r = requests.get(url, timeout=5)
98
 
99
-
100
  # Parse JSON from the response
101
  data = r.json()
102
  # print(data)
 
1
  import json
2
  import urllib.parse
3
+
4
  import requests
 
5
  from langchain.tools import tool
6
+ from loguru import logger
7
+
8
  from .common import config, vehicle
9
 
10
 
 
22
 
23
 
24
  @tool
25
+ def search_points_of_interest(search_query: str = "french restaurant"):
26
  """
27
  Get some of the closest points of interest matching the query.
28
 
 
49
  "lon": lon,
50
  "radius": 5000,
51
  "idxSet": "POI",
52
+ "limit": 50,
53
  }
54
 
55
  r = requests.get(url, params=params, timeout=5)
 
78
  output = (
79
  f"There are {len(results)} options in the vicinity. The most relevant are: "
80
  )
81
+ return output + ".\n ".join(formatted_results), [x["poi"] for x in results[:3]]
82
 
83
 
84
  def find_points_of_interest(lat="0", lon="0", type_of_poi="restaurant"):
 
98
 
99
  r = requests.get(url, timeout=5)
100
 
 
101
  # Parse JSON from the response
102
  data = r.json()
103
  # print(data)
kitt/skills/routing.py CHANGED
@@ -1,7 +1,9 @@
1
  from datetime import datetime
 
2
  import requests
3
- from loguru import logger
4
  from langchain.tools import tool
 
 
5
  from .common import config, vehicle
6
 
7
 
@@ -12,13 +14,29 @@ def find_coordinates(address):
12
  """
13
  # https://developer.tomtom.com/geocoding-api/documentation/geocode
14
  url = f"https://api.tomtom.com/search/2/geocode/{address}.json?key={config.TOMTOM_API_KEY}"
15
- response = requests.get(url)
16
  data = response.json()
17
  lat = data["results"][0]["position"]["lat"]
18
  lon = data["results"][0]["position"]["lon"]
19
  return lat, lon
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def calculate_route(origin, destination):
23
  """This function is called when the origin or destination is updated in the GUI. It calculates the route between the origin and destination."""
24
  print(f"calculate_route(origin: {origin}, destination: {destination})")
@@ -37,7 +55,7 @@ def calculate_route(origin, destination):
37
  # destination = "49.586745,6.140002"
38
 
39
  url = f"https://api.tomtom.com/routing/1/calculateRoute/{orig_coords_str}:{dest_coords_str}/json?key={config.TOMTOM_API_KEY}"
40
- response = requests.get(url)
41
  data = response.json()
42
  points = data["routes"][0]["legs"][0]["points"]
43
 
@@ -150,7 +168,6 @@ def find_route(destination):
150
  )
151
  return _format_tomtom_trip_info(trip_info, destination)
152
 
153
-
154
  # raw_response["routes"][0]["legs"][0]["points"]
155
 
156
 
@@ -178,4 +195,4 @@ def _format_tomtom_trip_info(trip_info, destination="destination"):
178
  arrival_hour_display = arrival_time.strftime("%H:%M")
179
 
180
  # return the distance and time
181
- return f"The route to {destination} is {distance_km:.2f} km which takes {time_display}. Leaving now, the arrival time is estimated at {arrival_hour_display}."
 
1
  from datetime import datetime
2
+
3
  import requests
 
4
  from langchain.tools import tool
5
+ from loguru import logger
6
+
7
  from .common import config, vehicle
8
 
9
 
 
14
  """
15
  # https://developer.tomtom.com/geocoding-api/documentation/geocode
16
  url = f"https://api.tomtom.com/search/2/geocode/{address}.json?key={config.TOMTOM_API_KEY}"
17
+ response = requests.get(url, timeout=5)
18
  data = response.json()
19
  lat = data["results"][0]["position"]["lat"]
20
  lon = data["results"][0]["position"]["lon"]
21
  return lat, lon
22
 
23
 
24
+ def find_address(lat, lon):
25
+ """
26
+ Find the address of a specific location.
27
+
28
+ Args:
29
+ lat (string): Required. The latitude
30
+ lon (string): Required. The longitude
31
+ """
32
+ # https://developer.tomtom.com/search-api/documentation/reverse-geocoding
33
+ url = f"https://api.tomtom.com/search/2/reverseGeocode/{lat},{lon}.json?key={config.TOMTOM_API_KEY}"
34
+ response = requests.get(url, timeout=5)
35
+ data = response.json()
36
+ address = data["addresses"][0]["address"]["freeformAddress"]
37
+ return address
38
+
39
+
40
  def calculate_route(origin, destination):
41
  """This function is called when the origin or destination is updated in the GUI. It calculates the route between the origin and destination."""
42
  print(f"calculate_route(origin: {origin}, destination: {destination})")
 
55
  # destination = "49.586745,6.140002"
56
 
57
  url = f"https://api.tomtom.com/routing/1/calculateRoute/{orig_coords_str}:{dest_coords_str}/json?key={config.TOMTOM_API_KEY}"
58
+ response = requests.get(url, timeout=5)
59
  data = response.json()
60
  points = data["routes"][0]["legs"][0]["points"]
61
 
 
168
  )
169
  return _format_tomtom_trip_info(trip_info, destination)
170
 
 
171
  # raw_response["routes"][0]["legs"][0]["points"]
172
 
173
 
 
195
  arrival_hour_display = arrival_time.strftime("%H:%M")
196
 
197
  # return the distance and time
198
+ return f"The route to {destination} is {distance_km:.2f} km which takes {time_display}. Leaving now, the arrival time is estimated at {arrival_hour_display}."
kitt/skills/weather.py CHANGED
@@ -1,6 +1,6 @@
1
  import requests
2
- from loguru import logger
3
  from langchain.tools import tool
 
4
 
5
  from .common import config, vehicle
6
 
 
1
  import requests
 
2
  from langchain.tools import tool
3
+ from loguru import logger
4
 
5
  from .common import config, vehicle
6
 
main.py CHANGED
@@ -1,49 +1,65 @@
1
  import time
 
2
  import gradio as gr
3
  import numpy as np
 
4
  import torch
5
  import torchaudio
6
- from transformers import pipeline
7
  import typer
8
-
9
- from kitt.skills.common import config, vehicle
10
- from kitt.skills.routing import calculate_route
11
- from kitt.core.tts import run_tts_replicate, run_tts_fast, run_melo_tts
12
- import ollama
13
-
14
- from langchain.tools.base import StructuredTool
15
  from langchain.memory import ChatMessageHistory
16
- from langchain_core.utils.function_calling import convert_to_openai_tool
17
  from langchain.tools import tool
 
 
18
  from loguru import logger
 
19
 
 
 
 
20
 
 
 
 
21
  from kitt.skills import (
22
- get_weather,
 
 
 
23
  find_route,
24
  get_forecast,
25
- vehicle_status as vehicle_status_fn,
26
- set_vehicle_speed,
27
- search_points_of_interest,
28
  search_along_route_w_coordinates,
 
29
  set_vehicle_destination,
30
- do_anything_else,
31
- date_time_info,
32
- get_weather_current_location,
33
- code_interpreter,
34
  )
35
- from kitt.skills import extract_func_args
36
- from kitt.core import voice_options, tts_gradio
37
-
38
- # from kitt.core.model import process_query
39
- from kitt.core.model import generate_function_call as process_query
40
- from kitt.core import utils as kitt_utils
41
 
 
 
 
 
 
 
 
42
 
43
  global_context = {
44
  "vehicle": vehicle,
45
  "query": "How is the weather?",
46
  "route_points": [],
 
 
 
 
 
 
 
 
 
 
47
  }
48
 
49
  speaker_embedding_cache = {}
@@ -72,8 +88,6 @@ Answer questions concisely and do not mention what you base your reply on.<|im_e
72
  <|im_start|>assistant
73
  """
74
 
75
- USER_PREFERENCES = "I love italian food\nI like doing sports"
76
-
77
 
78
  def get_prompt(template, input, history, tools):
79
  # "vehicle_status": vehicle_status_fn()[0]
@@ -221,7 +235,7 @@ def run_llama3_model(query, voice_character, state):
221
  if state["tts_enabled"]:
222
  # voice_out = run_tts_replicate(output_text, voice_character)
223
  # voice_out = run_tts_fast(output_text)[0]
224
- voice_out = run_melo_tts(output_text, voice_character)
225
  # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
226
  return (
227
  output_text,
@@ -245,33 +259,47 @@ def run_model(query, voice_character, state):
245
 
246
  if not state["enable_history"]:
247
  history.clear()
248
- return text, voice, vehicle.model_dump_json()
 
 
 
 
 
 
 
 
249
 
250
 
251
  def calculate_route_gradio(origin, destination):
252
  vehicle_status, points = calculate_route(origin, destination)
253
  plot = kitt_utils.plot_route(points, vehicle=vehicle.location_coordinates)
254
  global_context["route_points"] = points
 
255
  vehicle.location_coordinates = points[0]["latitude"], points[0]["longitude"]
256
  return plot, vehicle_status, 0
257
 
258
 
259
- def update_vehicle_status(trip_progress, origin, destination):
260
  if not global_context["route_points"]:
261
  vehicle_status, points = calculate_route(origin, destination)
262
  global_context["route_points"] = points
 
 
263
  n_points = len(global_context["route_points"])
264
  index = min(int(trip_progress / 100 * n_points), n_points - 1)
265
- print(f"Trip progress: {trip_progress} len: {n_points}, index: {index}")
266
  new_coords = global_context["route_points"][index]
267
  new_coords = new_coords["latitude"], new_coords["longitude"]
268
- print(f"Trip progress: {trip_progress}, len: {n_points}, new_coords: {new_coords}")
 
 
269
  vehicle.location_coordinates = new_coords
270
- vehicle.location = ""
 
271
  plot = kitt_utils.plot_route(
272
  global_context["route_points"], vehicle=vehicle.location_coordinates
273
  )
274
- return vehicle.model_dump_json(), plot
275
 
276
 
277
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -314,8 +342,10 @@ def save_and_transcribe_audio(audio):
314
 
315
  def save_and_transcribe_run_model(audio, voice_character, state):
316
  text = save_and_transcribe_audio(audio)
317
- out_text, out_voice, vehicle_status = run_model(text, voice_character, state)
318
- return text, out_text, out_voice, vehicle_status
 
 
319
 
320
 
321
  def set_tts_enabled(tts_enabled, state):
@@ -324,6 +354,7 @@ def set_tts_enabled(tts_enabled, state):
324
  f"TTS enabled was {state['tts_enabled']} and changed to {new_tts_enabled}"
325
  )
326
  state["tts_enabled"] = new_tts_enabled
 
327
  return state
328
 
329
 
@@ -333,6 +364,7 @@ def set_llm_backend(llm_backend, state):
333
  f"LLM backend was {state['llm_backend']} and changed to {new_llm_backend}"
334
  )
335
  state["llm_backend"] = new_llm_backend
 
336
  return state
337
 
338
 
@@ -340,6 +372,7 @@ def set_user_preferences(preferences, state):
340
  new_preferences = preferences
341
  logger.info(f"User preferences changed to: {new_preferences}")
342
  state["user_preferences"] = new_preferences
 
343
  return state
344
 
345
 
@@ -349,9 +382,40 @@ def set_enable_history(enable_history, state):
349
  f"Enable history was {state['enable_history']} and changed to {new_enable_history}"
350
  )
351
  state["enable_history"] = new_enable_history
 
 
 
 
 
 
 
 
 
 
 
352
  return state
353
 
354
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
356
  # in "Insecure origins treated as secure", enable it and relaunch chrome
357
 
@@ -360,13 +424,6 @@ def set_enable_history(enable_history, state):
360
  # What's the closest restaurant from here?
361
 
362
 
363
- ORIGIN = "Mondorf-les-Bains, Luxembourg"
364
- DESTINATION = "Rue Alphonse Weicker, Luxembourg"
365
- DEFAULT_LLM_BACKEND = "ollama"
366
- ENABLE_HISTORY = True
367
- ENABLE_TTS = True
368
-
369
-
370
  def create_demo(tts_server: bool = False, model="llama3"):
371
  print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
372
  with gr.Blocks(theme=gr.themes.Default()) as demo:
@@ -380,10 +437,13 @@ def create_demo(tts_server: bool = False, model="llama3"):
380
  "llm_backend": DEFAULT_LLM_BACKEND,
381
  "user_preferences": USER_PREFERENCES,
382
  "enable_history": ENABLE_HISTORY,
 
 
383
  }
384
  )
385
- trip_points = gr.State(value=[])
386
  plot, vehicle_status, _ = calculate_route_gradio(ORIGIN, DESTINATION)
 
387
 
388
  with gr.Row():
389
  with gr.Column(scale=1, min_width=300):
@@ -452,6 +512,10 @@ def create_demo(tts_server: bool = False, model="llama3"):
452
  label="Input text",
453
  interactive=True,
454
  )
 
 
 
 
455
  vehicle_status = gr.JSON(
456
  value=vehicle.model_dump_json(), label="Vehicle status"
457
  )
@@ -462,6 +526,12 @@ def create_demo(tts_server: bool = False, model="llama3"):
462
  value="Yes" if ENABLE_TTS else "No",
463
  interactive=True,
464
  )
 
 
 
 
 
 
465
  llm_backend = gr.Radio(
466
  choices=["Ollama", "Replicate"],
467
  label="LLM Backend",
@@ -505,26 +575,34 @@ def create_demo(tts_server: bool = False, model="llama3"):
505
  input_text.submit(
506
  fn=run_model,
507
  inputs=[input_text, voice_character, state],
508
- outputs=[output_text, output_audio, vehicle_status],
509
  )
510
  input_text_debug.submit(
511
  fn=run_model,
512
  inputs=[input_text_debug, voice_character, state],
513
- outputs=[output_text, output_audio, vehicle_status],
514
  )
515
 
516
  # Set the vehicle status based on the trip progress
517
  trip_progress.release(
518
  fn=update_vehicle_status,
519
- inputs=[trip_progress, origin, destination],
520
- outputs=[vehicle_status, map_plot],
521
  )
522
 
523
  # Save and transcribe the audio
524
  input_audio.stop_recording(
525
  fn=save_and_transcribe_run_model,
526
  inputs=[input_audio, voice_character, state],
527
- outputs=[input_text, output_text, output_audio, vehicle_status],
 
 
 
 
 
 
 
 
528
  )
529
  input_audio_debug.stop_recording(
530
  fn=save_and_transcribe_audio,
@@ -539,12 +617,16 @@ def create_demo(tts_server: bool = False, model="llama3"):
539
  tts_enabled.change(
540
  fn=set_tts_enabled, inputs=[tts_enabled, state], outputs=[state]
541
  )
 
 
 
542
  llm_backend.change(
543
  fn=set_llm_backend, inputs=[llm_backend, state], outputs=[state]
544
  )
545
  enable_history.change(
546
  fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
547
  )
 
548
 
549
  return demo
550
 
 
1
  import time
2
+
3
  import gradio as gr
4
  import numpy as np
5
+ import ollama
6
  import torch
7
  import torchaudio
 
8
  import typer
 
 
 
 
 
 
 
9
  from langchain.memory import ChatMessageHistory
 
10
  from langchain.tools import tool
11
+ from langchain.tools.base import StructuredTool
12
+ from langchain_core.utils.function_calling import convert_to_openai_tool
13
  from loguru import logger
14
+ from transformers import pipeline
15
 
16
+ from kitt.core import tts_gradio
17
+ from kitt.core import utils as kitt_utils
18
+ from kitt.core import voice_options
19
 
20
+ # from kitt.core.model import process_query
21
+ from kitt.core.model import generate_function_call as process_query
22
+ from kitt.core.tts import run_melo_tts, run_tts_fast, run_tts_replicate
23
  from kitt.skills import (
24
+ code_interpreter,
25
+ date_time_info,
26
+ do_anything_else,
27
+ extract_func_args,
28
  find_route,
29
  get_forecast,
30
+ get_weather,
31
+ get_weather_current_location,
 
32
  search_along_route_w_coordinates,
33
+ search_points_of_interest,
34
  set_vehicle_destination,
35
+ set_vehicle_speed,
 
 
 
36
  )
37
+ from kitt.skills import vehicle_status as vehicle_status_fn
38
+ from kitt.skills.common import config, vehicle
39
+ from kitt.skills.routing import calculate_route, find_address
 
 
 
40
 
41
+ ORIGIN = "Mondorf-les-Bains, Luxembourg"
42
+ DESTINATION = "Rue Alphonse Weicker, Luxembourg"
43
+ DEFAULT_LLM_BACKEND = "ollama"
44
+ ENABLE_HISTORY = True
45
+ ENABLE_TTS = True
46
+ TTS_BACKEND = "local"
47
+ USER_PREFERENCES = "User loves italian food."
48
 
49
  global_context = {
50
  "vehicle": vehicle,
51
  "query": "How is the weather?",
52
  "route_points": [],
53
+ "origin": ORIGIN,
54
+ "destination": DESTINATION,
55
+ "enable_history": ENABLE_HISTORY,
56
+ "tts_enabled": ENABLE_TTS,
57
+ "tts_backend": TTS_BACKEND,
58
+ "llm_backend": DEFAULT_LLM_BACKEND,
59
+ "map_origin": ORIGIN,
60
+ "map_destination": DESTINATION,
61
+ "update_proxy": 0,
62
+ "map": None,
63
  }
64
 
65
  speaker_embedding_cache = {}
 
88
  <|im_start|>assistant
89
  """
90
 
 
 
91
 
92
  def get_prompt(template, input, history, tools):
93
  # "vehicle_status": vehicle_status_fn()[0]
 
235
  if state["tts_enabled"]:
236
  # voice_out = run_tts_replicate(output_text, voice_character)
237
  # voice_out = run_tts_fast(output_text)[0]
238
+ voice_out = run_melo_tts(output_text, voice_character)
239
  # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
240
  return (
241
  output_text,
 
259
 
260
  if not state["enable_history"]:
261
  history.clear()
262
+ global_context["update_proxy"] += 1
263
+
264
+ return (
265
+ text,
266
+ voice,
267
+ vehicle.model_dump_json(),
268
+ state,
269
+ dict(update_proxy=global_context["update_proxy"]),
270
+ )
271
 
272
 
273
  def calculate_route_gradio(origin, destination):
274
  vehicle_status, points = calculate_route(origin, destination)
275
  plot = kitt_utils.plot_route(points, vehicle=vehicle.location_coordinates)
276
  global_context["route_points"] = points
277
+ # state.value["route_points"] = points
278
  vehicle.location_coordinates = points[0]["latitude"], points[0]["longitude"]
279
  return plot, vehicle_status, 0
280
 
281
 
282
+ def update_vehicle_status(trip_progress, origin, destination, state):
283
  if not global_context["route_points"]:
284
  vehicle_status, points = calculate_route(origin, destination)
285
  global_context["route_points"] = points
286
+ global_context["destination"] = destination
287
+ global_context["route_points"] = global_context["route_points"]
288
  n_points = len(global_context["route_points"])
289
  index = min(int(trip_progress / 100 * n_points), n_points - 1)
290
+ logger.info(f"Trip progress: {trip_progress} len: {n_points}, index: {index}")
291
  new_coords = global_context["route_points"][index]
292
  new_coords = new_coords["latitude"], new_coords["longitude"]
293
+ logger.info(
294
+ f"Trip progress: {trip_progress}, len: {n_points}, new_coords: {new_coords}"
295
+ )
296
  vehicle.location_coordinates = new_coords
297
+ new_vehicle_location = find_address(new_coords[0], new_coords[1])
298
+ vehicle.location = new_vehicle_location
299
  plot = kitt_utils.plot_route(
300
  global_context["route_points"], vehicle=vehicle.location_coordinates
301
  )
302
+ return vehicle.model_dump_json(), plot, state
303
 
304
 
305
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
342
 
343
  def save_and_transcribe_run_model(audio, voice_character, state):
344
  text = save_and_transcribe_audio(audio)
345
+ out_text, out_voice, vehicle_status, state, update_proxy = run_model(
346
+ text, voice_character, state
347
+ )
348
+ return None, text, out_text, out_voice, vehicle_status, state, update_proxy
349
 
350
 
351
  def set_tts_enabled(tts_enabled, state):
 
354
  f"TTS enabled was {state['tts_enabled']} and changed to {new_tts_enabled}"
355
  )
356
  state["tts_enabled"] = new_tts_enabled
357
+ global_context["tts_enabled"] = new_tts_enabled
358
  return state
359
 
360
 
 
364
  f"LLM backend was {state['llm_backend']} and changed to {new_llm_backend}"
365
  )
366
  state["llm_backend"] = new_llm_backend
367
+ global_context["llm_backend"] = new_llm_backend
368
  return state
369
 
370
 
 
372
  new_preferences = preferences
373
  logger.info(f"User preferences changed to: {new_preferences}")
374
  state["user_preferences"] = new_preferences
375
+ global_context["user_preferences"] = new_preferences
376
  return state
377
 
378
 
 
382
  f"Enable history was {state['enable_history']} and changed to {new_enable_history}"
383
  )
384
  state["enable_history"] = new_enable_history
385
+ global_context["enable_history"] = new_enable_history
386
+ return state
387
+
388
+
389
+ def set_tts_backend(tts_backend, state):
390
+ new_tts_backend = tts_backend.lower()
391
+ logger.info(
392
+ f"TTS backend was {state['tts_backend']} and changed to {new_tts_backend}"
393
+ )
394
+ state["tts_backend"] = new_tts_backend
395
+ global_context["tts_backend"] = new_tts_backend
396
  return state
397
 
398
 
399
+ def conditional_update():
400
+ if global_context["destination"] != vehicle.destination:
401
+ global_context["destination"] = vehicle.destination
402
+
403
+ if global_context["origin"] != vehicle.location:
404
+ global_context["origin"] = vehicle.location
405
+
406
+ if (
407
+ global_context["map_origin"] != vehicle.location
408
+ or global_context["map_destination"] != vehicle.destination
409
+ or global_context["update_proxy"] == 0
410
+ ):
411
+ logger.info(f"Updating the map plot... in conditional_update")
412
+ map_plot, vehicle_status, _ = calculate_route_gradio(
413
+ vehicle.location, vehicle.destination
414
+ )
415
+ global_context["map"] = map_plot
416
+ return global_context["map"]
417
+
418
+
419
  # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
420
  # in "Insecure origins treated as secure", enable it and relaunch chrome
421
 
 
424
  # What's the closest restaurant from here?
425
 
426
 
 
 
 
 
 
 
 
427
  def create_demo(tts_server: bool = False, model="llama3"):
428
  print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
429
  with gr.Blocks(theme=gr.themes.Default()) as demo:
 
437
  "llm_backend": DEFAULT_LLM_BACKEND,
438
  "user_preferences": USER_PREFERENCES,
439
  "enable_history": ENABLE_HISTORY,
440
+ "tts_backend": TTS_BACKEND,
441
+ "destination": DESTINATION,
442
  }
443
  )
444
+
445
  plot, vehicle_status, _ = calculate_route_gradio(ORIGIN, DESTINATION)
446
+ global_context["map"] = plot
447
 
448
  with gr.Row():
449
  with gr.Column(scale=1, min_width=300):
 
512
  label="Input text",
513
  interactive=True,
514
  )
515
+ update_proxy = gr.JSON(
516
+ value=dict(update_proxy=0),
517
+ label="Global context",
518
+ )
519
  vehicle_status = gr.JSON(
520
  value=vehicle.model_dump_json(), label="Vehicle status"
521
  )
 
526
  value="Yes" if ENABLE_TTS else "No",
527
  interactive=True,
528
  )
529
+ tts_backend = gr.Radio(
530
+ ["Local", "Replicate"],
531
+ label="TTS Backend",
532
+ value=TTS_BACKEND.title(),
533
+ interactive=True,
534
+ )
535
  llm_backend = gr.Radio(
536
  choices=["Ollama", "Replicate"],
537
  label="LLM Backend",
 
575
  input_text.submit(
576
  fn=run_model,
577
  inputs=[input_text, voice_character, state],
578
+ outputs=[output_text, output_audio, vehicle_status, state, update_proxy],
579
  )
580
  input_text_debug.submit(
581
  fn=run_model,
582
  inputs=[input_text_debug, voice_character, state],
583
+ outputs=[output_text, output_audio, vehicle_status, state, update_proxy],
584
  )
585
 
586
  # Set the vehicle status based on the trip progress
587
  trip_progress.release(
588
  fn=update_vehicle_status,
589
+ inputs=[trip_progress, origin, destination, state],
590
+ outputs=[vehicle_status, map_plot, state],
591
  )
592
 
593
  # Save and transcribe the audio
594
  input_audio.stop_recording(
595
  fn=save_and_transcribe_run_model,
596
  inputs=[input_audio, voice_character, state],
597
+ outputs=[
598
+ input_audio,
599
+ input_text,
600
+ output_text,
601
+ output_audio,
602
+ vehicle_status,
603
+ state,
604
+ update_proxy,
605
+ ],
606
  )
607
  input_audio_debug.stop_recording(
608
  fn=save_and_transcribe_audio,
 
617
  tts_enabled.change(
618
  fn=set_tts_enabled, inputs=[tts_enabled, state], outputs=[state]
619
  )
620
+ tts_backend.change(
621
+ fn=set_tts_backend, inputs=[tts_backend, state], outputs=[state]
622
+ )
623
  llm_backend.change(
624
  fn=set_llm_backend, inputs=[llm_backend, state], outputs=[state]
625
  )
626
  enable_history.change(
627
  fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
628
  )
629
+ update_proxy.change(fn=conditional_update, inputs=[], outputs=[map_plot])
630
 
631
  return demo
632