# mars5_space / app.py
# arnavmehta7's picture
# Update app.py
# 9f8a599 verified
# raw
# history blame contribute delete
# No virus
# 5.5 kB
import gradio as gr
import torch
import librosa
from pathlib import Path
import tempfile, torchaudio
from transformers import pipeline
# Load the MARS5 model and its inference-config class from torch.hub.
# trust_repo=True marks the hub repo as trusted so torch.hub does not prompt.
mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)

# Whisper ASR pipeline, used to transcribe the reference clip when the user
# supplies no transcript. Pick the GPU only when one is actually available so
# the app still starts on CPU-only hosts instead of failing at import time.
_asr_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
asr_model = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,
    device=_asr_device,
)
def transcribe_file(f: str) -> str:
    """Transcribe the audio file at path *f* with the module-level ASR pipeline.

    Calls ``asr_model`` with timestamps enabled, prints the returned chunks
    for debugging, and returns the chunk texts joined by single spaces.
    """
    chunks = asr_model(f, return_timestamps=True)["chunks"]
    print(f">>>>>. predictions: {chunks}")
    chunk_texts = (chunk["text"] for chunk in chunks)
    return " ".join(chunk_texts)
# Function to process the text and audio input and generate the synthesized output
def synthesize(text, audio_file, transcript, kwargs_dict):
    """Clone the voice in *audio_file* and synthesize *text* with it.

    Args:
        text: Text to synthesize.
        audio_file: Path to the reference audio clip.
        transcript: Transcript of the reference clip. When empty, it is
            obtained by running ``transcribe_file`` on the clip.
        kwargs_dict: Keyword arguments forwarded to ``config_class``.

    Returns:
        Path, as a ``str``, of a temporary ``.wav`` file with the output.

    Raises:
        gr.Error: If no reference audio file was provided.
    """
    print(f">>>>>>> Kwargs dict: {kwargs_dict}")

    # Fail early with a readable message instead of an opaque error from the
    # ASR pipeline or librosa when the audio input was left empty.
    if not audio_file:
        raise gr.Error("Please provide a reference audio file to clone from.")

    if not transcript:
        transcript = transcribe_file(audio_file)

    # Load the reference audio, resampled to the model's rate and downmixed to mono
    wav, sr = librosa.load(audio_file, sr=mars5.sr, mono=True)
    wav = torch.from_numpy(wav)

    # Define the configuration for the TTS model
    cfg = config_class(**kwargs_dict)

    # Generate the synthesized audio
    ar_codes, wav_out = mars5.tts(text, wav, transcript.strip(), cfg=cfg)

    # Save the synthesized audio to a temporary file. NamedTemporaryFile
    # creates the file atomically (tempfile.mktemp is deprecated because the
    # name can be claimed by another process before it is used). The handle is
    # closed before writing so the path can be reopened on every platform.
    # delete=False: the file must outlive this call so the caller can read it;
    # nothing in this function removes it afterwards.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    torchaudio.save(output_path, wav_out.unsqueeze(0), mars5.sr)
    return output_path
