Commit
•
645d699
1
Parent(s):
e9633ca
modify
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import openai
|
|
7 |
import time
|
8 |
import base64
|
9 |
|
|
|
10 |
def create_client(api_key):
|
11 |
return openai.OpenAI(
|
12 |
base_url="https://llama3-1-8b.lepton.run/api/v1/",
|
@@ -24,7 +25,8 @@ def update_or_append_conversation(conversation, id, role, content):
|
|
24 |
conversation.append({"id": id, "role": role, "content": content})
|
25 |
|
26 |
|
27 |
-
def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[
|
|
|
28 |
if client is None:
|
29 |
raise gr.Error("Please enter a valid API key first.")
|
30 |
|
@@ -32,7 +34,7 @@ def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[st
|
|
32 |
audio_data = base64.b64encode(audio_bytes).decode()
|
33 |
|
34 |
try:
|
35 |
-
stream =
|
36 |
extra_body={
|
37 |
"require_audio": True,
|
38 |
"tts_preset_id": "jessica",
|
@@ -82,7 +84,7 @@ def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[st
|
|
82 |
raise gr.Error(f"Error during audio streaming: {e}")
|
83 |
|
84 |
def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
|
85 |
-
gradio_conversation: list[dict], client: OpenAI, output_format: str):
|
86 |
|
87 |
audio_buffer = io.BytesIO()
|
88 |
segment = AudioSegment(
|
@@ -93,7 +95,7 @@ def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
|
|
93 |
)
|
94 |
segment.export(audio_buffer, format="wav")
|
95 |
|
96 |
-
generator = generate_response_and_audio(audio_buffer.getvalue(),
|
97 |
|
98 |
for id, text, asr, audio in generator:
|
99 |
if asr:
|
@@ -107,53 +109,41 @@ def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
|
|
107 |
else:
|
108 |
yield AdditionalOutputs(lepton_conversation, gradio_conversation)
|
109 |
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
-
with gr.Blocks() as demo:
|
112 |
-
with gr.Row():
|
113 |
-
api_key_input = gr.Textbox(type="password", label="Enter your Lepton API Key")
|
114 |
-
set_key_button = gr.Button("Set API Key")
|
115 |
-
|
116 |
-
api_key_status = gr.Textbox(label="API Key Status", interactive=False)
|
117 |
-
|
118 |
-
with gr.Row():
|
119 |
-
format_dropdown = gr.Dropdown(choices=["mp3", "opus"], value="mp3", label="Output Audio Format")
|
120 |
|
|
|
121 |
with gr.Row():
|
122 |
-
with gr.
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
chatbot = gr.Chatbot(label="Conversation", type="messages")
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
[input_audio, state],
|
136 |
-
[input_audio, state],
|
137 |
-
stream_every=0.25, # Reduced to make it more responsive
|
138 |
-
time_limit=60, # Increased to allow for longer messages
|
139 |
-
)
|
140 |
-
|
141 |
-
stream.then(
|
142 |
-
maybe_call_response,
|
143 |
-
inputs=[state],
|
144 |
-
outputs=[chatbot, output_audio, state],
|
145 |
-
)
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
[
|
151 |
-
[input_audio]
|
152 |
)
|
|
|
153 |
|
154 |
-
|
155 |
-
cancel = gr.Button("Stop Conversation", variant="stop")
|
156 |
-
cancel.click(lambda: (AppState(stopped=True), gr.update(recording=False)), None,
|
157 |
-
[state, input_audio], cancels=[stream, restart])
|
158 |
-
|
159 |
demo.launch()
|
|
|
7 |
import time
|
8 |
import base64
|
9 |
|
10 |
+
|
11 |
def create_client(api_key):
|
12 |
return openai.OpenAI(
|
13 |
base_url="https://llama3-1-8b.lepton.run/api/v1/",
|
|
|
25 |
conversation.append({"id": id, "role": role, "content": content})
|
26 |
|
27 |
|
28 |
+
def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[dict],
|
29 |
+
client: openai.OpenAI, output_format: str):
|
30 |
if client is None:
|
31 |
raise gr.Error("Please enter a valid API key first.")
|
32 |
|
|
|
34 |
audio_data = base64.b64encode(audio_bytes).decode()
|
35 |
|
36 |
try:
|
37 |
+
stream = client.chat.completions.create(
|
38 |
extra_body={
|
39 |
"require_audio": True,
|
40 |
"tts_preset_id": "jessica",
|
|
|
84 |
raise gr.Error(f"Error during audio streaming: {e}")
|
85 |
|
86 |
def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
|
87 |
+
gradio_conversation: list[dict], client: openai.OpenAI, output_format: str):
|
88 |
|
89 |
audio_buffer = io.BytesIO()
|
90 |
segment = AudioSegment(
|
|
|
95 |
)
|
96 |
segment.export(audio_buffer, format="wav")
|
97 |
|
98 |
+
generator = generate_response_and_audio(audio_buffer.getvalue(), lepton_conversation, client, output_format)
|
99 |
|
100 |
for id, text, asr, audio in generator:
|
101 |
if asr:
|
|
|
109 |
else:
|
110 |
yield AdditionalOutputs(lepton_conversation, gradio_conversation)
|
111 |
|
112 |
+
def set_api_key(api_key):
|
113 |
+
if not api_key:
|
114 |
+
raise gr.Error("Please enter a valid API key.")
|
115 |
+
client = create_client(api_key)
|
116 |
+
return client
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
+
with gr.Blocks() as demo:
|
120 |
with gr.Row():
|
121 |
+
with gr.Group():
|
122 |
+
with gr.Column():
|
123 |
+
api_key_input = gr.Textbox(type="password", label="Enter your Lepton API Key")
|
124 |
+
api_key_status = gr.Textbox(label="API Key Status", interactive=False)
|
125 |
+
with gr.Column():
|
126 |
+
set_key_button = gr.Button("Set API Key")
|
127 |
+
|
128 |
+
with gr.Group():
|
129 |
+
with gr.Row():
|
130 |
chatbot = gr.Chatbot(label="Conversation", type="messages")
|
131 |
+
with gr.Row():
|
132 |
+
with gr.Column():
|
133 |
+
format_dropdown = gr.Dropdown(choices=["mp3", "opus"], value="mp3", label="Output Audio Format")
|
134 |
+
with gr.Column():
|
135 |
+
audio = WebRTC(modality="audio", mode="send-receive",
|
136 |
+
label="Audio Stream")
|
137 |
+
|
138 |
+
client_state = gr.State(None)
|
139 |
+
lepton_conversation = gr.State([])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
+
audio.stream(
|
142 |
+
ReplyOnPause(response),
|
143 |
+
inputs=[audio, lepton_conversation, chatbot, client_state, format_dropdown],
|
144 |
+
outputs=[audio]
|
|
|
145 |
)
|
146 |
+
audio.on_additional_outputs(lambda l, g: (l, g), outputs=[lepton_conversation, chatbot])
|
147 |
|
148 |
+
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
149 |
demo.launch()
|