YingxuHe commited on
Commit
b4ca523
·
verified ·
1 Parent(s): 00b13d9

Update pages.py

Browse files
Files changed (1) hide show
  1. pages.py +83 -76
pages.py CHANGED
@@ -1,8 +1,6 @@
1
- import os
2
  import base64
3
 
4
  import numpy as np
5
- from openai import APIConnectionError
6
  import streamlit as st
7
  from streamlit_mic_recorder import mic_recorder
8
 
@@ -13,62 +11,43 @@ from utils import (
13
  TunnelNotRunningException,
14
  retry_generate_response,
15
  load_model,
16
- generate_response,
17
  bytes_to_array,
18
  start_server,
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- def audio_llm():
23
- with st.sidebar:
24
- st.markdown("""<div class="sidebar-intro">
25
- <p><strong>📌 Supported Tasks</strong>
26
- <p>Automatic Speech Recognition</p>
27
- <p>Speech Translation</p>
28
- <p>Spoken Question Answering</p>
29
- <p>Spoken Dialogue Summarization</p>
30
- <p>Speech Instruction</p>
31
- <p>Paralinguistics</p>
32
- <br>
33
- <p><strong>📎 Generation Config</strong>
34
- </div>""", unsafe_allow_html=True)
35
-
36
- st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature')
37
-
38
- st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p')
39
-
40
-
41
- if st.sidebar.button('Clear History'):
42
- st.session_state.update(
43
- messages=[],
44
- on_upload=False,
45
- on_record=False,
46
- on_select=False,
47
- audio_base64='',
48
- audio_array=np.array([]),
49
- default_instruction=[]
50
- )
51
-
52
- if "server" not in st.session_state:
53
- st.session_state.server = start_server()
54
-
55
- if "client" not in st.session_state or 'model_name' not in st.session_state:
56
- st.session_state.client, st.session_state.model_name = load_model()
57
-
58
- if "audio_array" not in st.session_state:
59
- st.session_state.audio_base64 = ''
60
- st.session_state.audio_array = np.array([])
61
-
62
- if "default_instruction" not in st.session_state:
63
- st.session_state.default_instruction = []
64
-
65
- st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
66
- st.markdown(
67
- """This demo is based on [MERaLiON-AudioLLM](https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION),
68
- developed by I2R, A*STAR, in collaboration with AISG, Singapore.
69
- It is tailored for Singapore’s multilingual and multicultural landscape."""
70
- )
71
-
72
  col1, col2, col3 = st.columns([3.5, 4, 1.5])
73
 
74
  with col1:
@@ -82,7 +61,7 @@ def audio_llm():
82
  options=audio_sample_names,
83
  index=None,
84
  placeholder="Select an audio sample:",
85
- on_change=lambda: st.session_state.update(on_select=True, messages=[]),
86
  key='select')
87
 
88
  if sample_name and st.session_state.on_select:
@@ -99,7 +78,7 @@ def audio_llm():
99
  label="**Upload Audio:**",
100
  label_visibility="collapsed",
101
  type=['wav', 'mp3'],
102
- on_change=lambda: st.session_state.update(on_upload=True, messages=[]),
103
  key='upload'
104
  )
105
 
@@ -118,7 +97,7 @@ def audio_llm():
118
  stop_prompt="🔴 stop recording",
119
  format="wav",
120
  use_container_width=True,
121
- callback=lambda: st.session_state.update(on_record=True, messages=[]),
122
  key='record')
123
 
124
  if recording and st.session_state.on_record:
@@ -127,31 +106,29 @@ def audio_llm():
127
  st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
128
  st.session_state.audio_array = bytes_to_array(audio_bytes)
129
 
130
- if "prompt" not in st.session_state:
131
- st.session_state.prompt = ""
132
 
133
- if 'disprompt' not in st.session_state:
134
- st.session_state.disprompt = False
135
-
136
- if "messages" not in st.session_state:
137
- st.session_state.messages = []
138
-
139
  if st.session_state.audio_array.size:
