# MERaLiON-AudioLLM / utils.py
# YingxuHe's picture
# Update utils.py
# 45d3e7b verified
# raw
# history blame
# 1.68 kB
import io
import librosa
import paramiko
import streamlit as st
from openai import OpenAI
from sshtunnel import SSHTunnelForwarder
# Local TCP port shared by this module: start_server() binds the SSH tunnel
# to 127.0.0.1 on this port, and load_model() points the OpenAI client at
# http://localhost:<local_port>/v1.
local_port = 20_000
@st.cache_resource()
def load_model():
    """Create an OpenAI-compatible client and pick the model to query.

    The client targets ``http://localhost:{local_port}/v1`` and uses the
    placeholder API key ``"EMPTY"``. The model name is the ``id`` of the
    first entry returned by the endpoint's model listing.

    The function is wrapped in ``st.cache_resource`` so Streamlit reuses the
    returned objects rather than rebuilding them on every call.

    Returns:
        A ``(client, model_name)`` tuple.
    """
    endpoint = f"http://localhost:{local_port}/v1"
    api_client = OpenAI(api_key="EMPTY", base_url=endpoint)

    # The first listed model is the one used for all requests.
    first_model = api_client.models.list().data[0]
    return api_client, first_model.id
def generate_response(text_input, audio_input):
    """Request a streamed chat completion for a text instruction plus audio.

    The client and model name are read from ``st.session_state`` (attributes
    ``client`` and ``model_name``), so both must be stored there beforehand.

    Args:
        text_input: Instruction text; sent as ``"Text instruction: <text>"``.
        audio_input: Audio payload interpolated verbatim into a
            ``data:audio/ogg;base64,`` URL, so it is expected to already be
            base64-encoded text.

    Returns:
        The object returned by ``chat.completions.create`` with
        ``stream=True`` (at most 512 completion tokens are requested).
    """
    create_completion = st.session_state.client.chat.completions.create

    instruction_part = {
        "type": "text",
        "text": f"Text instruction: {text_input}",
    }
    audio_part = {
        "type": "audio_url",
        "audio_url": {"url": f"data:audio/ogg;base64,{audio_input}"},
    }
    user_message = {
        "role": "user",
        "content": [instruction_part, audio_part],
    }

    return create_completion(
        messages=[user_message],
        model=st.session_state.model_name,
        max_completion_tokens=512,
        stream=True,
    )
def bytes_to_array(audio_bytes):
    """Decode raw audio file bytes into a sample array.

    The bytes are wrapped in an in-memory file object and handed to
    ``librosa.load`` with a requested sample rate of 16000 Hz.

    Args:
        audio_bytes: Bytes of an encoded audio file.

    Returns:
        The audio samples returned by ``librosa.load``; the sample rate it
        also returns is discarded.
    """
    buffer = io.BytesIO(audio_bytes)
    samples, _sample_rate = librosa.load(buffer, sr=16000)
    return samples
def start_server(ssh_key, dns_name):
    """Start an SSH tunnel from the local port to port 8000 on the remote host.

    The tunnel listens on ``127.0.0.1:local_port`` and forwards to
    ``127.0.0.1:8000`` as seen from the SSH host, logging in as ``ec2-user``.

    Args:
        ssh_key: Text of the private key. It is parsed with
            ``paramiko.RSAKey``, so it must be an RSA key.
        dns_name: Address of the SSH host to connect to.

    Returns:
        The started ``SSHTunnelForwarder``. Earlier versions returned
        ``None`` and dropped the handle, which left callers no way to stop
        or inspect the tunnel; callers that ignore the return value are
        unaffected.
    """
    # paramiko reads keys from a file-like object, so wrap the key text.
    pkey = paramiko.RSAKey.from_private_key(io.StringIO(ssh_key))

    server = SSHTunnelForwarder(
        ssh_address_or_host=dns_name,
        ssh_username="ec2-user",
        ssh_pkey=pkey,
        local_bind_address=("127.0.0.1", local_port),
        remote_bind_address=("127.0.0.1", 8000),
    )
    server.start()

    # Hand the live tunnel back so the caller can keep a reference to it and
    # call ``server.stop()`` when it is no longer needed.
    return server