Update pages.py

pages.py CHANGED
@@ -10,6 +10,11 @@ from streamlit_mic_recorder import mic_recorder
 from utils import load_model, generate_response, bytes_to_array, start_server, NoAudioException
 
 
+general_instructions = [
+    "Please transcribe this speech.",
+    "Please summarise this speech."
+]
+
 def audio_llm():
     with st.sidebar:
         st.markdown("""<div class="sidebar-intro">
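The core idea of the commit: curated samples carry tailored instruction lists (see the dictionary hunk below), while user uploads and recordings fall back to this new generic list. A minimal sketch of that fallback pattern; `instructions_for` and `sample_instructions` are illustrative names, not code from the repo:

```python
# Sketch of the fallback pattern this commit introduces; names are illustrative.
general_instructions = [
    "Please transcribe this speech.",
    "Please summarise this speech.",
]

sample_instructions = {
    "1_ASR_IMDA_PART1_ASR_v2_141": ["Turn the spoken language into a text format."],
}

def instructions_for(source: str) -> list[str]:
    # Curated sample -> its tailored prompts; user upload/recording -> generic ones.
    return sample_instructions.get(source, general_instructions)

print(instructions_for("1_ASR_IMDA_PART1_ASR_v2_141"))  # tailored prompts
print(instructions_for("microphone_recording"))          # falls back to the general list
```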
@@ -47,7 +52,7 @@ def audio_llm():
     st.session_state.audio_array = np.array([])
 
     if "default_instruction" not in st.session_state:
-        st.session_state.default_instruction =
+        st.session_state.default_instruction = []
 
     st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
     st.markdown(
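Note the type change: `default_instruction` was previously a string and is now a list, so the UI can render one button per suggestion (see the button loop in a later hunk). The guarded initialisation matters because `st.session_state` survives Streamlit's reruns. A minimal sketch of the idiom, runnable with `streamlit run`:

```python
import streamlit as st

# st.session_state persists across reruns, so seed defaults only once per session.
if "default_instruction" not in st.session_state:
    st.session_state.default_instruction = []  # a list, not a str: iterated into buttons

# Iterating a str here would yield single characters, one button per letter.
for inst in st.session_state.default_instruction:
    st.button(inst)
```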
@@ -60,62 +65,62 @@ def audio_llm():
 
     with col1:
         audio_samples_w_instruct = {
-            '1_ASR_IMDA_PART1_ASR_v2_141' : "
-            '7_ASR_IMDA_PART3_30_ASR_v2_2269': "
-            '17_ASR_IMDA_PART6_30_ASR_v2_1413': "
+            '1_ASR_IMDA_PART1_ASR_v2_141' : ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
+            '7_ASR_IMDA_PART3_30_ASR_v2_2269': ["Need this talk written down, please."],
+            '17_ASR_IMDA_PART6_30_ASR_v2_1413': ["Record the spoken word in text form."],
 
-            '25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': "
-            '26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': "
-            '30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': "
+            '25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': ["Please translate the given speech to English."],
+            '26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': ["Please translate the given speech to Chinese."],
+            '30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': ["Please follow the instruction in the speech."],
 
-            '32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': "
-            '33_SQA_IMDA_PART3_30_SQA_V2_2310': "
-            '34_SQA_IMDA_PART3_30_SQA_V2_3621': "
-            '35_SQA_IMDA_PART3_30_SQA_V2_4062': "
-            '36_DS_IMDA_PART4_30_DS_V2_849': "
-
-            '39_Paralingual_IEMOCAP_ER_V2_91': "
-            '40_Paralingual_IEMOCAP_ER_V2_567': "
-            '42_Paralingual_IEMOCAP_GR_V2_320': "
-            '43_Paralingual_IEMOCAP_GR_V2_129': "
-            '45_Paralingual_IMDA_PART3_30_GR_V2_12312': "
-            '47_Paralingual_IMDA_PART3_30_NR_V2_10479': "
-            '49_Paralingual_MELD_ER_V2_676': "
-            '50_Paralingual_MELD_ER_V2_692': "
-            '51_Paralingual_VOXCELEB1_GR_V2_2148': "
-            '53_Paralingual_VOXCELEB1_NR_V2_2286': "
-
-            '55_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_2': "
-            '56_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_415': "
-            '57_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_460': "
+            '32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': ["What does the man think the woman should do at 4:00."],
+            '33_SQA_IMDA_PART3_30_SQA_V2_2310': ["Does Speaker2's wife cook for Speaker2 when they are at home."],
+            '34_SQA_IMDA_PART3_30_SQA_V2_3621': ["Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language."],
+            '35_SQA_IMDA_PART3_30_SQA_V2_4062': ["What is the color of the vase mentioned in the dialogue."],
+            '36_DS_IMDA_PART4_30_DS_V2_849': ["Condense the dialogue into a concise summary highlighting major topics and conclusions."],
+
+            '39_Paralingual_IEMOCAP_ER_V2_91': ["Based on the speaker's speech patterns, what do you think they are feeling."],
+            '40_Paralingual_IEMOCAP_ER_V2_567': ["Based on the speaker's speech patterns, what do you think they are feeling."],
+            '42_Paralingual_IEMOCAP_GR_V2_320': ["Is it possible for you to identify whether the speaker in this recording is male or female."],
+            '43_Paralingual_IEMOCAP_GR_V2_129': ["Is it possible for you to identify whether the speaker in this recording is male or female."],
+            '45_Paralingual_IMDA_PART3_30_GR_V2_12312': ["So, who's speaking in the second part of the clip?", "So, who's speaking in the first part of the clip?"],
+            '47_Paralingual_IMDA_PART3_30_NR_V2_10479': ["Can you guess which ethnic group this person is from based on their accent."],
+            '49_Paralingual_MELD_ER_V2_676': ["What emotions do you think the speaker is expressing."],
+            '50_Paralingual_MELD_ER_V2_692': ["Based on the speaker's speech patterns, what do you think they are feeling."],
+            '51_Paralingual_VOXCELEB1_GR_V2_2148': ["May I know the gender of the speaker."],
+            '53_Paralingual_VOXCELEB1_NR_V2_2286': ["What's the nationality identity of the speaker."],
+
+            '55_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_2': ["What impact would the growth of the healthcare sector have on the country's economy in terms of employment and growth."],
+            '56_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_415': ["Based on the statement, can you summarize the speaker's position on the recent controversial issues in Singapore."],
+            '57_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_460': ["How does the author respond to parents' worries about masks in schools."],
 
-            '2_ASR_IMDA_PART1_ASR_v2_2258': "
-            '3_ASR_IMDA_PART1_ASR_v2_2265': "
+            '2_ASR_IMDA_PART1_ASR_v2_2258': ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
+            '3_ASR_IMDA_PART1_ASR_v2_2265': ["Turn the spoken language into a text format."],
 
-            '4_ASR_IMDA_PART2_ASR_v2_999' : "
-            '5_ASR_IMDA_PART2_ASR_v2_2241': "
-            '6_ASR_IMDA_PART2_ASR_v2_3409': "
+            '4_ASR_IMDA_PART2_ASR_v2_999' : ["Translate the spoken words into text format."],
+            '5_ASR_IMDA_PART2_ASR_v2_2241': ["Translate the spoken words into text format."],
+            '6_ASR_IMDA_PART2_ASR_v2_3409': ["Translate the spoken words into text format."],
 
-            '8_ASR_IMDA_PART3_30_ASR_v2_1698': "
-            '9_ASR_IMDA_PART3_30_ASR_v2_2474': "
+            '8_ASR_IMDA_PART3_30_ASR_v2_1698': ["Need this talk written down, please."],
+            '9_ASR_IMDA_PART3_30_ASR_v2_2474': ["Need this talk written down, please."],
 
-            '11_ASR_IMDA_PART4_30_ASR_v2_3771': "
-            '12_ASR_IMDA_PART4_30_ASR_v2_103' : "
-            '10_ASR_IMDA_PART4_30_ASR_v2_1527': "
+            '11_ASR_IMDA_PART4_30_ASR_v2_3771': ["Write out the dialogue as text."],
+            '12_ASR_IMDA_PART4_30_ASR_v2_103' : ["Write out the dialogue as text."],
+            '10_ASR_IMDA_PART4_30_ASR_v2_1527': ["Write out the dialogue as text."],
 
-            '13_ASR_IMDA_PART5_30_ASR_v2_1446': "
-            '14_ASR_IMDA_PART5_30_ASR_v2_2281': "
-            '15_ASR_IMDA_PART5_30_ASR_v2_4388': "
+            '13_ASR_IMDA_PART5_30_ASR_v2_1446': ["Translate this vocal recording into a textual format."],
+            '14_ASR_IMDA_PART5_30_ASR_v2_2281': ["Translate this vocal recording into a textual format."],
+            '15_ASR_IMDA_PART5_30_ASR_v2_4388': ["Translate this vocal recording into a textual format."],
 
-            '16_ASR_IMDA_PART6_30_ASR_v2_576': "
-            '18_ASR_IMDA_PART6_30_ASR_v2_2834': "
+            '16_ASR_IMDA_PART6_30_ASR_v2_576': ["Record the spoken word in text form."],
+            '18_ASR_IMDA_PART6_30_ASR_v2_2834': ["Record the spoken word in text form."],
 
-            '19_ASR_AIShell_zh_ASR_v2_5044': "
-            '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': "
+            '19_ASR_AIShell_zh_ASR_v2_5044': ["Transform the oral presentation into a text document."],
+            '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': ["Please provide a written transcription of the speech."],
 
-            '27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': "
-            '28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': "
-            '29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': "
+            '27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': ["Please translate the given speech to Chinese."],
+            '28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': ["Please follow the instruction in the speech."],
+            '29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': ["Please follow the instruction in the speech."],
         }
 
         audio_sample_names = [audio_sample_name for audio_sample_name in audio_samples_w_instruct.keys()]
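How a selected sample's list reaches `default_instruction` is outside this diff (the selectbox handler is unchanged), but the new data shape suggests something like the sketch below; `on_sample_select` is a hypothetical helper, not code from the repo:

```python
import streamlit as st

# Hypothetical handler, inferred from the list-valued dictionary above.
def on_sample_select(sample_name: str, samples: dict[str, list[str]]) -> None:
    # Copy the sample's whole instruction list; several samples suggest two prompts.
    st.session_state.default_instruction = samples[sample_name]
```

As an aside, `list(audio_samples_w_instruct)` is the idiomatic spelling of the comprehension over `.keys()` kept as context above.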
@@ -151,6 +156,7 @@ def audio_llm():
 
     if uploaded_file and st.session_state.on_upload:
         audio_bytes = uploaded_file.read()
+        st.session_state.default_instruction = general_instructions
         st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
        st.session_state.audio_array = bytes_to_array(audio_bytes)
 
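`bytes_to_array` lives in utils.py, which this commit does not touch. Given that playback later passes `sample_rate=16000`, a plausible implementation decodes the uploaded bytes to 16 kHz mono float32; the base64 copy stored alongside is presumably what gets shipped to the model server. A hedged sketch only, the real helper may differ:

```python
import io

import librosa
import numpy as np

def bytes_to_array(audio_bytes: bytes) -> np.ndarray:
    # Decode any common container/codec to 16 kHz mono float32, matching the
    # sample_rate the page passes to st.audio. Assumption: the actual utils.py
    # helper may decode differently.
    array, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
    return array
```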
@@ -166,19 +172,10 @@ def audio_llm():
 
     if recording and st.session_state.on_record:
         audio_bytes = recording["bytes"]
+        st.session_state.default_instruction = general_instructions
         st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
         st.session_state.audio_array = bytes_to_array(audio_bytes)
 
-
-    st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
-    if st.session_state.audio_array.shape[0] / 16000 > 30.0:
-        st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")
-    st.session_state.update(on_upload=False, on_record=False, on_select=False)
-
-    if st.session_state.default_instruction:
-        st.write("**Example Instructions:**")
-        st.write(st.session_state.default_instruction)
-
     st.markdown(
         """
         <style>
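This hunk deletes the inline playback, the 30-second warning, and the plain-text "Example Instructions" writes; they reappear inside a user chat bubble in a later hunk. The duration check is just sample count divided by the 16 kHz rate. If one wanted to trim client-side instead of only warning (the app does not do this), the arithmetic would be:

```python
import numpy as np

SAMPLE_RATE = 16000   # samples per second, as used throughout the page
MAX_SECONDS = 30.0    # model-side limit the warning refers to

audio_array = np.zeros(int(SAMPLE_RATE * 45))  # a 45-second dummy clip

duration = audio_array.shape[0] / SAMPLE_RATE  # samples / (samples per second)
if duration > MAX_SECONDS:
    # Hypothetical client-side trim; the app itself lets the model truncate.
    audio_array = audio_array[: int(SAMPLE_RATE * MAX_SECONDS)]
print(f"{duration:.1f}s -> {audio_array.shape[0] / SAMPLE_RATE:.1f}s")
```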
@@ -187,16 +184,33 @@ def audio_llm():
             text-align: right;
         }
         </style>
-
         """,
         unsafe_allow_html=True,
     )
 
-    if "
-        st.session_state.
+    if "prompt" not in st.session_state:
+        st.session_state.prompt = ""
 
     if 'disprompt' not in st.session_state:
         st.session_state.disprompt = False
+
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    if st.session_state.audio_array.size:
+        with st.chat_message("user"):
+            st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
+            if st.session_state.audio_array.shape[0] / 16000 > 30.0:
+                st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")
+            st.session_state.update(on_upload=False, on_record=False, on_select=False)
+
+            for i, inst in enumerate(st.session_state.default_instruction):
+                st.button(
+                    f"**Example Instruction {i+1}**: {inst}",
+                    args=(inst,),
+                    disabled=st.session_state.disprompt,
+                    on_click=lambda p: st.session_state.update(disprompt=True, prompt=p)
+                )
 
     for message in st.session_state.messages[-2:]:
         with st.chat_message(message["role"]):
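A detail worth noting in the new button loop: the instruction is passed through `args=(inst,)` rather than captured in the lambda body. Streamlit calls `on_click(*args)`, and binding the value as an argument sidesteps Python's late-binding closures, where every callback created in a loop would otherwise see the final value of `inst`. A plain-Python demonstration:

```python
# Late binding: all three closures share the loop variable and see its last value.
broken = [lambda: inst for inst in ("a", "b", "c")]
print([f() for f in broken])  # ['c', 'c', 'c']

# Binding the current value as an argument fixes it, which is what
# st.button(..., args=(inst,), on_click=lambda p: ...) achieves.
fixed = [lambda p=inst: p for inst in ("a", "b", "c")]
print([f() for f in fixed])   # ['a', 'b', 'c']
```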
@@ -212,15 +226,18 @@ def audio_llm():
         disabled=st.session_state.disprompt,
         on_submit=lambda: st.session_state.update(disprompt=True)
     ):
+        st.session_state.prompt = prompt
+
+    if st.session_state.prompt:
         with st.chat_message("user"):
-            st.write(prompt)
-            st.session_state.messages.append({"role": "user", "content": prompt})
+            st.write(st.session_state.prompt)
+            st.session_state.messages.append({"role": "user", "content": st.session_state.prompt})
 
         with st.chat_message("assistant"):
             response, error_msg, warnings = "", "", []
             with st.spinner("Thinking..."):
                 try:
-                    stream, warnings = generate_response(prompt)
+                    stream, warnings = generate_response(st.session_state.prompt)
                     for warning_msg in warnings:
                         st.warning(warning_msg)
                     response = st.write_stream(stream)
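Routing both the chat box and the example buttons through `st.session_state.prompt` gives the two input paths a single submission flow. `generate_response` (in utils.py, unchanged here) evidently returns a token stream plus a list of warnings, since `st.write_stream` accepts a generator and returns the concatenated text. A hedged mock of that contract, not the real implementation:

```python
import time
from collections.abc import Iterator

def generate_response(prompt: str) -> tuple[Iterator[str], list[str]]:
    # Mock of the interface this page relies on: (token stream, warning strings).
    # The real utils.py presumably queries the MERaLiON-AudioLLM server.
    warnings = ["Empty prompt."] if not prompt else []

    def stream() -> Iterator[str]:
        for token in ("He", "llo", "!"):
            time.sleep(0.05)  # simulate per-chunk network latency
            yield token

    return stream(), warnings
```

In the page itself, `response = st.write_stream(stream)` both renders the tokens as they arrive and hands back the full string for the history entry.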
@@ -230,12 +247,12 @@ def audio_llm():
                     error_msg = "Internet connection seems to be down. Please contact the administrator to restart the space."
                 except Exception as e:
                     error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."
-
-
-
-
-
-
-
-            st.session_state.disprompt =
+            st.session_state.messages.append({
+                "role": "assistant",
+                "error": error_msg,
+                "warnings": warnings,
+                "content": response
+            })
+
+            st.session_state.update(disprompt=False, prompt="")
         st.rerun()
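Each assistant turn is now persisted with its error and warnings alongside the content, so the replay loop (`for message in st.session_state.messages[-2:]`) can re-render a failed turn faithfully. The loop body sits outside the hunks shown here; a hedged sketch of what it likely does:

```python
import streamlit as st

# Hypothetical replay of one stored turn; the actual loop body is not in this diff.
message = {"role": "assistant", "error": "", "warnings": ["Audio truncated to 30s."], "content": "Hello!"}

with st.chat_message(message["role"]):
    for warning_msg in message.get("warnings", []):
        st.warning(warning_msg)
    if message.get("error"):
        st.error(message["error"])
    if message.get("content"):
        st.write(message["content"])
```

Resetting both `disprompt` and `prompt` before `st.rerun()` re-enables the input widgets and prevents the same prompt from being resubmitted on the next pass.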