import gradio as gr
import soundfile
import torch
from espnet2.bin.asr_inference import Speech2Text
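# The three decoders below are built from the same UniverSLU-17 Task Specifier
# checkpoint; only the natural-language prompt tokens (language / task / dataset
# markers such as <|en|>, <|ner|>, <|SLURP|>) change which task is performed.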
# SLURP (English): the <|ner|> prompt selects entity-style SLU on the SLURP dataset.
speech2text_slurp = Speech2Text.from_pretrained(
    asr_train_config="UniverSLU-17-Task-Specifier/exp/asr_train_asr_whisper_full_correct_specaug2_copy_raw_en_whisper_multilingual/config.yaml",
    asr_model_file="UniverSLU-17-Task-Specifier/exp/asr_train_asr_whisper_full_correct_specaug2_copy_raw_en_whisper_multilingual/valid.acc.ave_10best.pth",
    # Decoding parameters are not included in the model file.
    lang_prompt_token="<|en|> <|ner|> <|SLURP|>",
    prompt_token_file="UniverSLU-17-Task-Specifier/add_tokens-Copy1.txt",
    nbest=1,
)
# FSC (English): the <|ic|> prompt selects intent classification on Fluent Speech Commands.
speech2text_fsc = Speech2Text.from_pretrained(
    asr_train_config="UniverSLU-17-Task-Specifier/exp/asr_train_asr_whisper_full_correct_specaug2_copy_raw_en_whisper_multilingual/config.yaml",
    asr_model_file="UniverSLU-17-Task-Specifier/exp/asr_train_asr_whisper_full_correct_specaug2_copy_raw_en_whisper_multilingual/valid.acc.ave_10best.pth",
    # Decoding parameters are not included in the model file.
    lang_prompt_token="<|en|> <|ic|> <|fsc|>",
    prompt_token_file="UniverSLU-17-Task-Specifier/add_tokens-Copy1.txt",
    nbest=1,
)
# Grabo (Dutch): the <|scr|> prompt selects speech command recognition on the Grabo dataset.
speech2text_grabo = Speech2Text.from_pretrained(
    asr_train_config="UniverSLU-17-Task-Specifier/exp/asr_train_asr_whisper_full_correct_specaug2_copy_raw_en_whisper_multilingual/config.yaml",
    asr_model_file="UniverSLU-17-Task-Specifier/exp/asr_train_asr_whisper_full_correct_specaug2_copy_raw_en_whisper_multilingual/valid.acc.ave_10best.pth",
    # Decoding parameters are not included in the model file.
    lang_prompt_token="<|nl|> <|scr|> <|grabo_scr|>",
    prompt_token_file="UniverSLU-17-Task-Specifier/add_tokens-Copy1.txt",
    nbest=1,
)
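# Standalone usage sketch (not executed by the demo): any decoder above can be
# called directly on a waveform read with soundfile. The file name below refers
# to the bundled SLURP example clip and is only illustrative.
#   speech, rate = soundfile.read("audio_slurp.flac")
#   text, *_ = speech2text_slurp(speech)[0]
#   print(text)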
def inference(wav, data):
    """Decode the uploaded audio with the decoder matching the selected task."""
    with torch.no_grad():
        if data == "english_slurp":
            speech, rate = soundfile.read(wav.name)
            nbests = speech2text_slurp(speech)
            text, *_ = nbests[0]
elif data == "english_fsc":
print(wav.name)
speech, rate = soundfile.read(wav.name)
print(speech.shape)
if len(speech.shape)==2:
speech=speech[:,0]
# soundfile.write("store_file.wav", speech, rate, subtype='FLOAT')
print(speech.shape)
nbests = speech2text_fsc(speech)
text, *_ = nbests[0]
elif data == "dutch":
print(wav.name)
speech, rate = soundfile.read(wav.name)
nbests = speech2text_grabo(speech)
text, *_ = nbests[0]
    return text
title = "UniverSLU"
description = "Gradio demo for UniverSLU: Universal Spoken Language Understanding for Diverse Tasks with Natural Language Instructions. To use it, simply record your audio or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://github.com/espnet/espnet' target='_blank'>Github Repo</a></p>"
examples=[['audio_slurp.flac',"english_slurp"],['audio_fsc.wav',"english_fsc"],['audio_grabo.wav',"dutch"]]
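# The example clips listed above are expected to sit next to app.py in the Space repository.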
gr.Interface(
    inference,
    [
        gr.inputs.Audio(label="input audio", source="microphone", type="file"),
        # Choices must match the branch names in inference() and the labels in `examples`.
        gr.inputs.Radio(choices=["english_slurp", "english_fsc", "dutch"], type="value", default="english_fsc", label="Task"),
    ],
    gr.outputs.Textbox(type="str", label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples,
).launch(debug=True)
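# Running `python app.py` starts the demo locally; on Hugging Face Spaces the
# Interface is served automatically from this file.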