|
import gradio as gr |
|
import torch |
|
import uuid |
|
import json |
|
import librosa |
|
import os |
|
import tempfile |
|
import soundfile as sf |
|
import scipy.io.wavfile as wav |
|
|
|
from transformers import pipeline, VitsModel, AutoTokenizer, set_seed |
|
from nemo.collections.asr.models import EncDecMultiTaskModel |
|
|
|
|
|
# Sample rate (Hz) that every input recording is resampled to before it is
# handed to the Canary model.
SAMPLE_RATE = 16000

# Canary multitask model, used below for both transcription ("asr") and
# translation ("ast").
canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b')

# Shrink the beam to a single hypothesis, then apply the edited decoding config.
decode_cfg = canary_model.cfg.decoding
decode_cfg.beam.beam_size = 1
canary_model.change_decoding_strategy(decode_cfg)

# MMS VITS text-to-speech model and its tokenizer, both loaded from the same
# checkpoint id.
_TTS_CHECKPOINT = "facebook/mms-tts-eng"
tts_model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
tts_tokenizer = AutoTokenizer.from_pretrained(_TTS_CHECKPOINT)
|
|
|
|
|
def gen_text(audio_filepath, action, *, source_lang="en", target_lang="fr"):
    """Run Canary speech recognition or speech translation on one recording.

    The recording is downmixed to mono, resampled to ``SAMPLE_RATE`` and
    written, together with a single-entry JSON manifest, into a temporary
    directory that is deleted before the function returns.

    Args:
        audio_filepath: Path to the input audio file, or ``None`` when the
            user has not recorded anything.
        action: Canary task name. ``"asr"`` decodes into ``source_lang``
            (transcription); any other value (e.g. ``"ast"``) decodes into
            ``target_lang`` (translation).
        source_lang: Language spoken in the recording. Defaults to ``"en"``.
        target_lang: Output language for non-``"asr"`` tasks. Defaults to
            ``"fr"``.

    Returns:
        The predicted text.

    Raises:
        gr.Error: If no audio was given, the audio contains no samples, or
            the audio is 40 seconds or longer.
    """
    # Longer recordings used to be routed to a buffered/chunked inference
    # path whose helpers (get_buffered_pred_feat_multitaskAED, frame_asr,
    # model_stride_in_secs) were never defined in this file, so they always
    # crashed with NameError. Reject them with a clear message instead until
    # chunked inference is actually wired up.
    max_duration_s = 40

    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")

    data, sr = librosa.load(audio_filepath, sr=None, mono=True)
    if len(data) == 0:
        raise gr.Error("The input audio is empty.")
    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)

    duration = len(data) / SAMPLE_RATE
    if duration >= max_duration_s:
        raise gr.Error(
            f"Please provide audio shorter than {max_duration_s} seconds "
            f"(got {duration:.1f} s)."
        )

    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
        sf.write(converted_audio_filepath, data, SAMPLE_RATE)

        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "taskname": action,
            "source_lang": source_lang,
            # Recognition outputs the spoken language; other tasks translate.
            "target_lang": source_lang if action == "asr" else target_lang,
            "pnc": "no",
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
        with open(manifest_filepath, "w", encoding="utf-8") as fout:
            json.dump(manifest_data, fout)

        # Must run inside the with-block: the manifest and the WAV it points
        # to are deleted when the temporary directory is closed.
        prediction = canary_model.transcribe(manifest_filepath)[0]

    # Depending on the NeMo version, transcribe() may yield plain strings or
    # hypothesis objects that carry the text in ``.text``; accept both.
    return getattr(prediction, "text", prediction)
|
|
|
|
|
def gen_speech(text):
    """Synthesize speech for *text* with the VITS TTS model and save it as WAV.

    Args:
        text: Text to speak.

    Returns:
        Path of a newly created WAV file in the system temporary directory,
        sampled at ``tts_model.config.sampling_rate``.

    Raises:
        gr.Error: If ``text`` is empty or only whitespace.
    """
    if not text or not text.strip():
        raise gr.Error("There is no text to synthesize.")

    # Fixed seed so repeated calls with the same text give the same audio
    # (presumably the TTS forward pass samples randomly).
    set_seed(555)

    model_inputs = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = tts_model(**model_inputs)
    waveform_np = outputs.waveform[0].cpu().numpy()

    # Previously a uuid-named file was dropped into the current working
    # directory on every call. mkstemp creates a uniquely named file in the
    # system temp directory instead; close the descriptor so scipy can reopen
    # the path for writing.
    fd, output_file = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)

    return output_file
|
|
|
|
|
def start_process(audio_filepath):
    """Transcribe, translate and re-synthesize one recording for the UI.

    Args:
        audio_filepath: Path of the recorded audio from the input component.

    Returns:
        A ``(transcription, translation, speech_path)`` tuple matching the
        three output components. ``speech_path`` is the WAV written by
        ``gen_speech`` for the transcription, or ``None`` when the
        transcription is blank and there is nothing to speak.
    """
    transcription = gen_text(audio_filepath, "asr")
    print("Done transcribing")

    translation = gen_text(audio_filepath, "ast")
    print("Done translation")

    # Silent input can produce an empty transcription. Skip TTS in that case
    # and leave the audio output empty rather than failing the whole request.
    if transcription.strip():
        audio_output_filepath = gen_speech(transcription)
        print("Done speaking")
    else:
        audio_output_filepath = None

    return transcription, translation, audio_output_filepath
|
|
|
|
|
|
|
# Demo UI: recorded input and synthesized speech on top, the transcription and
# translation text boxes below, and the action buttons last.
with gr.Blocks() as playground:
    # Row 1: microphone input (left) and generated speech (right).
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Input Audio",
            )
        with gr.Column():
            translated_speech = gr.Audio(type="filepath", label="Generated Speech")

    # Row 2: text outputs.
    with gr.Row():
        with gr.Column():
            transcipted_text = gr.Textbox(label="Transcription")
        with gr.Column():
            translated_text = gr.Textbox(label="Translation")

    # Row 3: run the pipeline, or reset every component.
    with gr.Row():
        with gr.Column():
            submit_button = gr.Button(value="Start Process", variant="primary")
        with gr.Column():
            clear_button = gr.ClearButton(
                components=[
                    input_audio,
                    transcipted_text,
                    translated_speech,
                    translated_text,
                ],
                value="Clear",
            )

    # start_process returns (transcription, translation, speech path), in the
    # same order as these output components.
    submit_button.click(
        fn=start_process,
        inputs=[input_audio],
        outputs=[transcipted_text, translated_text, translated_speech],
    )

playground.launch()