MERaLiON-AudioLLM / pages.py
YingxuHe's picture
Update pages.py
67da2ee verified
raw
history blame
8.03 kB
import os
import base64
import numpy as np
import streamlit as st
import streamlit.components.v1 as components
from streamlit_mic_recorder import mic_recorder
from utils import load_model, generate_response, bytes_to_array, start_server
def home_page():
    """Render the landing page: a banner image to the left of the welcome title."""
    # Banner markup and its styling; rendered through an HTML component in the
    # narrow left-hand column.
    banner_html = """
<div class="banner">
<img src="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRhB2e_AhOe11wKxnnwOmOVg9E7J1MBgiTeYzzFAESwcCP5IbBAc2X8BwGChMfJzwqtVg&usqp=CAU" alt="Banner Image">
</div>
<style>
.banner {
width: 100%;
height: 200px;
overflow: visible;
}
.banner img {
width: 100%;
object-fit: cover;
}
</style>
"""
    # 1:4 width split between the banner column and the title column.
    logo_col, title_col = st.columns([1, 4])

    with logo_col:
        components.html(banner_html)

    title_col.write("# Welcome to MERaLiON - AudioLLMs 🤖")

    # Slot for further home-page information (currently an empty markdown element).
    st.markdown('')
# Bundled sample clips (read from ``audio_samples/<name>.wav``) mapped to the
# example instruction text shown when that clip is selected. Built once at
# import time instead of on every Streamlit rerun.
_AUDIO_SAMPLES_W_INSTRUCT = {
    '1_ASR_IMDA_PART1_ASR_v2_141' : "- Turn the spoken language into a text format.\n\n- Please translate the content into Chinese.",
    '2_ASR_IMDA_PART1_ASR_v2_2258': "- Turn the spoken language into a text format.\n\n- Please translate the content into Chinese.",
    '3_ASR_IMDA_PART1_ASR_v2_2265': "- Turn the spoken language into a text format.",
    '4_ASR_IMDA_PART2_ASR_v2_999' : "- Translate the spoken words into text format.",
    '5_ASR_IMDA_PART2_ASR_v2_2241': "- Translate the spoken words into text format.",
    '6_ASR_IMDA_PART2_ASR_v2_3409': "- Translate the spoken words into text format.",
    '7_ASR_IMDA_PART3_30_ASR_v2_2269': "- Need this talk written down, please.",
    '8_ASR_IMDA_PART3_30_ASR_v2_1698': "- Need this talk written down, please.",
    '9_ASR_IMDA_PART3_30_ASR_v2_2474': "- Need this talk written down, please.",
    '10_ASR_IMDA_PART4_30_ASR_v2_1527': "- Write out the dialogue as text.",
    '11_ASR_IMDA_PART4_30_ASR_v2_3771': "- Write out the dialogue as text.",
    '12_ASR_IMDA_PART4_30_ASR_v2_103' : "- Write out the dialogue as text.",
    '13_ASR_IMDA_PART5_30_ASR_v2_1446': "- Translate this vocal recording into a textual format.",
    '14_ASR_IMDA_PART5_30_ASR_v2_2281': "- Translate this vocal recording into a textual format.",
    '15_ASR_IMDA_PART5_30_ASR_v2_4388': "- Translate this vocal recording into a textual format.",
    '16_ASR_IMDA_PART6_30_ASR_v2_576': "- Record the spoken word in text form.",
    '17_ASR_IMDA_PART6_30_ASR_v2_1413': "- Record the spoken word in text form.",
    '18_ASR_IMDA_PART6_30_ASR_v2_2834': "- Record the spoken word in text form.",
    '19_ASR_AIShell_zh_ASR_v2_5044': "- Transform the oral presentation into a text document.",
    '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': "- Please provide a written transcription of the speech."
}


def _store_audio(audio_bytes):
    """Store encoded audio in session state as base64 text and as a decoded array.

    Args:
        audio_bytes: Raw audio file contents (WAV from the recorder/samples,
            WAV or MP3 from the uploader).
    """
    st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
    st.session_state.audio_array = bytes_to_array(audio_bytes)


