Spaces:
Running
Running
File size: 7,621 Bytes
794b9ed e531eef 552600b 00b13d9 552600b e531eef b4ca523 e531eef b4ca523 4bd0ef1 e531eef 552600b e531eef c050149 aee1475 e531eef aee1475 e531eef b4ca523 e531eef 552600b e531eef 6a8f361 b4ca523 6a8f361 5f13969 6a8f361 73ade6e 6a8f361 b4ca523 6a8f361 5f13969 6a8f361 e531eef b4ca523 938f89c 7234479 0f41366 b4ca523 7234479 b4ca523 7234479 938f89c b4ca523 794b9ed 938f89c 6a8f361 b4ca523 c050149 938f89c b4ca523 794b9ed 7234479 e531eef b4ca523 e531eef 6a8f361 00b13d9 6a8f361 b4ca523 6a8f361 00b13d9 6a8f361 7234479 b4ca523 794b9ed b4ca523 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import copy
import base64
import numpy as np
import streamlit as st
from streamlit_mic_recorder import mic_recorder
from utils import (
GENERAL_INSTRUCTIONS,
AUDIO_SAMPLES_W_INSTRUCT,
NoAudioException,
TunnelNotRunningException,
retry_generate_response,
load_model,
bytes_to_array,
start_server,
)
# Per-session defaults for the dialogue UI. ``audio_base64`` is the encoded
# payload handed to the model, ``audio_array`` the decoded waveform used for
# playback, and the ``on_*`` entries are one-shot event flags set by widget
# callbacks. ``disprompt`` disables input widgets while a response is pending.
DEFAULT_DIALOGUE_STATES = {
    "default_instruction": [],
    "audio_base64": "",
    "audio_array": np.array([]),
    "disprompt": False,
    "new_prompt": "",
    "messages": [],
    "on_select": False,
    "on_upload": False,
    "on_record": False,
    "on_click_button": False,
}
@st.fragment
def sidebar_fragment():
    """Render the static sidebar: supported-task list and generation sliders.

    Wrapped in ``st.fragment`` so slider interactions rerun only this
    fragment instead of the whole app. The sliders write their values to
    ``st.session_state.temperature`` and ``st.session_state.top_p`` via
    their ``key`` arguments.
    """
    # Fix: corrected user-facing typo "Recognation" -> "Recognition".
    st.markdown("""<div class="sidebar-intro">
        <p><strong>📌 Supported Tasks</strong>
        <p>Automatic Speech Recognition</p>
        <p>Speech Translation</p>
        <p>Spoken Question Answering</p>
        <p>Spoken Dialogue Summarization</p>
        <p>Speech Instruction</p>
        <p>Paralinguistics</p>
        <br>
        <p><strong>📎 Generation Config</strong>
        </div>""", unsafe_allow_html=True)

    st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature')
    st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p')
