Refactor save_and_transcribe_audio function for better code organization
Browse files
- kitt/core/stt.py +11 -5
- space.py +3 -3
kitt/core/stt.py
CHANGED
@@ -22,12 +22,9 @@ def save_audio_as_wav(data, sample_rate, file_path):
|
|
22 |
)
|
23 |
|
24 |
|
25 |
-
def save_and_transcribe_audio(audio):
|
26 |
sample_rate, data = audio
|
27 |
try:
|
28 |
-
# add timestamp to file name
|
29 |
-
filename = f"recordings/audio{time.time()}.wav"
|
30 |
-
save_audio_as_wav(data, sample_rate, filename)
|
31 |
data = data.astype(np.float32)
|
32 |
data /= np.max(np.abs(data))
|
33 |
text = transcriber({"sampling_rate": sample_rate, "raw": data})["text"]
|
@@ -36,4 +33,13 @@ def save_and_transcribe_audio(audio):
|
|
36 |
except Exception as e:
|
37 |
logger.error(f"Error: {e}")
|
38 |
raise Exception("Error transcribing audio.")
|
39 |
-
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
)
|
23 |
|
24 |
|
25 |
+
def transcribe_audio(audio):
|
26 |
sample_rate, data = audio
|
27 |
try:
|
|
|
|
|
|
|
28 |
data = data.astype(np.float32)
|
29 |
data /= np.max(np.abs(data))
|
30 |
text = transcriber({"sampling_rate": sample_rate, "raw": data})["text"]
|
|
|
33 |
except Exception as e:
|
34 |
logger.error(f"Error: {e}")
|
35 |
raise Exception("Error transcribing audio.")
|
36 |
+
return text
|
37 |
+
|
38 |
+
|
39 |
+
def save_and_transcribe_audio(audio, save=True):
|
40 |
+
sample_rate, data = audio
|
41 |
+
# add timestamp to file name
|
42 |
+
filename = f"recordings/audio{time.time()}.wav"
|
43 |
+
if save:
|
44 |
+
save_audio_as_wav(data, sample_rate, filename)
|
45 |
+
return transcribe_audio(audio)
|
space.py
CHANGED
@@ -9,7 +9,7 @@ from kitt.core import tts_gradio
|
|
9 |
from kitt.core import utils as kitt_utils
|
10 |
from kitt.core import voice_options
|
11 |
from kitt.core.model import generate_function_call as process_query
|
12 |
-
from kitt.core.stt import save_and_transcribe_audio
|
13 |
from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_replicate
|
14 |
from kitt.skills import (
|
15 |
code_interpreter,
|
@@ -182,7 +182,7 @@ def update_vehicle_status(trip_progress, origin, destination, state):
|
|
182 |
|
183 |
|
184 |
def save_and_transcribe_run_model(audio, voice_character, state):
|
185 |
-
text = save_and_transcribe_audio(audio)
|
186 |
out_text, out_voice, vehicle_status, state, update_proxy = run_model(
|
187 |
text, voice_character, state
|
188 |
)
|
@@ -452,7 +452,7 @@ def create_demo(tts_server: bool = False, model="llama3"):
|
|
452 |
],
|
453 |
)
|
454 |
input_audio_debug.stop_recording(
|
455 |
-
fn=save_and_transcribe_audio,
|
456 |
inputs=[input_audio_debug],
|
457 |
outputs=[input_text_debug],
|
458 |
)
|
|
|
9 |
from kitt.core import utils as kitt_utils
|
10 |
from kitt.core import voice_options
|
11 |
from kitt.core.model import generate_function_call as process_query
|
12 |
+
from kitt.core.stt import transcribe_audio
|
13 |
from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_replicate
|
14 |
from kitt.skills import (
|
15 |
code_interpreter,
|
|
|
182 |
|
183 |
|
184 |
def save_and_transcribe_run_model(audio, voice_character, state):
|
185 |
+
text = transcribe_audio(audio)
|
186 |
out_text, out_voice, vehicle_status, state, update_proxy = run_model(
|
187 |
text, voice_character, state
|
188 |
)
|
|
|
452 |
],
|
453 |
)
|
454 |
input_audio_debug.stop_recording(
|
455 |
+
fn=transcribe_audio,
|
456 |
inputs=[input_audio_debug],
|
457 |
outputs=[input_text_debug],
|
458 |
)
|