# Hugging Face Spaces page header (Space status at capture time: Sleeping)
import gradio as gr | |
import torch | |
import uuid | |
import json | |
import librosa | |
import os | |
import tempfile | |
import soundfile as sf | |
import scipy.io.wavfile as wav | |
from transformers import pipeline, VitsModel, AutoTokenizer, set_seed | |
from nemo.collections.asr.models import EncDecMultiTaskModel | |
# Constants
SAMPLE_RATE = 16_000  # Hz; gen_text resamples input audio to this rate

# Load the ASR / speech-translation model
canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")

# Update decoding params: switch the model to a beam size of 1
decode_cfg = canary_model.cfg.decoding
decode_cfg.beam.beam_size = 1
canary_model.change_decoding_strategy(decode_cfg)

# NOTE: no TTS model is loaded here; gen_speech picks and loads the
# checkpoint that matches the requested language.
# Function to convert audio to text using ASR
def gen_text(audio_filepath, action, source_lang, target_lang):
    """Transcribe or translate an audio file with the Canary model.

    The audio is loaded as mono, resampled to SAMPLE_RATE if needed, written
    to a temporary WAV file and described by a one-line JSON manifest that is
    handed to ``canary_model.transcribe``.

    Args:
        audio_filepath: Path to the input audio, or None when nothing was
            provided.
        action: Task name written to the manifest: "asr" or
            "s2t_translation".
        source_lang: Language code of the spoken audio.
        target_lang: Language code of the output text. Ignored for "asr",
            where the source language is used instead.

    Returns:
        The first element of the list returned by
        ``canary_model.transcribe``.

    Raises:
        gr.Error: If ``audio_filepath`` is None.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")

    # A "translation" into the source language is really a transcription;
    # request the ASR task instead of a same-language translation task.
    if action == "s2t_translation" and source_lang == target_lang:
        action = "asr"

    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        # Convert to 16 kHz mono
        data, sr = librosa.load(audio_filepath, sr=None, mono=True)
        if sr != SAMPLE_RATE:
            data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
        converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
        sf.write(converted_audio_filepath, data, SAMPLE_RATE)

        # Describe the utterance in a manifest for the model
        duration = len(data) / SAMPLE_RATE
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "taskname": action,
            "source_lang": source_lang,
            "target_lang": source_lang if action == "asr" else target_lang,
            "pnc": "no",
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
        with open(manifest_filepath, "w", encoding="utf-8") as fout:
            # One JSON object per line, newline-terminated.
            fout.write(json.dumps(manifest_data) + "\n")

        # Must run inside the `with` block: the manifest and the converted
        # audio are deleted together with tmpdir.
        # NOTE: audio of any length goes through this single call; a
        # buffered-inference path for clips of 40 s or more was sketched
        # here previously but never enabled.
        predicted_text = canary_model.transcribe(manifest_filepath)[0]

    return predicted_text
# Function to convert text to speech using TTS

# Language code -> MMS-TTS checkpoint name.
_TTS_MODEL_NAMES = {
    "en": "facebook/mms-tts-eng",
    "fr": "facebook/mms-tts-fra",
    "de": "facebook/mms-tts-deu",
    "es": "facebook/mms-tts-spa",
}
# Checkpoint used when the language code is not in the table above.
_DEFAULT_TTS_MODEL_NAME = "facebook/mms-tts-eng"
# Checkpoint name -> (model, tokenizer), filled lazily by _load_tts.
_tts_cache = {}


def _load_tts(model_name):
    """Return the (model, tokenizer) pair for model_name, loading it at most once."""
    if model_name not in _tts_cache:
        _tts_cache[model_name] = (
            VitsModel.from_pretrained(model_name),
            AutoTokenizer.from_pretrained(model_name),
        )
    return _tts_cache[model_name]


def gen_speech(text, lang):
    """Synthesize speech for ``text`` and return the path of the WAV file.

    Args:
        text: Text to speak.
        lang: Language code ("en", "fr", "de" or "es"). Any other value
            falls back to the English checkpoint.

    Returns:
        Path of a newly written WAV file in the system temp directory.
    """
    set_seed(555)  # Make it deterministic

    model_name = _TTS_MODEL_NAMES.get(lang, _DEFAULT_TTS_MODEL_NAME)
    # Reuse an already loaded checkpoint instead of reloading it per request.
    tts_model, tts_tokenizer = _load_tts(model_name)

    input_text = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = tts_model(**input_text)
    waveform_np = outputs.waveform[0].cpu().numpy()

    # Write to the temp directory so generated audio does not pile up in the
    # application's working directory.
    output_file = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.wav")
    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
    return output_file
# Root function for Gradio interface
def start_process(audio_filepath, source_lang, target_lang):
    """Run the full pipeline: transcribe, translate, then speak the translation.

    Args:
        audio_filepath: Path of the input audio.
        source_lang: Language code of the input speech.
        target_lang: Language code for the translation and generated speech.

    Returns:
        A tuple ``(transcription, translation, audio_output_filepath)``.
    """
    asr_text = gen_text(
        audio_filepath,
        action="asr",
        source_lang=source_lang,
        target_lang=target_lang,
    )
    print("Done transcribing")

    translated = gen_text(
        audio_filepath,
        action="s2t_translation",
        source_lang=source_lang,
        target_lang=target_lang,
    )
    print("Done translation")

    speech_path = gen_speech(translated, target_lang)
    print("Done speaking")

    return asr_text, translated, speech_path
# Create Gradio interface
# Language codes offered in both dropdowns.
_LANG_CHOICES = ["en", "de", "es", "fr"]

playground = gr.Blocks()
with playground:
    # Row 1: language selection
    with gr.Row():
        with gr.Column():
            source_lang = gr.Dropdown(
                choices=_LANG_CHOICES, value="en", label="Source Language"
            )
        with gr.Column():
            target_lang = gr.Dropdown(
                choices=_LANG_CHOICES, value="fr", label="Target Language"
            )

    # Row 2: recorded input and synthesized output
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(
                sources=["microphone"], type="filepath", label="Input Audio"
            )
        with gr.Column():
            translated_speech = gr.Audio(type="filepath", label="Generated Speech")

    # Row 3: text results
    with gr.Row():
        with gr.Column():
            transcribed_text = gr.Textbox(label="Transcription")
        with gr.Column():
            translated_text = gr.Textbox(label="Translation")

    # Row 4: actions
    with gr.Row():
        with gr.Column():
            submit_button = gr.Button(value="Start Process", variant="primary")
        with gr.Column():
            # The language dropdowns are deliberately not cleared: emptying
            # them would make the next submit pass None as a language code.
            clear_button = gr.ClearButton(
                components=[
                    input_audio,
                    transcribed_text,
                    translated_text,
                    translated_speech,
                ],
                value="Clear",
            )

    # NOTE: a gr.Examples row (sample.wav, cached, run on click) is disabled.
    # If it is re-enabled, its outputs must be ordered
    # [transcribed_text, translated_text, translated_speech] to match the
    # tuple returned by start_process.

    submit_button.click(
        start_process,
        inputs=[input_audio, source_lang, target_lang],
        outputs=[transcribed_text, translated_text, translated_speech],
    )

playground.launch()