# Spaces:
# Sleeping
# Sleeping
import copy
import functools
import json
import os
import tempfile
import uuid

import gradio as gr
import librosa
import scipy.io.wavfile as wav
import soundfile as sf
import torch
from nemo.collections.asr.models import EncDecMultiTaskModel
from transformers import pipeline, VitsModel, AutoTokenizer, set_seed
# Target sample rate, in Hz, that input audio is resampled to before ASR.
SAMPLE_RATE = 16_000
# --- Speech recognition (Canary multitask model) -----------------------------
_ASR_CHECKPOINT = "nvidia/canary-1b"
canary_model = EncDecMultiTaskModel.from_pretrained(_ASR_CHECKPOINT)

# Update decoding params: restrict the beam to a single hypothesis.
decode_cfg = canary_model.cfg.decoding
decode_cfg.beam.beam_size = 1
canary_model.change_decoding_strategy(decode_cfg)
# --- Text-to-speech (VITS model and its tokenizer from the same checkpoint) --
_TTS_CHECKPOINT = "facebook/mms-tts-eng"
tts_model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
tts_tokenizer = AutoTokenizer.from_pretrained(_TTS_CHECKPOINT)
# Clips shorter than this many seconds are transcribed in one pass; longer
# clips go through NeMo's buffered (chunked) inference.
MAX_SINGLE_PASS_SECS = 40


@functools.lru_cache(maxsize=1)
def _get_frame_asr():
    """Build (once) the frame-batching wrapper used for buffered inference.

    The NeMo streaming module is imported lazily so that an import problem
    only affects long clips instead of preventing the app from starting.
    """
    from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED

    return FrameBatchMultiTaskAED(
        asr_model=canary_model,
        frame_len=float(MAX_SINGLE_PASS_SECS),
        total_buffer=float(MAX_SINGLE_PASS_SECS),
        batch_size=16,
    )


def _transcribe_buffered(manifest_filepath):
    """Transcribe the first entry of *manifest_filepath* with buffered inference."""
    from nemo.collections.asr.parts.utils.transcribe_utils import (
        get_buffered_pred_feat_multitaskAED,
    )

    # Work on a copy so the model's own config is not modified by the tweaks
    # below (or by anything the helper does to the config it is given).
    preprocessor_cfg = copy.deepcopy(canary_model.cfg.preprocessor)
    preprocessor_cfg.dither = 0.0  # no random dither on chunked features
    preprocessor_cfg.pad_to = 0
    # NOTE(review): 8 is the FastConformer encoder subsampling factor used by
    # the reference Canary demo; confirm if the checkpoint changes.
    model_stride_in_secs = preprocessor_cfg.window_stride * 8
    with torch.no_grad():
        hyps = get_buffered_pred_feat_multitaskAED(
            _get_frame_asr(),
            preprocessor_cfg,
            model_stride_in_secs,
            canary_model.device,
            manifest=manifest_filepath,
        )
    return hyps[0].text


def transcribe(audio_filepath):
    """Convert an audio file to English text using the Canary ASR model.

    Args:
        audio_filepath: Path to an audio file readable by librosa, or None
            when the user supplied no audio.

    Returns:
        The transcription as a string.

    Raises:
        gr.Error: If ``audio_filepath`` is None.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")

    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        # Convert to mono at SAMPLE_RATE (16 kHz) and store as a WAV file.
        data, sr = librosa.load(audio_filepath, sr=None, mono=True)
        if sr != SAMPLE_RATE:
            data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
        converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
        sf.write(converted_audio_filepath, data, SAMPLE_RATE)

        # Describe the request in a one-line JSON manifest for NeMo.
        duration = len(data) / SAMPLE_RATE
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "taskname": "asr",
            "source_lang": "en",
            "target_lang": "en",
            "pnc": "no",
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
        with open(manifest_filepath, "w") as fout:
            fout.write(json.dumps(manifest_data))

        if duration < MAX_SINGLE_PASS_SECS:
            result = canary_model.transcribe(manifest_filepath)[0]
            # Depending on the NeMo version, transcribe() yields plain strings
            # or Hypothesis objects that carry the text in `.text`; always hand
            # a str to downstream consumers (e.g. the TTS tokenizer).
            transcription = getattr(result, "text", result)
        else:
            transcription = _transcribe_buffered(manifest_filepath)
    return transcription
# Function to convert text to speech using TTS
def gen_speech(text):
    """Synthesize speech for *text* with the VITS TTS model.

    Args:
        text: The text to speak.

    Returns:
        Path to a newly written WAV file in the system temp directory. The
        file is not deleted here; it is left for the caller (e.g. a Gradio
        Audio output) to read.
    """
    set_seed(555)  # fixed seed so repeated calls produce the same audio
    input_text = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = tts_model(**input_text)
    waveform_np = outputs.waveform[0].cpu().numpy()

    # Write into the temp directory rather than the current working directory,
    # so generated clips do not pile up next to the application source.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_file = tmp.name
    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
    return output_file
# Root function for Gradio interface
def start_process(audio_filepath):
    """Run the full pipeline for one recording: ASR, then (placeholder) translation, then TTS.

    Args:
        audio_filepath: Path to the recorded audio, passed straight to ``transcribe``.

    Returns:
        A ``(transcription, translation, audio_output_filepath)`` tuple.
    """
    text = transcribe(audio_filepath)
    print("Done transcribing")

    # Translation is not implemented yet; a fixed placeholder is returned.
    placeholder_translation = "working in progress"

    speech_path = gen_speech(text)
    print("Done speaking")

    result = (text, placeholder_translation, speech_path)
    return result
# Create Gradio interface: input audio + transcription in the first column,
# generated speech + translation in the second, buttons underneath.
with gr.Blocks() as playground:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")
            transcipted_text = gr.Textbox(label="Transcription")
        with gr.Column():
            translated_speech = gr.Audio(type="filepath", label="Generated Speech")
            translated_text = gr.Textbox(label="Translation")
    with gr.Row():
        with gr.Column():
            submit_button = gr.Button(value="Start Process", variant="primary")
        with gr.Column():
            clear_button = gr.ClearButton(
                components=[input_audio, transcipted_text, translated_speech, translated_text],
                value="Clear",
            )
    # Outputs are matched positionally to start_process's return value
    # (transcription, translation, audio_output_filepath): the translation
    # text must go to the Textbox and the WAV path to the Audio component.
    submit_button.click(
        start_process,
        inputs=[input_audio],
        outputs=[transcipted_text, translated_text, translated_speech],
    )

playground.launch()