Spaces:
Running
Running
Update utils.py
Browse files
utils.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import io
|
2 |
import os
|
3 |
import re
|
|
|
4 |
|
5 |
import librosa
|
6 |
import paramiko
|
7 |
import streamlit as st
|
8 |
-
from openai import OpenAI
|
9 |
from sshtunnel import SSHTunnelForwarder
|
10 |
|
11 |
local_port = int(os.getenv('LOCAL_PORT'))
|
@@ -20,12 +21,9 @@ GENERAL_INSTRUCTIONS = [
|
|
20 |
AUDIO_SAMPLES_W_INSTRUCT = {
|
21 |
'1_ASR_IMDA_PART1_ASR_v2_141' : ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
|
22 |
'7_ASR_IMDA_PART3_30_ASR_v2_2269': ["Need this talk written down, please."],
|
|
|
23 |
'17_ASR_IMDA_PART6_30_ASR_v2_1413': ["Record the spoken word in text form."],
|
24 |
|
25 |
-
'25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': ["Please translate the given speech to English."],
|
26 |
-
'26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': ["Please translate the given speech to Chinese."],
|
27 |
-
'30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': ["Please follow the instruction in the speech."],
|
28 |
-
|
29 |
'32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': ["What does the man think the woman should do at 4:00."],
|
30 |
'33_SQA_IMDA_PART3_30_SQA_V2_2310': ["Does Speaker2's wife cook for Speaker2 when they are at home."],
|
31 |
'34_SQA_IMDA_PART3_30_SQA_V2_3621': ["Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language."],
|
@@ -61,7 +59,6 @@ AUDIO_SAMPLES_W_INSTRUCT = {
|
|
61 |
'12_ASR_IMDA_PART4_30_ASR_v2_103' : ["Write out the dialogue as text."],
|
62 |
'10_ASR_IMDA_PART4_30_ASR_v2_1527': ["Write out the dialogue as text."],
|
63 |
|
64 |
-
'13_ASR_IMDA_PART5_30_ASR_v2_1446': ["Translate this vocal recording into a textual format."],
|
65 |
'14_ASR_IMDA_PART5_30_ASR_v2_2281': ["Translate this vocal recording into a textual format."],
|
66 |
'15_ASR_IMDA_PART5_30_ASR_v2_4388': ["Translate this vocal recording into a textual format."],
|
67 |
|
@@ -71,9 +68,13 @@ AUDIO_SAMPLES_W_INSTRUCT = {
|
|
71 |
'19_ASR_AIShell_zh_ASR_v2_5044': ["Transform the oral presentation into a text document."],
|
72 |
'20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': ["Please provide a written transcription of the speech."],
|
73 |
|
|
|
|
|
|
|
74 |
'27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': ["Please translate the given speech to Chinese."],
|
75 |
'28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': ["Please follow the instruction in the speech."],
|
76 |
'29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': ["Please follow the instruction in the speech."],
|
|
|
77 |
}
|
78 |
|
79 |
|
@@ -81,22 +82,64 @@ class NoAudioException(Exception):
|
|
81 |
pass
|
82 |
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
@st.cache_resource()
|
85 |
def start_server():
|
86 |
-
|
87 |
-
|
88 |
-
server = SSHTunnelForwarder(
|
89 |
-
ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
|
90 |
-
ssh_username="ec2-user",
|
91 |
-
ssh_pkey=pkey,
|
92 |
-
local_bind_address=("127.0.0.1", local_port),
|
93 |
-
remote_bind_address=("127.0.0.1", 8000)
|
94 |
-
)
|
95 |
server.start()
|
96 |
return server
|
97 |
|
98 |
-
|
99 |
-
@st.cache_resource()
|
100 |
def load_model():
|
101 |
openai_api_key = os.getenv('API_KEY')
|
102 |
openai_api_base = f"http://localhost:{local_port}/v1"
|
@@ -122,34 +165,63 @@ def generate_response(text_input):
|
|
122 |
|
123 |
if re.search(r'[\u4e00-\u9fff]+', text_input):
|
124 |
warnings.append("NOTE: Please try to prompt in English for the best performance.")
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
{
|
136 |
-
"type": "audio_url",
|
137 |
-
"audio_url": {
|
138 |
-
"url": f"data:audio/ogg;base64,{st.session_state.audio_base64}"
|
139 |
},
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
return stream, warnings
|
151 |
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
def bytes_to_array(audio_bytes):
|
154 |
audio_array, _ = librosa.load(
|
155 |
io.BytesIO(audio_bytes),
|
|
|
1 |
import io
|
2 |
import os
|
3 |
import re
|
4 |
+
import time
|
5 |
|
6 |
import librosa
|
7 |
import paramiko
|
8 |
import streamlit as st
|
9 |
+
from openai import OpenAI, APIConnectionError
|
10 |
from sshtunnel import SSHTunnelForwarder
|
11 |
|
12 |
local_port = int(os.getenv('LOCAL_PORT'))
|
|
|
21 |
AUDIO_SAMPLES_W_INSTRUCT = {
|
22 |
'1_ASR_IMDA_PART1_ASR_v2_141' : ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
|
23 |
'7_ASR_IMDA_PART3_30_ASR_v2_2269': ["Need this talk written down, please."],
|
24 |
+
'13_ASR_IMDA_PART5_30_ASR_v2_1446': ["Translate this vocal recording into a textual format."],
|
25 |
'17_ASR_IMDA_PART6_30_ASR_v2_1413': ["Record the spoken word in text form."],
|
26 |
|
|
|
|
|
|
|
|
|
27 |
'32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': ["What does the man think the woman should do at 4:00."],
|
28 |
'33_SQA_IMDA_PART3_30_SQA_V2_2310': ["Does Speaker2's wife cook for Speaker2 when they are at home."],
|
29 |
'34_SQA_IMDA_PART3_30_SQA_V2_3621': ["Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language."],
|
|
|
59 |
'12_ASR_IMDA_PART4_30_ASR_v2_103' : ["Write out the dialogue as text."],
|
60 |
'10_ASR_IMDA_PART4_30_ASR_v2_1527': ["Write out the dialogue as text."],
|
61 |
|
|
|
62 |
'14_ASR_IMDA_PART5_30_ASR_v2_2281': ["Translate this vocal recording into a textual format."],
|
63 |
'15_ASR_IMDA_PART5_30_ASR_v2_4388': ["Translate this vocal recording into a textual format."],
|
64 |
|
|
|
68 |
'19_ASR_AIShell_zh_ASR_v2_5044': ["Transform the oral presentation into a text document."],
|
69 |
'20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': ["Please provide a written transcription of the speech."],
|
70 |
|
71 |
+
'25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': ["Please translate the given speech to English."],
|
72 |
+
'26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': ["Please translate the given speech to Chinese."],
|
73 |
+
|
74 |
'27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': ["Please translate the given speech to Chinese."],
|
75 |
'28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': ["Please follow the instruction in the speech."],
|
76 |
'29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': ["Please follow the instruction in the speech."],
|
77 |
+
'30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': ["Please follow the instruction in the speech."],
|
78 |
}
|
79 |
|
80 |
|
|
|
82 |
pass
|
83 |
|
84 |
|
85 |
+
class TunnelNotRunningException(Exception):
|
86 |
+
pass
|
87 |
+
|
88 |
+
|
89 |
+
class SSHTunnelManager:
|
90 |
+
def __init__(self):
|
91 |
+
pkey = paramiko.RSAKey.from_private_key(io.StringIO(os.getenv('PRIVATE_KEY')))
|
92 |
+
|
93 |
+
self.server = SSHTunnelForwarder(
|
94 |
+
ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
|
95 |
+
ssh_username="ec2-user",
|
96 |
+
ssh_pkey=pkey,
|
97 |
+
local_bind_address=("127.0.0.1", local_port),
|
98 |
+
remote_bind_address=("127.0.0.1", 8000)
|
99 |
+
)
|
100 |
+
|
101 |
+
self._is_starting = False
|
102 |
+
self._is_running = False
|
103 |
+
|
104 |
+
def update_status(self):
|
105 |
+
if not self._is_starting:
|
106 |
+
self.server.check_tunnels()
|
107 |
+
self._is_running = list(self.server.tunnel_is_up.values())[0]
|
108 |
+
else:
|
109 |
+
self._is_running = False
|
110 |
+
|
111 |
+
def is_starting(self):
|
112 |
+
self.update_status()
|
113 |
+
return self._is_starting
|
114 |
+
|
115 |
+
def is_running(self):
|
116 |
+
self.update_status()
|
117 |
+
return self._is_running
|
118 |
+
|
119 |
+
def is_down(self):
|
120 |
+
self.update_status()
|
121 |
+
return (not self._is_running) and (not self._is_starting)
|
122 |
+
|
123 |
+
def start(self, *args, **kwargs):
|
124 |
+
if not self._is_starting:
|
125 |
+
self._is_starting = True
|
126 |
+
self.server.start(*args, **kwargs)
|
127 |
+
self._is_starting = False
|
128 |
+
|
129 |
+
def restart(self, *args, **kwargs):
|
130 |
+
if not self._is_starting:
|
131 |
+
self._is_starting = True
|
132 |
+
self.server.restart(*args, **kwargs)
|
133 |
+
self._is_starting = False
|
134 |
+
|
135 |
+
|
136 |
@st.cache_resource()
|
137 |
def start_server():
|
138 |
+
server = SSHTunnelManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
server.start()
|
140 |
return server
|
141 |
|
142 |
+
|
|
|
143 |
def load_model():
|
144 |
openai_api_key = os.getenv('API_KEY')
|
145 |
openai_api_base = f"http://localhost:{local_port}/v1"
|
|
|
165 |
|
166 |
if re.search(r'[\u4e00-\u9fff]+', text_input):
|
167 |
warnings.append("NOTE: Please try to prompt in English for the best performance.")
|
168 |
+
|
169 |
+
try:
|
170 |
+
stream = st.session_state.client.chat.completions.create(
|
171 |
+
messages=[{
|
172 |
+
"role":
|
173 |
+
"user",
|
174 |
+
"content": [
|
175 |
+
{
|
176 |
+
"type": "text",
|
177 |
+
"text": f"Text instruction: {text_input}"
|
|
|
|
|
|
|
|
|
178 |
},
|
179 |
+
{
|
180 |
+
"type": "audio_url",
|
181 |
+
"audio_url": {
|
182 |
+
"url": f"data:audio/ogg;base64,{st.session_state.audio_base64}"
|
183 |
+
},
|
184 |
+
},
|
185 |
+
],
|
186 |
+
}],
|
187 |
+
model=st.session_state.model_name,
|
188 |
+
max_completion_tokens=512,
|
189 |
+
temperature=st.session_state.temperature,
|
190 |
+
top_p=st.session_state.top_p,
|
191 |
+
stream=True,
|
192 |
+
)
|
193 |
+
except APIConnectionError as e:
|
194 |
+
if not st.session_state.server.is_running():
|
195 |
+
raise TunnelNotRunningException()
|
196 |
+
raise e
|
197 |
|
198 |
return stream, warnings
|
199 |
|
200 |
|
201 |
+
def retry_generate_response(retry=3):
|
202 |
+
response, warnings = "", []
|
203 |
+
|
204 |
+
try:
|
205 |
+
stream, warnings = generate_response(st.session_state.prompt)
|
206 |
+
for warning_msg in warnings:
|
207 |
+
st.warning(warning_msg)
|
208 |
+
response = st.write_stream(stream)
|
209 |
+
except TunnelNotRunningException as e:
|
210 |
+
if retry == 0:
|
211 |
+
raise e
|
212 |
+
|
213 |
+
st.error(f"Internet connection is down. Trying to re-establish connection ({retry}).")
|
214 |
+
|
215 |
+
if st.session_state.server.is_down():
|
216 |
+
st.session_state.server.restart()
|
217 |
+
elif st.session_state.server.is_starting():
|
218 |
+
time.sleep(2)
|
219 |
+
|
220 |
+
return retry_generate_response(retry-1)
|
221 |
+
|
222 |
+
return response, warnings
|
223 |
+
|
224 |
+
|
225 |
def bytes_to_array(audio_bytes):
|
226 |
audio_array, _ = librosa.load(
|
227 |
io.BytesIO(audio_bytes),
|