AkhilTolani commited on
Commit
e658e7c
1 Parent(s): 09acd34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -177
app.py CHANGED
@@ -1,27 +1,23 @@
1
- import math
2
- from queue import Queue
3
- from threading import Thread
4
- from typing import Optional
5
-
6
- import numpy as np
7
  import spaces
8
  import gradio as gr
9
  import torch
 
 
 
 
10
 
11
  from parler_tts import ParlerTTSForConditionalGeneration
12
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
13
- from transformers.generation.streamers import BaseStreamer
14
 
15
- device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
16
- torch_dtype = torch.float16 if device != "cpu" else torch.float32
17
 
18
  custom_repo_id = "AkhilTolani/vocals-english"
19
 
20
- custom_model = ParlerTTSForConditionalGeneration.from_pretrained(custom_repo_id).to(device)
21
 
22
  tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
23
 
24
- SEED = 456
25
 
26
  default_text = "Raindrops on the window pane, mirroring my tears again. Autumn leaves are falling down, just like my world without you around. They called you different, a love forbidden, but in your eyes I saw my future written."
27
  default_description = "A man delivers his speech in a quiet, enclosed space with exceptional clarity, maintaining a very monotone tone of voice, at a relatively slow pace. His pitch is slightly low."
@@ -29,167 +25,44 @@ default_description = "A man delivers his speech in a quiet, enclosed space with
29
  examples = [
30
  [
31
  "Raindrops on the window pane, mirroring my tears again. Autumn leaves are falling down, just like my world without you around. They called you different, a love forbidden, but in your eyes I saw my future written.",
32
- "'A woman speaks with a somewhat monotone tone, delivering her words at a moderate pace, in a recording that sounds quite clear but slightly confined. Her voice has a slightly high pitch.'",
33
- 10.0,
34
- ],
35
  ]
36
 
