File size: 4,691 Bytes
748ecaa a86425f 748ecaa d5d8bf3 748ecaa 15961ae e5d26e9 748ecaa 97132bd 15961ae 46f1390 15961ae 46f1390 15961ae 46f1390 15961ae 4a04525 15961ae 46f1390 15961ae 46f1390 15961ae d5d8bf3 46f1390 15961ae 46f1390 b1f1246 46f1390 15961ae 46f1390 e5d26e9 46f1390 748ecaa 15961ae e5d26e9 748ecaa d743fc1 46f1390 d743fc1 46f1390 97132bd 46f1390 97132bd 46f1390 15961ae d5d8bf3 15961ae cce1550 d5d8bf3 15961ae d5d8bf3 46f1390 15961ae 46f1390 748ecaa 46f1390 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
import torchaudio
import gradio as gr
import spaces
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes
# Models that have already been loaded, keyed by the model name passed to
# load_model (e.g. "Zyphra/Zonos-v0.1-hybrid"), so they are not reloaded.
MODELS_CACHE = dict()
device = "cuda"
banner_url = (
    "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
)
# Centered header image shown at the top of the demo page.
BANNER = (
    '<div style="display: flex; justify-content: space-around;">'
    f'<img src="{banner_url}" alt="Banner" '
    'style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
)
def load_model(model_name: str):
    """Return the Zonos model for ``model_name``, loading it on first use.

    A newly loaded model has gradients disabled, is switched to eval mode and
    converted to bfloat16, then stored in ``MODELS_CACHE`` so later calls with
    the same name reuse it.
    """
    # Fast path: the model was already loaded by an earlier call.
    if model_name in MODELS_CACHE:
        return MODELS_CACHE[model_name]

    print(f"Loading model: {model_name}")
    fresh_model = (
        Zonos.from_pretrained(model_name, device=device)
        .requires_grad_(False)
        .eval()
    )
    fresh_model.bfloat16()  # optional if GPU supports bfloat16
    MODELS_CACHE[model_name] = fresh_model
    print(f"Model loaded successfully: {model_name}")
    return MODELS_CACHE[model_name]
def _speaker_audio_to_tensor(wav_np):
    """Convert a Gradio ``type="numpy"`` audio array into a mono float tensor.

    Gradio delivers PCM samples (int16 per its docs) shaped ``(samples,)`` for
    mono or ``(samples, channels)`` for multi-channel audio. The result is a
    float32 tensor of shape ``(1, samples)``; integer PCM is scaled into
    [-1, 1], the same range ``torchaudio.load`` produces by default.
    """
    wav = torch.from_numpy(wav_np)
    if wav.is_floating_point():
        wav = wav.float()
    else:
        # Divide by the dtype's full-scale magnitude (32768 for int16).
        info = torch.iinfo(wav.dtype)
        wav = wav.float() / float(max(abs(info.min), info.max))
    if wav.dim() == 1:
        return wav.unsqueeze(0)
    # (samples, channels) -> (channels, samples), then average to one channel.
    return wav.T.mean(dim=0, keepdim=True)


@spaces.GPU(duration=90)
def tts(text, speaker_audio, selected_language, model_choice):
    """Synthesize ``text`` in the voice of ``speaker_audio`` with a Zonos model.

    Args:
        text: Text prompt to synthesize.
        speaker_audio: ``(sample_rate, numpy_array)`` tuple from a Gradio
            ``gr.Audio(type="numpy")`` component; used for voice cloning.
        selected_language: Language code forwarded to ``make_cond_dict``.
        model_choice: Which Zonos model to use, e.g.
            ``"Zyphra/Zonos-v0.1-hybrid"``.

    Returns:
        ``(sample_rate, waveform)`` for a Gradio audio output, or ``None`` when
        the text is empty or the speaker audio is missing/empty.
    """
    # Validate the cheap inputs first so empty requests never load weights.
    if not text:
        return None
    if speaker_audio is None:
        return None

    sr, wav_np = speaker_audio
    wav_tensor = _speaker_audio_to_tensor(wav_np)
    if wav_tensor.numel() == 0:
        return None

    model = load_model(model_choice)

    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

        cond_dict = make_cond_dict(
            text=text,
            speaker=spk_embedding,
            language=selected_language,
            device=device,
        )
        conditioning = model.prepare_conditioning(cond_dict)
        codes = model.generate(conditioning)

        # Cast to float32 before handing to NumPy, which has no bfloat16 dtype.
        wav_out = model.autoencoder.decode(codes).cpu().float().squeeze()

    sr_out = model.autoencoder.sampling_rate
    return (sr_out, wav_out.numpy())
def build_demo():
    """Build the Gradio Blocks UI for the Zonos TTS demo.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve the demo.
    """
    # Each feature is its own list item. Consecutive "> " lines without a blank
    # line between them are merged into a single paragraph by Markdown, which
    # previously ran the three features together.
    # NOTE(review): "Audio Prefix Inputs" is advertised, but this UI does not
    # expose an audio-prefix input (see the inputs wired to `tts` below).
    features_md = (
        "> - **Zero-shot TTS with Voice Cloning**: Input text and a 10–30 second "
        "speaker sample to generate high-quality text-to-speech output.\n"
        "> - **Audio Prefix Inputs**: Enhance speaker matching by adding an audio "
        "prefix to the text, enabling behaviors like whispering that are hard to "
        "achieve with voice cloning alone.\n"
        "> - **Multilingual Support**: Supports English, Japanese, Chinese, "
        "French, and German.\n"
    )

    with gr.Blocks(theme='davehornik/Tealy') as demo:
        gr.HTML(BANNER, elem_id="banner")
        gr.Markdown("## Zonos-v0.1 TTS Demo")
        gr.Markdown(features_md)

        with gr.Row():
            text_input = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            ref_audio_input = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )

        # Model dropdown
        model_dropdown = gr.Dropdown(
            label="Model Choice",
            choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
            value="Zyphra/Zonos-v0.1-hybrid",
            interactive=True,
        )
        # Language dropdown (you can filter or use all from supported_language_codes)
        language_dropdown = gr.Dropdown(
            label="Language Code",
            choices=supported_language_codes,
            value="en-us",
            interactive=True,
        )

        generate_button = gr.Button("Generate")
        audio_output = gr.Audio(label="Synthesized Output", type="numpy")

        # Input order must match tts(text, speaker_audio, selected_language, model_choice).
        generate_button.click(
            fn=tts,
            inputs=[text_input, ref_audio_input, language_dropdown, model_dropdown],
            outputs=audio_output,
        )
    return demo
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and request a Gradio share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    build_demo().launch(**launch_options)
|