Commit bf1337a: handle browser refresh
Siddhant committed
Parent: 94b0033

app.py CHANGED
@@ -33,6 +33,9 @@ LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruc
 TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
 Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
 upload_to_hub=None
+ASR_curr_name=None
+LLM_curr_name=None
+TTS_curr_name=None
 # def read_args():
 #     global access_token
 #     global ASR_name
@@ -97,8 +100,13 @@ LLM_response_arr=[]
 total_response_arr=[]
 
 def handle_selection(option):
+    global TTS_curr_name
+    if TTS_curr_name is not None:
+        if option==TTS_curr_name:
+            return
     yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
     global text2speech
+    TTS_curr_name=option
     tag = option
     if tag=="ChatTTS":
         text2speech = ChatTTSModel()
@@ -108,17 +116,27 @@ def handle_selection(option):
     yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
 
 def handle_LLM_selection(option):
+    global LLM_curr_name
+    if LLM_curr_name is not None:
+        if option==LLM_curr_name:
+            return
     yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
     global LM_pipe
+    LLM_curr_name=option
     LM_pipe = HuggingFaceLLM(access_token=access_token,tag = option)
     LM_pipe.warmup()
     yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
 
 def handle_ASR_selection(option):
-    yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
+    global ASR_curr_name
     if option=="librispeech_asr":
         option="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
+    if ASR_curr_name is not None:
+        if option==ASR_curr_name:
+            return
+    yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
     global s2t
+    ASR_curr_name=option
     if option=="espnet/owsm_v3.1_ebf":
         s2t = OWSMModel()
     elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
@@ -180,14 +198,21 @@ def handle_type_selection(option,TTS_radio,ASR_radio,LLM_radio):
         text2speech=None
         s2t=None
         LM_pipe=None
+        global ASR_curr_name
+        global LLM_curr_name
+        global TTS_curr_name
+        ASR_curr_name=None
+        LLM_curr_name=None
+        TTS_curr_name=None
         handle_E2E_selection()
         yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=True),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=False),gr.Radio(visible=True, interactive=True))
 
 
 def handle_E2E_selection():
     global client
-    client = MiniOmniE2EModel()
-    client.warmup()
+    if client is None:
+        client = MiniOmniE2EModel()
+        client.warmup()
 
 def start_warmup():
     global client
@@ -320,7 +345,7 @@ def relevant_vote4_last_response(
 import json
 import time
 
-def transcribe(stream, new_chunk, option, asr_option):
+def transcribe(stream, new_chunk, TTS_option, ASR_option, LLM_option, type_option):
     sr, y = new_chunk
     global text_str
     global chat
@@ -338,6 +363,11 @@ def transcribe(stream, new_chunk, option, asr_option):
     global LLM_response_arr
     global total_response_arr
     if stream is None:
+        # Handle user refresh
+        # import pdb;pdb.set_trace()
+        for (_,_,_,_,asr_output_box,text_box,audio_box,_,_) in handle_type_selection(type_option,TTS_option,ASR_option,LLM_option):
+            gr.Info("The models are being reloaded due to a browser refresh.")
+            yield (stream,asr_output_box,text_box,audio_box,gr.Audio(visible=False))
         stream=y
         chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
         text_str=""
@@ -530,7 +560,7 @@ with gr.Blocks(
     diversity_response = gr.Textbox(label="diversity_response",visible=False,interactive=False)
     ip_address = gr.Textbox(label="ip_address",visible=False,interactive=False)
    callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address],"flagged_data_points")
-    user_audio.stream(transcribe, inputs=[state, user_audio, radio, ASR_radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
+    user_audio.stream(transcribe, inputs=[state, user_audio, radio, ASR_radio, LLM_radio, type_radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
     radio.change(fn=handle_selection, inputs=[radio], outputs=[output_asr_text, output_text, output_audio])
     LLM_radio.change(fn=handle_LLM_selection, inputs=[LLM_radio], outputs=[output_asr_text, output_text, output_audio])
     ASR_radio.change(fn=handle_ASR_selection, inputs=[ASR_radio], outputs=[output_asr_text, output_text, output_audio])
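
The core of this commit is the refresh check in transcribe: Gradio re-creates a gr.State value as None when the browser reloads the page, so the first streamed chunk after a refresh arrives with stream is None, and the handler reloads the models before processing any audio. The following is a minimal standalone sketch of that pattern, assuming a stripped-down app; load_models and the component names are illustrative stand-ins, not the Space's actual helpers.

import gradio as gr

def load_models():
    # Stand-in for the Space's handle_type_selection() reloads.
    print("reloading models...")

def transcribe(stream, new_chunk):
    sr, y = new_chunk
    if stream is None:
        # gr.State() comes back as None after a browser refresh,
        # so treat that as a signal to reload models and notify the user.
        gr.Info("The models are being reloaded due to a browser refresh.")
        load_models()
        stream = y
    # ... run ASR/LLM/TTS on the accumulated audio here ...
    return stream, f"received {len(y)} samples at {sr} Hz"

with gr.Blocks() as demo:
    state = gr.State()  # resets to None on page refresh
    mic = gr.Audio(sources=["microphone"], streaming=True)
    text = gr.Textbox()
    mic.stream(transcribe, inputs=[state, mic], outputs=[state, text])

demo.launch()

In the Space itself, the handler instead iterates over handle_type_selection's yields, so the visibility updates for the ASR, text, and audio boxes are re-emitted as each model finishes loading.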
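
The new ASR_curr_name/LLM_curr_name/TTS_curr_name globals also make the .change handlers idempotent: re-selecting the current model (or the refresh path re-driving the radios) would otherwise re-trigger an expensive reload. The same guard can be factored into a small wrapper; this is a sketch of the idea under that assumption, not code from the Space.

from typing import Any, Callable

def memoized_loader(load: Callable[[str], Any]) -> Callable[[str], Any]:
    """Wrap an expensive per-option loader so that selecting the same
    option again skips the reload, mirroring the *_curr_name guards
    added in this commit."""
    state = {"name": None, "model": None}

    def select(option: str) -> Any:
        if option == state["name"]:
            return state["model"]      # same model already loaded; skip
        state["model"] = load(option)  # expensive load runs only on change
        state["name"] = option
        return state["model"]

    return select

# Usage with a dummy loader:
loads = []
def dummy_load(tag):
    loads.append(tag)
    return f"<model for {tag}>"

select_tts = memoized_loader(dummy_load)
select_tts("ChatTTS")
select_tts("ChatTTS")
assert loads == ["ChatTTS"]  # the second selection did not reload

The commit keeps the guards inline with module-level globals instead, which fits the app's existing style, since handle_type_selection must also be able to reset all three names to None to force a full reload.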