Spaces:
Running
Running
import copy | |
import base64 | |
import numpy as np | |
import streamlit as st | |
from streamlit_mic_recorder import mic_recorder | |
from utils import ( | |
GENERAL_INSTRUCTIONS, | |
AUDIO_SAMPLES_W_INSTRUCT, | |
NoAudioException, | |
TunnelNotRunningException, | |
retry_generate_response, | |
load_model, | |
bytes_to_array, | |
start_server, | |
) | |
# Per-session dialogue defaults. Several values are mutable (lists, ndarray),
# so copy.deepcopy them before storing in st.session_state to keep sessions
# from sharing objects with this module-level dict.
DEFAULT_DIALOGUE_STATES = {
    # Example instructions offered for the currently loaded audio clip.
    "default_instruction": [],
    # Currently loaded audio: base64 of the raw bytes, and the decoded array.
    "audio_base64": '',
    "audio_array": np.array([]),
    # True while a prompt is being answered; used to disable prompt inputs.
    "disprompt": False,
    # Prompt queued to be answered on the next run.
    "new_prompt": "",
    # Chat transcript rendered by the dialogue section.
    "messages": [],
    # One-shot flags set by widget callbacks.
    "on_select": False,
    "on_upload": False,
    "on_record": False,
    "on_click_button": False,
}
def sidebar_fragment():
    """Render the sidebar: the list of supported tasks and the generation sliders.

    The sliders store their values in ``st.session_state`` under the keys
    ``temperature`` and ``top_p`` (through their ``key`` arguments). Nothing
    is returned.
    """
    # HTML lines are kept unindented: Markdown treats 4+ leading spaces as a
    # code block, which would show the raw tags instead of rendering them.
    st.markdown("""<div class="sidebar-intro">
<p><strong>📌 Supported Tasks</strong></p>
<p>Automatic Speech Recognition</p>
<p>Speech Translation</p>
<p>Spoken Question Answering</p>
<p>Spoken Dialogue Summarization</p>
<p>Speech Instruction</p>
<p>Paralinguistics</p>
<br>
<p><strong>📎 Generation Config</strong></p>
</div>""", unsafe_allow_html=True)
    st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature')
    st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p')
def _store_audio(audio_bytes, instructions):
    """Store *audio_bytes* and its example *instructions* in session state.

    Keeps the three audio-related session values in sync: the example
    instruction list, the base64 text of the raw bytes, and the array
    produced by ``bytes_to_array``.
    """
    st.session_state.default_instruction = instructions
    st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
    st.session_state.audio_array = bytes_to_array(audio_bytes)


def specify_audio_fragment():
    """Render the audio-input row and, once audio is loaded, its preview.

    Audio can come from one of three sources: a bundled example, an uploaded
    file, or a microphone recording. Each widget's callback raises a one-shot
    flag (``on_select`` / ``on_upload`` / ``on_record``). The audio is only
    (re)loaded when that flag is set, and all flags are cleared afterwards.
    This way a widget's current value is not re-applied on unrelated reruns.

    When audio is present, the clip and one button per example instruction
    are shown. Clicking a button queues that instruction as ``new_prompt``,
    sets ``disprompt``, clears the transcript, and triggers a full-app rerun.
    """
    col1, col2, col3 = st.columns([4, 2, 2])

    with col1:
        st.markdown("**Select Audio From Examples:**")
        sample_name = st.selectbox(
            label="**Select Audio:**",
            label_visibility="collapsed",
            options=list(AUDIO_SAMPLES_W_INSTRUCT),
            index=None,
            placeholder="Select an audio sample:",
            on_change=lambda: st.session_state.update(on_select=True),
            key='select')
        if sample_name and st.session_state.on_select:
            # Context manager so the file handle is closed deterministically.
            with open(f"audio_samples/{sample_name}.wav", "rb") as audio_file:
                audio_bytes = audio_file.read()
            _store_audio(audio_bytes, AUDIO_SAMPLES_W_INSTRUCT[sample_name])

    with col2:
        st.markdown("or **Upload Audio:**")
        uploaded_file = st.file_uploader(
            label="**Upload Audio:**",
            label_visibility="collapsed",
            type=['wav', 'mp3'],
            on_change=lambda: st.session_state.update(on_upload=True),
            key='upload'
        )
        if uploaded_file and st.session_state.on_upload:
            # getvalue() returns the full contents regardless of stream position.
            _store_audio(uploaded_file.getvalue(), GENERAL_INSTRUCTIONS)

    with col3:
        st.markdown("or **Record Audio:**")
        recording = mic_recorder(
            start_prompt="▶ start recording",
            stop_prompt="🔴 stop recording",
            format="wav",
            use_container_width=True,
            callback=lambda: st.session_state.update(on_record=True),
            key='record')
        if recording and st.session_state.on_record:
            _store_audio(recording["bytes"], GENERAL_INSTRUCTIONS)

    # The flags are one-shot: consume them on this run.
    st.session_state.update(on_upload=False, on_record=False, on_select=False)

    if st.session_state.audio_array.size:
        with st.chat_message("user"):
            # Duration check and playback both assume 16 kHz samples from
            # bytes_to_array — TODO confirm against utils.bytes_to_array.
            if st.session_state.audio_array.shape[0] / 16000 > 30.0:
                st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")
            st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)

            for i, inst in enumerate(st.session_state.default_instruction):
                st.button(
                    f"**Example Instruction {i+1}**: {inst}",
                    # Passed through args (not captured) to avoid late binding.
                    args=(inst,),
                    disabled=st.session_state.disprompt,
                    on_click=lambda p: st.session_state.update(disprompt=True, new_prompt=p, on_click_button=True, messages=[])
                )

    if st.session_state.on_click_button:
        st.session_state.on_click_button = False
        # Rerun the whole app so the dialogue section picks up new_prompt.
        st.rerun(scope="app")