37
-
38
- class ParlerTTSStreamer(BaseStreamer):
39
- def __init__(
40
- self,
41
- model: ParlerTTSForConditionalGeneration,
42
- device: Optional[str] = None,
43
- play_steps: Optional[int] = 10,
44
- stride: Optional[int] = None,
45
- timeout: Optional[float] = None,
46
- ):
47
- """
48
- Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
49
- useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
50
- Gradio demo).
51
- Parameters:
52
- model (`ParlerTTSForConditionalGeneration`):
53
- The Parler-TTS model used to generate the audio waveform.
54
- device (`str`, *optional*):
55
- The torch device on which to run the computation. If `None`, will default to the device of the model.
56
- play_steps (`int`, *optional*, defaults to 10):
57
- The number of generation steps with which to return the generated audio array. Using fewer steps will
58
- mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
59
- should be tuned to your device and latency requirements.
60
- stride (`int`, *optional*):
61
- The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
62
- the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
63
- play_steps // 6 in the audio space.
64
- timeout (`int`, *optional*):
65
- The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
66
- in `.generate()`, when it is called in a separate thread.
67
- """
68
- self.decoder = model.decoder
69
- self.audio_encoder = model.audio_encoder
70
- self.generation_config = model.generation_config
71
- self.device = device if device is not None else model.device
72
-
73
- # variables used in the streaming process
74
- self.play_steps = play_steps
75
- if stride is not None:
76
- self.stride = stride
77
- else:
78
- hop_length = math.floor(self.audio_encoder.config.sampling_rate / self.audio_encoder.config.frame_rate)
79
- self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
80
- self.token_cache = None
81
- self.to_yield = 0
82
-
83
- # varibles used in the thread process
84
- self.audio_queue = Queue()
85
- self.stop_signal = None
86
- self.timeout = timeout
87
-
88
- def apply_delay_pattern_mask(self, input_ids):
89
- # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler)
90
- _, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
91
- input_ids[:, :1],
92
- bos_token_id=self.generation_config.bos_token_id,
93
- pad_token_id=self.generation_config.decoder_start_token_id,
94
- max_length=input_ids.shape[-1],
95
- )
96
- # apply the pattern mask to the input ids
97
- input_ids = self.decoder.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
98
-
99
- # revert the pattern delay mask by filtering the pad token id
100
- mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (delay_pattern_mask != self.generation_config.pad_token_id)
101
- input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)
102
- # append the frame dimension back to the audio codes
103
- input_ids = input_ids[None, ...]
104
-
105
- # send the input_ids to the correct device
106
- input_ids = input_ids.to(self.audio_encoder.device)
107
-
108
- decode_sequentially = (
109
- self.generation_config.bos_token_id in input_ids
110
- or self.generation_config.pad_token_id in input_ids
111
- or self.generation_config.eos_token_id in input_ids
112
- )
113
- if not decode_sequentially:
114
- output_values = self.audio_encoder.decode(
115
- input_ids,
116
- audio_scales=[None],
117
- )
118
- else:
119
- sample = input_ids[:, 0]
120
- sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
121
- sample = sample[:, :, sample_mask]
122
- output_values = self.audio_encoder.decode(sample[None, ...], [None])
123
-
124
- audio_values = output_values.audio_values[0, 0]
125
- return audio_values.cpu().float().numpy()
126
-
127
- def put(self, value):
128
- batch_size = value.shape[0] // self.decoder.num_codebooks
129
- if batch_size > 1:
130
- raise ValueError("ParlerTTSStreamer only supports batch size 1")
131
-
132
- if self.token_cache is None:
133
- self.token_cache = value
134
- else:
135
- self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
136
-
137
- if self.token_cache.shape[-1] % self.play_steps == 0:
138
- audio_values = self.apply_delay_pattern_mask(self.token_cache)
139
- self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
140
- self.to_yield += len(audio_values) - self.to_yield - self.stride
141
-
142
- def end(self):
143
- """Flushes any remaining cache and appends the stop symbol."""
144
- if self.token_cache is not None:
145
- audio_values = self.apply_delay_pattern_mask(self.token_cache)
146
- else:
147
- audio_values = np.zeros(self.to_yield)
148
-
149
- self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
150
-
151
- def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
152
- """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
153
- self.audio_queue.put(audio, timeout=self.timeout)
154
- if stream_end:
155
- self.audio_queue.put(self.stop_signal, timeout=self.timeout)
156
-
157
- def __iter__(self):
158
- return self
159
-
160
- def __next__(self):
161
- value = self.audio_queue.get(timeout=self.timeout)
162
- if not isinstance(value, np.ndarray) and value == self.stop_signal:
163
- raise StopIteration()
164
- else:
165
- return value
166
-
167
-
168
- sampling_rate = custom_model.audio_encoder.config.sampling_rate
169
- frame_rate = custom_model.audio_encoder.config.frame_rate
170
 
171
  @spaces.GPU
172
- def generate_base(text, description, play_steps_in_s=2.0):
173
- play_steps = int(frame_rate * play_steps_in_s)
174
- streamer = ParlerTTSStreamer(custom_model, device=device, play_steps=play_steps)
175
-
176
  inputs = tokenizer(description, return_tensors="pt").to(device)
177
- prompt = tokenizer(text, return_tensors="pt").to(device)
178
 
179
- generation_kwargs = dict(
180
- input_ids=inputs.input_ids,
181
- prompt_input_ids=prompt.input_ids,
182
- streamer=streamer,
183
- min_length=20,
184
  )
 
185
 
186
- set_seed(SEED)
187
- thread = Thread(target=custom_model.generate, kwargs=generation_kwargs)
188
- thread.start()
189
 
190
- for new_audio in streamer:
191
- print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
192
- yield sampling_rate, new_audio
193
 
194
  css = """
195
  #share-btn-container {
@@ -256,20 +129,18 @@ with gr.Blocks(css=css) as block:
256
  </p>
