sasan commited on
Commit
0950a4c
·
1 Parent(s): bd0f899

Sort out vehicle status

Browse files
Files changed (3) hide show
  1. kitt/core/model.py +1 -1
  2. kitt/skills/routing.py +1 -1
  3. main.py +14 -18
kitt/core/model.py CHANGED
@@ -331,7 +331,7 @@ def run_inference_replicate(prompt):
331
  )
332
  out = "".join(output)
333
 
334
- logger.debug(f"Response from Ollama:\nOut:{out}")
335
 
336
  return out
337
 
 
331
  )
332
  out = "".join(output)
333
 
334
+ logger.debug(f"Response from Replicate:\nOut:{out}")
335
 
336
  return out
337
 
kitt/skills/routing.py CHANGED
@@ -59,7 +59,7 @@ def calculate_route(origin, destination):
59
  data = response.json()
60
  points = data["routes"][0]["legs"][0]["points"]
61
 
62
- return vehicle.model_dump_json(), points
63
 
64
 
65
  def find_route_tomtom(
 
59
  data = response.json()
60
  points = data["routes"][0]["legs"][0]["points"]
61
 
62
+ return vehicle, points
63
 
64
 
65
  def find_route_tomtom(
main.py CHANGED
@@ -133,11 +133,7 @@ def search_along_route(query=""):
133
 
134
  def set_time(time_picker):
135
  vehicle.time = time_picker
136
- return vehicle.model_dump_json(indent=2)
137
-
138
-
139
- def get_vehicle_status(state):
140
- return state.value["vehicle"].model_dump_json(indent=2)
141
 
142
 
143
  tools = [
@@ -238,10 +234,12 @@ def run_llama3_model(query, voice_character, state):
238
  elif global_context["tts_backend"] == "replicate":
239
  voice_out = run_tts_replicate(output_text, voice_character)
240
  else:
241
- voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
242
- #
 
 
243
  # voice_out = run_tts_fast(output_text)[0]
244
- #
245
  return (
246
  output_text,
247
  voice_out,
@@ -269,24 +267,24 @@ def run_model(query, voice_character, state):
269
  return (
270
  text,
271
  voice,
272
- vehicle,
273
  state,
274
  dict(update_proxy=global_context["update_proxy"]),
275
  )
276
 
277
 
278
  def calculate_route_gradio(origin, destination):
279
- vehicle_status, points = calculate_route(origin, destination)
280
  plot = kitt_utils.plot_route(points, vehicle=vehicle.location_coordinates)
281
  global_context["route_points"] = points
282
  # state.value["route_points"] = points
283
  vehicle.location_coordinates = points[0]["latitude"], points[0]["longitude"]
284
- return plot, vehicle_status, 0
285
 
286
 
287
  def update_vehicle_status(trip_progress, origin, destination, state):
288
  if not global_context["route_points"]:
289
- vehicle_status, points = calculate_route(origin, destination)
290
  global_context["route_points"] = points
291
  global_context["destination"] = destination
292
  global_context["route_points"] = global_context["route_points"]
@@ -305,7 +303,6 @@ def update_vehicle_status(trip_progress, origin, destination, state):
305
  global_context["route_points"], vehicle=vehicle.location_coordinates
306
  )
307
  return vehicle, plot, state
308
- return vehicle.model_dump_json(indent=2), plot, state
309
 
310
 
311
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -415,9 +412,7 @@ def conditional_update():
415
  or global_context["update_proxy"] == 0
416
  ):
417
  logger.info(f"Updating the map plot... in conditional_update")
418
- map_plot, vehicle_status, _ = calculate_route_gradio(
419
- vehicle.location, vehicle.destination
420
- )
421
  global_context["map"] = map_plot
422
  return global_context["map"]
423
 
@@ -448,13 +443,13 @@ def create_demo(tts_server: bool = False, model="llama3"):
448
  }
449
  )
450
 
451
- plot, vehicle_status, _ = calculate_route_gradio(ORIGIN, DESTINATION)
452
  global_context["map"] = plot
453
 
454
  with gr.Row():
455
  with gr.Column(scale=1, min_width=300):
456
  vehicle_status = gr.JSON(
457
- value=vehicle.model_dump_json(indent=2), label="Vehicle status"
458
  )
459
  time_picker = gr.Dropdown(
460
  choices=hour_options,
@@ -649,6 +644,7 @@ demo.launch(
649
  ssl_verify=False,
650
  share=False,
651
  )
 
652
  app = typer.Typer()
653
 
654
 
 
133
 
134
  def set_time(time_picker):
135
  vehicle.time = time_picker
136
+ return vehicle
 
 
 
 
137
 
138
 
139
  tools = [
 
234
  elif global_context["tts_backend"] == "replicate":
235
  voice_out = run_tts_replicate(output_text, voice_character)
236
  else:
237
+ voice_out = tts_gradio(
238
+ output_text, voice_character, speaker_embedding_cache
239
+ )[0]
240
+ #
241
  # voice_out = run_tts_fast(output_text)[0]
242
+ #
243
  return (
244
  output_text,
245
  voice_out,
 
267
  return (
268
  text,
269
  voice,
270
+ vehicle.model_dump(),
271
  state,
272
  dict(update_proxy=global_context["update_proxy"]),
273
  )
274
 
275
 
276
  def calculate_route_gradio(origin, destination):
277
+ _, points = calculate_route(origin, destination)
278
  plot = kitt_utils.plot_route(points, vehicle=vehicle.location_coordinates)
279
  global_context["route_points"] = points
280
  # state.value["route_points"] = points
281
  vehicle.location_coordinates = points[0]["latitude"], points[0]["longitude"]
282
+ return plot, vehicle.model_dump(), 0
283
 
284
 
285
  def update_vehicle_status(trip_progress, origin, destination, state):
286
  if not global_context["route_points"]:
287
+ _, points = calculate_route(origin, destination)
288
  global_context["route_points"] = points
289
  global_context["destination"] = destination
290
  global_context["route_points"] = global_context["route_points"]
 
303
  global_context["route_points"], vehicle=vehicle.location_coordinates
304
  )
305
  return vehicle, plot, state
 
306
 
307
 
308
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
412
  or global_context["update_proxy"] == 0
413
  ):
414
  logger.info(f"Updating the map plot... in conditional_update")
415
+ map_plot, _, _ = calculate_route_gradio(vehicle.location, vehicle.destination)
 
 
416
  global_context["map"] = map_plot
417
  return global_context["map"]
418
 
 
443
  }
444
  )
445
 
446
+ plot, _, _ = calculate_route_gradio(ORIGIN, DESTINATION)
447
  global_context["map"] = plot
448
 
449
  with gr.Row():
450
  with gr.Column(scale=1, min_width=300):
451
  vehicle_status = gr.JSON(
452
+ value=vehicle.model_dump(), label="Vehicle status"
453
  )
454
  time_picker = gr.Dropdown(
455
  choices=hour_options,
 
644
  ssl_verify=False,
645
  share=False,
646
  )
647
+
648
  app = typer.Typer()
649
 
650