Update pages.py

pages.py (CHANGED)
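In summary, this commit refactors pages.py: the unused `os` and `openai` imports go away; the scattered `st.session_state` bootstrapping is consolidated into a single `DEFAULT_DIALOGUE_STATES` dict; the monolithic page is split into `sidebar_fragment()` and `specify_audio_fragment()` (both decorated with `@st.fragment`) plus a new `dialogue_section()`, with `audio_llm()` reduced to an orchestrator; the `prompt` state key becomes `new_prompt` and is passed explicitly to `retry_generate_response()`; and the sidebar gains Temperature and Top P sliders and a Clear History button.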
@@ -1,8 +1,6 @@
-import os
 import base64

 import numpy as np
-from openai import APIConnectionError
 import streamlit as st
 from streamlit_mic_recorder import mic_recorder

@@ -13,62 +11,43 @@ from utils import (
     TunnelNotRunningException,
     retry_generate_response,
     load_model,
-    generate_response,
     bytes_to_array,
     start_server,
 )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-st.session_state.update(
-    messages=[],
-    on_upload=False,
-    on_record=False,
-    on_select=False,
-    audio_base64='',
-    audio_array=np.array([]),
-    default_instruction=[]
-)
-
-if "server" not in st.session_state:
-    st.session_state.server = start_server()
-
-if "client" not in st.session_state or 'model_name' not in st.session_state:
-    st.session_state.client, st.session_state.model_name = load_model()
-
-if "audio_array" not in st.session_state:
-    st.session_state.audio_base64 = ''
-    st.session_state.audio_array = np.array([])
-
-if "default_instruction" not in st.session_state:
-    st.session_state.default_instruction = []
-
-st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
-st.markdown(
-    """This demo is based on [MERaLiON-AudioLLM](https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION),
-    developed by I2R, A*STAR, in collaboration with AISG, Singapore.
-    It is tailored for Singapore’s multilingual and multicultural landscape."""
-)
-
+DEFAULT_DIALOGUE_STATES = dict(
+    default_instruction=[],
+    audio_base64='',
+    audio_array=np.array([]),
+    disprompt=False,
+    new_prompt="",
+    messages=[],
+    on_select=False,
+    on_upload=False,
+    on_record=False,
+    on_click_button=False
+)
+
+@st.fragment
+def sidebar_fragment():
+    st.markdown("""<div class="sidebar-intro">
+        <p><strong>Supported Tasks</strong>
+        <p>Automatic Speech Recognition</p>
+        <p>Speech Translation</p>
+        <p>Spoken Question Answering</p>
+        <p>Spoken Dialogue Summarization</p>
+        <p>Speech Instruction</p>
+        <p>Paralinguistics</p>
+        <br>
+        <p><strong>Generation Config</strong>
+    </div>""", unsafe_allow_html=True)
+
+    st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature')
+
+    st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p')
+
+@st.fragment
+def specify_audio_fragment():
     col1, col2, col3 = st.columns([3.5, 4, 1.5])

     with col1:
@@ -82,7 +61,7 @@ def audio_llm():
             options=audio_sample_names,
             index=None,
             placeholder="Select an audio sample:",
-            on_change=lambda: st.session_state.update(on_select=True
+            on_change=lambda: st.session_state.update(on_select=True),
             key='select')

         if sample_name and st.session_state.on_select:
@@ -99,7 +78,7 @@ def audio_llm():
             label="**Upload Audio:**",
             label_visibility="collapsed",
             type=['wav', 'mp3'],
-            on_change=lambda: st.session_state.update(on_upload=True
+            on_change=lambda: st.session_state.update(on_upload=True),
             key='upload'
         )

@@ -118,7 +97,7 @@ def audio_llm():
             stop_prompt="🔴 stop recording",
             format="wav",
             use_container_width=True,
-            callback=lambda: st.session_state.update(on_record=True
+            callback=lambda: st.session_state.update(on_record=True),
             key='record')

         if recording and st.session_state.on_record:
@@ -127,31 +106,29 @@ def audio_llm():
             st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
             st.session_state.audio_array = bytes_to_array(audio_bytes)

-
-    st.session_state.prompt = ""
+    st.session_state.update(on_upload=False, on_record=False, on_select=False)

-    if 'disprompt' not in st.session_state:
-        st.session_state.disprompt = False
-
-    if "messages" not in st.session_state:
-        st.session_state.messages = []
-
     if st.session_state.audio_array.size:
         with st.chat_message("user"):
             if st.session_state.audio_array.shape[0] / 16000 > 30.0:
                 st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")

             st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
-
-
+
             for i, inst in enumerate(st.session_state.default_instruction):
                 st.button(
                     f"**Example Instruction {i+1}**: {inst}",
                     args=(inst,),
                     disabled=st.session_state.disprompt,
-                    on_click=lambda p: st.session_state.update(disprompt=True,
+                    on_click=lambda p: st.session_state.update(disprompt=True, new_prompt=p, on_click_button=True)
                 )

+    if st.session_state.on_click_button:
+        st.session_state.on_click_button = False
+        st.rerun(scope="app")
+
+
+def dialogue_section():
     for message in st.session_state.messages[-2:]:
         with st.chat_message(message["role"]):
             if message.get("error"):
@@ -160,24 +137,26 @@ def audio_llm():
                 st.warning(warning_msg)
             if message.get("content"):
                 st.write(message["content"])
-
-    if
+
+    if chat_input := st.chat_input(
         placeholder="Type Your Instruction Here",
         disabled=st.session_state.disprompt,
         on_submit=lambda: st.session_state.update(disprompt=True)
     ):
-        st.session_state.
+        st.session_state.new_prompt = chat_input
+
+    if one_time_prompt := st.session_state.new_prompt:
+        st.session_state.new_prompt = ""

         with st.chat_message("user"):
-            st.write(
-            st.session_state.messages.append({"role": "user", "content":
+            st.write(one_time_prompt)
+            st.session_state.messages.append({"role": "user", "content": one_time_prompt})

         with st.chat_message("assistant"):
             with st.spinner("Thinking..."):
                 error_msg, warnings, response = "", [], ""
                 try:
-                    response, warnings = retry_generate_response()
+                    response, warnings = retry_generate_response(one_time_prompt)
                 except NoAudioException:
                     error_msg = "Please specify audio first!"
                 except TunnelNotRunningException:
@@ -191,5 +170,33 @@ def audio_llm():
                     "content": response
                 })

-        st.session_state.
-        st.rerun()
+        st.session_state.disprompt = False
+        st.rerun(scope="app")
+
+
+def audio_llm():
+    if "server" not in st.session_state:
+        st.session_state.server = start_server()
+
+    if "client" not in st.session_state or 'model_name' not in st.session_state:
+        st.session_state.client, st.session_state.model_name = load_model()
+
+    for key, value in DEFAULT_DIALOGUE_STATES.items():
+        if key not in st.session_state:
+            st.session_state[key] = value
+
+    with st.sidebar:
+        sidebar_fragment()
+
+    if st.sidebar.button('Clear History'):
+        st.session_state.update(DEFAULT_DIALOGUE_STATES)
+
+    st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
+    st.markdown(
+        """This demo is based on [MERaLiON-AudioLLM](https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION),
+        developed by I2R, A*STAR, in collaboration with AISG, Singapore.
+        It is tailored for Singapore’s multilingual and multicultural landscape."""
+    )
+
+    specify_audio_fragment()
+    dialogue_section()
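A few of the patterns in this commit are worth a closer look.

First, the `DEFAULT_DIALOGUE_STATES` dict replaces four separate `if ... not in st.session_state` blocks with a single seed-once loop, and it makes the new Clear History button a one-line reset. Below is a minimal, self-contained sketch of the pattern; the key names mirror the diff, but the app around them is hypothetical:

```python
import streamlit as st

# One place to declare every dialogue-related state key and its default.
DEFAULT_DIALOGUE_STATES = dict(
    messages=[],      # chat history
    new_prompt="",    # pending user instruction
    disprompt=False,  # True while a response is being generated
)

# Seed only the missing keys, so a rerun never clobbers live state.
for key, value in DEFAULT_DIALOGUE_STATES.items():
    if key not in st.session_state:
        st.session_state[key] = value

# Resetting the dialogue becomes a single update() with the defaults.
if st.sidebar.button("Clear History"):
    st.session_state.update(DEFAULT_DIALOGUE_STATES)

st.write(f"{len(st.session_state.messages)} message(s) in history")
```

One caveat: the real dict holds mutable defaults (`[]`, `np.array([])`), and `update()` assigns those same objects back on every reset, so the live history and the defaults can end up sharing one list; copying the values first (for example with `copy.deepcopy`) would avoid that.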
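Second, both prompt sources (the chat box and the example-instruction buttons) now write into one slot, `st.session_state.new_prompt`, and `dialogue_section()` consumes it with a read-and-clear, so each submission is processed exactly once even though the script reruns top to bottom. A compact sketch of the idea (a hypothetical echo bot, not the demo code):

```python
import streamlit as st

if "new_prompt" not in st.session_state:
    st.session_state.new_prompt = ""

# Producer 1: the chat box stores its text instead of acting on it directly.
if chat_input := st.chat_input("Type your instruction"):
    st.session_state.new_prompt = chat_input

# Producer 2: a button callback can feed the same slot.
st.button(
    "Example: transcribe this audio",
    on_click=lambda: st.session_state.update(new_prompt="transcribe this audio"),
)

# Consumer: read and immediately clear, so the prompt is handled exactly once.
if one_time_prompt := st.session_state.new_prompt:
    st.session_state.new_prompt = ""
    with st.chat_message("user"):
        st.write(one_time_prompt)
    with st.chat_message("assistant"):
        st.write(f"echo: {one_time_prompt}")
```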
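Finally, the `on_click_button` flag paired with `st.rerun(scope="app")` deserves a note. `specify_audio_fragment()` is decorated with `@st.fragment`, and a widget interaction inside a fragment reruns only that fragment, so the `new_prompt` written by an example-instruction button would not reach `dialogue_section()` until some later full run. Setting a flag in the callback and then escalating with `st.rerun(scope="app")` forces that full run immediately. A minimal sketch of the same mechanism (a hypothetical counter app; assumes streamlit >= 1.37, where `@st.fragment` and the `scope` argument of `st.rerun` are available):

```python
import streamlit as st

if "count" not in st.session_state:
    st.session_state.update(count=0, escalate=False)

@st.fragment
def counter_fragment():
    # A click here normally reruns ONLY this fragment...
    st.button(
        "increment",
        on_click=lambda: st.session_state.update(
            count=st.session_state.count + 1, escalate=True),
    )
    # ...so escalate to a full rerun when something outside must refresh.
    if st.session_state.escalate:
        st.session_state.escalate = False
        st.rerun(scope="app")

counter_fragment()

# Rendered outside the fragment: this would stay stale after a
# fragment-only rerun, which is exactly why the code above escalates.
st.write(f"count: {st.session_state.count}")
```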