# MERaLiON-AudioLLM Streamlit demo: streams chat completions for audio + text
# prompts from a model server reached through an SSH tunnel.
import io
import os
import re
import time

import librosa
import paramiko
import streamlit as st
from openai import OpenAI, APIConnectionError
from sshtunnel import SSHTunnelForwarder
# Local TCP port the SSH tunnel binds on this machine; LOCAL_PORT must be set
# in the environment (int() raises at import time otherwise).
local_port = int(os.getenv('LOCAL_PORT'))
# Generic prompts offered for any audio clip.
GENERAL_INSTRUCTIONS = [
    "Please transcribe this speech.",
    "Please summarise this speech.",
]

# Maps each bundled sample-audio identifier to the instruction prompts suited
# to it. The key prefix encodes the task: ASR (speech recognition), SQA
# (speech question answering), DS (dialogue summarisation), Paralingual
# (emotion / gender / nationality), ST (speech translation) and SI (speech
# instruction following). Insertion order is preserved for display.
AUDIO_SAMPLES_W_INSTRUCT = {
    # ASR samples.
    '1_ASR_IMDA_PART1_ASR_v2_141': [
        "Turn the spoken language into a text format.",
        "Please translate the content into Chinese.",
    ],
    '7_ASR_IMDA_PART3_30_ASR_v2_2269': ["Need this talk written down, please."],
    '13_ASR_IMDA_PART5_30_ASR_v2_1446': ["Translate this vocal recording into a textual format."],
    '17_ASR_IMDA_PART6_30_ASR_v2_1413': ["Record the spoken word in text form."],
    # Speech question answering.
    '32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': ["What does the man think the woman should do at 4:00."],
    '33_SQA_IMDA_PART3_30_SQA_V2_2310': ["Does Speaker2's wife cook for Speaker2 when they are at home."],
    '34_SQA_IMDA_PART3_30_SQA_V2_3621': ["Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language."],
    '35_SQA_IMDA_PART3_30_SQA_V2_4062': ["What is the color of the vase mentioned in the dialogue."],
    # Dialogue summarisation.
    '36_DS_IMDA_PART4_30_DS_V2_849': ["Condense the dialogue into a concise summary highlighting major topics and conclusions."],
    # Paralinguistic tasks (emotion, gender, nationality/ethnicity).
    '39_Paralingual_IEMOCAP_ER_V2_91': ["Based on the speaker's speech patterns, what do you think they are feeling."],
    '40_Paralingual_IEMOCAP_ER_V2_567': ["Based on the speaker's speech patterns, what do you think they are feeling."],
    '42_Paralingual_IEMOCAP_GR_V2_320': ["Is it possible for you to identify whether the speaker in this recording is male or female."],
    '43_Paralingual_IEMOCAP_GR_V2_129': ["Is it possible for you to identify whether the speaker in this recording is male or female."],
    '45_Paralingual_IMDA_PART3_30_GR_V2_12312': [
        "So, who's speaking in the second part of the clip?",
        "So, who's speaking in the first part of the clip?",
    ],
    '47_Paralingual_IMDA_PART3_30_NR_V2_10479': ["Can you guess which ethnic group this person is from based on their accent."],
    '49_Paralingual_MELD_ER_V2_676': ["What emotions do you think the speaker is expressing."],
    '50_Paralingual_MELD_ER_V2_692': ["Based on the speaker's speech patterns, what do you think they are feeling."],
    '51_Paralingual_VOXCELEB1_GR_V2_2148': ["May I know the gender of the speaker."],
    '53_Paralingual_VOXCELEB1_NR_V2_2286': ["What's the nationality identity of the speaker."],
    # SQA over public-speech recordings.
    '55_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_2': ["What impact would the growth of the healthcare sector have on the country's economy in terms of employment and growth."],
    '56_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_415': ["Based on the statement, can you summarize the speaker's position on the recent controversial issues in Singapore."],
    '57_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_460': ["How does the author respond to parents' worries about masks in schools."],
    # Additional ASR samples.
    '2_ASR_IMDA_PART1_ASR_v2_2258': [
        "Turn the spoken language into a text format.",
        "Please translate the content into Chinese.",
    ],
    '3_ASR_IMDA_PART1_ASR_v2_2265': ["Turn the spoken language into a text format."],
    '4_ASR_IMDA_PART2_ASR_v2_999': ["Translate the spoken words into text format."],
    '5_ASR_IMDA_PART2_ASR_v2_2241': ["Translate the spoken words into text format."],
    '6_ASR_IMDA_PART2_ASR_v2_3409': ["Translate the spoken words into text format."],
    '8_ASR_IMDA_PART3_30_ASR_v2_1698': ["Need this talk written down, please."],
    '9_ASR_IMDA_PART3_30_ASR_v2_2474': ["Need this talk written down, please."],
    '11_ASR_IMDA_PART4_30_ASR_v2_3771': ["Write out the dialogue as text."],
    '12_ASR_IMDA_PART4_30_ASR_v2_103': ["Write out the dialogue as text."],
    '10_ASR_IMDA_PART4_30_ASR_v2_1527': ["Write out the dialogue as text."],
    '14_ASR_IMDA_PART5_30_ASR_v2_2281': ["Translate this vocal recording into a textual format."],
    '15_ASR_IMDA_PART5_30_ASR_v2_4388': ["Translate this vocal recording into a textual format."],
    '16_ASR_IMDA_PART6_30_ASR_v2_576': ["Record the spoken word in text form."],
    '18_ASR_IMDA_PART6_30_ASR_v2_2834': ["Record the spoken word in text form."],
    '19_ASR_AIShell_zh_ASR_v2_5044': ["Transform the oral presentation into a text document."],
    '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': ["Please provide a written transcription of the speech."],
    # Speech translation.
    '25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': ["Please translate the given speech to English."],
    '26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': ["Please translate the given speech to Chinese."],
    '27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': ["Please translate the given speech to Chinese."],
    # Speech instruction following.
    '28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': ["Please follow the instruction in the speech."],
    '29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': ["Please follow the instruction in the speech."],
    '30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': ["Please follow the instruction in the speech."],
}
class NoAudioException(Exception):
    """Raised when a response is requested before any audio is loaded."""
