File size: 3,384 Bytes
748ecaa
 
 
 
 
d743fc1
748ecaa
46f1390
 
e5d26e9
748ecaa
46f1390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1f1246
46f1390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5d26e9
46f1390
 
748ecaa
46f1390
 
 
e5d26e9
748ecaa
d743fc1
46f1390
 
d743fc1
46f1390
 
d743fc1
46f1390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
748ecaa
 
46f1390
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torchaudio
import gradio as gr

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict

# Global cache to hold the loaded model
MODEL = None
device = "cuda"

def load_model():
    """
    Load the Zonos model once and cache it in the module-level ``MODEL``.

    Returns:
        The cached Zonos model in eval mode with gradients disabled.
    """
    global MODEL
    if MODEL is None:
        model_name = "Zyphra/Zonos-v0.1-hybrid"
        print(f"Loading model: {model_name}")
        # Use the module-level `device` rather than a hard-coded "cuda" so
        # the target device is configured in exactly one place.
        MODEL = Zonos.from_pretrained(model_name, device=device)
        MODEL = MODEL.requires_grad_(False).eval()
        MODEL.bfloat16()  # optional; requires a bfloat16-capable GPU
        print("Model loaded successfully!")
    return MODEL

def tts(text, speaker_audio):
    """
    Synthesize speech for ``text`` in the voice of ``speaker_audio``.

    Args:
        text: Prompt string to synthesize; empty/None yields no output.
        speaker_audio: ``(sample_rate, numpy_array)`` pair as produced by a
            Gradio ``Audio`` component with ``type="numpy"``, or None.

    Returns:
        ``(sample_rate, numpy_waveform)`` suitable for a Gradio audio
        output, or None when either input is missing.
    """
    model = load_model()

    if not text:
        return None

    # Without reference audio there is nothing to clone from.
    if speaker_audio is None:
        return None

    # Gradio provides audio in the format (sample_rate, numpy_array).
    sr, wav_np = speaker_audio

    wav_tensor = torch.from_numpy(wav_np)

    # Gradio commonly delivers integer PCM (e.g. int16) for uploaded or
    # recorded audio. Normalize integer samples to floats in [-1, 1];
    # float input is assumed to already be in that range.
    if not wav_tensor.is_floating_point():
        scale = float(torch.iinfo(wav_tensor.dtype).max)
        wav_tensor = wav_tensor.float() / scale
    else:
        wav_tensor = wav_tensor.float()

    # Gradio 2-D audio is (num_samples, channels); downmix to mono.
    # NOTE(review): the original transposed *after* unsqueeze, so stereo
    # input produced a 3-D tensor that was never fixed.
    if wav_tensor.dim() == 2:
        wav_tensor = wav_tensor.mean(dim=1)
    wav_tensor = wav_tensor.unsqueeze(0)  # shape (1, num_samples)

    # Get speaker embedding from the reference audio.
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

    # Prepare conditioning dictionary.
    cond_dict = make_cond_dict(
        text=text,                # The text prompt
        speaker=spk_embedding,    # Speaker embedding from reference audio
        language="en-us",         # Hard-coded language or switch to another if needed
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # Generate codes.
    with torch.no_grad():
        # Optionally set a manual seed for reproducibility
        # torch.manual_seed(1234)
        codes = model.generate(conditioning)

    # Decode the codes into raw audio.
    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
    sr_out = model.autoencoder.sampling_rate

    return (sr_out, wav_out.numpy())

def build_demo():
    """Assemble and return the Gradio Blocks UI for the TTS demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio)")

        # Input row: prompt text on the left, reference voice on the right.
        with gr.Row():
            prompt_box = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            reference_audio = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )

        run_button = gr.Button("Generate")

        # Playable widget for the synthesized result.
        result_audio = gr.Audio(label="Synthesized Output", type="numpy")

        # Wire the button to the synthesis function.
        run_button.click(
            fn=tts,
            inputs=[prompt_box, reference_audio],
            outputs=result_audio,
        )

    return demo

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 and create a public share link.
    app = build_demo()
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)