MERaLiON-AudioLLM / pages.py
YingxuHe's picture
Update pages.py
1494206 verified
raw
history blame
7.64 kB
import copy
import base64
import numpy as np
import streamlit as st
from streamlit_mic_recorder import mic_recorder
from utils import (
GENERAL_INSTRUCTIONS,
AUDIO_SAMPLES_W_INSTRUCT,
NoAudioException,
TunnelNotRunningException,
retry_generate_response,
load_model,
bytes_to_array,
start_server,
)
# Initial values for the per-session dialogue state. audio_llm() copies any
# missing key into st.session_state, and "Clear History" restores these values.
DEFAULT_DIALOGUE_STATES = {
    # Example instructions rendered as buttons under the audio player.
    "default_instruction": [],
    # Base64 text of the currently selected audio bytes.
    "audio_base64": '',
    # Decoded audio samples; an empty array means no audio has been chosen.
    "audio_array": np.array([]),
    # When True, the chat input and the instruction buttons are disabled.
    "disprompt": False,
    # Prompt waiting to be consumed by dialogue_section().
    "new_prompt": "",
    # Conversation history as a list of message dicts.
    "messages": [],
    # Flags set by widget callbacks to mark which widget triggered the run.
    "on_select": False,
    "on_upload": False,
    "on_record": False,
    "on_click_button": False,
}
@st.fragment
def sidebar_fragment():
    """Render the sidebar content: the supported-task list and generation sliders.

    The two sliders are keyed ``temperature`` and ``top_p``, so their current
    values are available from ``st.session_state`` under those keys.
    """
    # Raw HTML block; the "sidebar-intro" class is presumably styled by CSS
    # injected elsewhere in the app (not visible in this file).
    st.markdown("""<div class="sidebar-intro">
<p><strong>📌 Supported Tasks</strong></p>
<p>Automatic Speech Recognition</p>
<p>Speech Translation</p>
<p>Spoken Question Answering</p>
<p>Spoken Dialogue Summarization</p>
<p>Speech Instruction</p>
<p>Paralinguistics</p>
<br>
<p><strong>📎 Generation Config</strong></p>
</div>""", unsafe_allow_html=True)

    st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature')
    st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p')
def _store_audio(audio_bytes, instructions):
    """Save the chosen audio and its example instructions into session state."""
    st.session_state.default_instruction = instructions
    st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
    st.session_state.audio_array = bytes_to_array(audio_bytes)


@st.fragment
def specify_audio_fragment():
    """Let the user pick, upload or record audio, then preview it.

    Three columns offer a bundled example, a file upload and a microphone
    recording. Each widget's callback raises its own flag (``on_select``,
    ``on_upload``, ``on_record``) so the audio is only (re)loaded on the run
    triggered by that widget; all three flags are cleared afterwards.

    When audio is present it is shown in a "user" chat message together with
    one button per example instruction. Clicking a button stores the
    instruction as ``new_prompt`` and requests a full-app rerun.
    """
    # The code treats ``audio_array`` as 16 kHz samples for both the duration
    # check and playback.
    sample_rate = 16000
    max_seconds = 30.0

    col1, col2, col3 = st.columns([4, 2, 2])

    with col1:
        st.markdown("**Select Audio From Examples:**")

        sample_name = st.selectbox(
            label="**Select Audio:**",
            label_visibility="collapsed",
            options=list(AUDIO_SAMPLES_W_INSTRUCT.keys()),
            index=None,
            placeholder="Select an audio sample:",
            on_change=lambda: st.session_state.update(on_select=True),
            key='select',
        )

        if sample_name and st.session_state.on_select:
            # Context manager so the file handle is closed deterministically.
            with open(f"audio_samples/{sample_name}.wav", "rb") as audio_file:
                audio_bytes = audio_file.read()
            _store_audio(audio_bytes, AUDIO_SAMPLES_W_INSTRUCT[sample_name])

    with col2:
        st.markdown("or **Upload Audio:**")

        uploaded_file = st.file_uploader(
            label="**Upload Audio:**",
            label_visibility="collapsed",
            type=['wav', 'mp3'],
            on_change=lambda: st.session_state.update(on_upload=True),
            key='upload',
        )

        if uploaded_file and st.session_state.on_upload:
            _store_audio(uploaded_file.read(), GENERAL_INSTRUCTIONS)

    with col3:
        st.markdown("or **Record Audio:**")

        recording = mic_recorder(
            start_prompt="▶ start recording",
            stop_prompt="🔴 stop recording",
            format="wav",
            use_container_width=True,
            callback=lambda: st.session_state.update(on_record=True),
            key='record',
        )

        if recording and st.session_state.on_record:
            _store_audio(recording["bytes"], GENERAL_INSTRUCTIONS)

    # The trigger flags are one-shot: clear them so the next run does not
    # reload audio unless a widget callback raises a flag again.
    st.session_state.update(on_upload=False, on_record=False, on_select=False)

    audio_array = st.session_state.audio_array
    if audio_array.size:
        with st.chat_message("user"):
            if audio_array.shape[0] / sample_rate > max_seconds:
                st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")

            st.audio(audio_array, format="audio/wav", sample_rate=sample_rate)

            for i, inst in enumerate(st.session_state.default_instruction):
                st.button(
                    f"**Example Instruction {i+1}**: {inst}",
                    # The instruction is passed through ``args`` so each button
                    # gets its own value rather than the loop's last one.
                    args=(inst,),
                    disabled=st.session_state.disprompt,
                    on_click=lambda p: st.session_state.update(
                        disprompt=True, new_prompt=p, on_click_button=True, messages=[]
                    ),
                )

    if st.session_state.on_click_button:
        # Presumably needed because a widget inside a fragment reruns only the
        # fragment; the full rerun lets dialogue_section() see ``new_prompt``.
        st.session_state.on_click_button = False
        st.rerun(scope="app")
