# Spaces:
# Running
# Running
import base64
import os

import numpy as np
import streamlit as st
import streamlit.components.v1 as components
from streamlit_mic_recorder import mic_recorder

from utils import load_model, generate_response, bytes_to_array, start_server
def home_page():
    """Draw the landing page: banner image on the left, welcome title on the right."""
    # Narrow column (1 part) for the banner, wide column (4 parts) for the title.
    banner_col, title_col = st.columns([1, 4])

    # The banner markup and its stylesheet are kept as two pieces and joined
    # into one HTML snippet passed to components.html.
    banner_div = """
<div class="banner">
<img src="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRhB2e_AhOe11wKxnnwOmOVg9E7J1MBgiTeYzzFAESwcCP5IbBAc2X8BwGChMfJzwqtVg&usqp=CAU" alt="Banner Image">
</div>
"""
    banner_style = """<style>
.banner {
width: 100%;
height: 200px;
overflow: visible;
}
.banner img {
width: 100%;
object-fit: cover;
}
</style>
"""
    banner_markup = banner_div + banner_style

    with banner_col:
        components.html(banner_markup)

    with title_col:
        st.write("# Welcome to MERaLiON - AudioLLMs 🤖")

    # Slot reserved for further home-page information; currently renders an empty string.
    st.markdown('')
# Directory holding the bundled example clips offered in the "Select Audio" box.
_AUDIO_SAMPLE_DIR = "audio_samples"

# Example clip name (file stem inside _AUDIO_SAMPLE_DIR) -> suggested instruction
# shown under "Example Instruction" once that clip is selected.
_AUDIO_SAMPLES_W_INSTRUCT = {
    '1_ASR_IMDA_PART1_ASR_v2_141': "- Turn the spoken language into a text format.\n\n- Please translate the content into Chinese.",
    '2_ASR_IMDA_PART1_ASR_v2_2258': "- Turn the spoken language into a text format.\n\n- Please translate the content into Chinese.",
    '3_ASR_IMDA_PART1_ASR_v2_2265': "- Turn the spoken language into a text format.",
    '4_ASR_IMDA_PART2_ASR_v2_999': "- Translate the spoken words into text format.",
    '5_ASR_IMDA_PART2_ASR_v2_2241': "- Translate the spoken words into text format.",
    '6_ASR_IMDA_PART2_ASR_v2_3409': "- Translate the spoken words into text format.",
    '7_ASR_IMDA_PART3_30_ASR_v2_2269': "- Need this talk written down, please.",
    '8_ASR_IMDA_PART3_30_ASR_v2_1698': "- Need this talk written down, please.",
    '9_ASR_IMDA_PART3_30_ASR_v2_2474': "- Need this talk written down, please.",
    '10_ASR_IMDA_PART4_30_ASR_v2_1527': "- Write out the dialogue as text.",
    '11_ASR_IMDA_PART4_30_ASR_v2_3771': "- Write out the dialogue as text.",
    '12_ASR_IMDA_PART4_30_ASR_v2_103': "- Write out the dialogue as text.",
    '13_ASR_IMDA_PART5_30_ASR_v2_1446': "- Translate this vocal recording into a textual format.",
    '14_ASR_IMDA_PART5_30_ASR_v2_2281': "- Translate this vocal recording into a textual format.",
    '15_ASR_IMDA_PART5_30_ASR_v2_4388': "- Translate this vocal recording into a textual format.",
    '16_ASR_IMDA_PART6_30_ASR_v2_576': "- Record the spoken word in text form.",
    '17_ASR_IMDA_PART6_30_ASR_v2_1413': "- Record the spoken word in text form.",
    '18_ASR_IMDA_PART6_30_ASR_v2_2834': "- Record the spoken word in text form.",
    '19_ASR_AIShell_zh_ASR_v2_5044': "- Transform the oral presentation into a text document.",
    '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': "- Please provide a written transcription of the speech.",
}

# Sidebar blurb describing the model behind this page (rendered as raw HTML).
_SIDEBAR_INTRO_HTML = """<div class="sidebar-intro">
<p><strong>Purpose</strong>: Complex Audio Understanding</p>
<p><strong>Name</strong>: MERaLiON-AudioLLM-v1</p>
<p><strong>Version</strong>: Dec. 20, 2024</p>
</div>"""

# NOTE(review): `.st-emotion-cache-1c7y2kd` is an auto-generated Streamlit/Emotion
# class name, not a stable API; it may silently stop matching after a Streamlit
# upgrade. Presumably it targets the chat-message container so messages are
# laid out right-aligned -- confirm against the deployed Streamlit version.
_CHAT_ALIGN_CSS = """
<style>
.st-emotion-cache-1c7y2kd {
flex-direction: row-reverse;
text-align: right;
}
</style>
"""


def _render_sidebar():
    """Draw the sidebar blurb and handle the 'Clear History' button."""
    with st.sidebar:
        st.divider()
        st.markdown(_SIDEBAR_INTRO_HTML, unsafe_allow_html=True)

    if st.sidebar.button('Clear History'):
        # audio_base64 and audio_array always describe the same clip, so reset
        # both to their initial values; clearing only the array would leave the
        # previous clip's base64 payload behind in session state.
        st.session_state.update(messages=[],
                                on_upload=False,
                                on_record=False,
                                on_select=False,
                                audio_base64='',
                                audio_array=np.array([]))


