freddyaboulton (HF staff) committed
Commit 645d699
1 Parent(s): e9633ca
Files changed (1)
  1. app.py +36 -46
app.py CHANGED
@@ -7,6 +7,7 @@ import openai
 import time
 import base64
 
+
 def create_client(api_key):
     return openai.OpenAI(
         base_url="https://llama3-1-8b.lepton.run/api/v1/",
@@ -24,7 +25,8 @@ def update_or_append_conversation(conversation, id, role, content):
     conversation.append({"id": id, "role": role, "content": content})
 
 
-def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[str], client: OpenAI, output_format: str):
+def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[dict],
+                                client: openai.OpenAI, output_format: str):
     if client is None:
         raise gr.Error("Please enter a valid API key first.")
 
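Note: only the closing append of update_or_append_conversation is visible as context in the hunk above. For orientation, a hedged reconstruction of the full helper, showing the update-by-id pattern the streaming loop depends on (the lookup loop is an assumption; only the final append line appears in the diff):

def update_or_append_conversation(conversation: list[dict], id: str, role: str, content: str):
    # Overwrite the message with this id/role if it is already present
    # (streaming deltas repeatedly rewrite the same message)...
    for message in conversation:
        if message.get("id") == id and message.get("role") == role:
            message["content"] = content
            return
    # ...otherwise start a new message. This is the line visible in the diff.
    conversation.append({"id": id, "role": role, "content": content})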
@@ -32,7 +34,7 @@ def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[st
     audio_data = base64.b64encode(audio_bytes).decode()
 
     try:
-        stream = state.client.chat.completions.create(
+        stream = client.chat.completions.create(
             extra_body={
                 "require_audio": True,
                 "tts_preset_id": "jessica",
@@ -82,7 +84,7 @@ def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[st
         raise gr.Error(f"Error during audio streaming: {e}")
 
 def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
-             gradio_conversation: list[dict], client: OpenAI, output_format: str):
+             gradio_conversation: list[dict], client: openai.OpenAI, output_format: str):
 
     audio_buffer = io.BytesIO()
     segment = AudioSegment(
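Note: response buffers the incoming (sample_rate, samples) microphone tuple into WAV bytes before calling the API. A self-contained sketch of that conversion with pydub, assuming 16-bit mono input; only the AudioSegment( and export lines are visible in the diff, so the constructor arguments are an educated guess:

import io

import numpy as np
from pydub import AudioSegment

def audio_tuple_to_wav_bytes(audio: tuple[int, np.ndarray]) -> bytes:
    sample_rate, samples = audio
    segment = AudioSegment(
        samples.tobytes(),
        frame_rate=sample_rate,
        sample_width=samples.dtype.itemsize,  # 2 bytes for int16 mic audio
        channels=1,                           # assumed mono input
    )
    buffer = io.BytesIO()
    segment.export(buffer, format="wav")
    return buffer.getvalue()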
@@ -93,7 +95,7 @@ def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
     )
     segment.export(audio_buffer, format="wav")
 
-    generator = generate_response_and_audio(audio_buffer.getvalue(), state)
+    generator = generate_response_and_audio(audio_buffer.getvalue(), lepton_conversation, client, output_format)
 
     for id, text, asr, audio in generator:
         if asr:
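Note: generate_response_and_audio now receives the conversation, client, and output format explicitly instead of the removed state object, and yields (id, text, asr, audio) tuples. A hedged sketch of how the loop presumably routes them; everything between the visible "if asr:" and the closing "yield AdditionalOutputs(...)" is an assumption:

for id, text, asr, audio in generator:
    if asr:
        # ASR text is the transcription of the user's speech.
        update_or_append_conversation(lepton_conversation, id, "user", text)
        update_or_append_conversation(gradio_conversation, id, "user", text)
    else:
        # Streamed completion text is the assistant's turn.
        update_or_append_conversation(lepton_conversation, id, "assistant", text)
        update_or_append_conversation(gradio_conversation, id, "assistant", text)
    if audio:
        yield audio  # audio chunks go straight to the WebRTC component
    else:
        yield AdditionalOutputs(lepton_conversation, gradio_conversation)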
@@ -107,53 +109,41 @@ def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
         else:
             yield AdditionalOutputs(lepton_conversation, gradio_conversation)
 
+def set_api_key(api_key):
+    if not api_key:
+        raise gr.Error("Please enter a valid API key.")
+    client = create_client(api_key)
+    return client
 
-with gr.Blocks() as demo:
-    with gr.Row():
-        api_key_input = gr.Textbox(type="password", label="Enter your Lepton API Key")
-        set_key_button = gr.Button("Set API Key")
-
-    api_key_status = gr.Textbox(label="API Key Status", interactive=False)
-
-    with gr.Row():
-        format_dropdown = gr.Dropdown(choices=["mp3", "opus"], value="mp3", label="Output Audio Format")
-
+
+with gr.Blocks() as demo:
     with gr.Row():
-        with gr.Column():
-            input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
-        with gr.Column():
+        with gr.Group():
+            with gr.Column():
+                api_key_input = gr.Textbox(type="password", label="Enter your Lepton API Key")
+                api_key_status = gr.Textbox(label="API Key Status", interactive=False)
+            with gr.Column():
+                set_key_button = gr.Button("Set API Key")
+
+    with gr.Group():
+        with gr.Row():
             chatbot = gr.Chatbot(label="Conversation", type="messages")
-            output_audio = gr.Audio(label="Output Audio", autoplay=True)
-
-    state = gr.State(AppState())
-
-    set_key_button.click(set_api_key, inputs=[api_key_input, state], outputs=[api_key_status, state])
-    format_dropdown.change(update_format, inputs=[format_dropdown, state], outputs=[state])
-
-    stream = input_audio.stream(
-        process_audio,
-        [input_audio, state],
-        [input_audio, state],
-        stream_every=0.25,  # Reduced to make it more responsive
-        time_limit=60,  # Increased to allow for longer messages
-    )
-
-    stream.then(
-        maybe_call_response,
-        inputs=[state],
-        outputs=[chatbot, output_audio, state],
-    )
+        with gr.Row():
+            with gr.Column():
+                format_dropdown = gr.Dropdown(choices=["mp3", "opus"], value="mp3", label="Output Audio Format")
+            with gr.Column():
+                audio = WebRTC(modality="audio", mode="send-receive",
+                               label="Audio Stream")
+
+    client_state = gr.State(None)
+    lepton_conversation = gr.State([])
 
-    # Automatically restart recording after the assistant's response
-    restart = output_audio.change(
-        start_recording_user,
-        [state],
-        [input_audio]
+    audio.stream(
+        ReplyOnPause(response),
+        inputs=[audio, lepton_conversation, chatbot, client_state, format_dropdown],
+        outputs=[audio]
     )
+    audio.on_additional_outputs(lambda l, g: (l, g), outputs=[lepton_conversation, chatbot])
 
-    # Add a "Stop Conversation" button
-    cancel = gr.Button("Stop Conversation", variant="stop")
-    cancel.click(lambda: (AppState(stopped=True), gr.update(recording=False)), None,
-                 [state, input_audio], cancels=[stream, restart])
-
+if __name__ == "__main__":
     demo.launch()
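Note: the UI rewrite drops the gr.Audio record/stream/restart machinery and the AppState object in favor of a single send-receive WebRTC component: ReplyOnPause runs response whenever the caller stops speaking, and AdditionalOutputs carries the updated conversations back through on_additional_outputs. A minimal, self-contained sketch of that pattern, assuming the gradio_webrtc package this Space appears to use; the echo handler stands in for the app's real response:

import gradio as gr
import numpy as np
from gradio_webrtc import AdditionalOutputs, ReplyOnPause, WebRTC

def echo(audio: tuple[int, np.ndarray], messages: list[dict]):
    sample_rate, samples = audio
    messages = messages + [{"role": "assistant", "content": "Echoed your audio."}]
    yield (sample_rate, samples)       # stream the caller's audio straight back
    yield AdditionalOutputs(messages)  # then push the updated transcript

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages", label="Conversation")
    audio = WebRTC(modality="audio", mode="send-receive", label="Audio Stream")
    audio.stream(ReplyOnPause(echo), inputs=[audio, chatbot], outputs=[audio])
    audio.on_additional_outputs(lambda m: m, outputs=[chatbot])

if __name__ == "__main__":
    demo.launch()

Because ReplyOnPause does voice-activity detection server-side, the stream_every polling, the restart-recording event, and the explicit stop button from the old version are no longer needed.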
 