ylacombe commited on
Commit
fe7399d
1 Parent(s): f513635

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -84
app.py CHANGED
@@ -9,22 +9,26 @@ import numpy as np
9
  import spaces
10
  import gradio as gr
11
  import torch
 
 
12
 
13
  from parler_tts import ParlerTTSForConditionalGeneration
14
  from pydub import AudioSegment
15
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
16
 
17
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
18
  torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
19
 
20
  repo_id = "ai4bharat/indic-parler-tts-pretrained"
21
- jenny_repo_id = "ai4bharat/indic-parler-tts"
22
 
23
  model = ParlerTTSForConditionalGeneration.from_pretrained(
24
  repo_id, attn_implementation="eager", torch_dtype=torch_dtype,
25
  ).to(device)
26
- jenny_model = ParlerTTSForConditionalGeneration.from_pretrained(
27
- jenny_repo_id, attn_implementation="eager", torch_dtype=torch_dtype,
28
  ).to(device)
29
 
30
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
@@ -89,7 +93,7 @@ examples = [
89
  ]
90
 
91
 
92
- jenny_examples = [
93
  [
94
  "मुले बागेत खेळत आहेत आणि पक्षी किलबिलाट करत आहेत.",
95
  "Sunita speaks slowly in a calm, moderate-pitched voice, delivering the news with a neutral tone. The recording is very high quality with no background noise.",
@@ -171,44 +175,30 @@ def numpy_to_mp3(audio_array, sampling_rate):
171
  sampling_rate = model.audio_encoder.config.sampling_rate
172
  frame_rate = model.audio_encoder.config.frame_rate
173
 
174
- # @spaces.GPU
175
- # def generate_base(text, description, play_steps_in_s=2.0):
176
- # play_steps = int(frame_rate * play_steps_in_s)
177
- # streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
178
-
179
- # inputs = description_tokenizer(description, return_tensors="pt").to(device)
180
- # prompt = tokenizer(text, return_tensors="pt").to(device)
181
-
182
- # generation_kwargs = dict(
183
- # input_ids=inputs.input_ids,
184
- # prompt_input_ids=prompt.input_ids,
185
- # streamer=streamer,
186
- # do_sample=True,
187
- # temperature=1.0,
188
- # min_new_tokens=10,
189
- # )
190
-
191
- # set_seed(SEED)
192
- # thread = Thread(target=model.generate, kwargs=generation_kwargs)
193
- # thread.start()
194
-
195
- # for new_audio in streamer:
196
- # print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
197
- # yield numpy_to_mp3(new_audio, sampling_rate=sampling_rate)
198
-
199
  @spaces.GPU
200
- def generate_base(text, description, play_steps_in_s=2.0):
201
  # Initialize variables
202
- play_steps = int(frame_rate * play_steps_in_s)
203
- chunk_size = 15 # Process 10 words at a time
204
 
205
  # Tokenize the full text and description
206
  inputs = description_tokenizer(description, return_tensors="pt").to(device)
207
-
208
- # Split text into chunks of approximately 10 words
209
- words = text.split()
210
- chunks = [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
211
-
 
 
 
 
 
 
 
 
 
 
 
 
212
  all_audio = []
213
 
214
  # Process each chunk
@@ -223,8 +213,6 @@ def generate_base(text, description, play_steps_in_s=2.0):
223
  prompt_input_ids=prompt.input_ids,
224
  prompt_attention_mask=prompt.attention_mask,
225
  do_sample=True,
226
- # temperature=1.0,
227
- # min_new_tokens=10,
228
  return_dict_in_generate=True
229
  )
230
 
@@ -243,43 +231,30 @@ def generate_base(text, description, play_steps_in_s=2.0):
243
  print(f"Sample of length: {round(combined_audio.shape[0] / sampling_rate, 2)} seconds")
244
  yield numpy_to_mp3(combined_audio, sampling_rate=sampling_rate)
245
 
246
- # @spaces.GPU
247
- # def generate_jenny(text, description, play_steps_in_s=2.0):
248
- # play_steps = int(frame_rate * play_steps_in_s)
249
- # streamer = ParlerTTSStreamer(jenny_model, device=device, play_steps=play_steps)
250
-
251
- # inputs = description_tokenizer(description, return_tensors="pt").to(device)
252
- # prompt = tokenizer(text, return_tensors="pt").to(device)
253
-
254
- # generation_kwargs = dict(
255
- # input_ids=inputs.input_ids,
256
- # prompt_input_ids=prompt.input_ids,
257
- # streamer=streamer,
258
- # do_sample=True,
259
- # temperature=1.0,
260
- # min_new_tokens=10,
261
- # )
262
-
263
- # set_seed(SEED)
264
- # thread = Thread(target=jenny_model.generate, kwargs=generation_kwargs)
265
- # thread.start()
266
-
267
- # for new_audio in streamer:
268
- # print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
269
- # yield sampling_rate, new_audio
270
 
271
  @spaces.GPU
272
- def generate_jenny(text, description, play_steps_in_s=2.0):
273
  # Initialize variables
274
- play_steps = int(frame_rate * play_steps_in_s)
275
- chunk_size = 15 # Process 10 words at a time
276
 
277
  # Tokenize the full text and description
278
  inputs = description_tokenizer(description, return_tensors="pt").to(device)
279
-
280
- # Split text into chunks of approximately 10 words
281
- words = text.split()
282
- chunks = [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  all_audio = []
285
 
@@ -289,14 +264,12 @@ def generate_jenny(text, description, play_steps_in_s=2.0):
289
  prompt = tokenizer(chunk, return_tensors="pt").to(device)
290
 
291
  # Generate audio for the chunk
292
- generation = jenny_model.generate(
293
  input_ids=inputs.input_ids,
294
  attention_mask=inputs.attention_mask,
295
  prompt_input_ids=prompt.input_ids,
296
  prompt_attention_mask=prompt.attention_mask,
297
  do_sample=True,
298
- # temperature=1.0,
299
- # min_new_tokens=10,
300
  return_dict_in_generate=True
301
  )
302
 
@@ -387,29 +360,27 @@ with gr.Blocks(css=css) as block:
387
  with gr.Tab("Finetuned"):
388
  with gr.Row():
389
  with gr.Column():
390
- input_text = gr.Textbox(label="Input Text", lines=2, value=jenny_examples[0][0], elem_id="input_text")
391
- description = gr.Textbox(label="Description", lines=2, value=jenny_examples[0][1], elem_id="input_description")
392
- play_seconds = gr.Slider(3.0, 7.0, value=jenny_examples[0][2], step=2, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
393
  run_button = gr.Button("Generate Audio", variant="primary")
394
  with gr.Column():
395
- audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", streaming=True, autoplay=True)
396
 
397
- inputs = [input_text, description, play_seconds]
398
  outputs = [audio_out]
399
- gr.Examples(examples=jenny_examples, fn=generate_jenny, inputs=inputs, outputs=outputs, cache_examples=False)
400
- run_button.click(fn=generate_jenny, inputs=inputs, outputs=outputs, queue=True)
401
 
402
  with gr.Tab("Pretrained"):
403
  with gr.Row():
404
  with gr.Column():
405
  input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
406
  description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
407
- play_seconds = gr.Slider(3.0, 7.0, value=3.0, step=2, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
408
  run_button = gr.Button("Generate Audio", variant="primary")
409
  with gr.Column():
410
- audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", streaming=True, autoplay=True)
411
 
412
- inputs = [input_text, description, play_seconds]
413
  outputs = [audio_out]
414
  gr.Examples(examples=examples, fn=generate_base, inputs=inputs, outputs=outputs, cache_examples=False)
415
  run_button.click(fn=generate_base, inputs=inputs, outputs=outputs, queue=True)
 
9
  import spaces
10
  import gradio as gr
11
  import torch
12
+ import nltk
13
+
14
 
15
  from parler_tts import ParlerTTSForConditionalGeneration
16
  from pydub import AudioSegment
17
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
18
 
19
+ nltk.download('punkt_tab')
20
+
21
+ device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
22
  torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
23
 
24
  repo_id = "ai4bharat/indic-parler-tts-pretrained"
25
+ finetuned_repo_id = "ai4bharat/indic-parler-tts"
26
 
27
  model = ParlerTTSForConditionalGeneration.from_pretrained(
28
  repo_id, attn_implementation="eager", torch_dtype=torch_dtype,
29
  ).to(device)
30
+ finetuned_model = ParlerTTSForConditionalGeneration.from_pretrained(
31
+ finetuned_repo_id, attn_implementation="eager", torch_dtype=torch_dtype,
32
  ).to(device)
33
 
34
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
 
93
  ]
94
 
95
 
96
+ finetuned_examples = [
97
  [
98
  "मुले बागेत खेळत आहेत आणि पक्षी किलबिलाट करत आहेत.",
99
  "Sunita speaks slowly in a calm, moderate-pitched voice, delivering the news with a neutral tone. The recording is very high quality with no background noise.",
 
175
  sampling_rate = model.audio_encoder.config.sampling_rate
176
  frame_rate = model.audio_encoder.config.frame_rate
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  @spaces.GPU
179
+ def generate_base(text, description,):
180
  # Initialize variables
181
+ chunk_size = 25 # Process max 25 words or a sentence at a time
 
182
 
183
  # Tokenize the full text and description
184
  inputs = description_tokenizer(description, return_tensors="pt").to(device)
185
+
186
+ sentences_text = nltk.sent_tokenize(text) # this gives us a list of sentences
187
+ curr_sentence = ""
188
+ chunks = []
189
+ for sentence in sentences_text:
190
+ candidate = " ".join([curr_sentence, sentence])
191
+ if len(candidate.split()) >= chunk_size:
192
+ chunks.append(curr_sentence)
193
+ curr_sentence = sentence
194
+ else:
195
+ curr_sentence = candidate
196
+
197
+ if curr_sentence != "":
198
+ chunks.append(curr_sentence)
199
+
200
+ print(chunks)
201
+
202
  all_audio = []
203
 
204
  # Process each chunk
 
213
  prompt_input_ids=prompt.input_ids,
214
  prompt_attention_mask=prompt.attention_mask,
215
  do_sample=True,
 
 
216
  return_dict_in_generate=True
217
  )
218
 
 
231
  print(f"Sample of length: {round(combined_audio.shape[0] / sampling_rate, 2)} seconds")
232
  yield numpy_to_mp3(combined_audio, sampling_rate=sampling_rate)
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  @spaces.GPU
236
+ def generate_finetuned(text, description):
237
  # Initialize variables
238
+ chunk_size = 25 # Process max 25 words or a sentence at a time
 
239
 
240
  # Tokenize the full text and description
241
  inputs = description_tokenizer(description, return_tensors="pt").to(device)
242
+
243
+ sentences_text = nltk.sent_tokenize(text) # this gives us a list of sentences
244
+ curr_sentence = ""
245
+ chunks = []
246
+ for sentence in sentences_text:
247
+ candidate = " ".join([curr_sentence, sentence])
248
+ if len(candidate.split()) >= chunk_size:
249
+ chunks.append(curr_sentence)
250
+ curr_sentence = sentence
251
+ else:
252
+ curr_sentence = candidate
253
+
254
+ if curr_sentence != "":
255
+ chunks.append(curr_sentence)
256
+
257
+ print(chunks)
258
 
259
  all_audio = []
260
 
 
264
  prompt = tokenizer(chunk, return_tensors="pt").to(device)
265
 
266
  # Generate audio for the chunk
267
+ generation = finetuned_model.generate(
268
  input_ids=inputs.input_ids,
269
  attention_mask=inputs.attention_mask,
270
  prompt_input_ids=prompt.input_ids,
271
  prompt_attention_mask=prompt.attention_mask,
272
  do_sample=True,
 
 
273
  return_dict_in_generate=True
274
  )
275
 
 
360
  with gr.Tab("Finetuned"):
361
  with gr.Row():
362
  with gr.Column():
363
+ input_text = gr.Textbox(label="Input Text", lines=2, value=finetuned_examples[0][0], elem_id="input_text")
364
+ description = gr.Textbox(label="Description", lines=2, value=finetuned_examples[0][1], elem_id="input_description")
 
365
  run_button = gr.Button("Generate Audio", variant="primary")
366
  with gr.Column():
367
+ audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", autoplay=True)
368
 
369
+ inputs = [input_text, description]
370
  outputs = [audio_out]
371
+ gr.Examples(examples=finetuned_examples, fn=generate_finetuned, inputs=inputs, outputs=outputs, cache_examples=False)
372
+ run_button.click(fn=generate_finetuned, inputs=inputs, outputs=outputs, queue=True)
373
 
374
  with gr.Tab("Pretrained"):
375
  with gr.Row():
376
  with gr.Column():
377
  input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
378
  description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
 
379
  run_button = gr.Button("Generate Audio", variant="primary")
380
  with gr.Column():
381
+ audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", autoplay=True)
382
 
383
+ inputs = [input_text, description]
384
  outputs = [audio_out]
385
  gr.Examples(examples=examples, fn=generate_base, inputs=inputs, outputs=outputs, cache_examples=False)
386
  run_button.click(fn=generate_base, inputs=inputs, outputs=outputs, queue=True)