ylacombe committed on
Commit 4db98bc · verified · 1 Parent(s): 06cf1a9

Update app.py

Files changed (1)
  1. app.py +477 -24
app.py CHANGED
@@ -9,11 +9,413 @@ from transformers import MusicgenMelodyForConditionalGeneration, AutoProcessor,
 from transformers.generation.streamers import BaseStreamer
 
 import gradio as gr
-import spaces
+import io, wave
 
+# import spaces
 
-model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")
-processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
+
+from transformers import MusicgenMelodyForConditionalGeneration, MusicgenForConditionalGeneration, AutoProcessor, set_seed
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.utils import logging
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+import copy
+import torch
+import inspect
+
+from demucs import pretrained
+from demucs.apply import apply_model
+from demucs.audio import convert_audio
+
+logger = logging.get_logger(__name__)
+
+
+class MusicgenMelodyForLongFormConditionalGeneration(MusicgenMelodyForConditionalGeneration):
+    stride_longform = 500
+    max_longform_generation_length = 4000
+
+
+    def _prepare_audio_encoder_kwargs_for_longform_generation(
+        self, audio_codes, model_kwargs,):
+        frames, bsz, codebooks, seq_len = audio_codes.shape
+
+        if frames != 1:
+            raise ValueError(
+                f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
+                "disabled by setting `chunk_length=None` in the audio encoder."
+            )
+
+        decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
+
+        model_kwargs["decoder_input_ids"] = decoder_input_ids
+        return model_kwargs
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        synced_gpus: Optional[bool] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        **kwargs,
+    ):
+        """
+
+        Generates sequences of token ids for models with a language modeling head.
+
+        <Tip warning={true}>
+
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+        </Tip>
+
+        Parameters:
+            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+                should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logit processor is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criteria is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            synced_gpus (`bool`, *optional*, defaults to `False`):
+                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            streamer (`BaseStreamer`, *optional*):
+                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+        Return:
+            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+            If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+            [`~utils.ModelOutput`] types are:
+
+                - [`~generation.GenerateDecoderOnlyOutput`],
+                - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+            If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+            [`~utils.ModelOutput`] types are:
+
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
+        """
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+        if generation_config is None:
+            generation_config = self.generation_config
+
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
+        self._validate_model_kwargs(model_kwargs.copy())
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+            if model_kwargs.get("attention_mask", None) is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
+            eos_token_id = generation_config.eos_token_id
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            generation_config.pad_token_id = eos_token_id
+
+        # 3. Define model inputs
+        # inputs_tensor has to be defined
+        # model_input_name is defined if model-specific keyword input is passed
+        # otherwise model_input_name is None
+        # all model-specific keyword inputs are removed from `model_kwargs`
+        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+            inputs, generation_config.bos_token_id, model_kwargs
+        )
+        batch_size = inputs_tensor.shape[0]
+
+        # 4. Define other model kwargs
+        model_kwargs["output_attentions"] = generation_config.output_attentions
+        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+        model_kwargs["use_cache"] = generation_config.use_cache
+        model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+        if model_kwargs.get("attention_mask", None) is None:
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+            )
+
+        if "encoder_hidden_states" not in model_kwargs:
+            # encoder_hidden_states are created and added to `model_kwargs`
+            model_kwargs = self._prepare_encoder_hidden_states_kwargs_for_generation(
+                inputs_tensor,
+                model_kwargs,
+                model_input_name,
+                guidance_scale=generation_config.guidance_scale,
+            )
+
+        # 5. Prepare `input_ids` which will be used for auto-regressive generation
+        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+            batch_size=batch_size,
+            model_input_name=model_input_name,
+            model_kwargs=model_kwargs,
+            decoder_start_token_id=generation_config.decoder_start_token_id,
+            bos_token_id=generation_config.bos_token_id,
+            device=inputs_tensor.device,
+        )
+
+        # 6. Prepare `max_length` depending on other stopping criteria.
+        input_ids_seq_length = input_ids.shape[-1]
+
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
+            logger.warning(
+                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
+                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation."
+            )
+        elif generation_config.max_new_tokens is not None:
+            if not has_default_max_length:
+                logger.warning(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+
+        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
+            raise ValueError(
+                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
+                f" the maximum length ({generation_config.max_length})"
+            )
+        if input_ids_seq_length >= generation_config.max_length:
+            logger.warning(
+                f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )
+
+        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen Melody)
+        input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+            input_ids,
+            pad_token_id=generation_config.decoder_start_token_id,
+            max_length=generation_config.max_length,
+        )
+        # stash the delay mask so that we don't have to recompute in each forward pass
+        model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
+
+        # input_ids are ready to be placed on the streamer (if used)
+        if streamer is not None:
+            streamer.put(input_ids.cpu())
+
+        # 7. determine generation mode
+        is_greedy_gen_mode = (
+            (generation_config.num_beams == 1)
+            and (generation_config.num_beam_groups == 1)
+            and generation_config.do_sample is False
+        )
+        is_sample_gen_mode = (
+            (generation_config.num_beams == 1)
+            and (generation_config.num_beam_groups == 1)
+            and generation_config.do_sample is True
+        )
+
+        # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None
+
+        # 9. prepare distribution pre_processing samplers
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=inputs_tensor,
+            prefix_allowed_tokens_fn=None,
+            logits_processor=logits_processor,
+        )
+
+        # 10. prepare stopping criteria
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+
+        # ENTER LONGFORM GENERATION LOOP
+        generated_tokens = []
+
+        # the first timestamps corresponds to decoder_start_token
+        current_generated_length = input_ids.shape[1] - 1
+
+        while current_generated_length <= self.max_longform_generation_length:
+            if is_greedy_gen_mode:
+                if generation_config.num_return_sequences > 1:
+                    raise ValueError(
+                        "num_return_sequences has to be 1 when doing greedy search, "
+                        f"but is {generation_config.num_return_sequences}."
+                    )
+
+                # 11. run greedy search
+                outputs = self._greedy_search(
+                    input_ids,
+                    logits_processor=logits_processor,
+                    stopping_criteria=stopping_criteria,
+                    pad_token_id=generation_config.pad_token_id,
+                    eos_token_id=generation_config.eos_token_id,
+                    output_scores=generation_config.output_scores,
+                    return_dict_in_generate=generation_config.return_dict_in_generate,
+                    synced_gpus=synced_gpus,
+                    streamer=streamer,
+                    **model_kwargs,
+                )
+
+            elif is_sample_gen_mode:
+                # 11. prepare logits warper
+                logits_warper = self._get_logits_warper(generation_config)
+
+                # expand input_ids with `num_return_sequences` additional sequences per batch
+                input_ids, model_kwargs = self._expand_inputs_for_generation(
+                    input_ids=input_ids,
+                    expand_size=generation_config.num_return_sequences,
+                    is_encoder_decoder=self.config.is_encoder_decoder,
+                    **model_kwargs,
+                )
+
+                # 12. run sample
+                outputs = self._sample(
+                    input_ids,
+                    logits_processor=logits_processor,
+                    logits_warper=logits_warper,
+                    stopping_criteria=stopping_criteria,
+                    pad_token_id=generation_config.pad_token_id,
+                    eos_token_id=generation_config.eos_token_id,
+                    output_scores=generation_config.output_scores,
+                    return_dict_in_generate=generation_config.return_dict_in_generate,
+                    synced_gpus=synced_gpus,
+                    streamer=streamer,
+                    **model_kwargs,
+                )
+
+            else:
+                raise ValueError(
+                    "Got incompatible mode for generation, should be one of greedy or sampling. "
+                    "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+                )
+
+            if generation_config.return_dict_in_generate:
+                output_ids = outputs.sequences
+            else:
+                output_ids = outputs
+
+            # apply the pattern mask to the final ids
+            output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
+
+            # revert the pattern delay mask by filtering the pad token id
+            output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+                batch_size, self.decoder.num_codebooks, -1
+            )
+            if len(generated_tokens) >= 1:
+                generated_tokens.append(output_ids[:, :, self.stride_longform:])
+            else:
+                generated_tokens.append(output_ids)
+
+            current_generated_length += generated_tokens[-1].shape[-1]
+
+            # append the frame dimension back to the audio codes
+            # use last generated tokens as begining of the newest generation
+            output_ids = output_ids[None, :, :, (output_ids.shape[-1] - self.stride_longform):]
+
+            model_kwargs = self._prepare_audio_encoder_kwargs_for_longform_generation(output_ids, model_kwargs)
+
+            # Prepare new `input_ids` which will be used for auto-regressive generation
+            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+                batch_size=batch_size,
+                model_input_name="input_ids",
+                model_kwargs=model_kwargs,
+                decoder_start_token_id=self.generation_config.decoder_start_token_id,
+                bos_token_id=self.generation_config.bos_token_id,
+                device=input_ids.device,
+            )
+            # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen Melody)
+            input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+                input_ids,
+                pad_token_id=generation_config.decoder_start_token_id,
+                max_length=generation_config.max_length,
+            )
+            # stash the delay mask so that we don't have to recompute in each forward pass
+            model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
+
+
+            # TODO(YL): periodic prompt song
+
+            # encoder_hidden_states are created and added to `model_kwargs`
+            # model_kwargs = self._prepare_encoder_hidden_states_kwargs_for_generation(
+            #     inputs_tensor,
+            #     model_kwargs,
+            #     model_input_name,
+            #     guidance_scale=generation_config.guidance_scale,
+            # )
+
+        # append the frame dimension back to the audio codes
+        output_ids = torch.cat(generated_tokens, dim=-1)[None, ...]
+
+        # Specific to this gradio demo
+        if streamer is not None:
+            streamer.end(True)
+
+        audio_scales = model_kwargs.get("audio_scales")
+        if audio_scales is None:
+            audio_scales = [None] * batch_size
+
+        if self.decoder.config.audio_channels == 1:
+            output_values = self.audio_encoder.decode(
+                output_ids,
+                audio_scales=audio_scales,
+            ).audio_values
+        else:
+            codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
+            output_values_left = codec_outputs_left.audio_values
+
+            codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
+            output_values_right = codec_outputs_right.audio_values
+
+            output_values = torch.cat([output_values_left, output_values_right], dim=1)
+
+        if generation_config.return_dict_in_generate:
+            outputs.sequences = output_values
+            return outputs
+        else:
+            return output_values
+
+model = MusicgenMelodyForLongFormConditionalGeneration.from_pretrained("facebook/musicgen-melody", revision="refs/pr/14")#, attn_implementation="sdpa")
+processor = AutoProcessor.from_pretrained("facebook/musicgen-melody", revision="refs/pr/14")
+
+demucs = pretrained.get_model('htdemucs')
 
 title = "MusicGen Streaming"
 
@@ -52,11 +454,12 @@ For details on how the streaming class works, check out the source code for the
 class MusicgenStreamer(BaseStreamer):
     def __init__(
         self,
-        model: MusicgenForConditionalGeneration,
+        model: MusicgenMelodyForConditionalGeneration,
         device: Optional[str] = None,
         play_steps: Optional[int] = 10,
        stride: Optional[int] = None,
         timeout: Optional[float] = None,
+        is_longform: Optional[bool] = False,
    ):
         """
         Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
@@ -79,6 +482,8 @@ class MusicgenStreamer(BaseStreamer):
             timeout (`int`, *optional*):
                 The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
                 in `.generate()`, when it is called in a separate thread.
+            is_longform (`bool`, *optional*, defaults to `False`):
+                If `is_longform`, will takes into account long form stride and non regular ending signal.
         """
         self.decoder = model.decoder
         self.audio_encoder = model.audio_encoder
@@ -94,7 +499,12 @@ class MusicgenStreamer(BaseStreamer):
         self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
         self.token_cache = None
         self.to_yield = 0
-
+
+        self.is_longform = is_longform
+        if is_longform:
+            self.longform_stride = model.stride_longform
+            self.longform_stride_applied = True
+
         # varibles used in the thread process
         self.audio_queue = Queue()
         self.stop_signal = None
@@ -140,17 +550,24 @@ class MusicgenStreamer(BaseStreamer):
 
         if self.token_cache.shape[-1] % self.play_steps == 0:
             audio_values = self.apply_delay_pattern_mask(self.token_cache)
+            if self.is_longform:
+                if not self.longform_stride_applied:
+                    self.to_yield = self.to_yield + self.longform_stride
+                    self.longform_stride_applied = True
             self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
             self.to_yield += len(audio_values) - self.to_yield - self.stride
 
-    def end(self):
+    def end(self, stream_end=False):
         """Flushes any remaining cache and appends the stop symbol."""
         if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
         else:
             audio_values = np.zeros(self.to_yield)
 
-        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
+        stream_end = (not self.is_longform) or stream_end
+        if self.is_longform:
+            self.longform_stride_applied = False
+        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=stream_end)
 
     def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
         """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
@@ -175,11 +592,29 @@ frame_rate = model.audio_encoder.config.frame_rate
 target_dtype = np.int16
 max_range = np.iinfo(target_dtype).max
 
-
-@spaces.GPU()
-def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
+def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000):
+    # This will create a wave header then append the frame input
+    # It should be first on a streaming wav file
+    # Other frames better should not have it (else you will hear some artifacts each chunk start)
+    wav_buf = io.BytesIO()
+    with wave.open(wav_buf, "wb") as vfout:
+        vfout.setnchannels(channels)
+        vfout.setsampwidth(sample_width)
+        vfout.setframerate(sample_rate)
+        vfout.writeframes(frame_input)
+
+    wav_buf.seek(0)
+
+    return wav_buf.read()
+
+# @spaces.GPU()
+def generate_audio(text_prompt, audio, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
     max_new_tokens = int(frame_rate * audio_length_in_s)
     play_steps = int(frame_rate * play_steps_in_s)
+
+    if audio is not None:
+        audio = convert_audio(torch.tensor(audio[1]).float(), audio[0], demucs.samplerate, demucs.audio_channels)
+        audio = apply_model(demucs, audio[None])
 
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
     if device != model.device:
@@ -187,52 +622,70 @@ def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0, see
     if device == "cuda:0":
         model.half()
 
-    inputs = processor(
-        text=text_prompt,
-        padding=True,
-        return_tensors="pt",
-    )
+    if audio is not None:
+        inputs = processor(
+            text=text_prompt,
+            padding=True,
+            return_tensors="pt",
+            audio=audio, sampling_rate=demucs.samplerate
+        )
+        if device == "cuda:0":
+            inputs["input_features"] = inputs["input_features"].to(torch.float16)
+    else:
+        inputs = processor(
+            text=text_prompt,
+            padding=True,
+            return_tensors="pt",
+        )
 
-    streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
+    streamer = MusicgenStreamer(model, device=device, play_steps=play_steps, is_longform=True)
 
     generation_kwargs = dict(
         **inputs.to(device),
+        temperature=1.2,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
    )
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
+
+    yield wave_header_chunk()
 
     set_seed(seed)
     for new_audio in streamer:
         print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
+
         new_audio = (new_audio * max_range).astype(np.int16)
-        yield sampling_rate, new_audio
+        # (sampling_rate, new_audio)
+        yield new_audio.tobytes()
+
 
 
 demo = gr.Interface(
     fn=generate_audio,
     inputs=[
         gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
+        gr.Audio(source="upload", type="numpy", label="Conditioning audio"),
         gr.Slider(10, 30, value=15, step=5, label="Audio length in seconds"),
         gr.Slider(0.5, 2.5, value=1.5, step=0.5, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps"),
         gr.Slider(0, 10, value=5, step=1, label="Seed for random generations"),
     ],
     outputs=[
-        gr.Audio(label="Generated Music", streaming=True, autoplay=True)
+        gr.Audio(label="Generated Music", autoplay=True, interactive=False, streaming=True)
     ],
     examples=[
-        ["An 80s driving pop song with heavy drums and synth pads in the background", 30, 1.5, 5],
-        ["A cheerful country song with acoustic guitars", 30, 1.5, 5],
-        ["90s rock song with electric guitar and heavy drums", 30, 1.5, 5],
-        ["a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130", 30, 1.5, 5],
-        ["lofi slow bpm electro chill with organic samples", 30, 1.5, 5],
+        ["An 80s driving pop song with heavy drums and synth pads in the background", None, 30, 1.5, 5],
+        ["A cheerful country song with acoustic guitars", None, 30, 1.5, 5],
+        ["90s rock song with electric guitar and heavy drums", None, 30, 1.5, 5],
+        ["a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130", None, 30, 1.5, 5],
+        ["lofi slow bpm electro chill with organic samples", None, 30, 1.5, 5],
     ],
     title=title,
     description=description,
+    allow_flagging=False,
     article=article,
     cache_examples=False,
 )
 
 
-demo.queue().launch()
+demo.queue().launch(debug=True)
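
For reference, a minimal usage sketch of the streaming path touched by this diff. It is not part of the commit: the helper name stream_chunks and the parameter values are illustrative assumptions; it simply mirrors what generate_audio above does (start model.generate in a background thread with a MusicgenStreamer, then iterate the streamer for playback-ready chunks).

    # Illustrative sketch only; stream_chunks is a hypothetical helper, not code from this Space.
    # It assumes `model`, `processor`, and `MusicgenStreamer` from app.py are already defined.
    from threading import Thread

    import numpy as np

    def stream_chunks(prompt, play_steps=50, max_new_tokens=256, device="cpu"):
        inputs = processor(text=prompt, padding=True, return_tensors="pt")
        streamer = MusicgenStreamer(model, device=device, play_steps=play_steps, is_longform=True)
        thread = Thread(
            target=model.generate,
            kwargs=dict(**inputs.to(device), streamer=streamer, max_new_tokens=max_new_tokens),
        )
        thread.start()
        for chunk in streamer:  # float32 numpy audio, yielded chunk by chunk
            yield (chunk * np.iinfo(np.int16).max).astype(np.int16)
        thread.join()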