def _init_session_state():
    """Start the backend once per session and create every state key this page reads."""
    if "server" not in st.session_state:
        st.session_state.server = start_server()

    if "client" not in st.session_state or "model_name" not in st.session_state:
        st.session_state.client, st.session_state.model_name = load_model()

    # Existing values are never overwritten. The on_* flags are created here so
    # that reading them never depends on an earlier run having reached the
    # flag reset near the end of audio_llm().
    defaults = {
        "audio_base64": '',
        "audio_array": np.array([]),
        "default_instruction": "",
        "messages": [],
        "on_record": False,
        "on_upload": False,
        "on_select": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _store_audio(audio_bytes):
    """Save one clip in session state as base64 text and as the array from bytes_to_array."""
    st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
    st.session_state.audio_array = bytes_to_array(audio_bytes)


def _render_audio_inputs():
    """Draw the record / upload / select columns and store a clip from whichever just changed."""
    col1, col2, col3 = st.columns(3)

    # Each widget's callback sets its on_* flag (and clears the chat). The flags
    # are reset after every run, so a clip is only stored on the run triggered
    # by its own widget.
    with col1:
        st.markdown("**Record Audio:**")
        recording = mic_recorder(
            format="wav",
            use_container_width=True,
            callback=lambda: st.session_state.update(on_record=True, messages=[]),
            key='record')
        if recording and st.session_state.on_record:
            _store_audio(recording["bytes"])

    with col2:
        st.markdown("**Upload Audio:**")
        uploaded_file = st.file_uploader(
            label="**Upload Audio:**",
            label_visibility="collapsed",
            type=['wav', 'mp3'],
            on_change=lambda: st.session_state.update(on_upload=True, messages=[]),
            key='upload'
        )
        if uploaded_file and st.session_state.on_upload:
            _store_audio(uploaded_file.read())

    with col3:
        st.markdown("**Select Audio:**")
        sample_name = st.selectbox(
            label="**Select Audio:**",
            label_visibility="collapsed",
            options=list(_AUDIO_SAMPLES_W_INSTRUCT),
            index=None,
            placeholder="Select an audio sample:",
            on_change=lambda: st.session_state.update(on_select=True, messages=[]),
            key='select')
        if sample_name and st.session_state.on_select:
            sample_path = os.path.join(_AUDIO_SAMPLE_DIR, f"{sample_name}.wav")
            # Close the file deterministically rather than leaving the handle
            # to the garbage collector.
            with open(sample_path, "rb") as sample_file:
                audio_bytes = sample_file.read()
            st.session_state.default_instruction = _AUDIO_SAMPLES_W_INSTRUCT[sample_name]
            _store_audio(audio_bytes)


def _render_model_config():
    """Draw the sampling sliders; values live in session state under 'temperature' and 'top_p'."""
    st.markdown("**Model Configuration:**")
    col4, col5, _ = st.columns(3)
    with col4:
        st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature')
    with col5:
        st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p')


def _handle_chat():
    """Take one instruction from the chat input, stream the reply, and log both turns."""
    prompt = st.chat_input(placeholder="Your Instruction")
    if not prompt:
        return

    with st.chat_message("user"):
        st.write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            try:
                # NOTE(review): generate_response is given only the prompt;
                # presumably it reads the audio and slider values from
                # st.session_state -- confirm in utils.
                stream = generate_response(prompt)
                response = st.write_stream(stream)
            except Exception as e:
                # Show the failure in the chat, then re-raise with the
                # original traceback intact.
                response = f"Caught Exception: {repr(e)}. Please contact the administrator to restart this space."
                st.write(response)
                raise
    st.session_state.messages.append({"role": "assistant", "content": response})


def audio_llm():
    """Render the MERaLiON AudioLLM chat page.

    Draws, in order:
    - the sidebar info and the 'Clear History' button;
    - the record / upload / example-sample audio inputs plus an audio player;
    - the sampling sliders and the example instruction;
    - the chat box, which sends the instruction to ``generate_response``.

    Session-state keys written here: server, client, model_name, audio_base64,
    audio_array, default_instruction, messages, on_record, on_upload, on_select
    (the sliders add temperature and top_p through their widget keys).

    Raises:
        Whatever ``generate_response`` / streaming raises; it is shown in the
        chat first and then re-raised.
    """
    _render_sidebar()
    _init_session_state()
    _render_audio_inputs()

    # NOTE(review): assumes bytes_to_array yields 16 kHz samples -- confirm in utils.
    st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000)
    # Reset the "new clip" flags so later reruns do not store the same clip again.
    st.session_state.update(on_upload=False, on_record=False, on_select=False)

    _render_model_config()

    st.markdown("**Example Instruction:**")
    st.write(st.session_state.default_instruction)

    st.markdown(_CHAT_ALIGN_CSS, unsafe_allow_html=True)

    _handle_chat()