freddyaboulton HF staff committed on
Commit
694882d
·
1 Parent(s): 2558f9d

more edits

Browse files
Files changed (1) hide show
  1. app.py +16 -38
app.py CHANGED
@@ -24,12 +24,11 @@ def update_or_append_conversation(conversation, id, role, content):
24
  conversation.append({"id": id, "role": role, "content": content})
25
 
26
 
27
- def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[str], client: OpenAI, output_format):
28
- if state.client is None:
29
  raise gr.Error("Please enter a valid API key first.")
30
 
31
- format_ = state.output_format
32
- bitrate = 128 if format_ == "mp3" else 32 # Higher bitrate for MP3, lower for OPUS
33
  audio_data = base64.b64encode(audio_bytes).decode()
34
 
35
  try:
@@ -41,7 +40,7 @@ def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[st
41
  "tts_audio_bitrate": bitrate
42
  },
43
  model="llama3.1-8b",
44
- messages=state.conversation + [{"role": "user", "content": [{"type": "audio", "data": audio_data}]}],
45
  temperature=0.7,
46
  max_tokens=256,
47
  stream=True,
@@ -62,18 +61,18 @@ def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[st
62
 
63
  if asr_results:
64
  asr_result += "".join(asr_results)
65
- yield id, None, asr_result, None, state
66
 
67
  if content:
68
  full_response += content
69
- yield id, full_response, None, None, state
70
 
71
  if audio:
72
  # Accumulate audio bytes and yield them
73
  audio_bytes_accumulated += b''.join([base64.b64decode(a) for a in audio])
74
- yield id, None, None, audio_bytes_accumulated, state
75
 
76
- yield id, full_response, asr_result, audio_bytes_accumulated, state
77
 
78
  except Exception as e:
79
  raise gr.Error(f"Error during audio streaming: {e}")
@@ -81,7 +80,6 @@ def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[st
81
  def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
82
  gradio_conversation: list[dict], client: OpenAI, output_format: str):
83
 
84
-
85
  audio_buffer = io.BytesIO()
86
  segment = AudioSegment(
87
  audio[1].tobytes(),
@@ -93,36 +91,16 @@ def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
93
 
94
  generator = generate_response_and_audio(audio_buffer.getvalue(), state)
95
 
96
- for id, text, asr, audio, updated_state in generator:
97
- state = updated_state
98
  if asr:
99
- update_or_append_conversation(state.conversation, id, "user", asr)
 
100
  if text:
101
- update_or_append_conversation(state.conversation, id, "assistant", text)
102
- chatbot_output = state.conversation
103
- yield chatbot_output, audio, state
104
-
105
- # Reset the audio stream for the next interaction
106
- state.stream = None
107
- state.pause_detected = False
108
-
109
- def maybe_call_response(state):
110
- if state.pause_detected:
111
- return response(state)
112
- else:
113
- # Do nothing
114
- return gr.update(), gr.update(), state
115
-
116
- def start_recording_user(state: AppState):
117
- if not state.stopped:
118
- return gr.update(recording=True)
119
- else:
120
- return gr.update(recording=False)
121
-
122
-
123
- def update_format(format, state):
124
- state.output_format = format
125
- return state
126
 
127
  with gr.Blocks() as demo:
128
  with gr.Row():
 
24
  conversation.append({"id": id, "role": role, "content": content})
25
 
26
 
27
+ def generate_response_and_audio(audio_bytes: bytes, lepton_conversation: list[str], client: OpenAI, output_format: str):
28
+ if client is None:
29
  raise gr.Error("Please enter a valid API key first.")
30
 
31
+ bitrate = 128 if output_format == "mp3" else 32 # Higher bitrate for MP3, lower for OPUS
 
32
  audio_data = base64.b64encode(audio_bytes).decode()
33
 
34
  try:
 
40
  "tts_audio_bitrate": bitrate
41
  },
42
  model="llama3.1-8b",
43
+ messages=lepton_conversation + [{"role": "user", "content": [{"type": "audio", "data": audio_data}]}],
44
  temperature=0.7,
45
  max_tokens=256,
46
  stream=True,
 
61
 
62
  if asr_results:
63
  asr_result += "".join(asr_results)
64
+ yield id, None, asr_result, None
65
 
66
  if content:
67
  full_response += content
68
+ yield id, full_response, None, None
69
 
70
  if audio:
71
  # Accumulate audio bytes and yield them
72
  audio_bytes_accumulated += b''.join([base64.b64decode(a) for a in audio])
73
+ yield id, None, None, audio_bytes_accumulated
74
 
75
+ yield id, full_response, asr_result, audio_bytes_accumulated
76
 
77
  except Exception as e:
78
  raise gr.Error(f"Error during audio streaming: {e}")
 
80
  def response(audio: tuple[int, np.ndarray], lepton_conversation: list[dict],
81
  gradio_conversation: list[dict], client: OpenAI, output_format: str):
82
 
 
83
  audio_buffer = io.BytesIO()
84
  segment = AudioSegment(
85
  audio[1].tobytes(),
 
91
 
92
  generator = generate_response_and_audio(audio_buffer.getvalue(), state)
93
 
94
+ for id, text, asr, audio in generator:
 
95
  if asr:
96
+ update_or_append_conversation(lepton_conversation, id, "user", asr)
97
+ update_or_append_conversation(gradio_conversation, id, "user", asr)
98
  if text:
99
+ update_or_append_conversation(lepton_conversation, id, "assistant", text)
100
+ update_or_append_conversation(gradio_conversation, id, "assistant", text)
101
+
102
+ yield (np.frombuffer(audio, dtype=np.int16).reshape(1, -1), ), AdditionalOutputs(lepton_conversation, gradio_conversation)
103
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  with gr.Blocks() as demo:
106
  with gr.Row():