MERaLiON-AudioLLM / utils.py
YingxuHe's picture
Update utils.py
d1ebb9a verified
raw
history blame
2.41 kB
import io
import os
import re
import librosa
import paramiko
import streamlit as st
from openai import OpenAI
from sshtunnel import SSHTunnelForwarder
# Local TCP port, taken from the LOCAL_PORT environment variable. It is used
# both as the local bind port of the SSH tunnel and in the base URL of the
# OpenAI client. int() raises TypeError at import time if the variable is
# unset, and ValueError if it is not an integer literal.
local_port = int(os.environ.get('LOCAL_PORT'))
class NoAudioException(Exception):
    """Raised when a response is requested but no audio is available."""
@st.cache_resource()
def start_server():
    """Start an SSH tunnel to the remote host and return the forwarder.

    The RSA private key text is read from the ``PRIVATE_KEY`` environment
    variable and the SSH host from ``SERVER_DNS_NAME``. Connections to
    ``127.0.0.1:local_port`` are forwarded to ``127.0.0.1:8000`` on the
    remote side.

    Returns:
        The started ``SSHTunnelForwarder`` instance.
    """
    # paramiko parses the key from a file-like object, so wrap the env text.
    key_stream = io.StringIO(os.getenv('PRIVATE_KEY'))
    private_key = paramiko.RSAKey.from_private_key(key_stream)

    tunnel_options = {
        "ssh_address_or_host": os.getenv('SERVER_DNS_NAME'),
        "ssh_username": "ec2-user",
        "ssh_pkey": private_key,
        "local_bind_address": ("127.0.0.1", local_port),
        "remote_bind_address": ("127.0.0.1", 8000),
    }
    tunnel = SSHTunnelForwarder(**tunnel_options)
    tunnel.start()
    return tunnel
@st.cache_resource()
def load_model():
    """Build the API client and look up the name of the served model.

    The client is pointed at ``http://localhost:{local_port}/v1`` and
    authenticated with the ``API_KEY`` environment variable.

    Returns:
        A ``(client, model_name)`` tuple, where ``model_name`` is the ``id``
        of the first entry returned by ``client.models.list()``.
    """
    client = OpenAI(
        api_key=os.getenv('API_KEY'),
        base_url=f"http://localhost:{local_port}/v1",
    )
    available_models = client.models.list()
    # The first listed model is the one used for all requests.
    return client, available_models.data[0].id
def generate_response(text_input):
    """Send the text instruction plus the session's audio to the model.

    Reads ``audio_base64``, ``client``, ``model_name``, ``temperature`` and
    ``top_p`` from ``st.session_state``. The audio is embedded in the request
    as a base64 ``data:audio/ogg`` URL alongside the text instruction.

    Args:
        text_input: The user's text instruction.

    Returns:
        A ``(stream, warnings)`` tuple: the streaming chat-completion
        response and a list of advisory messages about the prompt.

    Raises:
        NoAudioException: If ``st.session_state.audio_base64`` is empty.
    """
    if not st.session_state.audio_base64:
        raise NoAudioException("audio is empty.")

    warnings = []
    # Match keywords case-insensitively so capitalised prompts such as
    # "Calculate ..." or "Java ..." also trigger the warning.
    if re.search(
        r"tool|code|python|java|math|calculate",
        text_input,
        flags=re.IGNORECASE,
    ):
        warnings.append("WARNING: MERaLiON-AudioLLM is not intended for use in tool calling, math, and coding tasks.")

    # Any character in the CJK Unified Ideographs range triggers the note.
    if re.search(r'[\u4e00-\u9fff]+', text_input):
        warnings.append("NOTE: Please try to prompt in English for the best performance.")

    stream = st.session_state.client.chat.completions.create(
        messages=[{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Text instruction: {text_input}"
                },
                {
                    "type": "audio_url",
                    "audio_url": {
                        "url": f"data:audio/ogg;base64,{st.session_state.audio_base64}"
                    },
                },
            ],
        }],
        model=st.session_state.model_name,
        max_completion_tokens=512,
        temperature=st.session_state.temperature,
        top_p=st.session_state.top_p,
        stream=True,
    )

    return stream, warnings
def bytes_to_array(audio_bytes):
    """Decode raw audio bytes into a sample array via ``librosa.load``.

    Args:
        audio_bytes: Encoded audio file content as bytes.

    Returns:
        The audio array returned by ``librosa.load`` with ``sr=16000``; the
        sampling rate it also returns is discarded.
    """
    # librosa reads from a file-like object, so wrap the bytes in memory.
    buffer = io.BytesIO(audio_bytes)
    samples, _sample_rate = librosa.load(buffer, sr=16000)
    return samples