YingxuHe committed on
Commit
7234479
·
verified ·
1 Parent(s): d1ebb9a

Update pages.py

Browse files
Files changed (1) hide show
  1. pages.py +88 -71
pages.py CHANGED
@@ -10,6 +10,11 @@ from streamlit_mic_recorder import mic_recorder
10
  from utils import load_model, generate_response, bytes_to_array, start_server, NoAudioException
11
 
12
 
 
 
 
 
 
13
  def audio_llm():
14
  with st.sidebar:
15
  st.markdown("""<div class="sidebar-intro">
@@ -47,7 +52,7 @@ def audio_llm():
47
  st.session_state.audio_array = np.array([])
48
 
49
  if "default_instruction" not in st.session_state:
50
- st.session_state.default_instruction = ""
51
 
52
  st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
53
  st.markdown(
@@ -60,62 +65,62 @@ def audio_llm():
60
 
61
  with col1:
62
  audio_samples_w_instruct = {
63
- '1_ASR_IMDA_PART1_ASR_v2_141' : "- Turn the spoken language into a text format.\n\n- Please translate the content into Chinese.",
64
- '7_ASR_IMDA_PART3_30_ASR_v2_2269': "- Need this talk written down, please.",
65
- '17_ASR_IMDA_PART6_30_ASR_v2_1413': "- Record the spoken word in text form.",
66
 
67
- '25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': "- Please translate the given speech to English.",
68
- '26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': "- Please translate the given speech to Chinese.",
69
- '30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': "- Please follow the instruction in the speech.",
70
 
71
- '32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': "- What does the man think the woman should do at 4:00?",
72
- '33_SQA_IMDA_PART3_30_SQA_V2_2310': "- Does Speaker2's wife cook for Speaker2 when they are at home?",
73
- '34_SQA_IMDA_PART3_30_SQA_V2_3621': "- Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language?",
74
- '35_SQA_IMDA_PART3_30_SQA_V2_4062': "- What is the color of the vase mentioned in the dialogue?",
75
- '36_DS_IMDA_PART4_30_DS_V2_849': "- Condense the dialogue into a concise summary highlighting major topics and conclusions.",
76
-
77
- '39_Paralingual_IEMOCAP_ER_V2_91': "- Based on the speaker's speech patterns, what do you think they are feeling?",
78
- '40_Paralingual_IEMOCAP_ER_V2_567': "- Based on the speaker's speech patterns, what do you think they are feeling?",
79
- '42_Paralingual_IEMOCAP_GR_V2_320': "- Is it possible for you to identify whether the speaker in this recording is male or female?",
80
- '43_Paralingual_IEMOCAP_GR_V2_129': "- Is it possible for you to identify whether the speaker in this recording is male or female?",
81
- '45_Paralingual_IMDA_PART3_30_GR_V2_12312': "- So, who's speaking in the second part of the clip? \n\n- So, who's speaking in the first part of the clip?",
82
- '47_Paralingual_IMDA_PART3_30_NR_V2_10479': "- Can you guess which ethnic group this person is from based on their accent?",
83
- '49_Paralingual_MELD_ER_V2_676': "- What emotions do you think the speaker is expressing?",
84
- '50_Paralingual_MELD_ER_V2_692': "- Based on the speaker's speech patterns, what do you think they are feeling?",
85
- '51_Paralingual_VOXCELEB1_GR_V2_2148': "- May I know the gender of the speaker?",
86
- '53_Paralingual_VOXCELEB1_NR_V2_2286': "- What's the nationality identity of the speaker?",
87
-
88
- '55_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_2': "- What impact would the growth of the healthcare sector have on the country's economy in terms of employment and growth?",
89
- '56_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_415': "- Based on the statement, can you summarize the speaker's position on the recent controversial issues in Singapore?",
90
- '57_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_460': "- How does the author respond to parents' worries about masks in schools?",
91
 
92
- '2_ASR_IMDA_PART1_ASR_v2_2258': "- Turn the spoken language into a text format.\n\n- Please translate the content into Chinese.",
93
- '3_ASR_IMDA_PART1_ASR_v2_2265': "- Turn the spoken language into a text format.",
94
 
95
- '4_ASR_IMDA_PART2_ASR_v2_999' : "- Translate the spoken words into text format.",
96
- '5_ASR_IMDA_PART2_ASR_v2_2241': "- Translate the spoken words into text format.",
97
- '6_ASR_IMDA_PART2_ASR_v2_3409': "- Translate the spoken words into text format.",
98
 
99
- '8_ASR_IMDA_PART3_30_ASR_v2_1698': "- Need this talk written down, please.",
100
- '9_ASR_IMDA_PART3_30_ASR_v2_2474': "- Need this talk written down, please.",
101
 
102
- '11_ASR_IMDA_PART4_30_ASR_v2_3771': "- Write out the dialogue as text.",
103
- '12_ASR_IMDA_PART4_30_ASR_v2_103' : "- Write out the dialogue as text.",
104
- '10_ASR_IMDA_PART4_30_ASR_v2_1527': "- Write out the dialogue as text.",
105
 
106
- '13_ASR_IMDA_PART5_30_ASR_v2_1446': "- Translate this vocal recording into a textual format.",
107
- '14_ASR_IMDA_PART5_30_ASR_v2_2281': "- Translate this vocal recording into a textual format.",
108
- '15_ASR_IMDA_PART5_30_ASR_v2_4388': "- Translate this vocal recording into a textual format.",
109
 
110
- '16_ASR_IMDA_PART6_30_ASR_v2_576': "- Record the spoken word in text form.",
111
- '18_ASR_IMDA_PART6_30_ASR_v2_2834': "- Record the spoken word in text form.",
112
 
113
- '19_ASR_AIShell_zh_ASR_v2_5044': "- Transform the oral presentation into a text document.",
114
- '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': "- Please provide a written transcription of the speech.",
115
 
116
- '27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': "- Please translate the given speech to Chinese.",
117
- '28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': "- Please follow the instruction in the speech.",
118
- '29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': "- Please follow the instruction in the speech.",
119
  }
120
 
121
  audio_sample_names = [audio_sample_name for audio_sample_name in audio_samples_w_instruct.keys()]
@@ -151,6 +156,7 @@ def audio_llm():
151
 
152
  if uploaded_file and st.session_state.on_upload:
153
  audio_bytes = uploaded_file.read()
 
154
  st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
155
  st.session_state.audio_array = bytes_to_array(audio_bytes)
156
 
@@ -166,19 +172,10 @@ def audio_llm():
166
 
167
  if recording and st.session_state.on_record:
168
  audio_bytes = recording["bytes"]
 
169
  st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
170
  st.session_state.audio_array = bytes_to_array(audio_bytes)
171
 
172
-
173
- st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
174
- if st.session_state.audio_array.shape[0] / 16000 > 30.0:
175
- st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")
176
- st.session_state.update(on_upload=False, on_record=False, on_select=False)
177
-
178
- if st.session_state.default_instruction:
179
- st.write("**Example Instructions:**")
180
- st.write(st.session_state.default_instruction)
181
-
182
  st.markdown(
183
  """
184
  <style>
@@ -187,16 +184,33 @@ def audio_llm():
187
  text-align: right;
188
  }
