ylacombe commited on
Commit
52cf258
·
verified ·
1 Parent(s): c12ffff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -263,7 +263,7 @@ class MusicgenMelodyForLongFormConditionalGeneration(MusicgenMelodyForConditiona
263
 
264
  max_new_tokens = generation_config.max_new_tokens
265
 
266
- while current_generated_length + 20 <= max_longform_generation_length:
267
  generation_config.max_new_tokens = min(max_new_tokens, max_longform_generation_length - current_generated_length)
268
  if is_greedy_gen_mode:
269
  if generation_config.num_return_sequences > 1:
@@ -378,7 +378,7 @@ class MusicgenMelodyForLongFormConditionalGeneration(MusicgenMelodyForConditiona
378
 
379
  # Specific to this gradio demo
380
  if streamer is not None:
381
- streamer.end(True)
382
 
383
  audio_scales = model_kwargs.get("audio_scales")
384
  if audio_scales is None:
@@ -414,7 +414,7 @@ title = "Streaming Long-form MusicGen"
414
  description = """
415
  Stream the outputs of the MusicGen Melody text-to-music model by playing the generated audio as soon as the first chunk is ready.
416
 
417
- The generation loop is adapted to perform **long-form** music generation. In this demo, we limit the duration of the music generated, but in theory, it could run **endlessly**.
418
 
419
  Demo uses [MusicGen Melody](https://huggingface.co/facebook/musicgen-melody) in the 🤗 Transformers library. Note that the
420
  demo works best on the Chrome browser. If there is no audio output, try switching browser to Chrome.
@@ -468,6 +468,7 @@ class MusicgenStreamer(BaseStreamer):
468
  stride: Optional[int] = None,
469
  timeout: Optional[float] = None,
470
  is_longform: Optional[bool] = False,
 
471
  ):
472
  """
473
  Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
@@ -496,6 +497,7 @@ class MusicgenStreamer(BaseStreamer):
496
  self.audio_encoder = model.audio_encoder
497
  self.generation_config = model.generation_config
498
  self.device = device if device is not None else model.device
 
499
 
500
  # variables used in the streaming process
501
  self.play_steps = play_steps
@@ -509,6 +511,8 @@ class MusicgenStreamer(BaseStreamer):
509
 
510
  self.is_longform = is_longform
511
 
 
 
512
  # varibles used in the thread process
513
  self.audio_queue = Queue()
514
  self.stop_signal = None
@@ -565,19 +569,19 @@ class MusicgenStreamer(BaseStreamer):
565
 
566
  if self.token_cache.shape[-1] % self.play_steps == 0:
567
  audio_values = self.apply_delay_pattern_mask(self.token_cache)
568
- if self.to_yield != len(audio_values) - self.stride:
569
- self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
570
- self.to_yield += len(audio_values) - self.to_yield - self.stride
571
 
572
- def end(self, stream_end=False):
573
  """Flushes any remaining cache and appends the stop symbol."""
574
  if self.token_cache is not None:
575
  audio_values = self.apply_delay_pattern_mask(self.token_cache)
576
  else:
577
  audio_values = np.zeros(self.to_yield)
578
 
579
- stream_end = (not self.is_longform) or stream_end
580
- self.on_finalized_audio(audio_values[self.to_yield :], stream_end=stream_end)
581
 
582
  def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
583
  """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
@@ -618,8 +622,10 @@ def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=2
618
  return wav_buf.read()
619
 
620
  @spaces.GPU(duration=90)
621
- def generate_audio(text_prompt, audio, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
 
622
  max_new_tokens = int(frame_rate * audio_length_in_s)
 
623
  play_steps = int(frame_rate * play_steps_in_s)
624
 
625
  if audio is not None:
@@ -649,7 +655,8 @@ def generate_audio(text_prompt, audio, audio_length_in_s=10.0, play_steps_in_s=2
649
  return_tensors="pt",
650
  )
651
 
652
- streamer = MusicgenStreamer(model, device=device, play_steps=play_steps, is_longform=True)
 
653
 
654
  generation_kwargs = dict(
655
  **inputs.to(device),
@@ -678,19 +685,17 @@ demo = gr.Interface(
678
  inputs=[
679
  gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
680
  gr.Audio(type="filepath", label="Conditioning audio. Use this for melody-guided generation."),
681
- gr.Slider(30, 60, value=45, step=5, label="(Approximate) Audio length in seconds."),
682
- gr.Slider(0.5, 2.5, value=1.5, step=0.5, label="Streaming interval in seconds.", info="Lower = shorter chunks, lower latency, more codec steps."),
683
  gr.Number(value=5, precision=0, step=1, minimum=0, label="Seed for random generations."),
684
  ],
685
  outputs=[
686
  gr.Audio(label="Generated Music", autoplay=True, interactive=False, streaming=True)
687
  ],
688
  examples=[
689
- ["An 80s driving pop song with heavy drums and synth pads in the background", None, 45, 1.5, 5],
690
- ["Bossa nova with guitars and synthesizer", "./assets/assets_bolero_ravel.mp3", 45, 1.5, 5],
691
- ["90s rock song with electric guitar and heavy drums", "./assets/assets_bach.mp3", 45, 1.5, 5],
692
- ["a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130", None, 45, 1.5, 5],
693
- ["lofi slow bpm electro chill with organic samples", None, 45, 1.5, 5],
694
  ],
695
  title=title,
696
  description=description,
 
263
 
264
  max_new_tokens = generation_config.max_new_tokens
265
 
266
+ while current_generated_length + 4 <= max_longform_generation_length:
267
  generation_config.max_new_tokens = min(max_new_tokens, max_longform_generation_length - current_generated_length)
268
  if is_greedy_gen_mode:
269
  if generation_config.num_return_sequences > 1:
 
378
 
379
  # Specific to this gradio demo
380
  if streamer is not None:
381
+ streamer.end(final_end=True)
382
 
383
  audio_scales = model_kwargs.get("audio_scales")
384
  if audio_scales is None:
 
414
  description = """
415
  Stream the outputs of the MusicGen Melody text-to-music model by playing the generated audio as soon as the first chunk is ready.
416
 
417
+ The generation loop is adapted to perform **long-form** music generation. In this demo, we limit the duration of the music generated to 1mn20, but in theory, it could run **endlessly**.
418
 
419
  Demo uses [MusicGen Melody](https://huggingface.co/facebook/musicgen-melody) in the 🤗 Transformers library. Note that the
420
  demo works best on the Chrome browser. If there is no audio output, try switching browser to Chrome.
 
468
  stride: Optional[int] = None,
469
  timeout: Optional[float] = None,
470
  is_longform: Optional[bool] = False,
471
+ longform_stride: Optional[float] = 10,
472
  ):
473
  """
474
  Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
 
497
  self.audio_encoder = model.audio_encoder
498
  self.generation_config = model.generation_config
499
  self.device = device if device is not None else model.device
500
+ self.longform_stride = longform_stride
501
 
502
  # variables used in the streaming process
503
  self.play_steps = play_steps
 
511
 
512
  self.is_longform = is_longform
513
 
514
+ self.previous_len = -1
515
+
516
  # varibles used in the thread process
517
  self.audio_queue = Queue()
518
  self.stop_signal = None
 
569
 
570
  if self.token_cache.shape[-1] % self.play_steps == 0:
571
  audio_values = self.apply_delay_pattern_mask(self.token_cache)
572
+ self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
573
+ self.to_yield = len(audio_values) - self.stride
574
+ self.previous_len = len(audio_values)
575
 
576
+ def end(self, stream_end=False, final_end=False):
577
  """Flushes any remaining cache and appends the stop symbol."""
578
  if self.token_cache is not None:
579
  audio_values = self.apply_delay_pattern_mask(self.token_cache)
580
  else:
581
  audio_values = np.zeros(self.to_yield)
582
 
583
+ if final_end:
584
+ self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
585
 
586
  def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
587
  """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
 
622
  return wav_buf.read()
623
 
624
  @spaces.GPU(duration=90)
625
+ def generate_audio(text_prompt, audio, seed=0):
626
+ audio_length_in_s = 60
627
  max_new_tokens = int(frame_rate * audio_length_in_s)
628
+ play_steps_in_s = 2.0
629
  play_steps = int(frame_rate * play_steps_in_s)
630
 
631
  if audio is not None:
 
655
  return_tensors="pt",
656
  )
657
 
658
+ streamer = MusicgenStreamer(model, device=device, play_steps=play_steps, is_longform=True,
659
+ longform_stride=15*32000)
660
 
661
  generation_kwargs = dict(
662
  **inputs.to(device),
 
685
  inputs=[
686
  gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
687
  gr.Audio(type="filepath", label="Conditioning audio. Use this for melody-guided generation."),
 
 
688
  gr.Number(value=5, precision=0, step=1, minimum=0, label="Seed for random generations."),
689
  ],
690
  outputs=[
691
  gr.Audio(label="Generated Music", autoplay=True, interactive=False, streaming=True)
692
  ],
693
  examples=[
694
+ ["An 80s driving pop song with heavy drums and synth pads in the background", None, 5],
695
+ ["Bossa nova with guitars and synthesizer", "./assets/assets_bolero_ravel.mp3", 5],
696
+ ["90s rock song with electric guitar and heavy drums", "./assets/assets_bach.mp3", 5],
697
+ ["a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130", None, 5],
698
+ ["lofi slow bpm electro chill with organic samples", None, 5],
699
  ],
700
  title=title,
701
  description=description,