MERaLiON-AudioLLM / utils.py
YingxuHe's picture
Update utils.py
abf7a31 verified
raw
history blame
1.9 kB
import os
import io
import librosa
import paramiko
import streamlit as st
from openai import OpenAI
from sshtunnel import SSHTunnelForwarder
# Local port for the near end of the SSH tunnel (bound in start_server); the
# OpenAI client in load_model connects to this port. Fail fast with a readable
# message instead of int(None)'s opaque TypeError when the variable is missing.
_local_port_setting = os.getenv('LOCAL_PORT')
if not _local_port_setting:
    raise RuntimeError(
        "LOCAL_PORT environment variable must be set to the local port "
        "used for the SSH tunnel to the model server."
    )
local_port = int(_local_port_setting)
@st.cache_resource()
def start_server():
    """Open an SSH port-forward from this machine to the model server.

    The RSA private key is read (as PEM text) from the ``PRIVATE_KEY``
    environment variable, and the SSH host from ``SERVER_DNS_NAME``; the login
    user is ``ec2-user``. Connections to ``127.0.0.1:<local_port>`` locally are
    forwarded to ``127.0.0.1:8000`` on the remote host.

    Because of ``st.cache_resource``, Streamlit runs this once and reuses the
    returned forwarder on later calls. Nothing in the visible code stops it.

    Returns:
        SSHTunnelForwarder: the tunnel, already started.
    """
    key_text = os.getenv('PRIVATE_KEY')
    rsa_key = paramiko.RSAKey.from_private_key(io.StringIO(key_text))

    tunnel = SSHTunnelForwarder(
        ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
        ssh_username="ec2-user",
        ssh_pkey=rsa_key,
        local_bind_address=("127.0.0.1", local_port),
        remote_bind_address=("127.0.0.1", 8000),
    )
    tunnel.start()
    return tunnel
@st.cache_resource()
def load_model():
    """Create an OpenAI-compatible client for the tunnelled server and pick a model.

    The client targets ``http://127.0.0.1:<local_port>/v1``, the local end of
    the SSH tunnel bound in ``start_server``. It authenticates with the
    ``API_KEY`` environment variable. Because of ``st.cache_resource``,
    Streamlit runs this once and reuses the returned pair on later calls.

    Returns:
        tuple: ``(client, model_name)``, where ``model_name`` is the id of the
        first model returned by ``client.models.list()``.

    Raises:
        RuntimeError: if the server reports no models.
    """
    openai_api_key = os.getenv('API_KEY')
    # Use the exact IPv4 loopback address the tunnel binds to; "localhost" may
    # resolve to ::1 first, where the forwarder is not listening.
    openai_api_base = f"http://127.0.0.1:{local_port}/v1"
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    if not models.data:
        # Surface a clear message instead of an IndexError from data[0].
        raise RuntimeError(
            f"No models are served at {openai_api_base}; cannot choose a model name."
        )
    model_name = models.data[0].id
    return client, model_name
def generate_response(text_input):
    """Ask the model about the stored audio clip and return a streaming reply.

    A single user message is sent, made of two parts:
    the instruction text (prefixed with ``"Text instruction: "``) and the clip
    stored in ``st.session_state.audio_base64``, sent as an ``audio/ogg``
    base64 data URL. The client, model name, ``temperature`` and ``top_p`` are
    also read from ``st.session_state``; output is capped at 512 tokens.

    Args:
        text_input: the user's instruction text.

    Returns:
        The stream object returned by ``chat.completions.create(stream=True)``.
    """
    state = st.session_state
    create_completion = state.client.chat.completions.create

    audio_data_url = f"data:audio/ogg;base64,{state.audio_base64}"
    instruction_part = {
        "type": "text",
        "text": f"Text instruction: {text_input}",
    }
    audio_part = {
        "type": "audio_url",
        "audio_url": {"url": audio_data_url},
    }
    user_message = {"role": "user", "content": [instruction_part, audio_part]}

    return create_completion(
        messages=[user_message],
        model=state.model_name,
        max_completion_tokens=512,
        temperature=state.temperature,
        top_p=state.top_p,
        stream=True,
    )
def bytes_to_array(audio_bytes):
    """Decode an in-memory audio file into a waveform resampled to 16 kHz.

    Args:
        audio_bytes: raw bytes of an audio file in any format librosa can read.

    Returns:
        The sample array from ``librosa.load`` at ``sr=16000``; the sample rate
        it also returns is discarded.
    """
    audio_buffer = io.BytesIO(audio_bytes)
    waveform, _sample_rate = librosa.load(audio_buffer, sr=16000)
    return waveform