import os
import re
import time
from typing import List, Dict, Optional

import numpy as np
import streamlit as st
from openai import OpenAI, APIConnectionError

from src.exceptions import TunnelNotRunningException


# Generation parameters kept fixed across requests (presumably splatted into
# chat.completions.create by callers -- usage is not visible in this module).
FIXED_GENERATION_CONFIG = dict(
    max_completion_tokens=1024,
    top_k=50,
    length_penalty=1.0,
    seed=42
)

# Length (seconds) quoted in the truncation warning emitted by
# _validate_input; the truncation itself is not performed in this module.
MAX_AUDIO_LENGTH = 120

# NOTE(review): assumes audio arrays are sampled at 16 kHz -- confirm against
# the code that decodes/resamples the audio.
AUDIO_SAMPLE_RATE = 16000

# Keywords hinting at tasks the model is not intended for. Case-insensitive so
# that e.g. "Python" or "Math" also trigger the warning.
_UNSUPPORTED_TASK_PATTERN = re.compile(
    r"tool|code|python|java|math|calculate", re.IGNORECASE)

# Characters from the CJK Unified Ideographs block (U+4E00..U+9FFF).
_CJK_PATTERN = re.compile(r"[\u4e00-\u9fff]")

# Annotations stripped from transcriptions: <...> tags (plus an optional
# trailing colon, e.g. speaker labels), (...) and [...] spans. Non-greedy so
# that spoken text *between* two annotations is preserved.
_ANNOTATION_PATTERN = re.compile(r"<.*?>:?|\(.*?\)|\[.*?\]")
_WHITESPACE_PATTERN = re.compile(r"\s+")


def load_model() -> Dict:
    """
    Create an OpenAI client with connection to vllm server.

    One client is created per port listed (whitespace-separated) in the
    ``LOCAL_PORTS`` environment variable, each pointing at
    ``http://localhost:<port>/v1`` and authenticated with ``API_KEY``.

    Returns:
        Mapping from every model id reported by a server's model listing to
        the client for that server. If several servers report the same model
        id, the port listed last wins.

    Raises:
        RuntimeError: if ``LOCAL_PORTS`` is unset or contains no ports.
    """
    openai_api_key = os.getenv('API_KEY')
    # split() (not split(" ")) so repeated/trailing spaces cannot produce
    # empty port strings and hence malformed URLs like "http://localhost:/v1".
    local_ports = (os.getenv('LOCAL_PORTS') or "").split()
    if not local_ports:
        raise RuntimeError(
            "LOCAL_PORTS environment variable is not set; expected a "
            "space-separated list of local server ports."
        )

    name_to_client_mapper = {}
    for port in local_ports:
        client = OpenAI(
            api_key=openai_api_key,
            base_url=f"http://localhost:{port}/v1",
        )
        for model in client.models.list().data:
            name_to_client_mapper[model.id] = client
    return name_to_client_mapper


def _text_part(text_input):
    """Build the text part of a multimodal user message."""
    return {
        "type": "text",
        "text": f"Text instruction: {text_input}"
    }


def _audio_part(base64_audio_input):
    """Build the audio part of a multimodal user message (audio/ogg data URL)."""
    return {
        "type": "audio_url",
        "audio_url": {
            "url": f"data:audio/ogg;base64,{base64_audio_input}"
        },
    }


def prepare_multimodal_content(text_input, base64_audio_input):
    """
    Build the content list of a user message carrying text plus audio.

    Args:
        text_input: instruction text; prefixed with "Text instruction: ".
        base64_audio_input: base64-encoded audio, embedded as an
            ``audio/ogg`` data URL.

    Returns:
        ``[text_part, audio_part]`` -- the text part at index 0 and the audio
        part at index 1, the layout ``change_multimodal_content`` relies on.
    """
    return [_text_part(text_input), _audio_part(base64_audio_input)]


def change_multimodal_content(
        original_content,
        text_input="",
        base64_audio_input=""):
    """
    Replace the text and/or audio part of a content list in place.

    ``original_content`` is expected to follow the layout produced by
    ``prepare_multimodal_content`` (text at index 0, audio at index 1).
    An empty/falsy argument leaves the corresponding part untouched.

    Returns:
        The same ``original_content`` list object, after mutation.
    """
    # Parts are addressed by list position, fixed by prepare_multimodal_content.
    if text_input:
        original_content[0] = _text_part(text_input)
    if base64_audio_input:
        original_content[1] = _audio_part(base64_audio_input)
    return original_content


