YingxuHe commited on
Commit
6a8f361
·
verified ·
1 Parent(s): ddaf4f0

Update pages.py

Browse files
Files changed (1) hide show
  1. pages.py +63 -46
pages.py CHANGED
@@ -2,11 +2,12 @@ import os
2
  import base64
3
 
4
  import numpy as np
 
5
  import streamlit as st
6
  import streamlit.components.v1 as components
7
  from streamlit_mic_recorder import mic_recorder
8
 
9
- from utils import load_model, generate_response, bytes_to_array, start_server
10
 
11
 
12
  def audio_llm():
@@ -56,36 +57,6 @@ def audio_llm():
56
  )
57
 
58
  col1, col2, col3 = st.columns([4, 4, 1.2])
59
-
60
- with col3:
61
- st.markdown("or **Record Audio:**")
62
-
63
- recording = mic_recorder(
64
- format="wav",
65
- use_container_width=True,
66
- callback=lambda: st.session_state.update(on_record=True, messages=[]),
67
- key='record')
68
-
69
- if recording and st.session_state.on_record:
70
- audio_bytes = recording["bytes"]
71
- st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
72
- st.session_state.audio_array = bytes_to_array(audio_bytes)
73
-
74
- with col2:
75
- st.markdown("or **Upload Audio:**")
76
-
77
- uploaded_file = st.file_uploader(
78
- label="**Upload Audio:**",
79
- label_visibility="collapsed",
80
- type=['wav', 'mp3'],
81
- on_change=lambda: st.session_state.update(on_upload=True, messages=[]),
82
- key='upload'
83
- )
84
-
85
- if uploaded_file and st.session_state.on_upload:
86
- audio_bytes = uploaded_file.read()
87
- st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
88
- st.session_state.audio_array = bytes_to_array(audio_bytes)
89
 
90
  with col1:
91
  audio_samples_w_instruct = {
@@ -165,8 +136,43 @@ def audio_llm():
165
  st.session_state.default_instruction = audio_samples_w_instruct[sample_name]
166
  st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
167
  st.session_state.audio_array = bytes_to_array(audio_bytes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
 
169
  st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
 
 
170
  st.session_state.update(on_upload=False, on_record=False, on_select=False)
171
 
172
  if st.session_state.default_instruction:
@@ -194,7 +200,12 @@ def audio_llm():
194
 
195
  for message in st.session_state.messages[-2:]:
196
  with st.chat_message(message["role"]):
197
- st.write(message["content"])
 
 
 
 
 
198
 
199
  if prompt := st.chat_input(
200
  placeholder="Type Your Instruction Here",
@@ -206,19 +217,25 @@ def audio_llm():
206
  st.session_state.messages.append({"role": "user", "content": prompt})
207
 
208
  with st.chat_message("assistant"):
209
- if not st.session_state.audio_base64:
210
- response = "Please specify audio first!"
211
- st.write(response)
212
- else:
213
- with st.spinner("Thinking..."):
214
- try:
215
- stream = generate_response(prompt)
216
- response = st.write_stream(stream)
217
- except Exception as e:
218
- response = f"Caught Exception: {repr(e)}. Please contact the administrator to restart this space."
219
- st.write(response)
220
- raise(e)
221
- st.session_state.messages.append({"role": "assistant", "content": response})
 
 
 
 
 
 
222
 
223
  st.session_state.disprompt = False
224
- st.rerun()
 
2
  import base64
3
 
4
  import numpy as np
5
+ from openai import APIConnectionError
6
  import streamlit as st
7
  import streamlit.components.v1 as components
8
  from streamlit_mic_recorder import mic_recorder
9
 
10
+ from utils import load_model, generate_response, bytes_to_array, start_server, NoAudioException
11
 
12
 
13
  def audio_llm():
 
57
  )
58
 
59
  col1, col2, col3 = st.columns([4, 4, 1.2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  with col1:
62
  audio_samples_w_instruct = {
 
136
  st.session_state.default_instruction = audio_samples_w_instruct[sample_name]
137
  st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
138
  st.session_state.audio_array = bytes_to_array(audio_bytes)
139
+
140
+
141
+ with col2:
142
+ st.markdown("or **Upload Audio:**")
143
+
144
+ uploaded_file = st.file_uploader(
145
+ label="**Upload Audio:**",
146
+ label_visibility="collapsed",
147
+ type=['wav', 'mp3'],
148
+ on_change=lambda: st.session_state.update(on_upload=True, messages=[]),
149
+ key='upload'
150
+ )
151
+
152
+ if uploaded_file and st.session_state.on_upload:
153
+ audio_bytes = uploaded_file.read()
154
+ st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
155
+ st.session_state.audio_array = bytes_to_array(audio_bytes)
156
+
157
+
158
+ with col3:
159
+ st.markdown("or **Record Audio:**")
160
+
161
+ recording = mic_recorder(
162
+ format="wav",
163
+ use_container_width=True,
164
+ callback=lambda: st.session_state.update(on_record=True, messages=[]),
165
+ key='record')
166
+
167
+ if recording and st.session_state.on_record:
168
+ audio_bytes = recording["bytes"]
169
+ st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
170
+ st.session_state.audio_array = bytes_to_array(audio_bytes)
171
 
172
+
173
  st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
174
+ if st.session_state.audio_array.shape[0] / 16000 > 30.0:
175
+ st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")
176
  st.session_state.update(on_upload=False, on_record=False, on_select=False)
177
 
178
  if st.session_state.default_instruction:
 
200
 
201
  for message in st.session_state.messages[-2:]:
202
  with st.chat_message(message["role"]):
203
+ if message.get("error"):
204
+ st.error(message["error"])
205
+ for warning_msg in message.get("warnings", []):
206
+ st.warning(warning_msg)
207
+ if message.get("content"):
208
+ st.write(message["content"])
209
 
210
  if prompt := st.chat_input(
211
  placeholder="Type Your Instruction Here",
 
217
  st.session_state.messages.append({"role": "user", "content": prompt})
218
 
219
  with st.chat_message("assistant"):
220
+ response, error_msg, warnings = "", "", []
221
+ with st.spinner("Thinking..."):
222
+ try:
223
+ stream, warnings = generate_response(prompt)
224
+ for warning_msg in warnings:
225
+ st.warning(warning_msg)
226
+ response = st.write_stream(stream)
227
+ except NoAudioException:
228
+ error_msg = "Please specify audio first!"
229
+ except APIConnectionError:
230
+ error_msg = "Internet connection seems to be down. Please contact the administrator to restart the space."
231
+ except Exception as e:
232
+ error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."
233
+ st.session_state.messages.append({
234
+ "role": "assistant",
235
+ "error": error_msg,
236
+ "warnings": warnings,
237
+ "content": response
238
+ })
239
 
240
  st.session_state.disprompt = False
241
+ st.rerun()