class TunnelNotRunningException(Exception):
    """Raised when the API is unreachable because the SSH tunnel is down."""
class SSHTunnelManager:
    """Owns the SSH tunnel that forwards a local port to the remote model server.

    Wraps ``sshtunnel.SSHTunnelForwarder`` and tracks two flags:
    ``_is_starting`` (a start/restart is in flight) and ``_is_running``
    (the forwarder reports its tunnel as up).
    """

    def __init__(self):
        # Private key material comes from the environment, not a file on disk.
        pkey = paramiko.RSAKey.from_private_key(io.StringIO(os.getenv('PRIVATE_KEY')))

        self.server = SSHTunnelForwarder(
            ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
            ssh_username="ec2-user",
            ssh_pkey=pkey,
            local_bind_address=("127.0.0.1", local_port),
            # Remote model server listens on port 8000 on the SSH host.
            remote_bind_address=("127.0.0.1", 8000)
        )
        self._is_starting = False
        self._is_running = False

    def update_status(self):
        """Refresh ``_is_running`` from the forwarder's own tunnel check."""
        if self._is_starting:
            # While a (re)start is in flight the tunnel is not usable yet.
            self._is_running = False
            return
        self.server.check_tunnels()
        # Only one tunnel is configured, so the first status entry is it.
        self._is_running = list(self.server.tunnel_is_up.values())[0]

    def is_starting(self):
        self.update_status()
        return self._is_starting

    def is_running(self):
        self.update_status()
        return self._is_running

    def is_down(self):
        """True when the tunnel is neither up nor in the middle of starting."""
        self.update_status()
        return (not self._is_running) and (not self._is_starting)

    def start(self, *args, **kwargs):
        """Start the tunnel unless a start/restart is already in progress."""
        if not self._is_starting:
            self._is_starting = True
            try:
                self.server.start(*args, **kwargs)
            finally:
                # BUGFIX: a failed start previously left _is_starting stuck at
                # True, so is_down() never reported down and every later
                # start()/restart() became a silent no-op.
                self._is_starting = False

    def restart(self, *args, **kwargs):
        """Restart the tunnel unless a start/restart is already in progress."""
        if not self._is_starting:
            self._is_starting = True
            try:
                self.server.restart(*args, **kwargs)
            finally:
                # Same stuck-flag fix as in start().
                self._is_starting = False
