freddyaboulton (HF staff) committed
Commit 2558f9d · 1 Parent(s): 2d5961d
Files changed (2):
  1. app.py (+10 -63)
  2. requirements.txt (+1 -3)
app.py CHANGED

@@ -1,64 +1,18 @@
 import gradio as gr
+from gradio_webrtc import WebRTC, ReplyOnPause, AdditionalOutputs
 import numpy as np
 import io
 from pydub import AudioSegment
-import tempfile
 import openai
 import time
-from dataclasses import dataclass, field
-from threading import Lock
 import base64
 
-@dataclass
-class AppState:
-    stream: np.ndarray | None = None
-    sampling_rate: int = 0
-    pause_detected: bool = False
-    conversation: list = field(default_factory=list)
-    client: openai.OpenAI = None
-    output_format: str = "mp3"
-    stopped: bool = False
-
-# Global lock for thread safety
-state_lock = Lock()
-
 def create_client(api_key):
     return openai.OpenAI(
         base_url="https://llama3-1-8b.lepton.run/api/v1/",
         api_key=api_key
     )
 
-def determine_pause(audio, sampling_rate, state):
-    # Take the last 1 second of audio
-    pause_length = int(sampling_rate * 1)  # 1 second
-    if len(audio) < pause_length:
-        return False
-    last_audio = audio[-pause_length:]
-    amplitude = np.abs(last_audio)
-
-    # Calculate the average amplitude in the last 1 second
-    avg_amplitude = np.mean(amplitude)
-    silence_threshold = 0.01  # Adjust this threshold as needed
-    if avg_amplitude < silence_threshold:
-        return True
-    else:
-        return False
-
-def process_audio(audio: tuple, state: AppState):
-    if state.stream is None:
-        state.stream = audio[1]
-        state.sampling_rate = audio[0]
-    else:
-        state.stream = np.concatenate((state.stream, audio[1]))
-
-    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
-    state.pause_detected = pause_detected
-
-    if state.pause_detected:
-        # Stop recording
-        return gr.update(recording=False), state
-    else:
-        return None, state
 
 def update_or_append_conversation(conversation, id, role, content):
     # Find if there's an existing message with the given id
@@ -69,7 +23,8 @@ def update_or_append_conversation(conversation, id, role, content):
     # If not found, append a new message
     conversation.append({"id": id, "role": role, "content": content})
 
-def generate_response_and_audio(audio_bytes: bytes, state: AppState):
+
+def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[str], client: OpenAI, output_format):
     if state.client is None:
         raise gr.Error("Please enter a valid API key first.")
 
@@ -123,19 +78,16 @@ def generate_response_and_audio(audio_bytes: bytes, state: AppState):
     except Exception as e:
         raise gr.Error(f"Error during audio streaming: {e}")
 
-def response(state: AppState):
-    if not state.pause_detected:
-        return gr.update(), gr.update(), state
-
-    if state.stream is None or len(state.stream) == 0:
-        return gr.update(), gr.update(), state
+def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
+             gradio_conversation: list[dict], client: OpenAI, output_format: str):
+
 
     audio_buffer = io.BytesIO()
     segment = AudioSegment(
-        state.stream.tobytes(),
-        frame_rate=state.sampling_rate,
-        sample_width=state.stream.dtype.itemsize,
-        channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
+        audio[1].tobytes(),
+        frame_rate=audio[0],
+        sample_width=audio[1].dtype.itemsize,
+        channels=1,
     )
     segment.export(audio_buffer, format="wav")
 
@@ -167,11 +119,6 @@ def start_recording_user(state: AppState):
     else:
         return gr.update(recording=False)
 
-def set_api_key(api_key, state):
-    if not api_key:
-        raise gr.Error("Please enter a valid API key.")
-    state.client = create_client(api_key)
-    return "API key set successfully!", state
 
 def update_format(format, state):
     state.output_format = format
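
Editor's note on the refactor: the AppState dataclass, determine_pause, process_audio, and set_api_key plumbing removed above is the kind of turn-taking logic that gradio_webrtc's ReplyOnPause is designed to handle. The sketch below is not part of this commit; it is a rough guess, based on the gradio_webrtc documentation, at how the new response handler could be wired to a WebRTC component. The component names (client_state, lepton_state, chatbot, output_format), the dropdown choices, and the time_limit value are illustrative assumptions.

# Hypothetical wiring sketch -- not part of commit 2558f9d.
# ReplyOnPause runs voice-activity detection (the [vad] extra in
# requirements.txt) and calls the handler once the speaker pauses,
# replacing the removed determine_pause/process_audio logic.
import gradio as gr
from gradio_webrtc import WebRTC, ReplyOnPause

with gr.Blocks() as demo:
    # Illustrative state and components; the real app's layout may differ.
    client_state = gr.State(None)          # openai client made by create_client(api_key)
    lepton_state = gr.State([])            # conversation in the OpenAI message format
    chatbot = gr.Chatbot(type="messages")  # conversation rendered for the user
    output_format = gr.Dropdown(["mp3", "opus"], value="mp3", label="Output format")
    audio = WebRTC(mode="send-receive", modality="audio")

    # Stream microphone audio in; whatever `response` (defined in app.py above)
    # yields is streamed back to the browser.
    audio.stream(
        fn=ReplyOnPause(response),
        inputs=[audio, lepton_state, chatbot, client_state, output_format],
        outputs=[audio],
        time_limit=60,
    )
    # Values the handler wraps in AdditionalOutputs (the updated conversations,
    # assumed here to be two values) arrive through this callback.
    audio.on_additional_outputs(
        lambda lepton, chat: (lepton, chat),
        outputs=[lepton_state, chatbot],
    )

demo.launch()

The practical effect is that audio buffering, pause detection, and recording state no longer live in a hand-rolled AppState guarded by a lock; the handler simply receives one complete utterance as (sample_rate, numpy array) plus whatever extra inputs it declares.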
requirements.txt CHANGED

@@ -1,4 +1,2 @@
-https://gradio-builds.s3.amazonaws.com/cffe9a7ab7f71e76d7214dc57c6278ffaf5bcdf9/gradio-5.0.0b1-py3-none-any.whl
-numpy
-pydub
+gradio_webrtc[vad]==0.0.11
 openai