Commit 2558f9d
Parent(s): 2d5961d

Files changed:
- app.py (+10, -63)
- requirements.txt (+1, -3)
app.py
CHANGED
@@ -1,64 +1,18 @@
 import gradio as gr
+from gradio_webrtc import WebRTC, ReplyOnPause, AdditionalOutputs
 import numpy as np
 import io
 from pydub import AudioSegment
-import tempfile
 import openai
 import time
-from dataclasses import dataclass, field
-from threading import Lock
 import base64
 
-@dataclass
-class AppState:
-    stream: np.ndarray | None = None
-    sampling_rate: int = 0
-    pause_detected: bool = False
-    conversation: list = field(default_factory=list)
-    client: openai.OpenAI = None
-    output_format: str = "mp3"
-    stopped: bool = False
-
-# Global lock for thread safety
-state_lock = Lock()
-
 def create_client(api_key):
     return openai.OpenAI(
         base_url="https://llama3-1-8b.lepton.run/api/v1/",
         api_key=api_key
     )
 
-def determine_pause(audio, sampling_rate, state):
-    # Take the last 1 second of audio
-    pause_length = int(sampling_rate * 1)  # 1 second
-    if len(audio) < pause_length:
-        return False
-    last_audio = audio[-pause_length:]
-    amplitude = np.abs(last_audio)
-
-    # Calculate the average amplitude in the last 1 second
-    avg_amplitude = np.mean(amplitude)
-    silence_threshold = 0.01  # Adjust this threshold as needed
-    if avg_amplitude < silence_threshold:
-        return True
-    else:
-        return False
-
-def process_audio(audio: tuple, state: AppState):
-    if state.stream is None:
-        state.stream = audio[1]
-        state.sampling_rate = audio[0]
-    else:
-        state.stream = np.concatenate((state.stream, audio[1]))
-
-    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
-    state.pause_detected = pause_detected
-
-    if state.pause_detected:
-        # Stop recording
-        return gr.update(recording=False), state
-    else:
-        return None, state
 
 def update_or_append_conversation(conversation, id, role, content):
     # Find if there's an existing message with the given id
@@ -69,7 +23,8 @@ def update_or_append_conversation(conversation, id, role, content):
     # If not found, append a new message
     conversation.append({"id": id, "role": role, "content": content})
 
-def generate_response_and_audio(audio_bytes: bytes, state: AppState):
+
+def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[str], client: OpenAI, output_format):
     if state.client is None:
         raise gr.Error("Please enter a valid API key first.")
 
@@ -123,19 +78,16 @@ def generate_response_and_audio(audio_bytes: bytes, state: AppState):
     except Exception as e:
         raise gr.Error(f"Error during audio streaming: {e}")
 
-def response(state: AppState):
-    if not state.pause_detected and not state.stopped:
-        return None, None, state
-
-    if state.stream is None or len(state.stream) == 0:
-        return gr.update(), gr.update(), state
+def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
+             gradio_conversation: list[dict], client: OpenAI, output_format: str):
+
 
     audio_buffer = io.BytesIO()
     segment = AudioSegment(
-        state.stream.tobytes(),
-        frame_rate=state.sampling_rate,
-        sample_width=state.stream.dtype.itemsize,
-        channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
+        audio[1].tobytes(),
+        frame_rate=audio[0],
+        sample_width=audio[1].dtype.itemsize,
+        channels=1,
     )
     segment.export(audio_buffer, format="wav")
 
@@ -167,11 +119,6 @@ def start_recording_user(state: AppState):
     else:
         return gr.update(recording=False)
 
-def set_api_key(api_key, state):
-    if not api_key:
-        raise gr.Error("Please enter a valid API key.")
-    state.client = create_client(api_key)
-    return "API key set successfully!", state
 
 def update_format(format, state):
     state.output_format = format
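What this commit deletes by hand (AppState, the lock, determine_pause, process_audio, set_api_key) is what gradio_webrtc's ReplyOnPause now does for app.py: buffer the incoming stream, run voice-activity detection, and call the handler once with the whole utterance when the speaker pauses. A minimal sketch of that wiring, following the gradio_webrtc docs rather than code shown in this diff (the Blocks layout, the chatbot variable, and the echo body are placeholders, not the app's actual code):

import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, ReplyOnPause, AdditionalOutputs

def response(audio: tuple[int, np.ndarray], history: list[dict]):
    # ReplyOnPause invokes this with the complete utterance
    # (sample_rate, samples) only after VAD detects a pause.
    sample_rate, samples = audio
    history.append({"role": "assistant", "content": "..."})  # placeholder update
    yield AdditionalOutputs(history)  # non-audio side outputs (the chat log)
    yield (sample_rate, samples)      # reply audio; echoing is a stand-in

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    audio = WebRTC(mode="send-receive", modality="audio")
    audio.stream(ReplyOnPause(response), inputs=[audio, chatbot], outputs=[audio])
    audio.on_additional_outputs(lambda history: history, outputs=[chatbot])

demo.launch()

This is also why the new response() signature takes the audio tuple and both conversation lists directly: ReplyOnPause passes the buffered (sampling_rate, np.ndarray) pair in, so the global state object and lock become unnecessary.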
requirements.txt
CHANGED
@@ -1,4 +1,2 @@
-gradio
-numpy
-pydub
+gradio_webrtc[vad]==0.0.11
 openai
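numpy and pydub leave requirements.txt presumably because gradio_webrtc (which itself depends on gradio) pulls both in transitively, while the [vad] extra adds the voice-activity-detection dependencies that ReplyOnPause relies on. The openai package stays for Lepton's OpenAI-compatible endpoint; as a quick smoke test of the unchanged create_client() helper (assuming the endpoint implements the standard /models route; the token placeholder is hypothetical):

import openai

# Same base URL as create_client() in app.py.
client = openai.OpenAI(
    base_url="https://llama3-1-8b.lepton.run/api/v1/",
    api_key="<LEPTON_API_TOKEN>",  # assumption: a Lepton workspace token
)

# List what the deployment actually serves instead of guessing a model id.
print([m.id for m in client.models.list()])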