Spaces:
Running
Running
import os | |
import re | |
import time | |
from typing import List, Dict, Optional | |
import numpy as np | |
import streamlit as st | |
from openai import OpenAI, APIConnectionError | |
from src.exceptions import TunnelNotRunningException | |
# Generation parameters that are kept constant (name suggests they are passed
# along with every completion request — confirm against the caller).
FIXED_GENERATION_CONFIG = {
    "max_completion_tokens": 1024,
    "top_k": 50,
    "length_penalty": 1.0,
    "seed": 42,
}

# Audio length limit in seconds; the validation warning reports it as the
# point beyond which audio is truncated.
MAX_AUDIO_LENGTH = 120
def load_model() -> Dict:
    """
    Create OpenAI clients connected to the local vllm servers.

    The API key is read from the ``API_KEY`` environment variable and the
    ports from ``LOCAL_PORTS`` (whitespace-separated). One client is created
    per port, pointing at ``http://localhost:<port>/v1``, and every model id
    that server lists is mapped to that client.

    Returns:
        Mapping of model id to the client that serves it. If the same model
        id is listed on several ports, the client of the last port wins.

    Raises:
        ValueError: If ``LOCAL_PORTS`` is unset or contains no port.
    """
    openai_api_key = os.getenv("API_KEY")

    # split() without a separator ignores repeated, leading and trailing
    # whitespace; split(" ") would turn those into empty "ports" and produce
    # a malformed base URL.
    local_ports = (os.getenv("LOCAL_PORTS") or "").split()
    if not local_ports:
        raise ValueError(
            "Environment variable LOCAL_PORTS must contain at least one port "
            "(whitespace-separated)."
        )

    name_to_client_mapper = {}
    for port in local_ports:
        client = OpenAI(
            api_key=openai_api_key,
            base_url=f"http://localhost:{port}/v1",
        )
        for model in client.models.list().data:
            name_to_client_mapper[model.id] = client
    return name_to_client_mapper
def prepare_multimodal_content(text_input, base64_audio_input):
    """
    Build the two-part message content for a text + audio request.

    Returns a list whose first item is the text part (the instruction,
    prefixed with "Text instruction: ") and whose second item is the audio
    part, carrying the base64 payload as an ``audio/ogg`` data URL.
    """
    text_part = {
        "type": "text",
        "text": f"Text instruction: {text_input}",
    }
    audio_part = {
        "type": "audio_url",
        "audio_url": {"url": f"data:audio/ogg;base64,{base64_audio_input}"},
    }
    return [text_part, audio_part]
def change_multimodal_content(
        original_content,
        text_input="",
        base64_audio_input=""):
    """
    Overwrite the text and/or audio part of an existing message content.

    The content is addressed by position: slot 0 holds the text part and
    slot 1 holds the audio part. A slot is only replaced when the matching
    argument is non-empty. ``original_content`` is modified in place and
    also returned.
    """
    replacements = {}
    if text_input:
        replacements[0] = {
            "type": "text",
            "text": f"Text instruction: {text_input}",
        }
    if base64_audio_input:
        replacements[1] = {
            "type": "audio_url",
            "audio_url": {"url": f"data:audio/ogg;base64,{base64_audio_input}"},
        }

    # Slots are written in ascending order: text first, then audio.
    for slot, part in replacements.items():
        original_content[slot] = part
    return original_content
def _retrive_response(
        model: str,
        text_input: str,
        base64_audio_input: str,
        history: Optional[List] = None,
        **kwargs):
    """
    Send request through OpenAI client.

    Args:
        model: Model id; used both to look up the client in
            ``st.session_state.client_mapper`` and as the ``model`` argument
            of the request.
        text_input: Text instruction of the user.
        base64_audio_input: Base64-encoded audio. When empty, the message
            content is the plain text instruction; otherwise it is the
            text + audio content built by ``prepare_multimodal_content``.
        history: Earlier messages placed before the new user message. It is
            not modified.
        **kwargs: Forwarded unchanged to ``chat.completions.create``.

    Returns:
        Whatever ``chat.completions.create`` returns for this request.

    Raises:
        KeyError: If ``model`` has no client in the client mapper.
    """
    if base64_audio_input:
        # Reuse the shared builder so the payload format is defined once.
        content = prepare_multimodal_content(text_input, base64_audio_input)
    else:
        content = text_input

    messages = [*(history or []), {"role": "user", "content": content}]

    current_client = st.session_state.client_mapper[model]
    return current_client.chat.completions.create(
        messages=messages,
        model=model,
        **kwargs
    )