257
  """
258
  )
259
- with gr.Tab("Vocals"):
260
- with gr.Row():
261
- with gr.Column():
262
- input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
263
- description = gr.Textbox(label="Description", lines=2, value=default_description, elem_id="input_description")
264
- play_seconds = gr.Slider(3.0, 15.0, value=10.0, step=0.5, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
265
- run_button = gr.Button("Generate Audio", variant="primary")
266
- with gr.Column():
267
- audio_out = gr.Audio(label="Parler-TTS + Vocals", type="numpy", elem_id="audio_out", streaming=True, autoplay=True)
268
-
269
- inputs = [input_text, description, play_seconds]
270
- outputs = [audio_out]
271
- gr.Examples(examples=examples, fn=generate_base, inputs=inputs, outputs=outputs, cache_examples=False)
272
- run_button.click(fn=generate_base, inputs=inputs, outputs=outputs, queue=True)
273
 
274
  block.queue()
275
  block.launch(share=True)
 
 
 
 
 
 
 
1
  import spaces
2
  import gradio as gr
3
  import torch
4
+ from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
5
+ from string import punctuation
6
+ import re
7
+
8
 
9
  from parler_tts import ParlerTTSForConditionalGeneration
10
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
 
11
 
12
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
13
 
14
  custom_repo_id = "AkhilTolani/vocals-english"
15
 
16
+ model = ParlerTTSForConditionalGeneration.from_pretrained(custom_repo_id).to(device)
17
 
18
  tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
19
 
20
+ SEED = 42
21
 
22
  default_text = "Raindrops on the window pane, mirroring my tears again. Autumn leaves are falling down, just like my world without you around. They called you different, a love forbidden, but in your eyes I saw my future written."
23
  default_description = "A man delivers his speech in a quiet, enclosed space with exceptional clarity, maintaining a very monotone tone of voice, at a relatively slow pace. His pitch is slightly low."
 
25
  examples = [
26
  [
27
  "Raindrops on the window pane, mirroring my tears again. Autumn leaves are falling down, just like my world without you around. They called you different, a love forbidden, but in your eyes I saw my future written.",
28
+ "A woman speaks with a somewhat monotone tone, delivering her words at a moderate pace, in a recording that sounds quite clear but slightly confined. Her voice has a slightly high pitch.",
29
+ ]
 
30
  ]
31
 
32
+ number_normalizer = EnglishNumberNormalizer()
33
+
34
+ def preprocess(text):
35
+ text = number_normalizer(text).strip()
36
+ text = text.replace("-", " ")
37
+ if text[-1] not in punctuation:
38
+ text = f"{text}."
39
+
40
+ abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
41
+
42
+ def separate_abb(chunk):
43
+ chunk = chunk.replace(".","")
44
+ print(chunk)
45
+ return " ".join(chunk)
46
+
47
+ abbreviations = re.findall(abbreviations_pattern, text)
48
+ for abv in abbreviations:
49
+ if abv in text:
50
+ text = text.replace(abv, separate_abb(abv))
51
+ return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  @spaces.GPU
54
+ def gen_tts(text, description):
 
 
 
55
  inputs = tokenizer(description, return_tensors="pt").to(device)
56
+ prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
57
 
58
+ set_seed(SEED)
59
+ generation = model.generate(
60
+ input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, do_sample=True, temperature=1.0
 
 
61
  )
62
+ audio_arr = generation.cpu().numpy().squeeze()
63
 
64
+ return audio_arr
 
 
65
 
 
 
 
66
 
67
  css = """
68
  #share-btn-container {
 
129
  </p>
130
  """
131
  )
132
+ with gr.Row():
133
+ with gr.Column():
134
+ input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
135
+ description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
136
+ run_button = gr.Button("Generate Audio", variant="primary")
137
+ with gr.Column():
138
+ audio_out = gr.Audio(label="Parler-TTS + Vocals", type="numpy", elem_id="audio_out")
139
+
140
+ inputs = [input_text, description]
141
+ outputs = [audio_out]
142
+ gr.Examples(examples=examples, fn=gen_tts, inputs=inputs, outputs=outputs, cache_examples=True)
143
+ run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
 
 
144
 
145
  block.queue()
146
  block.launch(share=True)