140
  with st.chat_message("user"):
141
  if st.session_state.audio_array.shape[0] / 16000 > 30.0:
142
  st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")
143
 
144
  st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
145
- st.session_state.update(on_upload=False, on_record=False, on_select=False)
146
-
147
  for i, inst in enumerate(st.session_state.default_instruction):
148
  st.button(
149
  f"**Example Instruction {i+1}**: {inst}",
150
  args=(inst,),
151
  disabled=st.session_state.disprompt,
152
- on_click=lambda p: st.session_state.update(disprompt=True, prompt=p)
153
  )
154
 
 
 
 
 
 
 
155
  for message in st.session_state.messages[-2:]:
156
  with st.chat_message(message["role"]):
157
  if message.get("error"):
@@ -160,24 +137,26 @@ def audio_llm():
160
  st.warning(warning_msg)
161
  if message.get("content"):
162
  st.write(message["content"])
163
-
164
- if prompt := st.chat_input(
165
  placeholder="Type Your Instruction Here",
166
  disabled=st.session_state.disprompt,
167
  on_submit=lambda: st.session_state.update(disprompt=True)
168
  ):
169
- st.session_state.prompt = prompt
 
 
 
170
 
171
- if st.session_state.prompt:
172
  with st.chat_message("user"):
173
- st.write(st.session_state.prompt)
174
- st.session_state.messages.append({"role": "user", "content": st.session_state.prompt})
175
 
176
  with st.chat_message("assistant"):
177
  with st.spinner("Thinking..."):
178
  error_msg, warnings, response = "", [], ""
179
  try:
180
- response, warnings = retry_generate_response()
181
  except NoAudioException:
182
  error_msg = "Please specify audio first!"
183
  except TunnelNotRunningException:
@@ -191,5 +170,33 @@ def audio_llm():
191
  "content": response
192
  })
193
 
194
- st.session_state.update(disprompt=False, prompt="")
195
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import base64
2
 
3
  import numpy as np
 
4
  import streamlit as st
5
  from streamlit_mic_recorder import mic_recorder
6
 
 
11
  TunnelNotRunningException,
12
  retry_generate_response,
13
  load_model,
 
14
  bytes_to_array,
15
  start_server,
16
  )
17
 
18
+ DEFAULT_DIALOGUE_STATES = dict(
19
+ default_instruction=[],
20
+ audio_base64='',
21
+ audio_array=np.array([]),
22
+ disprompt = False,
23
+ new_prompt = "",
24
+ messages=[],
25
+ on_select=False,
26
+ on_upload=False,
27
+ on_record=False,
28
+ on_click_button = False
29
+ )
30
 
31
+ @st.fragment
32
+ def sidebar_fragment():
33
+ st.markdown("""<div class="sidebar-intro">
34
+ <p><strong>📌 Supported Tasks</strong>
35
+ <p>Automatic Speech Recognition</p>
36
+ <p>Speech Translation</p>
37
+ <p>Spoken Question Answering</p>
38
+ <p>Spoken Dialogue Summarization</p>
39
+ <p>Speech Instruction</p>
40
+ <p>Paralinguistics</p>
41
+ <br>
42
+ <p><strong>📎 Generation Config</strong>
43
+ </div>""", unsafe_allow_html=True)
44
+
45
+ st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature')
46
+
47
+ st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p')
48
+
49
+ @st.fragment
50
+ def specify_audio_fragment():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  col1, col2, col3 = st.columns([3.5, 4, 1.5])
52
 
53
  with col1:
 
61
  options=audio_sample_names,
62
  index=None,
63
  placeholder="Select an audio sample:",
64
+ on_change=lambda: st.session_state.update(on_select=True),
65
  key='select')
66
 
67
  if sample_name and st.session_state.on_select:
 
