Siddhant committed
Commit 1e42459 · verified · 1 Parent(s): 0c97eed

Update app.py

Files changed (1): app.py (+106 -34)
app.py CHANGED
@@ -20,6 +20,7 @@ from espnet2.sds.eval.ASR_WER import handle_espnet_ASR_WER
 from espnet2.sds.eval.TTS_speech_quality import TTS_psuedomos
 from espnet2.sds.eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_perplexity
 from espnet2.sds.utils.chat import Chat
+from espnet2.sds.end_to_end.mini_omni_e2e import MiniOmniE2EModel
 import argparse
 
 access_token = os.environ.get("HF_TOKEN")
@@ -80,6 +81,7 @@ user_role = "user"
 text2speech=None
 s2t=None
 LM_pipe=None
+client=None
 
 latency_ASR=0.0
 latency_LM=0.0
@@ -144,6 +146,48 @@ def handle_eval_selection(option, TTS_audio_output, LLM_Output, ASR_audio_output
     elif option=="Text Dialog Metrics":
         yield (None,perplexity(LLM_Output.replace("\n"," "))+vert(LLM_response_arr)+bert_score(total_response_arr)+DialoGPT_perplexity(ASR_transcript.replace("\n"," "),LLM_Output.replace("\n"," ")))
 
+def handle_eval_selection_E2E(option, TTS_audio_output, LLM_Output):
+    global LLM_response_arr
+    global total_response_arr
+    yield (option,gr.Textbox(visible=True))
+    if option=="Latency":
+        text=f"Total Latency: {latency_TTS:.2f}"
+        yield (None,text)
+    elif option=="TTS Intelligibility":
+        yield (None,handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output))
+    elif option=="TTS Speech Quality":
+        yield (None,TTS_psuedomos(TTS_audio_output))
+    elif option=="Text Dialog Metrics":
+        yield (None,perplexity(LLM_Output.replace("\n"," "))+vert(LLM_response_arr))
+
+def handle_type_selection(option,TTS_radio,ASR_radio,LLM_radio):
+    global client
+    global LM_pipe
+    global s2t
+    global text2speech
+    yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False), gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False))
+    if option=="Cascaded":
+        client=None
+        for _ in handle_selection(TTS_radio):
+            continue
+        for _ in handle_ASR_selection(ASR_radio):
+            continue
+        for _ in handle_LLM_selection(LLM_radio):
+            continue
+        yield (gr.Radio(visible=True),gr.Radio(visible=True),gr.Radio(visible=True),gr.Radio(visible=False),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=True, interactive=True),gr.Radio(visible=False))
+    else:
+        text2speech=None
+        s2t=None
+        LM_pipe=None
+        handle_E2E_selection()
+        yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=True),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=False),gr.Radio(visible=True, interactive=True))
+
+
+def handle_E2E_selection():
+    global client
+    client = MiniOmniE2EModel()
+    client.warmup()
+
 for _ in handle_selection(TTS_name):
     continue
 for _ in handle_ASR_selection(ASR_name):
@@ -270,36 +314,44 @@ def transcribe(stream, new_chunk, option, asr_option):
     stream=np.concatenate((stream,y))
     orig_sr=sr
     sr=16000
-    array=vad_model(y,orig_sr)
+    if client is not None:
+        array=vad_model(y,orig_sr, binary=True)
+    else:
+        array=vad_model(y,orig_sr)
 
     if array is not None:
         print("VAD: end of speech detected")
         start_time = time.time()
-        prompt=s2t(array)
-        if len(prompt.strip().split())<2:
-            text_str1=text_str
-            yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
-            return
-
-
-        asr_output_str=prompt
-        total_response_arr.append(prompt.replace("\n"," "))
-        start_LM_time=time.time()
-        latency_ASR=(start_LM_time - start_time)
-        chat.append({"role": user_role, "content": prompt})
-        chat_messages = chat.to_list()
-        generated_text = LM_pipe(chat_messages)
-        start_TTS_time=time.time()
-        latency_LM=(start_TTS_time - start_LM_time)
-
-        chat.append({"role": "assistant", "content": generated_text})
-        text_str=generated_text
-        LLM_response_arr.append(text_str.replace("\n"," "))
-        total_response_arr.append(text_str.replace("\n"," "))
-        audio_output=text2speech(text_str)
+        if client is not None:
+            (text_str, audio_output)=client(array, orig_sr)
+            asr_output_str=""
+            latency_TTS=(time.time() - start_time)
+        else:
+            prompt=s2t(array)
+            if len(prompt.strip().split())<2:
+                text_str1=text_str
+                yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
+                return
+
+
+            asr_output_str=prompt
+            total_response_arr.append(prompt.replace("\n"," "))
+            start_LM_time=time.time()
+            latency_ASR=(start_LM_time - start_time)
+            chat.append({"role": user_role, "content": prompt})
+            chat_messages = chat.to_list()
+            generated_text = LM_pipe(chat_messages)
+            start_TTS_time=time.time()
+            latency_LM=(start_TTS_time - start_LM_time)
+
+            chat.append({"role": "assistant", "content": generated_text})
+            text_str=generated_text
+            audio_output=text2speech(text_str)
+            latency_TTS=(time.time() - start_TTS_time)
        audio_output1=(orig_sr,stream)
        stream=y
-       latency_TTS=(time.time() - start_TTS_time)
+       LLM_response_arr.append(text_str.replace("\n"," "))
+       total_response_arr.append(text_str.replace("\n"," "))
        text_str1=text_str
        if ((text_str!="") and (start_record_time is None)):
            start_record_time=time.time()
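
Note: the E2E branch added to transcribe() above assumes MiniOmniE2EModel is used as "construct once, call warmup(), then call the instance with the VAD-segmented audio and its original sample rate, getting back the response text and synthesized audio in one step". A minimal sketch of that routing under those assumptions, with the espnet2 components replaced by illustrative stubs (none of the stub names below exist in the repo):

    import time
    import numpy as np

    # Illustrative stubs only; the real objects come from espnet2.sds.*
    def vad_model(y, sr, binary=False): return y            # speech segment, or None if no end of speech yet
    def s2t(array): return "transcribed user turn"          # ASR
    def LM_pipe(chat_messages): return "assistant reply"    # LLM
    def text2speech(text): return (16000, np.zeros(16000))  # TTS -> (sr, waveform)

    class FakeE2EClient:
        """Mirrors the call pattern used for MiniOmniE2EModel in the diff."""
        def warmup(self): pass
        def __call__(self, array, orig_sr):
            return "assistant reply", (orig_sr, np.zeros(16000))

    def respond(y, orig_sr, client=None):
        array = vad_model(y, orig_sr, binary=client is not None)
        if array is None:
            return None
        start = time.time()
        if client is not None:
            # E2E path: one model handles understanding and synthesis
            text_str, audio_output = client(array, orig_sr)
            return "", text_str, audio_output, time.time() - start
        # Cascaded path: ASR -> LLM -> TTS
        prompt = s2t(array)
        text_str = LM_pipe([{"role": "user", "content": prompt}])
        audio_output = text2speech(text_str)
        return prompt, text_str, audio_output, time.time() - start

    print(respond(np.zeros(16000), 16000, client=FakeE2EClient()))
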
@@ -338,6 +390,12 @@ with gr.Blocks(
     with gr.Row():
         with gr.Column(scale=1):
             user_audio = gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
+            with gr.Row():
+                type_radio = gr.Radio(
+                    choices=["Cascaded", "E2E"],
+                    label="Choose type of Spoken Dialog:",
+                    value="Cascaded",
+                )
             with gr.Row():
                 ASR_radio = gr.Radio(
                     choices=ASR_options,
@@ -356,6 +414,13 @@
                     label="Choose TTS:",
                     value=TTS_name,
                 )
+            with gr.Row():
+                E2Eradio = gr.Radio(
+                    choices=["mini-omni"],
+                    label="Choose E2E model:",
+                    value="mini-omni",
+                    visible=False,
+                )
             with gr.Row():
                 feedback_btn = gr.Button(
                     value="Please provide your feedback after each system response below.", visible=True, interactive=False, elem_id="button"
@@ -391,6 +456,11 @@
                 choices=["Latency", "TTS Intelligibility", "TTS Speech Quality", "ASR WER","Text Dialog Metrics"],
                 label="Choose Evaluation metrics:",
             )
+            eval_radio_E2E = gr.Radio(
+                choices=["Latency", "TTS Intelligibility", "TTS Speech Quality","Text Dialog Metrics"],
+                label="Choose Evaluation metrics:",
+                visible=False,
+            )
             output_eval_text = gr.Textbox(label="Evaluation Results")
             state = gr.State()
             with gr.Row():
@@ -421,22 +491,24 @@
     natural_response = gr.Textbox(label="natural_response",visible=False,interactive=False)
     diversity_response = gr.Textbox(label="diversity_response",visible=False,interactive=False)
     ip_address = gr.Textbox(label="ip_address",visible=False,interactive=False)
-    callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address],"flagged_data_points")
+    callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address],"flagged_data_points")
     user_audio.stream(transcribe, inputs=[state, user_audio, radio, ASR_radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
     radio.change(fn=handle_selection, inputs=[radio], outputs=[output_asr_text, output_text, output_audio])
     LLM_radio.change(fn=handle_LLM_selection, inputs=[LLM_radio], outputs=[output_asr_text, output_text, output_audio])
     ASR_radio.change(fn=handle_ASR_selection, inputs=[ASR_radio], outputs=[output_asr_text, output_text, output_audio])
     eval_radio.change(fn=handle_eval_selection, inputs=[eval_radio,output_audio,output_text,output_audio1,output_asr_text], outputs=[eval_radio,output_eval_text])
+    eval_radio_E2E.change(fn=handle_eval_selection_E2E, inputs=[eval_radio_E2E,output_audio,output_text], outputs=[eval_radio_E2E,output_eval_text])
+    type_radio.change(fn=handle_type_selection,inputs=[type_radio,radio,ASR_radio,LLM_radio], outputs=[radio,ASR_radio,LLM_radio, E2Eradio,output_asr_text, output_text, output_audio,eval_radio,eval_radio_E2E])
     output_audio.play(
         flash_buttons, [], [natural_response,diversity_response]+btn_list
-    ).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio], None,preprocess=False)
-    natural_btn1.click(natural_vote1_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
-    natural_btn2.click(natural_vote2_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
-    natural_btn3.click(natural_vote3_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
-    natural_btn4.click(natural_vote4_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
-    relevant_btn1.click(relevant_vote1_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
-    relevant_btn2.click(relevant_vote2_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
-    relevant_btn3.click(relevant_vote3_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
-    relevant_btn4.click(relevant_vote4_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
+    ).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio], None,preprocess=False)
+    natural_btn1.click(natural_vote1_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
+    natural_btn2.click(natural_vote2_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
+    natural_btn3.click(natural_vote3_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
+    natural_btn4.click(natural_vote4_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
+    relevant_btn1.click(relevant_vote1_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
+    relevant_btn2.click(relevant_vote2_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
+    relevant_btn3.click(relevant_vote3_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
+    relevant_btn4.click(relevant_vote4_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
 
 demo.launch(share=True)
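
Note: the type_radio.change wiring above depends on a Gradio pattern in which an event handler returns (or yields) updated component instances, e.g. gr.Radio(visible=False), matched positionally to its outputs list, so picking "Cascaded" or "E2E" shows one set of controls and hides the other. A self-contained sketch of just that pattern (component names here are invented for the example; gr.update(visible=...) would work equally well):

    import gradio as gr

    def toggle(mode):
        # One update per output, in order: (asr, e2e)
        if mode == "Cascaded":
            return gr.Radio(visible=True), gr.Radio(visible=False)
        return gr.Radio(visible=False), gr.Radio(visible=True)

    with gr.Blocks() as demo:
        mode = gr.Radio(choices=["Cascaded", "E2E"], value="Cascaded", label="Choose type of Spoken Dialog:")
        asr = gr.Radio(choices=["whisper"], value="whisper", label="Choose ASR:")
        e2e = gr.Radio(choices=["mini-omni"], value="mini-omni", label="Choose E2E model:", visible=False)
        mode.change(fn=toggle, inputs=[mode], outputs=[asr, e2e])

    if __name__ == "__main__":
        demo.launch()
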
 
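Note: callback.setup(...) and callback.flag(...) above assume a Gradio flagging logger (for example gr.CSVLogger) created earlier in app.py; setup() registers which component values get logged and the target directory, and each flag() call appends one row. A rough, standalone sketch of that pattern under that assumption (the components here are placeholders):

    import gradio as gr

    callback = gr.CSVLogger()  # assumed; app.py instantiates its logger before the Blocks context

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="input")
        out = gr.Textbox(label="output")
        btn = gr.Button("Run")
        # Values of these components are written under flagged_data_points/ on each flag()
        callback.setup([inp, out], "flagged_data_points")
        btn.click(lambda x: x.upper(), inputs=inp, outputs=out).then(
            lambda *args: callback.flag(list(args)), [inp, out], None, preprocess=False
        )

    if __name__ == "__main__":
        demo.launch()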