Steveeeeeeen HF staff commited on
Commit
d5d8bf3
·
verified ·
1 Parent(s): ab5fd90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -15
app.py CHANGED
@@ -3,7 +3,7 @@ import torchaudio
3
  import gradio as gr
4
 
5
  from zonos.model import Zonos
6
- from zonos.conditioning import make_cond_dict
7
 
8
  # Global cache to hold the loaded model
9
  MODEL = None
@@ -12,7 +12,7 @@ device = "cuda"
12
  def load_model():
13
  """
14
  Loads the Zonos model once and caches it globally.
15
- Adjust the model name to the one you want to use.
16
  """
17
  global MODEL
18
  if MODEL is None:
@@ -20,26 +20,29 @@ def load_model():
20
  print(f"Loading model: {model_name}")
21
  MODEL = Zonos.from_pretrained(model_name, device="cuda")
22
  MODEL = MODEL.requires_grad_(False).eval()
23
- MODEL.bfloat16() # optional, if your GPU supports bfloat16
24
  print("Model loaded successfully!")
25
  return MODEL
26
 
27
- def tts(text, speaker_audio):
28
  """
29
  text: str
30
  speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
 
 
31
  Returns (sample_rate, waveform) for Gradio audio output.
32
  """
33
  model = load_model()
34
 
 
35
  if not text:
36
  return None
37
 
38
- # If the user hasn't provided any audio, just return None or a placeholder
39
  if speaker_audio is None:
40
  return None
41
 
42
- # Gradio provides audio in the format (sample_rate, numpy_array)
43
  sr, wav_np = speaker_audio
44
 
45
  # Convert to Torch tensor: shape (1, num_samples)
@@ -55,17 +58,15 @@ def tts(text, speaker_audio):
55
 
56
  # Prepare conditioning dictionary
57
  cond_dict = make_cond_dict(
58
- text=text, # The text prompt
59
- speaker=spk_embedding, # Speaker embedding from reference audio
60
- language="en-us", # Hard-coded language or switch to another if needed
61
  device=device,
62
  )
63
  conditioning = model.prepare_conditioning(cond_dict)
64
 
65
  # Generate codes
66
  with torch.no_grad():
67
- # Optionally set a manual seed for reproducibility
68
- # torch.manual_seed(1234)
69
  codes = model.generate(conditioning)
70
 
71
  # Decode the codes into raw audio
@@ -76,7 +77,7 @@ def tts(text, speaker_audio):
76
 
77
  def build_demo():
78
  with gr.Blocks() as demo:
79
- gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio)")
80
 
81
  with gr.Row():
82
  text_input = gr.Textbox(
@@ -88,16 +89,26 @@ def build_demo():
88
  label="Reference Audio (Speaker Cloning)",
89
  type="numpy"
90
  )
 
 
 
 
 
 
 
 
 
 
91
 
92
  generate_button = gr.Button("Generate")
93
 
94
- # The output will be an audio widget that Gradio will play
95
  audio_output = gr.Audio(label="Synthesized Output", type="numpy")
96
 
97
- # Bind the generate button
98
  generate_button.click(
99
  fn=tts,
100
- inputs=[text_input, ref_audio_input],
101
  outputs=audio_output,
102
  )
103
 
 
3
  import gradio as gr
4
 
5
  from zonos.model import Zonos
6
+ from zonos.conditioning import make_cond_dict, supported_language_codes
7
 
8
  # Global cache to hold the loaded model
9
  MODEL = None
 
12
  def load_model():
13
  """
14
  Loads the Zonos model once and caches it globally.
15
+ Adjust the model name if you want to switch from hybrid to transformer, etc.
16
  """
17
  global MODEL
18
  if MODEL is None:
 
20
  print(f"Loading model: {model_name}")
21
  MODEL = Zonos.from_pretrained(model_name, device="cuda")
22
  MODEL = MODEL.requires_grad_(False).eval()
23
+ MODEL.bfloat16() # optional if your GPU supports bfloat16
24
  print("Model loaded successfully!")
25
  return MODEL
26
 
27
+ def tts(text, speaker_audio, selected_language):
28
  """
29
  text: str
30
  speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
31
+ selected_language: str (e.g., "en-us", "es-es", etc.)
32
+
33
  Returns (sample_rate, waveform) for Gradio audio output.
34
  """
35
  model = load_model()
36
 
37
+ # If no text, return None
38
  if not text:
39
  return None
40
 
41
+ # If no reference audio, return None
42
  if speaker_audio is None:
43
  return None
44
 
45
+ # Gradio provides audio in (sample_rate, numpy_array)
46
  sr, wav_np = speaker_audio
47
 
48
  # Convert to Torch tensor: shape (1, num_samples)
 
58
 
59
  # Prepare conditioning dictionary
60
  cond_dict = make_cond_dict(
61
+ text=text, # The text prompt
62
+ speaker=spk_embedding, # Speaker embedding
63
+ language=selected_language, # Language from the Dropdown
64
  device=device,
65
  )
66
  conditioning = model.prepare_conditioning(cond_dict)
67
 
68
  # Generate codes
69
  with torch.no_grad():
 
 
70
  codes = model.generate(conditioning)
71
 
72
  # Decode the codes into raw audio
 
77
 
78
  def build_demo():
79
  with gr.Blocks() as demo:
80
+ gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio + Language)")
81
 
82
  with gr.Row():
83
  text_input = gr.Textbox(
 
89
  label="Reference Audio (Speaker Cloning)",
90
  type="numpy"
91
  )
92
+ # Add a dropdown for language selection
93
+ language_dropdown = gr.Dropdown(
94
+ label="Language",
95
+ # You can provide your own subset or use all:
96
+ # For demonstration, let's pick 5 common ones
97
+ # or you can do: choices=supported_language_codes
98
+ choices=["en-us", "es-es", "fr-fr", "de-de", "it"],
99
+ value="en-us",
100
+ interactive=True
101
+ )
102
 
103
  generate_button = gr.Button("Generate")
104
 
105
+ # The output is an audio widget that Gradio will play
106
  audio_output = gr.Audio(label="Synthesized Output", type="numpy")
107
 
108
+ # Bind the generate button: pass text, reference audio, and selected language
109
  generate_button.click(
110
  fn=tts,
111
+ inputs=[text_input, ref_audio_input, language_dropdown],
112
  outputs=audio_output,
113
  )
114