Siddhant committed on
Commit
bf1337a
·
1 Parent(s): 94b0033

handle browser refresh

Browse files
Files changed (1) hide show
  1. app.py +35 -5
app.py CHANGED
@@ -33,6 +33,9 @@ LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruc
33
  TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
34
  Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
35
  upload_to_hub=None
 
 
 
36
  # def read_args():
37
  # global access_token
38
  # global ASR_name
@@ -97,8 +100,13 @@ LLM_response_arr=[]
97
  total_response_arr=[]
98
 
99
  def handle_selection(option):
 
 
 
 
100
  yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
101
  global text2speech
 
102
  tag = option
103
  if tag=="ChatTTS":
104
  text2speech = ChatTTSModel()
@@ -108,17 +116,27 @@ def handle_selection(option):
108
  yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
109
 
110
  def handle_LLM_selection(option):
 
 
 
 
111
  yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
112
  global LM_pipe
 
113
  LM_pipe = HuggingFaceLLM(access_token=access_token,tag = option)
114
  LM_pipe.warmup()
115
  yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
116
 
117
  def handle_ASR_selection(option):
118
- yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
119
  if option=="librispeech_asr":
120
  option="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
 
 
 
 
121
  global s2t
 
122
  if option=="espnet/owsm_v3.1_ebf":
123
  s2t = OWSMModel()
124
  elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
@@ -180,14 +198,21 @@ def handle_type_selection(option,TTS_radio,ASR_radio,LLM_radio):
180
  text2speech=None
181
  s2t=None
182
  LM_pipe=None
 
 
 
 
 
 
183
  handle_E2E_selection()
184
  yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=True),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=False),gr.Radio(visible=True, interactive=True))
185
 
186
 
187
  def handle_E2E_selection():
188
  global client
189
- client = MiniOmniE2EModel()
190
- client.warmup()
 
191
 
192
  def start_warmup():
193
  global client
@@ -320,7 +345,7 @@ def relevant_vote4_last_response(
320
  import json
321
  import time
322
 
323
- def transcribe(stream, new_chunk, option, asr_option):
324
  sr, y = new_chunk
325
  global text_str
326
  global chat
@@ -338,6 +363,11 @@ def transcribe(stream, new_chunk, option, asr_option):
338
  global LLM_response_arr
339
  global total_response_arr
340
  if stream is None:
 
 
 
 
 
341
  stream=y
342
  chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
343
  text_str=""
@@ -530,7 +560,7 @@ with gr.Blocks(
530
  diversity_response = gr.Textbox(label="diversity_response",visible=False,interactive=False)
531
  ip_address = gr.Textbox(label="ip_address",visible=False,interactive=False)
532
  callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address],"flagged_data_points")
533
- user_audio.stream(transcribe, inputs=[state, user_audio, radio, ASR_radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
534
  radio.change(fn=handle_selection, inputs=[radio], outputs=[output_asr_text, output_text, output_audio])
535
  LLM_radio.change(fn=handle_LLM_selection, inputs=[LLM_radio], outputs=[output_asr_text, output_text, output_audio])
536
  ASR_radio.change(fn=handle_ASR_selection, inputs=[ASR_radio], outputs=[output_asr_text, output_text, output_audio])
 
33
  TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
34
  Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
35
  upload_to_hub=None
36
+ ASR_curr_name=None
37
+ LLM_curr_name=None
38
+ TTS_curr_name=None
39
  # def read_args():
40
  # global access_token
41
  # global ASR_name
 
100
  total_response_arr=[]
101
 
102
  def handle_selection(option):
103
+ global TTS_curr_name
104
+ if TTS_curr_name is not None:
105
+ if option==TTS_curr_name:
106
+ return
107
  yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
108
  global text2speech
109
+ TTS_curr_name=option
110
  tag = option
111
  if tag=="ChatTTS":
112
  text2speech = ChatTTSModel()
 
116
  yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
117
 
118
  def handle_LLM_selection(option):
119
+ global LLM_curr_name
120
+ if LLM_curr_name is not None:
121
+ if option==LLM_curr_name:
122
+ return
123
  yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
124
  global LM_pipe
125
+ LLM_curr_name=option
126
  LM_pipe = HuggingFaceLLM(access_token=access_token,tag = option)
127
  LM_pipe.warmup()
128
  yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
129
 
130
  def handle_ASR_selection(option):
131
+ global ASR_curr_name
132
  if option=="librispeech_asr":
133
  option="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
134
+ if ASR_curr_name is not None:
135
+ if option==ASR_curr_name:
136
+ return
137
+ yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
138
  global s2t
139
+ ASR_curr_name=option
140
  if option=="espnet/owsm_v3.1_ebf":
141
  s2t = OWSMModel()
142
  elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
 
198
  text2speech=None
199
  s2t=None
200
  LM_pipe=None
201
+ global ASR_curr_name
202
+ global LLM_curr_name
203
+ global TTS_curr_name
204
+ ASR_curr_name=None
205
+ LLM_curr_name=None
206
+ TTS_curr_name=None
207
  handle_E2E_selection()
208
  yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=True),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=False),gr.Radio(visible=True, interactive=True))
209
 
210
 
211
  def handle_E2E_selection():
212
  global client
213
+ if client is None:
214
+ client = MiniOmniE2EModel()
215
+ client.warmup()
216
 
217
  def start_warmup():
218
  global client
 
345
  import json
346
  import time
347
 
348
+ def transcribe(stream, new_chunk, TTS_option, ASR_option, LLM_option, type_option):
349
  sr, y = new_chunk
350
  global text_str
351
  global chat
 
363
  global LLM_response_arr
364
  global total_response_arr
365
  if stream is None:
366
+ # Handle user refresh
367
+ # import pdb;pdb.set_trace()
368
+ for (_,_,_,_,asr_output_box,text_box,audio_box,_,_) in handle_type_selection(type_option,TTS_option,ASR_option,LLM_option):
369
+ gr.Info("The models are being reloaded due to a browser refresh.")
370
+ yield (stream,asr_output_box,text_box,audio_box,gr.Audio(visible=False))
371
  stream=y
372
  chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
373
  text_str=""
 
560
  diversity_response = gr.Textbox(label="diversity_response",visible=False,interactive=False)
561
  ip_address = gr.Textbox(label="ip_address",visible=False,interactive=False)
562
  callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address],"flagged_data_points")
563
+ user_audio.stream(transcribe, inputs=[state, user_audio, radio, ASR_radio, LLM_radio, type_radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
564
  radio.change(fn=handle_selection, inputs=[radio], outputs=[output_asr_text, output_text, output_audio])
565
  LLM_radio.change(fn=handle_LLM_selection, inputs=[LLM_radio], outputs=[output_asr_text, output_text, output_audio])
566
  ASR_radio.change(fn=handle_ASR_selection, inputs=[ASR_radio], outputs=[output_asr_text, output_text, output_audio])