189
  </style>
190
-
191
  """,
192
  unsafe_allow_html=True,
193
  )
194
 
195
- if "messages" not in st.session_state:
196
- st.session_state.messages = []
197
 
198
  if 'disprompt' not in st.session_state:
199
  st.session_state.disprompt = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  for message in st.session_state.messages[-2:]:
202
  with st.chat_message(message["role"]):
@@ -212,15 +226,18 @@ def audio_llm():
212
  disabled=st.session_state.disprompt,
213
  on_submit=lambda: st.session_state.update(disprompt=True)
214
  ):
 
 
 
215
  with st.chat_message("user"):
216
- st.write(prompt)
217
- st.session_state.messages.append({"role": "user", "content": prompt})
218
 
219
  with st.chat_message("assistant"):
220
  response, error_msg, warnings = "", "", []
221
  with st.spinner("Thinking..."):
222
  try:
223
- stream, warnings = generate_response(prompt)
224
  for warning_msg in warnings:
225
  st.warning(warning_msg)
226
  response = st.write_stream(stream)
@@ -230,12 +247,12 @@ def audio_llm():
230
  error_msg = "Internet connection seems to be down. Please contact the administrator to restart the space."
231
  except Exception as e:
232
  error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."
233
- st.session_state.messages.append({
234
- "role": "assistant",
235
- "error": error_msg,
236
- "warnings": warnings,
237
- "content": response
238
- })
239
-
240
- st.session_state.disprompt = False
241
  st.rerun()
 
10
  from utils import load_model, generate_response, bytes_to_array, start_server, NoAudioException
11
 
12
 
13
+ general_instructions = [
14
+ "Please transcribe this speech.",
15
+ "Please summarise this speech."
16
+ ]
17
+
18
  def audio_llm():
19
  with st.sidebar:
20
  st.markdown("""<div class="sidebar-intro">
 
