import copy import base64 import numpy as np import streamlit as st from streamlit_mic_recorder import mic_recorder from utils import ( GENERAL_INSTRUCTIONS, AUDIO_SAMPLES_W_INSTRUCT, NoAudioException, TunnelNotRunningException, retry_generate_response, load_model, bytes_to_array, start_server, ) DEFAULT_DIALOGUE_STATES = dict( default_instruction=[], audio_base64='', audio_array=np.array([]), disprompt = False, new_prompt = "", messages=[], on_select=False, on_upload=False, on_record=False, on_click_button = False ) @st.fragment def sidebar_fragment(): st.markdown("""
""", unsafe_allow_html=True) st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature') st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p') @st.fragment def specify_audio_fragment(): col1, col2, col3 = st.columns([3.5, 4, 1.5]) with col1: audio_sample_names = [audio_sample_name for audio_sample_name in AUDIO_SAMPLES_W_INSTRUCT.keys()] st.markdown("**Select Audio From Examples:**") sample_name = st.selectbox( label="**Select Audio:**", label_visibility="collapsed", options=audio_sample_names, index=None, placeholder="Select an audio sample:", on_change=lambda: st.session_state.update(on_select=True), key='select') if sample_name and st.session_state.on_select: audio_bytes = open(f"audio_samples/{sample_name}.wav", "rb").read() st.session_state.default_instruction = AUDIO_SAMPLES_W_INSTRUCT[sample_name] st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') st.session_state.audio_array = bytes_to_array(audio_bytes) with col2: st.markdown("or **Upload Audio:**") uploaded_file = st.file_uploader( label="**Upload Audio:**", label_visibility="collapsed", type=['wav', 'mp3'], on_change=lambda: st.session_state.update(on_upload=True), key='upload' ) if uploaded_file and st.session_state.on_upload: audio_bytes = uploaded_file.read() st.session_state.default_instruction = GENERAL_INSTRUCTIONS st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') st.session_state.audio_array = bytes_to_array(audio_bytes) with col3: st.markdown("or **Record Audio:**") recording = mic_recorder( start_prompt="▶ start recording", stop_prompt="🔴 stop recording", format="wav", use_container_width=True, callback=lambda: st.session_state.update(on_record=True), key='record') if recording and st.session_state.on_record: audio_bytes = recording["bytes"] st.session_state.default_instruction = GENERAL_INSTRUCTIONS st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') st.session_state.audio_array = bytes_to_array(audio_bytes) st.session_state.update(on_upload=False, on_record=False, on_select=False) if st.session_state.audio_array.size: with st.chat_message("user"): if st.session_state.audio_array.shape[0] / 16000 > 30.0: st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.") st.audio(st.session_state.audio_array, format="audio/wav", sample_rate=16000) for i, inst in enumerate(st.session_state.default_instruction): st.button( f"**Example Instruction {i+1}**: {inst}", args=(inst,), disabled=st.session_state.disprompt, on_click=lambda p: st.session_state.update(disprompt=True, new_prompt=p, on_click_button=True) ) if st.session_state.on_click_button: st.session_state.on_click_button = False st.rerun(scope="app") def dialogue_section(): for message in st.session_state.messages: with st.chat_message(message["role"]): if message.get("error"): st.error(message["error"]) for warning_msg in message.get("warnings", []): st.warning(warning_msg) if message.get("content"): st.write(message["content"]) if chat_input := st.chat_input( placeholder="Type Your Instruction Here", disabled=st.session_state.disprompt, on_submit=lambda: st.session_state.update(disprompt=True) ): st.session_state.new_prompt = chat_input if one_time_prompt := st.session_state.new_prompt: st.session_state.update(new_prompt="", messages=[]) with st.chat_message("user"): st.write(one_time_prompt) st.session_state.messages.append({"role": "user", "content": one_time_prompt}) with st.chat_message("assistant"): with st.spinner("Thinking..."): error_msg, warnings, response = "", [], "" try: response, warnings = retry_generate_response(one_time_prompt) except NoAudioException: error_msg = "Please specify audio first!" except TunnelNotRunningException: error_msg = "Internet connection cannot be established. Please contact the administrator." except Exception as e: error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator." st.session_state.messages.append({ "role": "assistant", "error": error_msg, "warnings": warnings, "content": response }) st.session_state.disprompt=False st.rerun(scope="app") def audio_llm(): if "server" not in st.session_state: st.session_state.server = start_server() if "client" not in st.session_state or 'model_name' not in st.session_state: st.session_state.client, st.session_state.model_name = load_model() for key, value in DEFAULT_DIALOGUE_STATES.items(): if key not in st.session_state: st.session_state[key]=copy.deepcopy(value) with st.sidebar: sidebar_fragment() if st.sidebar.button('Clear History'): st.session_state.update(DEFAULT_DIALOGUE_STATES) st.markdown("