# Captured from a Hugging Face Space page whose status was shown as "Runtime error".
from transformers import pipeline | |
import tempfile | |
import gradio as gr | |
from neon_tts_plugin_coqui import CoquiTTS | |
import os | |
import time | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
from flores200_codes import flores_codes | |
# Hub id of the speech-recognition model; audio_tts reads the "text" field of
# this pipeline's output. Presumably a Cantonese ("yue") ASR checkpoint — confirm.
ASR_MODEL_ID = "Yuyang2022/yue"
pipe = pipeline(model=ASR_MODEL_ID)

# Language codes known to the Coqui TTS plugin; offered as the speech-language choices.
LANGUAGES = list(CoquiTTS.langs.keys())
# Shared text-to-speech engine used by audio_tts.
coquiTTS = CoquiTTS()
def audio_tts(audio, language: str, lang):
    """Transcribe recorded speech, translate it, and synthesize the translation.

    Args:
        audio: Path to the recorded clip (the Gradio Audio input is configured
            with ``type="filepath"``); may be ``None`` if the user submitted
            without recording.
        language: Coqui TTS language code (one of ``LANGUAGES``) used for
            speech synthesis.
        lang: Target language name for the translation step; it is looked up
            in ``flores_codes`` by ``translation``.

    Returns:
        Path of a ``.wav`` file holding the synthesized translation.

    Raises:
        gr.Error: if no audio was provided.
    """
    # Without this guard, None would be passed to the ASR pipeline and the user
    # would only see an opaque traceback.
    if audio is None:
        raise gr.Error("Please record some audio before submitting.")
    text = pipe(audio)["text"]
    # The transcription is treated as Traditional Chinese ("zho_Hant") for translation.
    text = translation("zho_Hant", lang, text)
    # delete=False keeps the file on disk after the handle closes, so its path
    # can still be returned to (and served by) Gradio.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        coquiTTS.get_tts(text, fp, speaker={"language": language})
    return fp.name
def load_models():
    """Load each configured seq2seq translation model together with its tokenizer.

    Returns:
        dict: for every alias in the registry below, ``"<alias>_model"`` maps to
        the loaded model and ``"<alias>_tokenizer"`` to its tokenizer.
    """
    # Short alias used throughout this file -> Hugging Face Hub checkpoint id.
    registry = {"nllb-distilled-600M": "facebook/nllb-200-distilled-600M"}
    loaded = {}
    for alias in registry:
        checkpoint = registry[alias]
        print(f"\tLoading model: {alias}")
        loaded[f"{alias}_model"] = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        loaded[f"{alias}_tokenizer"] = AutoTokenizer.from_pretrained(checkpoint)
    return loaded
def translation(source, target, text, model_name="nllb-distilled-600M", max_length=400):
    """Translate ``text`` from ``source`` to ``target`` with a preloaded seq2seq model.

    Args:
        source: Source language, either a key of ``flores_codes`` or an
            already-resolved FLORES-200 code such as ``"zho_Hant"`` (which is
            what ``audio_tts`` passes).
        target: Target language; must be a key of ``flores_codes``.
        text: Text to translate.
        model_name: Alias under which ``load_models`` stored the model and
            tokenizer in the global ``model_dict``.
        max_length: Maximum generation length forwarded to the pipeline call.

    Returns:
        The translated string.

    Raises:
        KeyError: if ``target`` is not in ``flores_codes`` or ``model_name``
            was not loaded into ``model_dict``.
    """
    # Previously the model alias was only bound when model_dict had exactly two
    # entries (NameError otherwise), and ``source`` was overwritten with a
    # hard-coded "zho_Hant". Accept a language name or a raw code instead;
    # "zho_Hant" is presumably not a flores_codes key, so it passes through.
    src_code = flores_codes.get(source, source)
    tgt_code = flores_codes[target]
    model = model_dict[model_name + "_model"]
    tokenizer = model_dict[model_name + "_tokenizer"]
    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=src_code,
        tgt_lang=tgt_code,
    )
    output = translator(text, max_length=max_length)
    return output[0]["translation_text"]
if __name__ == "__main__":
    # Populate the module-level ``model_dict`` that ``translation`` reads.
    model_dict = load_models()
    lang_codes = list(flores_codes.keys())

    # Inputs map positionally onto audio_tts(audio, language, lang):
    # the radio picks the Coqui TTS speech language, the dropdown picks the
    # translation target (a flores_codes key).
    inputs = [
        gr.Audio(source="microphone", type="filepath"),
        gr.Radio(label="Target speech Language", choices=LANGUAGES, value="en"),
        # gr.inputs.* is deprecated; use the component class with value= like the radio.
        gr.Dropdown(choices=lang_codes, value="English", label="Target text Language"),
    ]
    outputs = gr.Audio(label="Output")
    demo = gr.Interface(
        fn=audio_tts,
        inputs=inputs,
        outputs=outputs,
        title="translation - speech to speech",
        description="Realtime demo for speech translation.",
    )
    demo.launch()