def _retrive_response(
        model: str,
        text_input: str,
        base64_audio_input: str,
        history: Optional[List] = None,
        **kwargs):
    """
    Send request through OpenAI client.

    Args:
        model: model id; also selects the client from
            ``st.session_state.client_mapper``.
        text_input: user instruction.
        base64_audio_input: base64 audio; when falsy, ``text_input`` is sent
            as a plain-text message instead of multimodal content.
        history: earlier chat messages placed before the new user message
            (the passed list is not mutated).
        **kwargs: forwarded to ``chat.completions.create``.

    Returns:
        Whatever ``chat.completions.create`` returns.
    """
    if history is None:
        history = []

    if base64_audio_input:
        content = prepare_multimodal_content(text_input, base64_audio_input)
    else:
        content = text_input

    current_client = st.session_state.client_mapper[model]
    return current_client.chat.completions.create(
        messages=history + [{"role": "user", "content": content}],
        model=model,
        **kwargs
    )


def _retry_retrive_response_throws_exception(retry=3, **kwargs):
    """
    Call ``_retrive_response``, recovering when the server tunnel is down.

    On ``APIConnectionError`` while ``st.session_state.server`` reports it is
    not running, a toast is shown, the server is restarted (if down) or given
    2 seconds (if starting), and the request is retried -- at most ``retry``
    more times.

    Args:
        retry: number of retries still allowed.
        **kwargs: forwarded to ``_retrive_response``.

    Raises:
        TunnelNotRunningException: retries exhausted while the server still
            reports it is not running.
        APIConnectionError: the connection failed although the server reports
            it is running (not a tunnel problem, so it is not retried).
    """
    while True:
        try:
            return _retrive_response(**kwargs)
        except APIConnectionError as e:
            server = st.session_state.server
            if server.is_running():
                raise
            # `<= 0` (not `== 0`) so a negative budget cannot retry forever.
            if retry <= 0:
                raise TunnelNotRunningException() from e
            st.toast(
                ":warning: Internet connection is down. "
                f"Trying to re-establish connection ({retry})."
            )
            if server.is_down():
                server.restart()
            elif server.is_starting():
                time.sleep(2)
            retry -= 1


def _validate_input(text_input, array_audio_input) -> List[str]:
    """
    Collect user-facing warnings/notes about the prompt and audio.

    Nothing is rejected here; the messages are only returned, and
    ``retrive_response`` passes them back alongside the response.

    TODO: improve the input validation regex.

    Args:
        text_input: user instruction.
        array_audio_input: audio samples; the sample count is read from
            ``shape[0]`` (assumes samples lie along axis 0 -- confirm).

    Returns:
        List of message strings, possibly empty.
    """
    warnings = []
    if _UNSUPPORTED_TASK_PATTERN.search(text_input):
        warnings.append("WARNING: MERaLiON-AudioLLM is not intended for use in tool calling, math, and coding tasks.")

    if _CJK_PATTERN.search(text_input):
        warnings.append("NOTE: Please try to prompt in English for the best performance.")

    num_samples = array_audio_input.shape[0]
    if num_samples == 0:
        warnings.append("NOTE: Please specify audio from examples or local files.")

    if num_samples / AUDIO_SAMPLE_RATE > 30.0:
        warnings.append((
            "WARNING: MERaLiON-AudioLLM is trained to process audio up to **30 seconds**."
            f" Audio longer than **{MAX_AUDIO_LENGTH} seconds** will be truncated."
        ))

    return warnings


def retrive_response(
        text_input: str,
        array_audio_input: np.ndarray,
        **kwargs
):
    """
    Validate the input, query the model and report failures as a message.

    Request errors are caught and turned into ``error_msg``. Validation
    happens before the ``try``, so it is not covered.

    Args:
        text_input: user instruction.
        array_audio_input: audio samples; used only for validation here.
        **kwargs: forwarded to ``_retrive_response`` (e.g. ``model``,
            ``base64_audio_input``, ``history``, generation parameters).

    Returns:
        ``(error_msg, warnings, response_object)``. ``error_msg`` is ``""``
        on success; ``response_object`` is ``None`` when the request failed.
    """
    warnings = _validate_input(text_input, array_audio_input)

    response_object, error_msg = None, ""
    try:
        response_object = _retry_retrive_response_throws_exception(
            text_input=text_input,
            **kwargs
        )
    except TunnelNotRunningException:
        error_msg = "Internet connection cannot be established. Please contact the administrator."
    except Exception as e:
        # UI boundary: surface any other failure as a message instead of crashing.
        error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."

    return error_msg, warnings, response_object


def postprocess_voice_transcription(text):
    """
    Strip annotations from a transcription and normalise whitespace.

    Removes ``<...>`` tags (with a directly following colon), ``(...)`` and
    ``[...]`` spans, then collapses whitespace runs to single spaces and
    trims both ends. Matching is non-greedy, so text between two
    annotations is kept:

    >>> postprocess_voice_transcription("<Speaker1>: hi <Speaker2>: bye")
    'hi bye'
    """
    text = _ANNOTATION_PATTERN.sub("", text)
    return _WHITESPACE_PATTERN.sub(" ", text).strip()