78
  label="**Upload Audio:**",
79
  label_visibility="collapsed",
80
  type=['wav', 'mp3'],
81
+ on_change=lambda: st.session_state.update(on_upload=True),
82
  key='upload'
83
  )
84
 
 
97
  stop_prompt="๐Ÿ”ด stop recording",
98
  format="wav",
99
  use_container_width=True,
100
+ callback=lambda: st.session_state.update(on_record=True),
101
  key='record')
102
 
103
  if recording and st.session_state.on_record:
 
106
  st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
107
  st.session_state.audio_array = bytes_to_array(audio_bytes)
108
 
109
+ st.session_state.update(on_upload=False, on_record=False, on_select=False)
 
110
 
 
 
 
 
 
 
111
  if st.session_state.audio_array.size:
112
  with st.chat_message("user"):
113
  if st.session_state.audio_array.shape[0] / 16000 > 30.0:
114
  st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")
115
 
116
  st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
117
+
 
118
  for i, inst in enumerate(st.session_state.default_instruction):
119
  st.button(
120
  f"**Example Instruction {i+1}**: {inst}",
121
  args=(inst,),
122
  disabled=st.session_state.disprompt,
123
+ on_click=lambda p: st.session_state.update(disprompt=True, new_prompt=p, on_click_button=True)
124
  )
125
 
126
+ if st.session_state.on_click_button:
127
+ st.session_state.on_click_button = False
128
+ st.rerun(scope="app")
129
+
130
+
131
+ def dialogue_section():
132
  for message in st.session_state.messages[-2:]:
133
  with st.chat_message(message["role"]):
134
  if message.get("error"):
 
137
  st.warning(warning_msg)
138
  if message.get("content"):
139
  st.write(message["content"])
140
+
141
+ if chat_input := st.chat_input(
142
  placeholder="Type Your Instruction Here",
143
  disabled=st.session_state.disprompt,
144
  on_submit=lambda: st.session_state.update(disprompt=True)
145
  ):
146
+ st.session_state.new_prompt = chat_input
147
+
148
+ if one_time_prompt := st.session_state.new_prompt:
149
+ st.session_state.new_prompt = ""
150
 
 
151
  with st.chat_message("user"):
152
+ st.write(one_time_prompt)
153
+ st.session_state.messages.append({"role": "user", "content": one_time_prompt})
154
 
155
  with st.chat_message("assistant"):
156
  with st.spinner("Thinking..."):
157
  error_msg, warnings, response = "", [], ""
158
  try:
159
+ response, warnings = retry_generate_response(one_time_prompt)
160
  except NoAudioException:
161
  error_msg = "Please specify audio first!"
162
  except TunnelNotRunningException:
 
170
  "content": response
171
  })
172
 
173
+ st.session_state.disprompt=False
174
+ st.rerun(scope="app")
175
+
176
+
177
+ def audio_llm():
178
+ if "server" not in st.session_state:
179
+ st.session_state.server = start_server()
180
+
181
+ if "client" not in st.session_state or 'model_name' not in st.session_state:
182
+ st.session_state.client, st.session_state.model_name = load_model()
183
+
184
+ for key, value in DEFAULT_DIALOGUE_STATES.items():
185
+ if key not in st.session_state:
186
+ st.session_state[key]=value
187
+
188
+ with st.sidebar:
189
+ sidebar_fragment()
190
+
191
+ if st.sidebar.button('Clear History'):
192
+ st.session_state.update(DEFAULT_DIALOGUE_STATES)
193
+
194
+ st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
195
+ st.markdown(
196
+ """This demo is based on [MERaLiON-AudioLLM](https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION),
197
+ developed by I2R, A*STAR, in collaboration with AISG, Singapore.
198
+ It is tailored for Singaporeโ€™s multilingual and multicultural landscape."""
199
+ )
200
+
201
+ specify_audio_fragment()
202
+ dialogue_section()