def _init_session_state():
    """Create every session-state entry this page reads, keeping existing values."""
    if "server" not in st.session_state:
        st.session_state.server = start_server()
    if "client" not in st.session_state or 'model_name' not in st.session_state:
        st.session_state.client, st.session_state.model_name = load_model()
    if "audio_array" not in st.session_state:
        st.session_state.audio_base64 = ''
        st.session_state.audio_array = np.array([])
    if "default_instruction" not in st.session_state:
        st.session_state.default_instruction = ""
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # One-shot flags set by the input widgets' callbacks. Initialise them so
    # the reads below never depend on short-circuit evaluation to avoid an
    # attribute lookup on a missing key.
    for flag in ("on_upload", "on_record", "on_select"):
        if flag not in st.session_state:
            st.session_state[flag] = False


def audio_llm():
    """Render the MERaLiON-AudioLLM playground page.

    The page lets the user provide one audio clip (record, upload, or pick a
    bundled sample), tune sampling settings (stored in session state under
    ``temperature`` and ``top_p``), and send text instructions through
    ``generate_response``.

    Raises:
        Exception: any exception from ``generate_response`` / streaming is
            shown to the user as a message and then re-raised.
    """
    with st.sidebar:
        st.divider()
        st.markdown("""<div class="sidebar-intro">
<p><strong>Purpose</strong>: Complex Audio Understanding</p>
<p><strong>Name</strong>: MERaLiON-AudioLLM-v1</p>
<p><strong>Version</strong>: Dec. 20, 2024</p>
</div>""", unsafe_allow_html=True)

    if st.sidebar.button('Clear History'):
        # Reset the base64 copy together with the array; otherwise the cleared
        # clip's encoded bytes would linger in session state.
        st.session_state.update(messages=[],
                                on_upload=False,
                                on_record=False,
                                on_select=False,
                                audio_base64='',
                                audio_array=np.array([]))

    _init_session_state()

    col1, col2, col3 = st.columns(3)

    with col1:
        st.markdown("**Record Audio:**")
        recording = mic_recorder(
            format="wav",
            use_container_width=True,
            callback=lambda: st.session_state.update(on_record=True, messages=[]),
            key='record')
        # Only ingest the recording on the rerun triggered by a new recording.
        if recording and st.session_state.on_record:
            _store_audio(recording["bytes"])

    with col2:
        st.markdown("**Upload Audio:**")
        uploaded_file = st.file_uploader(
            label="**Upload Audio:**",
            label_visibility="collapsed",
            type=['wav', 'mp3'],
            on_change=lambda: st.session_state.update(on_upload=True, messages=[]),
            key='upload'
        )
        if uploaded_file and st.session_state.on_upload:
            _store_audio(uploaded_file.read())

    with col3:
        st.markdown("**Select Audio:**")
        sample_name = st.selectbox(
            label="**Select Audio:**",
            label_visibility="collapsed",
            options=list(_AUDIO_SAMPLES_W_INSTRUCT),
            index=None,
            placeholder="Select an audio sample:",
            on_change=lambda: st.session_state.update(on_select=True, messages=[]),
            key='select')
        if sample_name and st.session_state.on_select:
            # Close the sample file deterministically instead of leaking the handle.
            with open(f"audio_samples/{sample_name}.wav", "rb") as sample_file:
                audio_bytes = sample_file.read()
            st.session_state.default_instruction = _AUDIO_SAMPLES_W_INSTRUCT[sample_name]
            _store_audio(audio_bytes)

    # NOTE(review): assumes bytes_to_array yields 16 kHz samples — confirm in utils.
    st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
    # Consume the one-shot flags: widget values persist across reruns, so
    # without this they would re-ingest (and overwrite) audio on every rerun.
    st.session_state.update(on_upload=False, on_record=False, on_select=False)

    st.markdown("**Model Configuration:**")
    col4, col5, _ = st.columns(3)
    with col4:
        st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature')
    with col5:
        st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p')

    st.markdown("**Example Instruction:**")
    st.write(st.session_state.default_instruction)

    # NOTE(review): `st-emotion-cache-*` is an auto-generated Streamlit class
    # name (apparently the user chat-message row); it may change between
    # Streamlit versions.
    st.markdown(
        """
<style>
.st-emotion-cache-1c7y2kd {
flex-direction: row-reverse;
text-align: right;
}
</style>
""",
        unsafe_allow_html=True,
    )

    if prompt := st.chat_input(placeholder="Your Instruction"):
        with st.chat_message("user"):
            st.write(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    stream = generate_response(prompt)
                    response = st.write_stream(stream)
                except Exception as e:
                    response = f"Caught Exception: {repr(e)}. Please contact the administrator to restart this space."
                    st.write(response)
                    # Surface the failure to the user, then propagate with the
                    # original traceback intact.
                    raise
            st.session_state.messages.append({"role": "assistant", "content": response})