YingxuHe commited on
Commit
8a5a187
·
verified ·
1 Parent(s): 4bd0ef1

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +112 -40
utils.py CHANGED
@@ -1,11 +1,12 @@
1
  import io
2
  import os
3
  import re
 
4
 
5
  import librosa
6
  import paramiko
7
  import streamlit as st
8
- from openai import OpenAI
9
  from sshtunnel import SSHTunnelForwarder
10
 
11
  local_port = int(os.getenv('LOCAL_PORT'))
@@ -20,12 +21,9 @@ GENERAL_INSTRUCTIONS = [
20
  AUDIO_SAMPLES_W_INSTRUCT = {
21
  '1_ASR_IMDA_PART1_ASR_v2_141' : ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
22
  '7_ASR_IMDA_PART3_30_ASR_v2_2269': ["Need this talk written down, please."],
 
23
  '17_ASR_IMDA_PART6_30_ASR_v2_1413': ["Record the spoken word in text form."],
24
 
25
- '25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': ["Please translate the given speech to English."],
26
- '26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': ["Please translate the given speech to Chinese."],
27
- '30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': ["Please follow the instruction in the speech."],
28
-
29
  '32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': ["What does the man think the woman should do at 4:00."],
30
  '33_SQA_IMDA_PART3_30_SQA_V2_2310': ["Does Speaker2's wife cook for Speaker2 when they are at home."],
31
  '34_SQA_IMDA_PART3_30_SQA_V2_3621': ["Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language."],
@@ -61,7 +59,6 @@ AUDIO_SAMPLES_W_INSTRUCT = {
61
  '12_ASR_IMDA_PART4_30_ASR_v2_103' : ["Write out the dialogue as text."],
62
  '10_ASR_IMDA_PART4_30_ASR_v2_1527': ["Write out the dialogue as text."],
63
 
64
- '13_ASR_IMDA_PART5_30_ASR_v2_1446': ["Translate this vocal recording into a textual format."],
65
  '14_ASR_IMDA_PART5_30_ASR_v2_2281': ["Translate this vocal recording into a textual format."],
66
  '15_ASR_IMDA_PART5_30_ASR_v2_4388': ["Translate this vocal recording into a textual format."],
67
 
@@ -71,9 +68,13 @@ AUDIO_SAMPLES_W_INSTRUCT = {
71
  '19_ASR_AIShell_zh_ASR_v2_5044': ["Transform the oral presentation into a text document."],
72
  '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': ["Please provide a written transcription of the speech."],
73
 
 
 
 
74
  '27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': ["Please translate the given speech to Chinese."],
75
  '28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': ["Please follow the instruction in the speech."],
76
  '29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': ["Please follow the instruction in the speech."],
 
77
  }
78
 
79
 
@@ -81,22 +82,64 @@ class NoAudioException(Exception):
81
  pass
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @st.cache_resource()
85
  def start_server():
86
- pkey = paramiko.RSAKey.from_private_key(io.StringIO(os.getenv('PRIVATE_KEY')))
87
-
88
- server = SSHTunnelForwarder(
89
- ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
90
- ssh_username="ec2-user",
91
- ssh_pkey=pkey,
92
- local_bind_address=("127.0.0.1", local_port),
93
- remote_bind_address=("127.0.0.1", 8000)
94
- )
95
  server.start()
96
  return server
97
 
98
-
99
- @st.cache_resource()
100
  def load_model():
101
  openai_api_key = os.getenv('API_KEY')
102
  openai_api_base = f"http://localhost:{local_port}/v1"
@@ -122,34 +165,63 @@ def generate_response(text_input):
122
 
123
  if re.search(r'[\u4e00-\u9fff]+', text_input):
124
  warnings.append("NOTE: Please try to prompt in English for the best performance.")
125
-
126
- stream = st.session_state.client.chat.completions.create(
127
- messages=[{
128
- "role":
129
- "user",
130
- "content": [
131
- {
132
- "type": "text",
133
- "text": f"Text instruction: {text_input}"
134
- },
135
- {
136
- "type": "audio_url",
137
- "audio_url": {
138
- "url": f"data:audio/ogg;base64,{st.session_state.audio_base64}"
139
  },
140
- },
141
- ],
142
- }],
143
- model=st.session_state.model_name,
144
- max_completion_tokens=512,
145
- temperature=st.session_state.temperature,
146
- top_p=st.session_state.top_p,
147
- stream=True,
148
- )
 
 
 
 
 
 
 
 
 
149
 
150
  return stream, warnings
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def bytes_to_array(audio_bytes):
154
  audio_array, _ = librosa.load(
155
  io.BytesIO(audio_bytes),
 
1
  import io
2
  import os
3
  import re
4
+ import time
5
 
6
  import librosa
7
  import paramiko
8
  import streamlit as st
9
+ from openai import OpenAI, APIConnectionError
10
  from sshtunnel import SSHTunnelForwarder
11
 
12
  local_port = int(os.getenv('LOCAL_PORT'))
 
21
  AUDIO_SAMPLES_W_INSTRUCT = {
22
  '1_ASR_IMDA_PART1_ASR_v2_141' : ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
23
  '7_ASR_IMDA_PART3_30_ASR_v2_2269': ["Need this talk written down, please."],
24
+ '13_ASR_IMDA_PART5_30_ASR_v2_1446': ["Translate this vocal recording into a textual format."],
25
  '17_ASR_IMDA_PART6_30_ASR_v2_1413': ["Record the spoken word in text form."],
26
 
 
 
 
 
27
  '32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': ["What does the man think the woman should do at 4:00."],
28
  '33_SQA_IMDA_PART3_30_SQA_V2_2310': ["Does Speaker2's wife cook for Speaker2 when they are at home."],
29
  '34_SQA_IMDA_PART3_30_SQA_V2_3621': ["Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language."],
 
59
  '12_ASR_IMDA_PART4_30_ASR_v2_103' : ["Write out the dialogue as text."],
60
  '10_ASR_IMDA_PART4_30_ASR_v2_1527': ["Write out the dialogue as text."],
61
 
 
62
  '14_ASR_IMDA_PART5_30_ASR_v2_2281': ["Translate this vocal recording into a textual format."],
63
  '15_ASR_IMDA_PART5_30_ASR_v2_4388': ["Translate this vocal recording into a textual format."],
64
 
 
68
  '19_ASR_AIShell_zh_ASR_v2_5044': ["Transform the oral presentation into a text document."],
69
  '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': ["Please provide a written transcription of the speech."],
70
 
71
+ '25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': ["Please translate the given speech to English."],
72
+ '26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': ["Please translate the given speech to Chinese."],
73
+
74
  '27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': ["Please translate the given speech to Chinese."],
75
  '28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': ["Please follow the instruction in the speech."],
76
  '29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': ["Please follow the instruction in the speech."],
77
+ '30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': ["Please follow the instruction in the speech."],
78
  }
79
 
80
 
 
82
  pass
83
 
84
 
85
class TunnelNotRunningException(Exception):
    """Raised when an API call fails because the SSH tunnel is not up."""
87
+
88
+
89
class SSHTunnelManager:
    """Wraps an SSHTunnelForwarder and tracks whether it is starting or running.

    The tunnel forwards 127.0.0.1:local_port to 127.0.0.1:8000 on the remote
    host named by SERVER_DNS_NAME, authenticating with the RSA key in the
    PRIVATE_KEY environment variable.
    """

    def __init__(self):
        # Key material comes from the environment; StringIO adapts it to
        # paramiko's file-like API.
        pkey = paramiko.RSAKey.from_private_key(io.StringIO(os.getenv('PRIVATE_KEY')))

        self.server = SSHTunnelForwarder(
            ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
            ssh_username="ec2-user",
            ssh_pkey=pkey,
            local_bind_address=("127.0.0.1", local_port),
            remote_bind_address=("127.0.0.1", 8000)
        )

        self._is_starting = False
        self._is_running = False

    def update_status(self):
        """Refresh the cached running state by probing the tunnel."""
        if not self._is_starting:
            self.server.check_tunnels()
            # any() is equivalent to checking the single configured tunnel,
            # and does not raise if no tunnels are registered yet.
            self._is_running = any(self.server.tunnel_is_up.values())
        else:
            # While a (re)start is in flight, report not-running.
            self._is_running = False

    def is_starting(self):
        """Return True while a start/restart is in progress."""
        self.update_status()
        return self._is_starting

    def is_running(self):
        """Return True when the tunnel is confirmed up."""
        self.update_status()
        return self._is_running

    def is_down(self):
        """Return True when the tunnel is neither running nor starting."""
        self.update_status()
        return (not self._is_running) and (not self._is_starting)

    def start(self, *args, **kwargs):
        """Start the tunnel unless a start is already in progress."""
        if not self._is_starting:
            self._is_starting = True
            try:
                self.server.start(*args, **kwargs)
            finally:
                # Clear the flag even on failure; otherwise a failed start
                # would leave _is_starting stuck True and block restart().
                self._is_starting = False

    def restart(self, *args, **kwargs):
        """Restart the tunnel unless a start is already in progress."""
        if not self._is_starting:
            self._is_starting = True
            try:
                self.server.restart(*args, **kwargs)
            finally:
                # Same guarantee as start(): never leave the flag stuck.
                self._is_starting = False
134
+
135
+
136
@st.cache_resource()
def start_server():
    """Build the cached SSH tunnel manager and bring the tunnel up.

    st.cache_resource ensures a single shared tunnel per Streamlit process.
    """
    manager = SSHTunnelManager()
    manager.start()
    return manager
141
 
142
+
 
143
  def load_model():
144
  openai_api_key = os.getenv('API_KEY')
145
  openai_api_base = f"http://localhost:{local_port}/v1"
 
165
 
166
  if re.search(r'[\u4e00-\u9fff]+', text_input):
167
  warnings.append("NOTE: Please try to prompt in English for the best performance.")
168
+
169
+ try:
170
+ stream = st.session_state.client.chat.completions.create(
171
+ messages=[{
172
+ "role":
173
+ "user",
174
+ "content": [
175
+ {
176
+ "type": "text",
177
+ "text": f"Text instruction: {text_input}"
 
 
 
 
178
  },
179
+ {
180
+ "type": "audio_url",
181
+ "audio_url": {
182
+ "url": f"data:audio/ogg;base64,{st.session_state.audio_base64}"
183
+ },
184
+ },
185
+ ],
186
+ }],
187
+ model=st.session_state.model_name,
188
+ max_completion_tokens=512,
189
+ temperature=st.session_state.temperature,
190
+ top_p=st.session_state.top_p,
191
+ stream=True,
192
+ )
193
+ except APIConnectionError as e:
194
+ if not st.session_state.server.is_running():
195
+ raise TunnelNotRunningException()
196
+ raise e
197
 
198
  return stream, warnings
199
 
200
 
201
def retry_generate_response(retry=3):
    """Stream a model response, retrying when the SSH tunnel has dropped.

    On TunnelNotRunningException, attempts to restart the tunnel (or waits
    if it is already restarting) and tries again, up to `retry` times.
    Re-raises the exception once the retry budget is exhausted.
    """
    response, warnings = "", []

    while True:
        try:
            stream, warnings = generate_response(st.session_state.prompt)
            for warning_msg in warnings:
                st.warning(warning_msg)
            response = st.write_stream(stream)
        except TunnelNotRunningException as e:
            if retry == 0:
                raise e

            st.error(f"Internet connection is down. Trying to re-establish connection ({retry}).")

            # Kick the tunnel back up, or give an in-flight restart time to finish.
            if st.session_state.server.is_down():
                st.session_state.server.restart()
            elif st.session_state.server.is_starting():
                time.sleep(2)

            retry -= 1
            continue

        return response, warnings
223
+
224
+
225
  def bytes_to_array(audio_bytes):
226
  audio_array, _ = librosa.load(
227
  io.BytesIO(audio_bytes),