sanchit-gandhi commited on
Commit
496bf8a
·
1 Parent(s): 33d12bd

yield mp3 bytes

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +31 -5
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📝
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.27.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.31.5
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import math
2
  from queue import Queue
3
  from threading import Thread
@@ -9,6 +10,7 @@ import gradio as gr
9
  import torch
10
 
11
  from parler_tts import ParlerTTSForConditionalGeneration
 
12
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
13
  from transformers.generation.streamers import BaseStreamer
14
 
@@ -208,6 +210,30 @@ class ParlerTTSStreamer(BaseStreamer):
208
  else:
209
  return value
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  sampling_rate = model.audio_encoder.config.sampling_rate
213
  frame_rate = model.audio_encoder.config.frame_rate
@@ -235,7 +261,7 @@ def generate_base(text, description, play_steps_in_s=2.0):
235
 
236
  for new_audio in streamer:
237
  print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
238
- yield sampling_rate, new_audio
239
 
240
  @spaces.GPU
241
  def generate_jenny(text, description, play_steps_in_s=2.0):
@@ -338,10 +364,10 @@ with gr.Blocks(css=css) as block:
338
  with gr.Column():
339
  input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
340
  description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
341
- play_seconds = gr.Slider(3.0, 5.0, value=3.0, step=0.5, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
342
  run_button = gr.Button("Generate Audio", variant="primary")
343
  with gr.Column():
344
- audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out", streaming=True, autoplay=True)
345
 
346
  inputs = [input_text, description, play_seconds]
347
  outputs = [audio_out]
@@ -353,10 +379,10 @@ with gr.Blocks(css=css) as block:
353
  with gr.Column():
354
  input_text = gr.Textbox(label="Input Text", lines=2, value=jenny_examples[0][0], elem_id="input_text")
355
  description = gr.Textbox(label="Description", lines=2, value=jenny_examples[0][1], elem_id="input_description")
356
- play_seconds = gr.Slider(3.0, 5.0, value=jenny_examples[0][2], step=0.5, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
357
  run_button = gr.Button("Generate Audio", variant="primary")
358
  with gr.Column():
359
- audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out", streaming=True, autoplay=True)
360
 
361
  inputs = [input_text, description, play_seconds]
362
  outputs = [audio_out]
 
1
+ import io
2
  import math
3
  from queue import Queue
4
  from threading import Thread
 
10
  import torch
11
 
12
  from parler_tts import ParlerTTSForConditionalGeneration
13
+ from pydub import AudioSegment
14
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
15
  from transformers.generation.streamers import BaseStreamer
16
 
 
210
  else:
211
  return value
212
 
213
+ def numpy_to_mp3(audio_array, sampling_rate):
214
+ # Normalize audio_array if it's floating-point
215
+ if np.issubdtype(audio_array.dtype, np.floating):
216
+ max_val = np.max(np.abs(audio_array))
217
+ audio_array = (audio_array / max_val) * 32767 # Normalize to 16-bit range
218
+ audio_array = audio_array.astype(np.int16)
219
+
220
+ # Create an audio segment from the numpy array
221
+ audio_segment = AudioSegment(
222
+ audio_array.tobytes(),
223
+ frame_rate=sampling_rate,
224
+ sample_width=audio_array.dtype.itemsize,
225
+ channels=1
226
+ )
227
+
228
+ # Export the audio segment to MP3 bytes - use a high bitrate to maximise quality
229
+ mp3_io = io.BytesIO()
230
+ audio_segment.export(mp3_io, format="mp3", bitrate="320k")
231
+
232
+ # Get the MP3 bytes
233
+ mp3_bytes = mp3_io.getvalue()
234
+ mp3_io.close()
235
+
236
+ return mp3_bytes
237
 
238
  sampling_rate = model.audio_encoder.config.sampling_rate
239
  frame_rate = model.audio_encoder.config.frame_rate
 
261
 
262
  for new_audio in streamer:
263
  print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
264
+ yield numpy_to_mp3(new_audio, sampling_rate=sampling_rate)
265
 
266
  @spaces.GPU
267
  def generate_jenny(text, description, play_steps_in_s=2.0):
 
364
  with gr.Column():
365
  input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
366
  description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
367
+ play_seconds = gr.Slider(3.0, 7.0, value=3.0, step=2, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
368
  run_button = gr.Button("Generate Audio", variant="primary")
369
  with gr.Column():
370
+ audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", streaming=True, autoplay=True)
371
 
372
  inputs = [input_text, description, play_seconds]
373
  outputs = [audio_out]
 
379
  with gr.Column():
380
  input_text = gr.Textbox(label="Input Text", lines=2, value=jenny_examples[0][0], elem_id="input_text")
381
  description = gr.Textbox(label="Description", lines=2, value=jenny_examples[0][1], elem_id="input_description")
382
+ play_seconds = gr.Slider(3.0, 7.0, value=jenny_examples[0][2], step=2, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
383
  run_button = gr.Button("Generate Audio", variant="primary")
384
  with gr.Column():
385
+ audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", streaming=True, autoplay=True)
386
 
387
  inputs = [input_text, description, play_seconds]
388
  outputs = [audio_out]