@st.fragment
def specify_audio_fragment():
    """Let the user supply audio via example selection, upload, or microphone.

    Whichever source fires populates three session-state fields:
    ``default_instruction`` (suggested prompts for this audio),
    ``audio_base64`` (encoded payload for the model), and ``audio_array``
    (decoded samples for playback). The one-shot ``on_*`` flags are then
    cleared so a plain rerun does not re-ingest the same input.
    """
    col1, col2, col3 = st.columns([3.5, 4, 1.5])

    with col1:
        audio_sample_names = list(AUDIO_SAMPLES_W_INSTRUCT.keys())
        st.markdown("**Select Audio From Examples:**")

        sample_name = st.selectbox(
            label="**Select Audio:**",
            label_visibility="collapsed",
            options=audio_sample_names,
            index=None,
            placeholder="Select an audio sample:",
            on_change=lambda: st.session_state.update(on_select=True),
            key='select')

        if sample_name and st.session_state.on_select:
            # Fix: use a context manager so the sample file handle is closed
            # deterministically instead of leaking until garbage collection.
            with open(f"audio_samples/{sample_name}.wav", "rb") as audio_file:
                audio_bytes = audio_file.read()
            st.session_state.default_instruction = AUDIO_SAMPLES_W_INSTRUCT[sample_name]
            st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
            st.session_state.audio_array = bytes_to_array(audio_bytes)

    with col2:
        st.markdown("or **Upload Audio:**")

        uploaded_file = st.file_uploader(
            label="**Upload Audio:**",
            label_visibility="collapsed",
            type=['wav', 'mp3'],
            on_change=lambda: st.session_state.update(on_upload=True),
            key='upload'
        )

        if uploaded_file and st.session_state.on_upload:
            audio_bytes = uploaded_file.read()
            # Uploaded/recorded audio gets the generic instruction list rather
            # than sample-specific prompts.
            st.session_state.default_instruction = GENERAL_INSTRUCTIONS
            st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
            st.session_state.audio_array = bytes_to_array(audio_bytes)

    with col3:
        st.markdown("or **Record Audio:**")

        recording = mic_recorder(
            start_prompt="▶ start recording",
            stop_prompt="🔴 stop recording",
            format="wav",
            use_container_width=True,
            callback=lambda: st.session_state.update(on_record=True),
            key='record')

        if recording and st.session_state.on_record:
            audio_bytes = recording["bytes"]
            st.session_state.default_instruction = GENERAL_INSTRUCTIONS
            st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
            st.session_state.audio_array = bytes_to_array(audio_bytes)

    # Clear the one-shot ingestion flags so subsequent reruns are no-ops.
    st.session_state.update(on_upload=False, on_record=False, on_select=False)

    if st.session_state.audio_array.size:
        with st.chat_message("user"):
            # Warn when the clip exceeds the 30-second limit (assumes 16 kHz
            # samples, matching the playback sample_rate below).
            if st.session_state.audio_array.shape[0] / 16000 > 30.0:
                st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")

            st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)

            # One clickable button per suggested instruction; clicking stores
            # the instruction as the pending prompt and triggers a full rerun.
            for i, inst in enumerate(st.session_state.default_instruction):
                st.button(
                    f"**Example Instruction {i+1}**: {inst}",
                    args=(inst,),
                    disabled=st.session_state.disprompt,
                    on_click=lambda p: st.session_state.update(disprompt=True, new_prompt=p, on_click_button=True)
                )

    if st.session_state.on_click_button:
        st.session_state.on_click_button = False
        st.rerun(scope="app")
def dialogue_section():
    """Draw the chat transcript, accept a new instruction, and run one round.

    A pending prompt may arrive either from the chat input box or from an
    example-instruction button (which stores it in
    ``st.session_state.new_prompt`` before rerunning). Model errors are
    captured into the message record rather than raised, and ``disprompt``
    is lifted before the final rerun so inputs re-enable.
    """
    # Replay every stored message, in order.
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            if entry.get("error"):
                st.error(entry["error"])
            for warn in entry.get("warnings", []):
                st.warning(warn)
            if entry.get("content"):
                st.write(entry["content"])

    typed = st.chat_input(
        placeholder="Type Your Instruction Here",
        disabled=st.session_state.disprompt,
        on_submit=lambda: st.session_state.update(disprompt=True)
    )
    if typed:
        st.session_state.new_prompt = typed

    pending = st.session_state.new_prompt
    if pending:
        # Consume the prompt and start a fresh transcript for this round.
        st.session_state.update(new_prompt="", messages=[])

        with st.chat_message("user"):
            st.write(pending)
        st.session_state.messages.append({"role": "user", "content": pending})

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                error_msg, warnings, response = "", [], ""
                try:
                    response, warnings = retry_generate_response(pending)
                except NoAudioException:
                    error_msg = "Please specify audio first!"
                except TunnelNotRunningException:
                    error_msg = "Internet connection cannot be established. Please contact the administrator."
                except Exception as e:
                    error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."

                st.session_state.messages.append({
                    "role": "assistant",
                    "error": error_msg,
                    "warnings": warnings,
                    "content": response
                })

        st.session_state.disprompt = False
        st.rerun(scope="app")
def audio_llm():
    """Page entry point: initialize session resources, then render the UI.

    Server, model client, and dialogue-state defaults are created lazily,
    once per session. Defaults are deep-copied so the mutable entries
    (lists, arrays) in ``DEFAULT_DIALOGUE_STATES`` are never shared.
    """
    if "server" not in st.session_state:
        st.session_state.server = start_server()

    if "client" not in st.session_state or 'model_name' not in st.session_state:
        st.session_state.client, st.session_state.model_name = load_model()

    # Seed only the keys that are missing from this session.
    for state_key, default_value in DEFAULT_DIALOGUE_STATES.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = copy.deepcopy(default_value)

    with st.sidebar:
        sidebar_fragment()

    if st.sidebar.button('Clear History'):
        st.session_state.update(DEFAULT_DIALOGUE_STATES)

    st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
    st.markdown(
        """This demo is based on [MERaLiON-AudioLLM](https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION),
developed by I2R, A*STAR, in collaboration with AISG, Singapore.
It is tailored for Singapore’s multilingual and multicultural landscape."""
    )

    specify_audio_fragment()
    dialogue_section()
|