|
import gradio as gr |
|
import torch |
|
import uuid |
|
import json |
|
import librosa |
|
import os |
|
import tempfile |
|
import soundfile as sf |
|
import scipy.io.wavfile as wav |
|
|
|
from transformers import pipeline, VitsModel, AutoTokenizer, set_seed |
|
from nemo.collections.asr.models import EncDecMultiTaskModel |
|
|
|
|
|
# Sample rate, in Hz, that gen_text converts input audio to before
# handing it to the Canary model.
SAMPLE_RATE = 16_000


def _load_canary(model_name="nvidia/canary-1b", beam_size=1):
    """Load the Canary checkpoint and apply the decoding configuration.

    Returns:
        A ``(model, decoding_cfg)`` tuple, where ``decoding_cfg`` is the
        model's decoding config after its beam size was set and it was
        passed to ``change_decoding_strategy``.
    """
    model = EncDecMultiTaskModel.from_pretrained(model_name)
    decoding = model.cfg.decoding
    decoding.beam.beam_size = beam_size
    model.change_decoding_strategy(decoding)
    return model, decoding


# Loaded once at import time; gen_text uses this module-level model.
canary_model, decode_cfg = _load_canary()
|
|
|
|
|
def gen_text(audio_filepath, action, source_lang, target_lang):
    """Run the Canary model on one audio file and return the predicted text.

    The audio is loaded as mono, resampled to ``SAMPLE_RATE`` when its native
    rate differs, and written as a WAV file into a temporary directory next
    to a single-entry JSON manifest describing the request. The manifest
    path is then passed to ``canary_model.transcribe``.

    Args:
        audio_filepath: Path of the input audio, or ``None`` if none was
            provided.
        action: Value stored under the manifest's ``taskname`` key; callers
            in this file pass ``"asr"`` or ``"s2t_translation"``.
        source_lang: Value stored under the manifest's ``source_lang`` key.
        target_lang: Value stored under the manifest's ``target_lang`` key.
            Ignored when ``action`` is ``"asr"``; ``source_lang`` is used
            instead in that case.

    Returns:
        The first element of the result of ``canary_model.transcribe``.

    Raises:
        gr.Error: If ``audio_filepath`` is ``None``.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")

    request_id = uuid.uuid4()

    with tempfile.TemporaryDirectory() as workdir:
        wav_path = os.path.join(workdir, f"{request_id}.wav")
        manifest_path = os.path.join(workdir, f"{request_id}.json")

        # Keep the file's native rate on load and only resample on mismatch.
        samples, native_rate = librosa.load(audio_filepath, sr=None, mono=True)
        if native_rate != SAMPLE_RATE:
            samples = librosa.resample(
                samples, orig_sr=native_rate, target_sr=SAMPLE_RATE
            )
        sf.write(wav_path, samples, SAMPLE_RATE)

        # For plain transcription the output language is the input language.
        if action == "asr":
            manifest_target = source_lang
        else:
            manifest_target = target_lang

        manifest_entry = {
            "audio_filepath": wav_path,
            "taskname": action,
            "source_lang": source_lang,
            "target_lang": manifest_target,
            "pnc": "no",
            "answer": "predict",
            "duration": str(len(samples) / SAMPLE_RATE),
        }
        with open(manifest_path, 'w') as manifest_file:
            json.dump(manifest_entry, manifest_file)

        # Must run while the temporary directory (and its files) still exist.
        predicted_text = canary_model.transcribe(manifest_path)[0]

    return predicted_text
|
|
|
|
|
# Hugging Face checkpoint used for each supported language code.
_TTS_MODEL_IDS = {
    "en": "facebook/mms-tts-eng",
    "fr": "facebook/mms-tts-fra",
    "de": "facebook/mms-tts-deu",
    "es": "facebook/mms-tts-spa",
}

# Checkpoint used when the language code is not in _TTS_MODEL_IDS.
_DEFAULT_TTS_MODEL_ID = "facebook/mms-tts-eng"

# (model, tokenizer) pairs that have already been loaded, keyed by
# checkpoint id. Bounded by the number of distinct checkpoints above.
_TTS_CACHE = {}


def _load_tts(model_id):
    """Return the ``(model, tokenizer)`` pair for *model_id*, loading it once."""
    if model_id not in _TTS_CACHE:
        _TTS_CACHE[model_id] = (
            VitsModel.from_pretrained(model_id),
            AutoTokenizer.from_pretrained(model_id),
        )
    return _TTS_CACHE[model_id]


def gen_speech(text, lang):
    """Synthesize *text* as speech and return the path of the WAV file.

    Args:
        text: Text to speak.
        lang: Language code selecting the TTS checkpoint (``"en"``, ``"fr"``,
            ``"de"`` or ``"es"``). Any other value falls back to the English
            checkpoint.

    Returns:
        Path of a newly written WAV file in the system temporary directory.
        The file is not deleted by this function.
    """
    model_id = _TTS_MODEL_IDS.get(lang, _DEFAULT_TTS_MODEL_ID)

    # Reuse an already-loaded checkpoint instead of re-instantiating the
    # model and tokenizer on every request.
    tts_model, tts_tokenizer = _load_tts(model_id)

    input_text = tts_tokenizer(text, return_tensors="pt")

    # Seed immediately before inference so the output for a given text is
    # the same whether the model was just loaded or came from the cache.
    set_seed(555)
    with torch.no_grad():
        outputs = tts_model(**input_text)
    waveform_np = outputs.waveform[0].cpu().numpy()

    # Write to the system temp directory rather than the working directory,
    # so generated audio does not pile up next to the application code.
    fd, output_file = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
    return output_file
|
|
|
|
|
def start_process(audio_filepath, source_lang, target_lang):
    """Transcribe, translate and re-synthesize one audio recording.

    Args:
        audio_filepath: Path of the recorded input audio, or ``None``.
        source_lang: Language code of the input speech.
        target_lang: Language code to translate into and to speak.

    Returns:
        A ``(transcription, translation, audio_output_filepath)`` tuple: the
        two texts produced by ``gen_text`` and the path returned by
        ``gen_speech`` for the translation.

    Raises:
        gr.Error: If no audio was provided (raised by ``gen_text``), or if
            audio was provided but either language is missing.
    """
    # The Clear button in this file also resets both language dropdowns, so
    # either language can arrive empty. Fail fast with a readable message
    # instead of passing an empty language on to the models. The check is
    # skipped when audio is missing so gen_text's own error keeps priority.
    if audio_filepath is not None and not (source_lang and target_lang):
        raise gr.Error("Please select both a source and a target language.")

    transcription = gen_text(audio_filepath, "asr", source_lang, target_lang)
    print("Done transcribing")

    translation = gen_text(audio_filepath, "s2t_translation", source_lang, target_lang)
    print("Done translation")

    audio_output_filepath = gen_speech(translation, target_lang)
    print("Done speaking")

    return transcription, translation, audio_output_filepath
|
|
|
|
|
|
|
# Language codes offered by both the source and the target dropdown.
_LANGUAGE_CODES = ("en", "de", "es", "fr")

with gr.Blocks() as playground:

    # Header: what the app does and which models it names.
    with gr.Row():
        gr.Markdown("""
            ## Your AI Translate Assistant
            ### Gets input audio from user, transcribe and translate it. Convert back to speech.
            - category: Automatic Speech Recognition, model: [nvidia/canary-1b](https://huggingface.co/nvidia/canary-1b)
            - category: Text-to-Speech, model: [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng)
        """)

    # Language selection.
    with gr.Row():
        with gr.Column():
            source_lang = gr.Dropdown(
                choices=list(_LANGUAGE_CODES), value="en", label="Source Language"
            )
        with gr.Column():
            target_lang = gr.Dropdown(
                choices=list(_LANGUAGE_CODES), value="fr", label="Target Language"
            )

    # Recorded input on the left, synthesized output on the right.
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")
        with gr.Column():
            translated_speech = gr.Audio(type="filepath", label="Generated Speech")

    # Text results.
    with gr.Row():
        with gr.Column():
            transcipted_text = gr.Textbox(label="Transcription")
        with gr.Column():
            translated_text = gr.Textbox(label="Translation")

    # Components wired to start_process, in its argument / return order.
    process_inputs = [input_audio, source_lang, target_lang]
    process_outputs = [transcipted_text, translated_text, translated_speech]

    # Action buttons; Clear resets every input and output component.
    with gr.Row():
        with gr.Column():
            submit_button = gr.Button(value="Start Process", variant="primary")
        with gr.Column():
            clear_button = gr.ClearButton(
                components=process_inputs + process_outputs, value="Clear"
            )

    # Bundled sample recordings with their source/target language pairs.
    with gr.Row():
        gr.Examples(
            examples=[
                ["sample_en.wav", "en", "fr"],
                ["sample_fr.wav", "fr", "de"],
                ["sample_de.wav", "de", "es"],
                ["sample_es.wav", "es", "en"],
            ],
            inputs=process_inputs,
            outputs=process_outputs,
            run_on_click=True,
            cache_examples=True,
            fn=start_process,
        )

    submit_button.click(start_process, inputs=process_inputs, outputs=process_outputs)

playground.launch()