Steveeeeeeen HF staff commited on
Commit
46f1390
·
verified ·
1 Parent(s): b1f1246

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -41
app.py CHANGED
@@ -5,53 +5,103 @@ import gradio as gr
5
  from zonos.model import Zonos
6
  from zonos.conditioning import make_cond_dict
7
 
8
- model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device="cuda")
9
- model.bfloat16()
10
 
11
- def tts(text, reference_audio):
12
- if reference_audio is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  return None
14
-
15
- # Gradio returns (sample_rate, audio_data) for type="numpy"
16
- sr, wav_np = reference_audio
17
-
18
- # Convert NumPy audio data to Torch tensor
19
- wav_torch = torch.from_numpy(wav_np).float().unsqueeze(0)
20
- if wav_torch.dim() == 2 and wav_torch.shape[0] > wav_torch.shape[1]:
21
- wav_torch = wav_torch.T
22
-
23
- # Create speaker embedding
24
- spk_embedding = model.embed_spk_audio(wav_torch, sr)
25
-
26
- # Prepare conditioning
 
 
 
 
 
 
 
27
  cond_dict = make_cond_dict(
28
- text=text,
29
- speaker=spk_embedding.to(torch.bfloat16),
30
- language="en-us",
 
31
  )
32
  conditioning = model.prepare_conditioning(cond_dict)
33
-
34
- # Generate codes & decode
35
  with torch.no_grad():
36
- torch.manual_seed(421)
 
37
  codes = model.generate(conditioning)
38
-
39
- wavs = model.autoencoder.decode(codes).cpu()
40
- out_audio = wavs[0].numpy()
41
-
42
- # Return a tuple of (sample_rate, audio_data) for playback
43
- return (model.autoencoder.sampling_rate, out_audio)
44
-
45
- demo = gr.Interface(
46
- fn=tts,
47
- inputs=[
48
- gr.Textbox(label="Text to Synthesize"),
49
- gr.Audio(type="numpy", label="Reference Audio (Speaker)"),
50
- ],
51
- outputs=gr.Audio(label="Generated Audio"),
52
- title="Zonos TTS Demo (Hybrid)",
53
- description="Upload a reference audio for speaker embedding, enter text, and generate speech!"
54
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  if __name__ == "__main__":
57
- demo.launch(debug=True)
 
 
5
  from zonos.model import Zonos
6
  from zonos.conditioning import make_cond_dict
7
 
8
+ # Global cache to hold the loaded model
9
+ MODEL = None
10
 
11
+ def load_model():
12
+ """
13
+ Loads the Zonos model once and caches it globally.
14
+ Adjust the model name to the one you want to use.
15
+ """
16
+ global MODEL
17
+ if MODEL is None:
18
+ model_name = "Zyphra/Zonos-v0.1-hybrid"
19
+ print(f"Loading model: {model_name}")
20
+ MODEL = Zonos.from_pretrained(model_name, device="cuda")
21
+ MODEL = MODEL.requires_grad_(False).eval()
22
+ MODEL.bfloat16() # optional, if your GPU supports bfloat16
23
+ print("Model loaded successfully!")
24
+ return MODEL
25
+
26
+ def tts(text, speaker_audio):
27
+ """
28
+ text: str
29
+ speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
30
+ Returns (sample_rate, waveform) for Gradio audio output.
31
+ """
32
+ model = load_model()
33
+
34
+ if not text:
35
  return None
36
+
37
+ # If the user hasn't provided any audio, just return None or a placeholder
38
+ if speaker_audio is None:
39
+ return None
40
+
41
+ # Gradio provides audio in the format (sample_rate, numpy_array)
42
+ sr, wav_np = speaker_audio
43
+
44
+ # Convert to Torch tensor: shape (1, num_samples)
45
+ wav_tensor = torch.from_numpy(wav_np).unsqueeze(0).float()
46
+ if wav_tensor.dim() == 2 and wav_tensor.shape[0] > wav_tensor.shape[1]:
47
+ # If shape is transposed, fix it
48
+ wav_tensor = wav_tensor.T
49
+
50
+ # Get speaker embedding
51
+ with torch.no_grad():
52
+ spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
53
+ spk_embedding = spk_embedding.to(model.device, dtype=torch.bfloat16)
54
+
55
+ # Prepare conditioning dictionary
56
  cond_dict = make_cond_dict(
57
+ text=text, # The text prompt
58
+ speaker=spk_embedding, # Speaker embedding from reference audio
59
+ language="en-us", # Hard-coded language or switch to another if needed
60
+ device=model.device,
61
  )
62
  conditioning = model.prepare_conditioning(cond_dict)
63
+
64
+ # Generate codes
65
  with torch.no_grad():
66
+ # Optionally set a manual seed for reproducibility
67
+ # torch.manual_seed(1234)
68
  codes = model.generate(conditioning)
69
+
70
+ # Decode the codes into raw audio
71
+ wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
72
+ sr_out = model.autoencoder.sampling_rate
73
+
74
+ return (sr_out, wav_out.numpy())
75
+
76
+ def build_demo():
77
+ with gr.Blocks() as demo:
78
+ gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio)")
79
+
80
+ with gr.Row():
81
+ text_input = gr.Textbox(
82
+ label="Text Prompt",
83
+ value="Hello from Zonos!",
84
+ lines=3
85
+ )
86
+ ref_audio_input = gr.Audio(
87
+ label="Reference Audio (Speaker Cloning)",
88
+ type="numpy"
89
+ )
90
+
91
+ generate_button = gr.Button("Generate")
92
+
93
+ # The output will be an audio widget that Gradio will play
94
+ audio_output = gr.Audio(label="Synthesized Output", type="numpy")
95
+
96
+ # Bind the generate button
97
+ generate_button.click(
98
+ fn=tts,
99
+ inputs=[text_input, ref_audio_input],
100
+ outputs=audio_output,
101
+ )
102
+
103
+ return demo
104
 
105
  if __name__ == "__main__":
106
+ demo_app = build_demo()
107
+ demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True)