def dialogue_section():
    """Replay the stored conversation and answer a pending prompt, if any.

    Each stored message is rendered in a chat bubble for its role, showing its
    error, warnings and content when present. A prompt typed into the chat
    input is stored as ``new_prompt``. If ``new_prompt`` is non-empty it is
    consumed: the history is reset, the prompt is echoed, a response is
    requested via ``retry_generate_response``, both messages are stored,
    ``disprompt`` is cleared and the whole app is rerun.
    """
    state = st.session_state

    # Replay whatever is already in the history.
    for past in state.messages:
        with st.chat_message(past["role"]):
            if past.get("error"):
                st.error(past["error"])
            for note in past.get("warnings", []):
                st.warning(note)
            if past.get("content"):
                st.write(past["content"])

    # Submitting disables further input and clears the history via the callback.
    typed_prompt = st.chat_input(
        placeholder="Type Your Instruction Here",
        disabled=state.disprompt,
        on_submit=lambda: st.session_state.update(disprompt=True, messages=[]),
    )
    if typed_prompt:
        state.new_prompt = typed_prompt

    one_time_prompt = state.new_prompt
    if not one_time_prompt:
        return

    # Consume the prompt so it is handled exactly once.
    state.update(new_prompt="", messages=[])

    with st.chat_message("user"):
        st.write(one_time_prompt)
    state.messages.append({"role": "user", "content": one_time_prompt})

    error_msg = ""
    warnings = []
    response = ""
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            try:
                response, warnings = retry_generate_response(one_time_prompt)
            except NoAudioException:
                error_msg = "Please specify audio first!"
            except TunnelNotRunningException:
                error_msg = "Internet connection cannot be established. Please contact the administrator."
            except Exception as e:
                error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."

    assistant_turn = {
        "role": "assistant",
        "error": error_msg,
        "warnings": warnings,
        "content": response,
    }
    state.messages.append(assistant_turn)

    # Re-enable input, then rerun so the stored messages are drawn by the
    # replay loop above.
    state.disprompt = False
    st.rerun(scope="app")
def audio_llm():
    """Entry point for the chatbot page.

    Ensures the server handle, client and model name exist in
    ``st.session_state`` (created once via ``start_server`` / ``load_model``),
    seeds any missing dialogue-state keys from ``DEFAULT_DIALOGUE_STATES``,
    then renders the sidebar, the page header, the audio picker and the
    dialogue.
    """
    if "server" not in st.session_state:
        st.session_state.server = start_server()

    if "client" not in st.session_state or 'model_name' not in st.session_state:
        st.session_state.client, st.session_state.model_name = load_model()

    # Deep-copy so the session never shares mutable defaults (lists, the
    # array) with the module-level template.
    for key, value in DEFAULT_DIALOGUE_STATES.items():
        if key not in st.session_state:
            st.session_state[key] = copy.deepcopy(value)

    with st.sidebar:
        sidebar_fragment()

    if st.sidebar.button('Clear History'):
        # Reset with a deep copy as well, for the same reason as above: a
        # later in-place change to session state must not touch the defaults.
        st.session_state.update(copy.deepcopy(DEFAULT_DIALOGUE_STATES))

    st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
    st.markdown(
        """This demo is based on [MERaLiON-AudioLLM](https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION),
developed by I2R, A*STAR, in collaboration with AISG, Singapore.
It is tailored for Singapore’s multilingual and multicultural landscape."""
    )

    specify_audio_fragment()
    dialogue_section()