|
import torch |
|
import torchaudio |
|
import gradio as gr |
|
import spaces |
|
|
|
from zonos.model import Zonos |
|
from zonos.conditioning import make_cond_dict, supported_language_codes |
|
|
|
|
|
# Loaded models keyed by their model name; filled lazily by load_model().
MODELS_CACHE: dict = {}

# Torch device string used for model loading and for conditioning tensors.
device = "cuda"

# Header image shown at the top of the demo page.
banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"

# HTML snippet that centres the banner image; assembled from adjacent string
# pieces so each fragment stays on a readable line.
BANNER = (
    '<div style="display: flex; justify-content: space-around;">'
    f'<img src="{banner_url}" alt="Banner" '
    'style="width: 40vw; min-width: 150px; max-width: 300px;">'
    ' </div>'
)
|
|
|
|
|
def load_model(model_name: str):
    """Return the Zonos model for ``model_name``, loading it on first use.

    A name that is not yet in ``MODELS_CACHE`` is fetched with
    ``Zonos.from_pretrained`` onto ``device``, has gradients disabled, is
    switched to eval mode and cast to bfloat16, and is then stored in the
    cache. Later calls with the same name return the stored instance.
    """
    # Cache hit: nothing to load.
    if model_name in MODELS_CACHE:
        return MODELS_CACHE[model_name]

    print(f"Loading model: {model_name}")
    model = Zonos.from_pretrained(model_name, device=device)
    model = model.requires_grad_(False)
    model = model.eval()
    model.bfloat16()

    # The dict is only mutated, never rebound, so no ``global`` is needed.
    MODELS_CACHE[model_name] = model
    print(f"Model loaded successfully: {model_name}")
    return MODELS_CACHE[model_name]
|
|
|
def _reference_wav_to_tensor(wav_np):
    """Turn a Gradio NumPy audio array into a float32 2-D tensor, channels first."""
    wav_tensor = torch.from_numpy(wav_np)

    if wav_tensor.is_floating_point():
        wav_tensor = wav_tensor.float()
    else:
        # Integer PCM: scale by the dtype's full range so samples land in
        # [-1, 1) instead of keeping raw magnitudes such as +/-32768.
        full_scale = float(torch.iinfo(wav_tensor.dtype).max) + 1.0
        wav_tensor = wav_tensor.float() / full_scale

    if wav_tensor.dim() == 1:
        # Mono clip: add the leading channel axis.
        wav_tensor = wav_tensor.unsqueeze(0)
    elif wav_tensor.dim() == 2 and wav_tensor.shape[0] > wav_tensor.shape[1]:
        # Multi-channel clip laid out (samples, channels): flip it so the
        # channel axis comes first. The orientation test has to run on the
        # raw 2-D array, before any axis is added, or it can never match.
        # NOTE(review): assumes make_speaker_embedding wants (channels,
        # samples), the torchaudio layout -- confirm against Zonos.
        wav_tensor = wav_tensor.T

    return wav_tensor


@spaces.GPU(duration=90)
def tts(text, speaker_audio, selected_language, model_choice):
    """Synthesize ``text`` in the voice of the reference clip.

    Args:
        text: Text prompt to synthesize.
        speaker_audio: ``(sample_rate, numpy_array)`` tuple as produced by a
            Gradio audio input with ``type="numpy"``, or ``None``.
        selected_language: Language code passed to ``make_cond_dict``.
        model_choice: Name of the Zonos model to use,
            e.g. ``"Zyphra/Zonos-v0.1-hybrid"``.

    Returns:
        ``(sample_rate, waveform)`` for a Gradio audio output, or ``None``
        when the text is empty or no usable reference audio was supplied.
    """
    # Validate first so an incomplete form never triggers a model load.
    if not text or speaker_audio is None:
        return None

    sr, wav_np = speaker_audio
    if wav_np is None or wav_np.size == 0:
        return None

    model = load_model(model_choice)
    wav_tensor = _reference_wav_to_tensor(wav_np)

    # Inference only: keep autograd off for the whole pipeline.
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

        cond_dict = make_cond_dict(
            text=text,
            speaker=spk_embedding,
            language=selected_language,
            device=device,
        )
        conditioning = model.prepare_conditioning(cond_dict)

        codes = model.generate(conditioning)
        decoded = model.autoencoder.decode(codes)

    # NumPy has no bfloat16, so force float32 before leaving torch.
    wav_out = decoded.float().cpu().squeeze()
    sr_out = model.autoencoder.sampling_rate

    return (sr_out, wav_out.numpy())
|
|
|
def build_demo():
    """Assemble the Gradio Blocks interface for the demo and return it.

    The page shows the banner and introductory markdown, a row holding the
    text prompt and the reference-audio input, dropdowns for the model and
    the language code, and a "Generate" button whose click runs ``tts`` and
    sends the result to the output audio player.
    """
    model_options = ["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"]

    with gr.Blocks(theme='davehornik/Tealy') as demo:
        # Header: banner image, title, feature summary.
        gr.HTML(BANNER, elem_id="banner")
        gr.Markdown("## Zonos-v0.1 TTS Demo")
        gr.Markdown(
            """
            > **Zero-shot TTS with Voice Cloning**: Input text and a 10–30 second speaker sample to generate high-quality text-to-speech output.

            > **Audio Prefix Inputs**: Enhance speaker matching by adding an audio prefix to the text, enabling behaviors like whispering that are hard to achieve with voice cloning alone.

            > **Multilingual Support**: Supports English, Japanese, Chinese, French, and German.
            """
        )

        # Prompt and reference clip share one row.
        with gr.Row():
            prompt_box = gr.Textbox(label="Text Prompt", value="Hello from Zonos!", lines=3)
            reference_clip = gr.Audio(label="Reference Audio (Speaker Cloning)", type="numpy")

        model_picker = gr.Dropdown(
            label="Model Choice",
            choices=model_options,
            value="Zyphra/Zonos-v0.1-hybrid",
            interactive=True,
        )
        language_picker = gr.Dropdown(
            label="Language Code",
            choices=supported_language_codes,
            value="en-us",
            interactive=True,
        )

        run_button = gr.Button("Generate")
        result_player = gr.Audio(label="Synthesized Output", type="numpy")

        # Input order must match the positional parameters of tts().
        tts_inputs = [prompt_box, reference_clip, language_picker, model_picker]
        run_button.click(fn=tts, inputs=tts_inputs, outputs=result_player)

    return demo
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 and request a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo_app = build_demo()
    demo_app.launch(**launch_options)
|
|