52
  st.session_state.audio_array = np.array([])
53
 
54
  if "default_instruction" not in st.session_state:
55
+ st.session_state.default_instruction = []
56
 
57
  st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
58
  st.markdown(
 
65
 
66
  with col1:
67
  audio_samples_w_instruct = {
68
+ '1_ASR_IMDA_PART1_ASR_v2_141' : ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
69
+ '7_ASR_IMDA_PART3_30_ASR_v2_2269': ["Need this talk written down, please."],
70
+ '17_ASR_IMDA_PART6_30_ASR_v2_1413': ["Record the spoken word in text form."],
71
 
72
+ '25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': ["Please translate the given speech to English."],
73
+ '26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': ["Please translate the given speech to Chinese."],
74
+ '30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': ["Please follow the instruction in the speech."],
75
 
76
+ '32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': ["What does the man think the woman should do at 4:00."],
77
+ '33_SQA_IMDA_PART3_30_SQA_V2_2310': ["Does Speaker2's wife cook for Speaker2 when they are at home."],
78
+ '34_SQA_IMDA_PART3_30_SQA_V2_3621': ["Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language."],
79
+ '35_SQA_IMDA_PART3_30_SQA_V2_4062': ["What is the color of the vase mentioned in the dialogue."],
80
+ '36_DS_IMDA_PART4_30_DS_V2_849': ["Condense the dialogue into a concise summary highlighting major topics and conclusions."],
81
+
82
+ '39_Paralingual_IEMOCAP_ER_V2_91': ["Based on the speaker's speech patterns, what do you think they are feeling."],
83
+ '40_Paralingual_IEMOCAP_ER_V2_567': ["Based on the speaker's speech patterns, what do you think they are feeling."],
84
+ '42_Paralingual_IEMOCAP_GR_V2_320': ["Is it possible for you to identify whether the speaker in this recording is male or female."],
85
+ '43_Paralingual_IEMOCAP_GR_V2_129': ["Is it possible for you to identify whether the speaker in this recording is male or female."],
86
+ '45_Paralingual_IMDA_PART3_30_GR_V2_12312': ["So, who's speaking in the second part of the clip?", "So, who's speaking in the first part of the clip?"],
87
+ '47_Paralingual_IMDA_PART3_30_NR_V2_10479': ["Can you guess which ethnic group this person is from based on their accent."],
88
+ '49_Paralingual_MELD_ER_V2_676': ["What emotions do you think the speaker is expressing."],
89
+ '50_Paralingual_MELD_ER_V2_692': ["Based on the speaker's speech patterns, what do you think they are feeling."],
90
+ '51_Paralingual_VOXCELEB1_GR_V2_2148': ["May I know the gender of the speaker."],
91
+ '53_Paralingual_VOXCELEB1_NR_V2_2286': ["What's the nationality identity of the speaker."],
92
+
93
+ '55_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_2': ["What impact would the growth of the healthcare sector have on the country's economy in terms of employment and growth."],
94
+ '56_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_415': ["Based on the statement, can you summarize the speaker's position on the recent controversial issues in Singapore."],
95
+ '57_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_460': ["How does the author respond to parents' worries about masks in schools."],
96
 
97
+ '2_ASR_IMDA_PART1_ASR_v2_2258': ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
98
+ '3_ASR_IMDA_PART1_ASR_v2_2265': ["Turn the spoken language into a text format."],
99
 
100
+ '4_ASR_IMDA_PART2_ASR_v2_999' : ["Translate the spoken words into text format."],
101
+ '5_ASR_IMDA_PART2_ASR_v2_2241': ["Translate the spoken words into text format."],
102
+ '6_ASR_IMDA_PART2_ASR_v2_3409': ["Translate the spoken words into text format."],
103
 
104
+ '8_ASR_IMDA_PART3_30_ASR_v2_1698': ["Need this talk written down, please."],
105
+ '9_ASR_IMDA_PART3_30_ASR_v2_2474': ["Need this talk written down, please."],
106
 
107
+ '11_ASR_IMDA_PART4_30_ASR_v2_3771': ["Write out the dialogue as text."],
108
+ '12_ASR_IMDA_PART4_30_ASR_v2_103' : ["Write out the dialogue as text."],
109
+ '10_ASR_IMDA_PART4_30_ASR_v2_1527': ["Write out the dialogue as text."],
110
 
111
+ '13_ASR_IMDA_PART5_30_ASR_v2_1446': ["Translate this vocal recording into a textual format."],
112
+ '14_ASR_IMDA_PART5_30_ASR_v2_2281': ["Translate this vocal recording into a textual format."],
113
+ '15_ASR_IMDA_PART5_30_ASR_v2_4388': ["Translate this vocal recording into a textual format."],
114
 
115
+ '16_ASR_IMDA_PART6_30_ASR_v2_576': ["Record the spoken word in text form."],
116
+ '18_ASR_IMDA_PART6_30_ASR_v2_2834': ["Record the spoken word in text form."],
117
 
118
+ '19_ASR_AIShell_zh_ASR_v2_5044': ["Transform the oral presentation into a text document."],
119
+ '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': ["Please provide a written transcription of the speech."],
120
 
121
+ '27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': ["Please translate the given speech to Chinese."],
122
+ '28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': ["Please follow the instruction in the speech."],
123
+ '29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': ["Please follow the instruction in the speech."],
124
  }
125
 
126
  audio_sample_names = [audio_sample_name for audio_sample_name in audio_samples_w_instruct.keys()]
 
156
 
157
  if uploaded_file and st.session_state.on_upload:
158
  audio_bytes = uploaded_file.read()
159
+ st.session_state.default_instruction = general_instructions
160
  st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
161
  st.session_state.audio_array = bytes_to_array(audio_bytes)
162
 
 
172
 
173
  if recording and st.session_state.on_record:
174
  audio_bytes = recording["bytes"]
175
+ st.session_state.default_instruction = general_instructions
176
  st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
177
  st.session_state.audio_array = bytes_to_array(audio_bytes)
178
 
 
 
 
 
 
 
 
 
 
 
179
  st.markdown(
180
  """
181
  <style>
 
184
  text-align: right;
185
  }
186
  </style>
 
187
  """,
188
  unsafe_allow_html=True,
189
  )
190
 
191
+ if "prompt" not in st.session_state:
192
+ st.session_state.prompt = ""
193
 
194
  if 'disprompt' not in st.session_state:
195
  st.session_state.disprompt = False
196
+
197
+ if "messages" not in st.session_state:
198
+ st.session_state.messages = []
199
+
200
+ if st.session_state.audio_array.size:
201
+ with st.chat_message("user"):
202
+ st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
203
+ if st.session_state.audio_array.shape[0] / 16000 > 30.0:
204
+ st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")
205
+ st.session_state.update(on_upload=False, on_record=False, on_select=False)
206
+
207
+ for i, inst in enumerate(st.session_state.default_instruction):
208
+ st.button(
209
+ f"**Example Instruction {i+1}**: {inst}",
210
+ args=(inst,),
211
+ disabled=st.session_state.disprompt,
212
+ on_click=lambda p: st.session_state.update(disprompt=True, prompt=p)
213
+ )
214
 
215
  for message in st.session_state.messages[-2:]:
216
  with st.chat_message(message["role"]):
 
226
  disabled=st.session_state.disprompt,
227
  on_submit=lambda: st.session_state.update(disprompt=True)
228
  ):
229
+ st.session_state.prompt = prompt
230
+
231
+ if st.session_state.prompt:
232
  with st.chat_message("user"):
233
+ st.write(st.session_state.prompt)
234
+ st.session_state.messages.append({"role": "user", "content": st.session_state.prompt})
235
 
236
  with st.chat_message("assistant"):
237
  response, error_msg, warnings = "", "", []
238
  with st.spinner("Thinking..."):
239
  try:
240
+ stream, warnings = generate_response(st.session_state.prompt)
241
  for warning_msg in warnings:
242
  st.warning(warning_msg)
243
  response = st.write_stream(stream)
 
247
  error_msg = "Internet connection seems to be down. Please contact the administrator to restart the space."
248
  except Exception as e:
249
  error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."
250
+ st.session_state.messages.append({
251
+ "role": "assistant",
252
+ "error": error_msg,
253
+ "warnings": warnings,
254
+ "content": response
255
+ })
256
+
257
+ st.session_state.update(disprompt=False, prompt="")
258
  st.rerun()