AkhilTolani committed on
Commit
1f26343
1 Parent(s): ebccdd9

Create app.py

Files changed (1)
  1. app.py +326 -0
app.py ADDED
@@ -0,0 +1,326 @@
+ import math
+ from queue import Queue
+ from threading import Thread
+ from typing import Optional
+
+ import numpy as np
+ import spaces
+ import gradio as gr
+ import torch
+
+ from parler_tts import ParlerTTSForConditionalGeneration
+ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
+ from transformers.generation.streamers import BaseStreamer
+
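+ # prefer CUDA, then Apple MPS, then CPU; use fp16 on accelerators and fp32 on CPU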
+ device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ torch_dtype = torch.float16 if device != "cpu" else torch.float32
+
+ repo_id = "parler-tts/parler_tts_mini_v0.1"
+ custom_repo_id = "AkhilTolani/vocals"
+
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
+     repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+ ).to(device)
+ custom_model = ParlerTTSForConditionalGeneration.from_pretrained(
+     custom_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+ ).to(device)
+
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
+ feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
+
+ SAMPLE_RATE = feature_extractor.sampling_rate
+ SEED = 777
+
+ default_text = "ooooh, please surprise me and sing whatever you would like to sing, ooh."
+
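+ # each example row is [input text, voice/style description, streaming interval in seconds]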
+ examples = [
+     [
+         "In the quiet of the night, Shadows dance beneath the moonlight, Whispers of a love that's gone, Echoes of a heart withdrawn.",
+         "A woman speaks at an average pace with a slightly animated delivery in a very confined sounding environment with clear audio quality.",
+         15.0,
+     ],
+     [
+         "Lost in the memories, Of the days that used to be, A fragile heart in a storm, Yearning for a love reborn.",
+         "A man speaks in quite a monotone voice at a slightly faster-than-average pace in a confined space with very clear audio.",
+         15.0,
+     ],
+     [
+         "Tears fall like autumn rain, Tracing lines of hidden pain, In the silence, I still hear, The ghost of you always near.",
+         "A man delivers his words at a slightly slow pace in a small, confined space with a touch of background noise and a quite monotone tone.",
+         15.0,
+     ],
+     [
+         "Where did we go wrong? In the story we've outgrown, Now I'm left to sing alone, In a world so cold and stone.",
+         "A woman delivers her words at a fast pace and an animated tone, in a very spacious environment, in a very clear voice and audio quality.",
+         15.0,
+     ],
+ ]
+
+
+ class ParlerTTSStreamer(BaseStreamer):
+     def __init__(
+         self,
+         model: ParlerTTSForConditionalGeneration,
+         device: Optional[str] = None,
+         play_steps: Optional[int] = 10,
+         stride: Optional[int] = None,
+         timeout: Optional[float] = None,
+     ):
+         """
+         Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an
+         iterator. This is useful for applications that benefit from accessing the generated audio in a
+         non-blocking way (e.g. in an interactive Gradio demo).
+
+         Parameters:
+             model (`ParlerTTSForConditionalGeneration`):
+                 The Parler-TTS model used to generate the audio waveform.
+             device (`str`, *optional*):
+                 The torch device on which to run the computation. If `None`, defaults to the device of the model.
+             play_steps (`int`, *optional*, defaults to 10):
+                 The number of generation steps after which to return a chunk of generated audio. Fewer steps
+                 mean the first chunk is ready sooner, but require more codec decoding steps overall. This value
+                 should be tuned to your device and latency requirements.
+             stride (`int`, *optional*):
+                 The overlap (in samples) between adjacent audio chunks. Overlapping the chunks softens the hard
+                 boundary between them, giving smoother playback. If `None`, defaults to a value equivalent to
+                 `play_steps // 6` in the audio space.
+             timeout (`float`, *optional*):
+                 The timeout for the audio queue. If `None`, the queue blocks indefinitely. Useful for handling
+                 exceptions raised by `.generate()` when it is called in a separate thread.
+         """
+         self.decoder = model.decoder
+         self.audio_encoder = model.audio_encoder
+         self.generation_config = model.generation_config
+         self.device = device if device is not None else model.device
+
+         # variables used in the streaming process
+         self.play_steps = play_steps
+         if stride is not None:
+             self.stride = stride
+         else:
+             hop_length = math.floor(self.audio_encoder.config.sampling_rate / self.audio_encoder.config.frame_rate)
+             self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
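+             # worked example (assuming the DAC codec config of parler_tts_mini_v0.1:
+             # 44.1 kHz audio, 86 frames/s, 9 codebooks): hop_length = 44100 // 86 = 512,
+             # so play_steps = 172 (~2 s) gives stride = 512 * (172 - 9) // 6 ≈ 13,900 samples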
+         self.token_cache = None
+         self.to_yield = 0
+
+         # variables used in the thread process
+         self.audio_queue = Queue()
+         self.stop_signal = None
+         self.timeout = timeout
+
+     def apply_delay_pattern_mask(self, input_ids):
+         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler)
+         _, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+             input_ids[:, :1],
+             bos_token_id=self.generation_config.bos_token_id,
+             pad_token_id=self.generation_config.decoder_start_token_id,
+             max_length=input_ids.shape[-1],
+         )
+         # apply the pattern mask to the input ids
+         input_ids = self.decoder.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
+
+         # revert the pattern delay mask by filtering the pad token id
+         mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (delay_pattern_mask != self.generation_config.pad_token_id)
+         input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)
+         # append the frame dimension back to the audio codes
+         input_ids = input_ids[None, ...]
+
+         # send the input_ids to the correct device
+         input_ids = input_ids.to(self.audio_encoder.device)
+
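+         # frames containing special tokens can't be batch-decoded by the codec,
+         # so fall back to stripping them out and decoding what remains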
+         decode_sequentially = (
+             self.generation_config.bos_token_id in input_ids
+             or self.generation_config.pad_token_id in input_ids
+             or self.generation_config.eos_token_id in input_ids
+         )
+         if not decode_sequentially:
+             output_values = self.audio_encoder.decode(
+                 input_ids,
+                 audio_scales=[None],
+             )
+         else:
+             sample = input_ids[:, 0]
+             sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
+             sample = sample[:, :, sample_mask]
+             output_values = self.audio_encoder.decode(sample[None, ...], [None])
+
+         audio_values = output_values.audio_values[0, 0]
+         return audio_values.cpu().float().numpy()
+
+     def put(self, value):
+         batch_size = value.shape[0] // self.decoder.num_codebooks
+         if batch_size > 1:
+             raise ValueError("ParlerTTSStreamer only supports batch size 1")
+
+         if self.token_cache is None:
+             self.token_cache = value
+         else:
+             self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
+
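+         # every play_steps tokens, decode the whole cache and emit only the new
+         # samples, holding back `stride` samples to overlap with the next chunk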
+         if self.token_cache.shape[-1] % self.play_steps == 0:
+             audio_values = self.apply_delay_pattern_mask(self.token_cache)
+             self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
+             self.to_yield += len(audio_values) - self.to_yield - self.stride
+
+     def end(self):
+         """Flushes any remaining cache and appends the stop symbol."""
+         if self.token_cache is not None:
+             audio_values = self.apply_delay_pattern_mask(self.token_cache)
+         else:
+             audio_values = np.zeros(self.to_yield)
+
+         self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
+
+     def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
+         """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
+         self.audio_queue.put(audio, timeout=self.timeout)
+         if stream_end:
+             self.audio_queue.put(self.stop_signal, timeout=self.timeout)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.audio_queue.get(timeout=self.timeout)
+         if not isinstance(value, np.ndarray) and value == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value
+
+
+ sampling_rate = model.audio_encoder.config.sampling_rate
+ frame_rate = model.audio_encoder.config.frame_rate
+
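+ # each generator below runs `generate` on a background thread and yields
+ # (sampling_rate, audio_chunk) tuples as the streamer produces them, so the
+ # Gradio audio component can start playback before generation finishes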
+ @spaces.GPU
+ def generate_base(text, description, play_steps_in_s=2.0):
+     play_steps = int(frame_rate * play_steps_in_s)
+     streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
+
+     inputs = tokenizer(description, return_tensors="pt").to(device)
+     prompt = tokenizer(text, return_tensors="pt").to(device)
+
+     generation_kwargs = dict(
+         input_ids=inputs.input_ids,
+         prompt_input_ids=prompt.input_ids,
+         streamer=streamer,
+         do_sample=True,
+         temperature=1.0,
+         min_new_tokens=10,
+     )
+
+     set_seed(SEED)
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
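+     # drain the streamer on the main thread; each iteration blocks until the next chunk is ready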
+     for new_audio in streamer:
+         print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
+         yield sampling_rate, new_audio
+
+ @spaces.GPU
+ def generate_custom(text, description, play_steps_in_s=2.0):
+     play_steps = int(frame_rate * play_steps_in_s)
+     streamer = ParlerTTSStreamer(custom_model, device=device, play_steps=play_steps)
+
+     inputs = tokenizer(description, return_tensors="pt").to(device)
+     prompt = tokenizer(text, return_tensors="pt").to(device)
+
+     generation_kwargs = dict(
+         input_ids=inputs.input_ids,
+         prompt_input_ids=prompt.input_ids,
+         streamer=streamer,
+         do_sample=True,
+         temperature=1.0,
+         min_new_tokens=10,
+     )
+
+     set_seed(SEED)
+     thread = Thread(target=custom_model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     for new_audio in streamer:
+         print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
+         yield sampling_rate, new_audio
+
+
+ css = """
+ #share-btn-container {
+     display: flex;
+     padding-left: 0.5rem !important;
+     padding-right: 0.5rem !important;
+     background-color: #000000;
+     justify-content: center;
+     align-items: center;
+     border-radius: 9999px !important;
+     width: 13rem;
+     margin-top: 10px;
+     margin-left: auto;
+     flex: unset !important;
+ }
+ #share-btn {
+     all: initial;
+     color: #ffffff;
+     font-weight: 600;
+     cursor: pointer;
+     font-family: 'IBM Plex Sans', sans-serif;
+     margin-left: 0.5rem !important;
+     padding-top: 0.25rem !important;
+     padding-bottom: 0.25rem !important;
+     right: 0;
+ }
+ #share-btn * {
+     all: unset !important;
+ }
+ #share-btn-container div:nth-child(-n+2) {
+     width: auto !important;
+     min-height: 0px !important;
+ }
+ #share-btn-container .wrap {
+     display: none !important;
+ }
+ """
+ with gr.Blocks(css=css) as block:
+     gr.HTML(
+         """
+         <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+             <div
+                 style="
+                     display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
+                 "
+             >
+                 <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
+                     Parler-TTS 🗣️
+                 </h1>
+             </div>
+         </div>
+         """
+     )
+     gr.HTML(
+         f"""
+         <p><a href="https://github.com/huggingface/parler-tts">Parler-TTS + Vocals</a> is a training and inference library for
+         high-fidelity text-to-speech (TTS) models. It generates high-quality vocals with features that can be controlled by a
+         simple text prompt (e.g. gender, background noise, speaking rate, pitch, and reverberation).</p>
+
+         <p>Tips for ensuring good generation:
+         <ul>
+             <li>Include the term <b>"very clear audio"</b> to generate the highest-quality audio, and "very noisy audio" for high levels of background noise</li>
+             <li>Punctuation can be used to control the prosody of the generations, e.g. use commas to add small breaks in speech</li>
+             <li>The remaining speech features (gender, speaking rate, pitch, and reverberation) can be controlled directly through the prompt</li>
+         </ul>
+         </p>
+         """
+     )
+     with gr.Tab("Base"):
+         with gr.Row():
+             with gr.Column():
+                 input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
+                 description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
+                 play_seconds = gr.Slider(3.0, 10.0, value=5.0, step=0.5, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
+                 run_button = gr.Button("Generate Audio", variant="primary")
+             with gr.Column():
+                 audio_out = gr.Audio(label="Parler-TTS + Vocals", type="numpy", elem_id="audio_out", streaming=True, autoplay=True)
+
+         inputs = [input_text, description, play_seconds]
+         outputs = [audio_out]
+         gr.Examples(examples=examples, fn=generate_base, inputs=inputs, outputs=outputs, cache_examples=False)
+         run_button.click(fn=generate_base, inputs=inputs, outputs=outputs, queue=True)
+
+ block.queue()
+ block.launch(share=True)