def _generate_reply(prompt):
    """Ask the model to answer *prompt*, mapping failures to user-facing text.

    Returns:
        ``(error_msg, warnings, response)``. On success ``error_msg`` is ``""``
        and ``response``/``warnings`` come from ``retry_generate_response``.
        On failure ``error_msg`` describes the problem, ``warnings`` is ``[]``
        and ``response`` is ``""``.
    """
    try:
        response, warnings = retry_generate_response(prompt)
    except NoAudioException:
        return "Please specify audio first!", [], ""
    except TunnelNotRunningException:
        return "Internet connection cannot be established. Please contact the administrator.", [], ""
    except Exception as e:  # UI boundary: report anything else to the user.
        return f"Caught Exception: {e!r}. Please contact the administrator.", [], ""
    return "", warnings, response


def dialogue_section():
    """Render the chat transcript and the prompt input, then answer a pending prompt.

    A prompt can come from the chat input or, via ``new_prompt``, from an
    example-instruction button. Each message dict may hold ``error``,
    ``warnings`` and ``content``.

    While a prompt is being answered, ``disprompt`` keeps the inputs
    disabled. It is reset in ``finally`` so an interrupted or failed run
    cannot leave the input disabled for the rest of the session.
    """
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message.get("error"):
                st.error(message["error"])
            for warning_msg in message.get("warnings", []):
                st.warning(warning_msg)
            if message.get("content"):
                st.write(message["content"])

    if chat_input := st.chat_input(
        placeholder="Type Your Instruction Here",
        disabled=st.session_state.disprompt,
        on_submit=lambda: st.session_state.update(disprompt=True, messages=[])
    ):
        st.session_state.new_prompt = chat_input

    one_time_prompt = st.session_state.new_prompt
    if not one_time_prompt:
        return

    # Consume the prompt and start a fresh single-turn transcript.
    st.session_state.update(new_prompt="", messages=[])
    try:
        with st.chat_message("user"):
            st.write(one_time_prompt)
        st.session_state.messages.append({"role": "user", "content": one_time_prompt})

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                error_msg, warnings, response = _generate_reply(one_time_prompt)

        st.session_state.messages.append({
            "role": "assistant",
            "error": error_msg,
            "warnings": warnings,
            "content": response
        })
    finally:
        # Always re-enable the inputs, even if this run was interrupted.
        st.session_state.disprompt = False

    st.rerun(scope="app")
def audio_llm():
    """Page entry point: set up per-session resources and state, then render the app.

    ``start_server()`` and ``load_model()`` are called at most once per
    session; ``load_model()`` must return ``(client, model_name)``. Missing
    dialogue-state keys are filled with deep copies of
    ``DEFAULT_DIALOGUE_STATES``. The "Clear History" button restores all of
    them, also from deep copies.
    """
    if "server" not in st.session_state:
        st.session_state.server = start_server()

    if "client" not in st.session_state or 'model_name' not in st.session_state:
        st.session_state.client, st.session_state.model_name = load_model()

    for key, value in DEFAULT_DIALOGUE_STATES.items():
        if key not in st.session_state:
            st.session_state[key] = copy.deepcopy(value)

    with st.sidebar:
        sidebar_fragment()

    if st.sidebar.button('Clear History'):
        # Deep-copy, as during initialisation: otherwise the session would
        # alias the module-level lists/array, and mutating them in place
        # would corrupt the defaults for every later reset.
        st.session_state.update(copy.deepcopy(DEFAULT_DIALOGUE_STATES))

    st.markdown("<h1 style='text-align: center; color: black;'>MERaLiON-AudioLLM ChatBot 🤖</h1>", unsafe_allow_html=True)
    st.markdown(
        """This demo is based on [MERaLiON-AudioLLM](https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION),
developed by I2R, A*STAR, in collaboration with AISG, Singapore.
It is tailored for Singapore’s multilingual and multicultural landscape."""
    )

    specify_audio_fragment()
    dialogue_section()