import torch
import gradio as gr
import spaces
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes
# We'll keep a global dictionary of loaded models to avoid reloading
MODELS_CACHE = {}
device = "cuda"
banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
def load_model(model_name: str):
    """
    Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
    """
    if model_name not in MODELS_CACHE:
        print(f"Loading model: {model_name}")
        model = Zonos.from_pretrained(model_name, device=device)
        model = model.requires_grad_(False).eval()
        model.bfloat16()  # optional; only if the GPU supports bfloat16
        MODELS_CACHE[model_name] = model
        print(f"Model loaded successfully: {model_name}")
    return MODELS_CACHE[model_name]
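

def warm_up_models():
    """Optional warm-up sketch (an assumption, not called by the original app):
    pre-populating the cache at startup avoids the model-loading delay on the
    first request. Assumes both checkpoints fit in GPU memory at once."""
    for name in ("Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"):
        load_model(name)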
@spaces.GPU(duration=90)
def tts(text, speaker_audio, selected_language, model_choice):
    """
    text: str (text prompt to synthesize)
    speaker_audio: (sample_rate, numpy_array) from Gradio when type="numpy"
    selected_language: str (language code)
    model_choice: str (which Zonos model to use, e.g. "Zyphra/Zonos-v0.1-hybrid")
    Returns (sample_rate, waveform) for the Gradio audio output.
    """
    if not text:
        return None
    if speaker_audio is None:
        return None

    # Load the selected model (cached after the first call)
    model = load_model(model_choice)

    # Gradio delivers audio as (sample_rate, numpy_array)
    sr, wav_np = speaker_audio

    # Gradio typically hands back integer PCM (e.g. int16); scale to [-1, 1]
    wav_tensor = torch.from_numpy(wav_np)
    if not wav_tensor.dtype.is_floating_point:
        wav_tensor = wav_tensor.float() / torch.iinfo(wav_tensor.dtype).max
    else:
        wav_tensor = wav_tensor.float()

    # Normalize shape to (channels, num_samples)
    if wav_tensor.dim() == 1:
        wav_tensor = wav_tensor.unsqueeze(0)
    elif wav_tensor.shape[0] > wav_tensor.shape[1]:
        # Gradio uses (num_samples, channels); transpose to channels-first
        wav_tensor = wav_tensor.T

    # Get speaker embedding
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

    # Prepare conditioning dictionary
    cond_dict = make_cond_dict(
        text=text,
        speaker=spk_embedding,
        language=selected_language,
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # Generate codes, then decode them into raw audio
    with torch.no_grad():
        codes = model.generate(conditioning)
    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
    sr_out = model.autoencoder.sampling_rate
    return (sr_out, wav_out.numpy())
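

# A minimal helper sketch (an assumption, not part of the original app): the
# demo text recommends a 10-30 second speaker sample, so capping very long
# uploads keeps speaker-embedding extraction fast. Not wired into tts() above.
def trim_reference(wav_tensor: torch.Tensor, sr: int, max_seconds: float = 30.0) -> torch.Tensor:
    """Return at most the first `max_seconds` of audio; expects (..., num_samples)."""
    max_samples = int(sr * max_seconds)
    return wav_tensor[..., :max_samples]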
def build_demo():
    with gr.Blocks(theme="davehornik/Tealy") as demo:
        gr.HTML(BANNER, elem_id="banner")
        gr.Markdown("## Zonos-v0.1 TTS Demo")
        gr.Markdown(
            """
            > **Zero-shot TTS with Voice Cloning**: Input text and a 10–30 second speaker sample to generate high-quality text-to-speech output.
            > **Audio Prefix Inputs**: Enhance speaker matching by adding an audio prefix to the text, enabling behaviors like whispering that are hard to achieve with voice cloning alone.
            > **Multilingual Support**: Supports English, Japanese, Chinese, French, and German.
            """
        )
        with gr.Row():
            text_input = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            ref_audio_input = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )
        # Model dropdown
        model_dropdown = gr.Dropdown(
            label="Model Choice",
            choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
            value="Zyphra/Zonos-v0.1-hybrid",
            interactive=True,
        )
        # Language dropdown (filter or use all of supported_language_codes)
        language_dropdown = gr.Dropdown(
            label="Language Code",
            choices=supported_language_codes,
            value="en-us",
            interactive=True,
        )
        generate_button = gr.Button("Generate")
        audio_output = gr.Audio(label="Synthesized Output", type="numpy")

        generate_button.click(
            fn=tts,
            inputs=[text_input, ref_audio_input, language_dropdown, model_dropdown],
            outputs=audio_output,
        )
    return demo
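

# Optional UI sketch (illustrative, not shipped with this demo): a gr.Examples
# row inside the Blocks context above could pre-fill the prompt for new users,
# for instance:
#
#   gr.Examples(
#       examples=[["Hello from Zonos!"], ["Zonos clones voices from short samples."]],
#       inputs=[text_input],
#   )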
if __name__ == "__main__":
    demo_app = build_demo()
    demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True)
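
# Deployment note (an assumption about the hosting setup): on Hugging Face
# Spaces the platform manages the host and port, so a bare `demo_app.launch()`
# also works; the explicit server_name/server_port/share flags above matter
# mainly for self-hosted runs.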