import io
import os
import re
import librosa
import paramiko
import streamlit as st
from openai import OpenAI
from sshtunnel import SSHTunnelForwarder


def _require_env(name):
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is not set, so misconfiguration is
            reported by name instead of surfacing later as an opaque
            ``TypeError`` or SSH/connection error.
    """
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(f"Required environment variable {name!r} is not set.")
    return value


# Loopback address used on both ends: the tunnel binds it locally and the
# OpenAI client connects to it.
_LOOPBACK = "127.0.0.1"
# Port the model server listens on, on the remote host's loopback interface.
_REMOTE_PORT = 8000

# Local end of the SSH tunnel.
local_port = int(_require_env('LOCAL_PORT'))

# Heuristic keyword filter for task types the model is not intended for.
# Plain substring match (so e.g. "decode" also hits "code"); case-insensitive
# so "Python", "Math", "Code" etc. are caught as well.
_UNSUPPORTED_TASK_RE = re.compile(r"tool|code|python|java|math|calculate", re.IGNORECASE)
# Characters in the CJK Unified Ideographs block (U+4E00..U+9FFF) only.
_CJK_RE = re.compile(r'[\u4e00-\u9fff]')


class NoAudioException(Exception):
    """Raised when a response is requested but no audio is stored in the session."""


@st.cache_resource()
def start_server():
    """Open an SSH tunnel from the local port to port 8000 on the remote host.

    The tunnel forwards ``127.0.0.1:local_port`` to ``127.0.0.1:8000`` on the
    remote machine. It is wrapped in ``st.cache_resource`` so it is created
    once and reused across Streamlit reruns.

    Environment:
        PRIVATE_KEY: RSA private key text used to authenticate as ``ec2-user``.
        SERVER_DNS_NAME: SSH host to connect to.

    Returns:
        The started ``SSHTunnelForwarder``.

    Raises:
        RuntimeError: if ``PRIVATE_KEY`` or ``SERVER_DNS_NAME`` is not set.
    """
    pkey = paramiko.RSAKey.from_private_key(io.StringIO(_require_env('PRIVATE_KEY')))
    server = SSHTunnelForwarder(
        ssh_address_or_host=_require_env('SERVER_DNS_NAME'),
        ssh_username="ec2-user",
        ssh_pkey=pkey,
        local_bind_address=(_LOOPBACK, local_port),
        remote_bind_address=(_LOOPBACK, _REMOTE_PORT),
    )
    server.start()
    return server


@st.cache_resource()
def load_model():
    """Create an OpenAI-compatible client for the tunnelled server and pick a model.

    Returns:
        ``(client, model_name)``, where ``model_name`` is the id of the first
        model the server lists.

    Raises:
        RuntimeError: if the server reports no models.
    """
    client = OpenAI(
        api_key=os.getenv('API_KEY'),
        # Connect to the exact address the tunnel binds; "localhost" may
        # resolve to ::1 first, where nothing is listening.
        base_url=f"http://{_LOOPBACK}:{local_port}/v1",
    )
    models = client.models.list()
    if not models.data:
        raise RuntimeError("The model server did not report any available models.")
    model_name = models.data[0].id
    return client, model_name


def _collect_warnings(text_input):
    """Return the user-facing warning strings triggered by *text_input*."""
    warnings = []
    if _UNSUPPORTED_TASK_RE.search(text_input):
        warnings.append("WARNING: MERaLiON-AudioLLM is not intended for use in tool calling, math, and coding tasks.")
    if _CJK_RE.search(text_input):
        warnings.append("NOTE: Please try to prompt in English for the best performance.")
    return warnings


def generate_response(text_input):
    """Send the text instruction and the session's audio to the model, streaming the reply.

    Reads ``audio_base64``, ``client``, ``model_name``, ``temperature`` and
    ``top_p`` from ``st.session_state``.

    Args:
        text_input: The user's text instruction.

    Returns:
        ``(stream, warnings)``: the streaming chat-completion iterator and a
        list of warning strings to display alongside it.

    Raises:
        NoAudioException: if no audio is stored in the session.
    """
    if not st.session_state.audio_base64:
        raise NoAudioException("audio is empty.")

    warnings = _collect_warnings(text_input)

    stream = st.session_state.client.chat.completions.create(
        messages=[{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Text instruction: {text_input}",
                },
                {
                    "type": "audio_url",
                    # NOTE(review): the data URL declares audio/ogg; this assumes
                    # audio_base64 holds OGG-encoded audio — confirm with the producer.
                    "audio_url": {
                        "url": f"data:audio/ogg;base64,{st.session_state.audio_base64}"
                    },
                },
            ],
        }],
        model=st.session_state.model_name,
        max_completion_tokens=512,
        temperature=st.session_state.temperature,
        top_p=st.session_state.top_p,
        stream=True,
    )
    return stream, warnings


def bytes_to_array(audio_bytes):
    """Decode an in-memory audio file and resample it to 16 kHz.

    Args:
        audio_bytes: Raw bytes of an audio file in any format librosa can read.

    Returns:
        The decoded samples as returned by ``librosa.load`` (mono by librosa's
        default); the sample rate is discarded since it is fixed at 16000.
    """
    audio_array, _ = librosa.load(
        io.BytesIO(audio_bytes),
        sr=16000,
    )
    return audio_array