def start_server():
    """Create the SSH tunnel manager, start its tunnel, and return it."""
    manager = SSHTunnelManager()
    manager.start()
    return manager
def load_model():
    """Build an OpenAI-compatible client pointed at the tunnelled server.

    Reads the API key from the API_KEY environment variable, connects to
    the local end of the SSH tunnel, asks the server which models it
    serves, and selects the first one.

    Returns:
        Tuple of (client, model_name).
    """
    client = OpenAI(
        api_key=os.getenv('API_KEY'),
        base_url=f"http://localhost:{local_port}/v1",
    )
    served = client.models.list()
    return client, served.data[0].id
def generate_response(text_input):
    """Request a streamed chat completion for the text prompt plus session audio.

    Args:
        text_input: the user's text instruction.

    Returns:
        Tuple of (stream, warnings) — the streaming completion object and a
        list of warning strings to surface to the user.

    Raises:
        NoAudioException: no audio has been loaded into the session.
        TunnelNotRunningException: the API was unreachable and the SSH
            tunnel is not running.
        APIConnectionError: the API was unreachable while the tunnel was up.
    """
    if not st.session_state.audio_base64:
        raise NoAudioException("audio is empty.")

    warnings = []
    # Heuristic guardrail: flag prompts that look like out-of-scope tasks.
    if re.search("tool|code|python|java|math|calculate", text_input):
        warnings.append("WARNING: MERaLiON-AudioLLM is not intended for use in tool calling, math, and coding tasks.")
    # Flag prompts containing CJK ideographs.
    if re.search(r'[\u4e00-\u9fff]+', text_input):
        warnings.append("NOTE: Please try to prompt in English for the best performance.")

    user_content = [
        {
            "type": "text",
            "text": f"Text instruction: {text_input}"
        },
        {
            "type": "audio_url",
            "audio_url": {
                "url": f"data:audio/ogg;base64,{st.session_state.audio_base64}"
            },
        },
    ]

    try:
        stream = st.session_state.client.chat.completions.create(
            messages=[{"role": "user", "content": user_content}],
            model=st.session_state.model_name,
            max_completion_tokens=512,
            temperature=st.session_state.temperature,
            top_p=st.session_state.top_p,
            stream=True,
        )
    except APIConnectionError as e:
        # Distinguish "tunnel collapsed" from other connection failures.
        if not st.session_state.server.is_running():
            raise TunnelNotRunningException()
        raise e

    return stream, warnings
def retry_generate_response(prompt, retry=3):
    """Stream a response for ``prompt``, retrying if the SSH tunnel dropped.

    On TunnelNotRunningException the tunnel is restarted (or, if a restart
    is already underway, given a moment to finish) and the call recurses
    with one fewer retry remaining.

    Args:
        prompt: text instruction forwarded to ``generate_response``.
        retry: remaining retry attempts; the exception propagates at 0.

    Returns:
        Tuple of (response, warnings) — the fully streamed response text and
        any warnings produced while generating it.
    """
    response, warnings = "", []
    try:
        stream, warnings = generate_response(prompt)
        for warning_msg in warnings:
            st.warning(warning_msg)
        response = st.write_stream(stream)
    except TunnelNotRunningException as e:
        if retry == 0:
            raise e
        st.warning(f"Internet connection is down. Trying to re-establish connection ({retry}).")
        if st.session_state.server.is_down():
            st.session_state.server.restart()
        elif st.session_state.server.is_starting():
            # Another restart is already in flight; give it time to finish.
            time.sleep(2)
        # BUGFIX: this call was retry_generate_response(retry-1), which passed
        # the decremented counter as the *prompt* and reset retry to its
        # default — losing the user's prompt and never reaching retry == 0
        # on this path.
        return retry_generate_response(prompt, retry - 1)
    return response, warnings
def bytes_to_array(audio_bytes):
    """Decode raw audio bytes into a float waveform resampled to 16 kHz."""
    buffer = io.BytesIO(audio_bytes)
    waveform, _sample_rate = librosa.load(buffer, sr=16000)
    return waveform