|
import torch |
|
import torchaudio |
|
import gradio as gr |
|
|
|
from zonos.model import Zonos |
|
from zonos.conditioning import make_cond_dict, supported_language_codes |
|
|
|
|
|
MODEL = None |
|
device = "cuda" |
|
|
|
def load_model(): |
|
""" |
|
Loads the Zonos model once and caches it globally. |
|
Adjust the model name if you want to switch from hybrid to transformer, etc. |
|
""" |
|
global MODEL |
|
if MODEL is None: |
|
model_name = "Zyphra/Zonos-v0.1-hybrid" |
|
print(f"Loading model: {model_name}") |
|
MODEL = Zonos.from_pretrained(model_name, device="cuda") |
|
MODEL = MODEL.requires_grad_(False).eval() |
|
MODEL.bfloat16() |
|
print("Model loaded successfully!") |
|
return MODEL |
|
|
|
def tts(text, speaker_audio, selected_language): |
|
""" |
|
text: str |
|
speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy" |
|
selected_language: str (e.g., "en-us", "es-es", etc.) |
|
|
|
Returns (sample_rate, waveform) for Gradio audio output. |
|
""" |
|
model = load_model() |
|
|
|
|
|
if not text: |
|
return None |
|
|
|
|
|
if speaker_audio is None: |
|
return None |
|
|
|
|
|
sr, wav_np = speaker_audio |
|
|
|
|
|
wav_tensor = torch.from_numpy(wav_np).unsqueeze(0).float() |
|
if wav_tensor.dim() == 2 and wav_tensor.shape[0] > wav_tensor.shape[1]: |
|
|
|
wav_tensor = wav_tensor.T |
|
|
|
|
|
with torch.no_grad(): |
|
spk_embedding = model.make_speaker_embedding(wav_tensor, sr) |
|
spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16) |
|
|
|
|
|
cond_dict = make_cond_dict( |
|
text=text, |
|
speaker=spk_embedding, |
|
language=selected_language, |
|
device=device, |
|
) |
|
conditioning = model.prepare_conditioning(cond_dict) |
|
|
|
|
|
with torch.no_grad(): |
|
codes = model.generate(conditioning) |
|
|
|
|
|
wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze() |
|
sr_out = model.autoencoder.sampling_rate |
|
|
|
return (sr_out, wav_out.numpy()) |
|
|
|
def build_demo(): |
|
with gr.Blocks() as demo: |
|
gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio + Language)") |
|
|
|
with gr.Row(): |
|
text_input = gr.Textbox( |
|
label="Text Prompt", |
|
value="Hello from Zonos!", |
|
lines=3 |
|
) |
|
ref_audio_input = gr.Audio( |
|
label="Reference Audio (Speaker Cloning)", |
|
type="numpy" |
|
) |
|
|
|
language_dropdown = gr.Dropdown( |
|
label="Language", |
|
|
|
|
|
|
|
choices=["en-us", "es-es", "fr-fr", "de-de", "it"], |
|
value="en-us", |
|
interactive=True |
|
) |
|
|
|
generate_button = gr.Button("Generate") |
|
|
|
|
|
audio_output = gr.Audio(label="Synthesized Output", type="numpy") |
|
|
|
|
|
generate_button.click( |
|
fn=tts, |
|
inputs=[text_input, ref_audio_input, language_dropdown], |
|
outputs=audio_output, |
|
) |
|
|
|
return demo |
|
|
|
if __name__ == "__main__": |
|
demo_app = build_demo() |
|
demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True) |
|
|