# Zonos / app.py
# Source: Hugging Face Space (author: Steveeeeeeen, HF staff), commit d5d8bf3 ("Update app.py"), 3.87 kB
import torch
import torchaudio
import gradio as gr
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes
# Global cache to hold the loaded model
MODEL = None  # populated lazily by load_model() on first use

# Target device for inference; assumes a CUDA-capable GPU is present.
device = "cuda"
def load_model():
    """
    Load the Zonos model once and cache it in the module-level MODEL global.

    Adjust `model_name` if you want to switch from hybrid to transformer, etc.

    Returns:
        The cached Zonos model in eval mode with gradients disabled.
    """
    global MODEL
    if MODEL is None:
        model_name = "Zyphra/Zonos-v0.1-hybrid"
        print(f"Loading model: {model_name}")
        # Use the module-level `device` so the target device is configured in
        # exactly one place (was previously hard-coded to "cuda" here).
        MODEL = Zonos.from_pretrained(model_name, device=device)
        MODEL = MODEL.requires_grad_(False).eval()
        MODEL.bfloat16()  # optional if your GPU supports bfloat16
        print("Model loaded successfully!")
    return MODEL
def tts(text, speaker_audio, selected_language):
    """
    Synthesize speech in the cloned voice of the reference audio.

    Args:
        text: str — the prompt to speak.
        speaker_audio: (sample_rate, numpy_array) tuple from Gradio
            (type="numpy"); the array may be mono (samples,) or
            multi-channel (samples, channels), int PCM or float.
        selected_language: str language code (e.g. "en-us", "es-es").

    Returns:
        (sample_rate, waveform) tuple for a Gradio audio output, or None
        when either the text or the reference audio is missing.
    """
    model = load_model()

    # Guard clauses: nothing to synthesize without both inputs.
    if not text:
        return None
    if speaker_audio is None:
        return None

    # Gradio provides audio as (sample_rate, numpy_array).
    sr, wav_np = speaker_audio

    # Normalize dtype: Gradio commonly delivers int16 PCM, whose raw values
    # span +/-32768 — scale integer audio into [-1, 1] before embedding.
    raw = torch.from_numpy(wav_np)
    if raw.is_floating_point():
        wav_tensor = raw.float()
    else:
        wav_tensor = raw.float() / float(torch.iinfo(raw.dtype).max)

    # Normalize shape to (channels, num_samples). The original code
    # unsqueezed first, which left stereo input as a 3-D tensor that the
    # transpose guard could never fix.
    if wav_tensor.dim() == 1:
        wav_tensor = wav_tensor.unsqueeze(0)
    elif wav_tensor.dim() == 2 and wav_tensor.shape[0] > wav_tensor.shape[1]:
        # (samples, channels) -> (channels, samples)
        wav_tensor = wav_tensor.T
    # Downmix multi-channel audio to mono so the embedding sees the same
    # (1, num_samples) layout the original code produced for mono input.
    # NOTE(review): assumes make_speaker_embedding expects mono — confirm.
    if wav_tensor.shape[0] > 1:
        wav_tensor = wav_tensor.mean(dim=0, keepdim=True)

    # Get speaker embedding from the reference clip.
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

    # Prepare conditioning dictionary.
    cond_dict = make_cond_dict(
        text=text,                   # The text prompt
        speaker=spk_embedding,       # Speaker embedding
        language=selected_language,  # Language from the Dropdown
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # Generate codes, then decode them into raw audio.
    with torch.no_grad():
        codes = model.generate(conditioning)

    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
    sr_out = model.autoencoder.sampling_rate
    return (sr_out, wav_out.numpy())
def build_demo():
    """Construct and return the Gradio Blocks UI for the Zonos TTS demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio + Language)")

        # Input row: text prompt alongside the reference voice clip.
        with gr.Row():
            prompt_box = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            speaker_clip = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )

        # Language picker — a small hand-picked subset for the demo.
        # Swap in `supported_language_codes` to expose every language.
        lang_select = gr.Dropdown(
            label="Language",
            choices=["en-us", "es-es", "fr-fr", "de-de", "it"],
            value="en-us",
            interactive=True,
        )

        synth_btn = gr.Button("Generate")

        # Playable output widget for the synthesized speech.
        result_audio = gr.Audio(label="Synthesized Output", type="numpy")

        # Wire the button to the synthesis function: text, reference audio,
        # and selected language in; audio out.
        synth_btn.click(
            fn=tts,
            inputs=[prompt_box, speaker_clip, lang_select],
            outputs=result_audio,
        )
    return demo
if __name__ == "__main__":
    # Bind on all interfaces so the demo is reachable inside containers.
    build_demo().launch(server_name="0.0.0.0", server_port=7860, share=True)