# Default inference settings; the UI widgets read their initial values from
# this mapping.
defaults = dict(
    temperature=0.8,
    top_k=-1,
    top_p=0.2,
    typical_p=1.0,
    freq_penalty=2.6,
    presence_penalty=0.4,
    rep_penalty_window=100,
    max_prompt_phones=360,
    deep_clone=True,
    nar_guidance_w=3,
)
# Build the Gradio UI: text + reference audio in, synthesized audio out.
# NOTE(review): the original indentation was lost; the nesting below assumes
# the transcript box and the sampling controls live inside the
# "Advanced Settings" accordion and the output player sits outside it — confirm.
with gr.Blocks() as demo:
    link = "https://github.com/Camb-ai/MARS5-TTS"
    # f-string so the project URL is actually substituted into the link
    # (the plain string rendered a literal "{link}").
    gr.Markdown(f"## MARS5 TTS Demo\nEnter text and upload an audio file to clone the voice and generate synthesized speech using **[MARS5-TTS]({link})**")
    text = gr.Textbox(label="Text to synthesize")
    audio_file = gr.Audio(label="Audio file to clone from", type="filepath")
    generate_btn = gr.Button("Generate Synthesized Audio")

    with gr.Accordion("Advanced Settings", open=False):
        gr.Markdown("additional inference settings\nWARNING: changing these incorrectly may degrade quality.")
        prompt_text = gr.Textbox(label="Transcript of voice reference")
        temperature = gr.Slider(minimum=0.01, maximum=3, step=0.01, label="temperature", value=defaults['temperature'])
        top_k = gr.Slider(minimum=-1, maximum=2000, step=1, label="top_k", value=defaults['top_k'])
        top_p = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label="top_p", value=defaults['top_p'])
        typical_p = gr.Slider(minimum=0.01, maximum=1, step=0.01, label="typical_p", value=defaults['typical_p'])
        freq_penalty = gr.Slider(minimum=0, maximum=5, step=0.05, label="freq_penalty", value=defaults['freq_penalty'])
        presence_penalty = gr.Slider(minimum=0, maximum=5, step=0.05, label="presence_penalty", value=defaults['presence_penalty'])
        rep_penalty_window = gr.Slider(minimum=1, maximum=500, step=1, label="rep_penalty_window", value=defaults['rep_penalty_window'])
        nar_guidance_w = gr.Slider(minimum=1, maximum=8, step=0.1, label="nar_guidance_w", value=defaults['nar_guidance_w'])
        deep_clone = gr.Checkbox(value=defaults['deep_clone'], label='deep_clone')

    output = gr.Audio(label="Synthesized Audio", type="filepath")

    def on_click(
        text,
        audio_file,
        prompt_text,
        temperature,
        top_k,
        top_p,
        typical_p,
        freq_penalty,
        presence_penalty,
        rep_penalty_window,
        nar_guidance_w,
        deep_clone
    ):
        """Gradio callback: pack the widget values and run ``synthesize``.

        The first three arguments are the text to speak, the reference audio
        path and its (possibly empty) transcript; the rest are forwarded as
        the config keyword dict. Returns the path of the generated audio file.
        """
        print(f">>>> transcript: {prompt_text}; audio_file = {audio_file}")
        of = synthesize(
            text,
            audio_file,
            prompt_text,
            {
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'typical_p': typical_p,
                'freq_penalty': freq_penalty,
                'presence_penalty': presence_penalty,
                'rep_penalty_window': rep_penalty_window,
                'nar_guidance_w': nar_guidance_w,
                'deep_clone': deep_clone
            }
        )
        print(f">>>> output file: {of}")
        return of

    # Input components, in the positional order on_click expects. Shared by
    # the button handler and the examples so the two cannot drift apart.
    all_inputs = [
        text,
        audio_file,
        prompt_text,
        temperature,
        top_k,
        top_p,
        typical_p,
        freq_penalty,
        presence_penalty,
        rep_penalty_window,
        nar_guidance_w,
        deep_clone
    ]

    generate_btn.click(
        on_click,
        inputs=all_inputs,
        outputs=[output]
    )

    # Add examples
    # Take the example settings from the `defaults` mapping (in the order of
    # the setting widgets in all_inputs) instead of rebinding `defaults` to a
    # hand-copied list, which shadowed the dict and could fall out of sync.
    example_settings = [
        defaults[name]
        for name in (
            'temperature',
            'top_k',
            'top_p',
            'typical_p',
            'freq_penalty',
            'presence_penalty',
            'rep_penalty_window',
            'nar_guidance_w',
            'deep_clone',
        )
    ]
    examples = [
        ["Can you please go there and figure it out?", "female_speaker_1.flac", "People look, but no one ever finds it.", *example_settings],
        ["Hey, do you need my help?", "male_speaker_1.flac", "Ask her to bring these things with her from the store.", *example_settings]
    ]
    gr.Examples(
        examples=examples,
        inputs=all_inputs,
        outputs=[output],
        cache_examples=False,
        fn=on_click
    )

demo.launch(share=False)