MERaLiON-AudioLLM / utils.py
YingxuHe's picture
Update utils.py
6994ab8 verified
raw
history blame
1.85 kB
import os
import io
import librosa
import paramiko
import streamlit as st
from openai import OpenAI
from sshtunnel import SSHTunnelForwarder
# Local end of the SSH tunnel: start_server binds 127.0.0.1:<local_port> here,
# and load_model points the OpenAI client at http://localhost:<local_port>/v1.
# os.environ[...] (rather than os.getenv) fails fast with KeyError('LOCAL_PORT')
# when the variable is missing, instead of an opaque "int() argument must be
# ... not 'NoneType'" TypeError.
local_port = int(os.environ['LOCAL_PORT'])
@st.cache_resource()
def load_model():
    """Create an OpenAI-compatible client for the locally tunnelled model server.

    The client targets ``http://localhost:{local_port}/v1`` and is given the
    value of the ``API_KEY`` environment variable as its API key. The
    ``st.cache_resource`` decorator lets Streamlit reuse the returned
    objects instead of rebuilding them on every call.

    Returns:
        tuple: ``(client, model_name)``. ``model_name`` is the id of the
        first entry returned by ``client.models.list()``.

    Raises:
        RuntimeError: if the server reports no models at all.
    """
    openai_api_key = os.getenv('API_KEY')
    openai_api_base = f"http://localhost:{local_port}/v1"
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    if not models.data:
        # Without this guard an empty listing surfaces as a bare IndexError
        # that gives no hint about which server was queried.
        raise RuntimeError(
            f"No models are served at {openai_api_base}; "
            "cannot select a model name."
        )
    model_name = models.data[0].id
    return client, model_name
def generate_response(text_input, audio_input, *, audio_mime_type="audio/ogg",
                      max_completion_tokens=512):
    """Send a text instruction plus an audio clip to the model and stream the reply.

    The client and model name are read from ``st.session_state.client`` and
    ``st.session_state.model_name``. Those attributes must be populated before
    this is called, presumably from ``load_model()``.

    Args:
        text_input: Instruction text. It is sent prefixed with
            ``"Text instruction: "``.
        audio_input: Audio payload embedded after ``;base64,`` in a data URL,
            so it is expected to be an already base64-encoded string.
        audio_mime_type: MIME type declared in the data URL. It defaults to
            ``"audio/ogg"``, the value previously hard-coded.
        max_completion_tokens: Upper bound on generated tokens. It defaults
            to 512, the value previously hard-coded.

    Returns:
        The object returned by ``chat.completions.create(..., stream=True)``,
        i.e. a stream of completion chunks.
    """
    # NOTE(review): the "audio_url" content-part type is not part of the core
    # OpenAI chat schema; presumably the serving backend accepts it as an
    # extension. Confirm against the server in use.
    content = [
        {
            "type": "text",
            "text": f"Text instruction: {text_input}"
        },
        {
            "type": "audio_url",
            "audio_url": {
                "url": f"data:{audio_mime_type};base64,{audio_input}"
            },
        },
    ]
    stream = st.session_state.client.chat.completions.create(
        messages=[{"role": "user", "content": content}],
        model=st.session_state.model_name,
        max_completion_tokens=max_completion_tokens,
        stream=True,
    )
    return stream
def bytes_to_array(audio_bytes):
    """Decode an in-memory audio file into a waveform sampled at 16 kHz.

    Args:
        audio_bytes: The complete contents of an audio file, as bytes.

    Returns:
        Only the decoded samples from ``librosa.load``. The sample rate it
        also returns is dropped, because it is the requested 16000.
    """
    in_memory_file = io.BytesIO(audio_bytes)
    decoded = librosa.load(in_memory_file, sr=16000)
    return decoded[0]
def start_server():
    """Open the SSH tunnel to the model server, at most once per process.

    The tunnel forwards ``127.0.0.1:local_port`` on this machine to
    ``127.0.0.1:8000`` on the host named by ``SERVER_DNS_NAME``. It logs in
    as ``ec2-user`` with the RSA private key whose text is held in the
    ``PRIVATE_KEY`` environment variable.

    The ``TUNNEL_STATUS`` environment variable serves as an
    "already started" flag, so repeated calls (presumably Streamlit reruns)
    do not open a second tunnel. The flag is written only after
    ``server.start()`` returns. If starting raises, the flag stays unset and
    the next call tries again.
    """
    if os.getenv('TUNNEL_STATUS') == "200":
        return

    pkey = paramiko.RSAKey.from_private_key(io.StringIO(os.getenv('PRIVATE_KEY')))
    server = SSHTunnelForwarder(
        ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
        ssh_username="ec2-user",
        ssh_pkey=pkey,
        local_bind_address=("127.0.0.1", local_port),
        remote_bind_address=("127.0.0.1", 8000)
    )
    server.start()
    # Mark the tunnel as up only after start() succeeded. Setting the flag
    # before start() would leave it stuck at "200" after a failed start, and
    # every later call would return early without ever retrying the tunnel.
    os.environ['TUNNEL_STATUS'] = "200"