def _retry_retrive_response_throws_exception(retry=3, **kwargs):
    """
    Call ``_retrive_response`` and retry while the connection is down.

    On ``APIConnectionError`` the state of ``st.session_state.server`` is
    inspected. If it reports that it is running, the error is not a
    connection outage handled here and is re-raised. Otherwise the user is
    notified, the server is restarted when it is down (or given two seconds
    when it is still starting), and the request is sent again.

    Args:
        retry: Number of reconnection attempts still allowed.
        **kwargs: Forwarded unchanged to ``_retrive_response``.

    Returns:
        The response object of the first successful request.

    Raises:
        TunnelNotRunningException: If the server is still not running once
            the attempts are used up.
        APIConnectionError: If the request fails while the server reports
            that it is running.
    """
    while True:
        try:
            return _retrive_response(**kwargs)
        except APIConnectionError as e:
            server = st.session_state.server
            if server.is_running():
                raise

            # "<=" rather than "==" so a negative retry count cannot loop
            # forever.
            if retry <= 0:
                raise TunnelNotRunningException() from e

            st.toast(f":warning: Internet connection is down. Trying to re-establish connection ({retry}).")

            if server.is_down():
                server.restart()
            elif server.is_starting():
                time.sleep(2)

            retry -= 1
def _validate_input(text_input, array_audio_input) -> List[str]:
    """
    Collect user-facing warnings about the text instruction and the audio.

    Checks performed, in this order:
      * the instruction mentions tool / code / math related keywords
        (matched case-insensitively, anywhere in the text);
      * the instruction contains CJK ideographs (U+4E00 to U+9FFF);
      * the audio array is empty;
      * the audio is longer than 30 seconds.

    Args:
        text_input: Text instruction of the user.
        array_audio_input: Audio samples; only ``shape[0]`` is read.

    Returns:
        The warning messages, empty when nothing was flagged.

    TODO: improve the input validation regex.
    """
    # NOTE(review): the duration is computed assuming 16 kHz samples —
    # confirm against the audio loading code.
    sample_rate = 16000
    trained_seconds = 30.0

    warnings = []

    # IGNORECASE so that e.g. "Python" or "Calculate" is caught as well.
    if re.search(r"tool|code|python|java|math|calculate", text_input, re.IGNORECASE):
        warnings.append("WARNING: MERaLiON-AudioLLM is not intended for use in tool calling, math, and coding tasks.")

    if re.search(r'[\u4e00-\u9fff]+', text_input):
        warnings.append("NOTE: Please try to prompt in English for the best performance.")

    num_samples = array_audio_input.shape[0]
    if num_samples == 0:
        warnings.append("NOTE: Please specify audio from examples or local files.")

    if num_samples / sample_rate > trained_seconds:
        warnings.append((
            "WARNING: MERaLiON-AudioLLM is trained to process audio up to **30 seconds**."
            f" Audio longer than **{MAX_AUDIO_LENGTH} seconds** will be truncated."
        ))

    return warnings
def retrive_response(
        text_input: str,
        array_audio_input: np.ndarray,
        **kwargs
):
    """
    Validate the input, send the request and report the outcome.

    The audio array is only used for validation; the request itself is
    built from ``text_input`` and ``**kwargs``. No exception escapes from
    the request: failures are turned into an error message instead.

    Returns:
        A tuple ``(error_msg, warnings, response_object)``. ``error_msg`` is
        an empty string on success; ``response_object`` is ``None`` when the
        request failed.
    """
    input_warnings = _validate_input(text_input, array_audio_input)

    try:
        reply = _retry_retrive_response_throws_exception(
            text_input=text_input,
            **kwargs
        )
    except TunnelNotRunningException:
        failure = "Internet connection cannot be established. Please contact the administrator."
        return failure, input_warnings, None
    except Exception as exc:
        failure = f"Caught Exception: {repr(exc)}. Please contact the administrator."
        return failure, input_warnings, None

    return "", input_warnings, reply
def postprocess_voice_transcription(text):
    """
    Strip annotation markers from a transcription and tidy its whitespace.

    Removes ``<...>`` tags (with an optional trailing colon), ``(...)``
    and ``[...]`` spans, then collapses every run of whitespace into a single
    space and trims both ends.

    Each marker is matched up to its *nearest* closing character on the same
    line, so the text between two separate markers is kept.
    """
    # Bounded character classes instead of greedy ".*": with ".*" the text
    # between the first and the last marker of a line was deleted as well.
    text = re.sub(r"<[^>\n]*>:?|\([^)\n]*\)|\[[^\]\n]*\]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text