File size: 35,162 Bytes
d4bad2d
 
 
 
 
 
 
eb3e66c
d4bad2d
 
98fb2de
4db98bc
98fb2de
bca1c4b
d4bad2d
4db98bc
 
 
 
 
 
 
 
 
 
 
bda8ed2
4db98bc
 
 
 
 
 
 
 
 
bda8ed2
4db98bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a204cc2
4db98bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bda8ed2
 
 
 
4db98bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bda8ed2
4db98bc
 
 
 
bda8ed2
4db98bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4bad2d
bda8ed2
29309b0
 
bda8ed2
 
 
 
 
ee252bb
29309b0
 
 
bda8ed2
 
 
29309b0
 
bda8ed2
29309b0
 
 
 
 
 
 
bda8ed2
29309b0
 
 
 
 
bda8ed2
 
 
 
 
29309b0
 
ad404fa
bda8ed2
 
 
 
 
 
 
 
29309b0
 
d4bad2d
 
 
 
4db98bc
d4bad2d
 
 
 
4db98bc
d4bad2d
 
 
ee252bb
d4bad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4db98bc
 
d4bad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4db98bc
 
 
 
 
 
d4bad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bda8ed2
d4bad2d
bda8ed2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4bad2d
 
 
 
 
 
 
 
 
 
bda8ed2
 
 
d4bad2d
 
 
 
 
 
 
4db98bc
d4bad2d
 
 
 
 
 
4db98bc
 
 
 
d4bad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1deba83
 
 
4db98bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bda8ed2
4db98bc
d4bad2d
 
4db98bc
 
bda8ed2
 
4db98bc
d4bad2d
ea17fc9
 
 
 
 
 
4db98bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4d4d63
a204cc2
d4bad2d
 
 
4db98bc
d4bad2d
a204cc2
 
d4bad2d
 
 
4db98bc
 
d4bad2d
a0646fa
d4bad2d
afb20ae
4db98bc
afb20ae
4db98bc
 
 
d4bad2d
 
 
 
 
 
bda8ed2
 
 
 
d4bad2d
 
4db98bc
d4bad2d
29309b0
bda8ed2
 
 
 
 
29309b0
 
 
4db98bc
29309b0
a0646fa
d4bad2d
 
29309b0
4db98bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
from queue import Queue
from threading import Thread
from typing import Optional

import numpy as np
import torch

from transformers import MusicgenMelodyForConditionalGeneration, AutoProcessor, set_seed
from transformers.generation.streamers import BaseStreamer

import gradio as gr
import io, wave

import spaces


from transformers import MusicgenMelodyForConditionalGeneration, MusicgenForConditionalGeneration, AutoProcessor, set_seed
from transformers.modeling_outputs import BaseModelOutput
from transformers.utils import logging
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import copy
import torch
import torchaudio

from demucs import pretrained
from demucs.apply import apply_model
from demucs.audio import convert_audio

# Module-level logger (created through the transformers logging helper imported above),
# used for the warnings emitted during long-form generation.
logger = logging.get_logger(name=__name__)


class MusicgenMelodyForLongFormConditionalGeneration(MusicgenMelodyForConditionalGeneration):
    """MusicGen Melody with a chunked, long-form `generate` loop.

    `generate` runs regular MusicGen decoding several times. Each chunk after the first is prompted with the last
    `stride_longform` timesteps of the previous chunk. Only the newly generated timesteps are kept, and all kept
    audio codes are decoded to a waveform at the end.
    """

    # Number of trailing timesteps (per codebook) of a chunk that are re-fed as the decoder prompt of the next chunk.
    stride_longform = 750
    # Long-form generation keeps starting new chunks only while at least this many timesteps of the
    # `max_longform_generation_length` budget remain.
    min_longform_chunk_length = 20

    def _prepare_audio_encoder_kwargs_for_longform_generation(self, audio_codes, model_kwargs):
        """Store previously generated audio codes as the next chunk's `decoder_input_ids`.

        Args:
            audio_codes (`torch.Tensor`):
                4-D codes unpacked as `(frames, batch_size, codebooks, seq_len)`; `frames` must be 1.
            model_kwargs (`dict`):
                Generation kwargs; updated in place and returned.

        Returns:
            `dict`: `model_kwargs` with `decoder_input_ids` reshaped to
            `(batch_size * self.decoder.num_codebooks, seq_len)`.

        Raises:
            ValueError: if `audio_codes` holds more than one frame.
        """
        frames, bsz, codebooks, seq_len = audio_codes.shape

        if frames != 1:
            raise ValueError(
                f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
                "disabled by setting `chunk_length=None` in the audio encoder."
            )

        decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)

        model_kwargs["decoder_input_ids"] = decoder_input_ids
        return model_kwargs

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        synced_gpus: Optional[bool] = None,
        max_longform_generation_length: Optional[int] = 4000,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ):
        """
        Generate long-form audio by running MusicGen Melody decoding chunk after chunk.

        The first chunk is generated from the conditioning inputs like a regular `generate` call. Every following
        chunk is prompted with the last `stride_longform` timesteps of the previous chunk, and only the timesteps
        generated after that prompt are kept. Chunks are started while at least `min_longform_chunk_length`
        timesteps of the `max_longform_generation_length` budget remain. The budget is only checked between chunks,
        so the result may be somewhat longer than the budget. The concatenated audio codes are finally decoded to a
        waveform with the audio encoder.

        Only greedy search and multinomial sampling (`num_beams=1`, `num_beam_groups=1`) with
        `num_return_sequences=1` are supported.

        Parameters:
            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The conditioning inputs for the encoder (e.g. `input_ids` from the processor). If `None` the method
                initializes it with `bos_token_id` and a batch size of 1.
            generation_config (`~generation.GenerationConfig`, *optional*):
                Base generation configuration; matching `**kwargs` override its attributes. Defaults to the model's
                `generation_config`. It is deep-copied, so the caller's object is not modified.
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default ones built from the generation config.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default ones built from the generation config.
            synced_gpus (`bool`, *optional*):
                Forwarded to the underlying greedy/sample loop (needed for ZeRO stage 3).
            max_longform_generation_length (`int`, *optional*, defaults to 4000):
                Budget for the whole long-form generation, in timesteps per codebook. If `None`, the budget is
                `generation_config.max_length`, which normally yields a single chunk.
            streamer (`BaseStreamer`, *optional*):
                Receives the initial decoder ids and the tokens generated for every chunk through `put`. Once all
                chunks are done it receives `end(True)`; the extra argument is specific to this demo's streamer.
            kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs forwarded to
                the model's `forward`.

        Return:
            `torch.FloatTensor` with the decoded audio of all chunks. For stereo checkpoints, the left and right
            channels are concatenated on dim 1. If `return_dict_in_generate=True`, the output object of the last
            chunk is returned instead, with its `sequences` replaced by that audio.

        Raises:
            ValueError: in any of these cases:
                - `num_return_sequences > 1`;
                - `max_longform_generation_length` is too small to generate a first chunk;
                - the generation mode is neither greedy search nor sampling.
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
        if generation_config is None:
            generation_config = self.generation_config

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        # 3. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        # 4. Define other model kwargs
        model_kwargs["output_attentions"] = generation_config.output_attentions
        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
        model_kwargs["use_cache"] = generation_config.use_cache
        model_kwargs["guidance_scale"] = generation_config.guidance_scale

        if model_kwargs.get("attention_mask", None) is None:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
            )

        if "encoder_hidden_states" not in model_kwargs:
            # encoder_hidden_states are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_hidden_states_kwargs_for_generation(
                inputs_tensor,
                model_kwargs,
                model_input_name,
                guidance_scale=generation_config.guidance_scale,
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config.decoder_start_token_id,
            bos_token_id=generation_config.bos_token_id,
            device=inputs_tensor.device,
        )

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            logger.warning(
                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation."
            )
        elif generation_config.max_new_tokens is not None:
            if not has_default_max_length:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length

        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
                f" the maximum length ({generation_config.max_length})"
            )
        if input_ids_seq_length >= generation_config.max_length:
            logger.warning(
                f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen Melody)
        input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids,
            pad_token_id=generation_config.decoder_start_token_id,
            max_length=generation_config.max_length,
        )
        # stash the delay mask so that we don't have to recompute in each forward pass
        model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask

        # 7. determine generation mode
        is_greedy_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
        )
        is_sample_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
        )

        # Validate the long-form settings before anything is sent to the streamer, so that a consumer waiting on
        # the streamer is not left without an end signal when generation is rejected.
        if generation_config.num_return_sequences > 1:
            if is_greedy_gen_mode:
                raise ValueError(
                    "num_return_sequences has to be 1 when doing greedy search, "
                    f"but is {generation_config.num_return_sequences}."
                )
            if is_sample_gen_mode:
                # The inputs would be re-expanded on every chunk while the outputs are reshaped with the
                # un-expanded `batch_size`, mixing up the returned sequences.
                raise ValueError(
                    "num_return_sequences has to be 1 for long-form generation, "
                    f"but is {generation_config.num_return_sequences}."
                )

        # the first timestep corresponds to the decoder_start_token
        current_generated_length = input_ids.shape[1] - 1
        if max_longform_generation_length is None:
            max_longform_generation_length = generation_config.max_length
        if current_generated_length + self.min_longform_chunk_length > max_longform_generation_length:
            # Otherwise the loop below would never run and there would be no tokens to decode.
            raise ValueError(
                f"`max_longform_generation_length` ({max_longform_generation_length}) is too small: at least "
                f"{current_generated_length + self.min_longform_chunk_length} is needed to generate a first chunk."
            )

        # input_ids are ready to be placed on the streamer (if used)
        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
            generation_config.guidance_scale = None

        # 9. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=None,
            logits_processor=logits_processor,
        )

        # 10. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )

        # 11. prepare the logits warper once: it is built from sampling settings that do not change between chunks
        logits_warper = self._get_logits_warper(generation_config) if is_sample_gen_mode else None

        # ENTER LONGFORM GENERATION LOOP
        generated_tokens = []
        # Number of prompt timesteps at the start of the current chunk's output (the first chunk has no audio prompt).
        prompt_length = 0
        max_new_tokens = generation_config.max_new_tokens

        while current_generated_length + self.min_longform_chunk_length <= max_longform_generation_length:
            if max_new_tokens is not None:
                # NOTE(review): the stopping criteria were built above from `max_length`, so this cap may not
                # change the chunk length - confirm against the transformers version in use.
                generation_config.max_new_tokens = min(
                    max_new_tokens, max_longform_generation_length - current_generated_length
                )

            if is_greedy_gen_mode:
                # 12. run greedy search
                outputs = self._greedy_search(
                    input_ids,
                    logits_processor=logits_processor,
                    stopping_criteria=stopping_criteria,
                    pad_token_id=generation_config.pad_token_id,
                    eos_token_id=generation_config.eos_token_id,
                    output_scores=generation_config.output_scores,
                    return_dict_in_generate=generation_config.return_dict_in_generate,
                    synced_gpus=synced_gpus,
                    streamer=streamer,
                    **model_kwargs,
                )

            elif is_sample_gen_mode:
                # expand input_ids with `num_return_sequences` additional sequences per batch
                # (a no-op here, since num_return_sequences == 1 was enforced above)
                input_ids, model_kwargs = self._expand_inputs_for_generation(
                    input_ids=input_ids,
                    expand_size=generation_config.num_return_sequences,
                    is_encoder_decoder=self.config.is_encoder_decoder,
                    **model_kwargs,
                )

                # 12. run sample
                outputs = self._sample(
                    input_ids,
                    logits_processor=logits_processor,
                    logits_warper=logits_warper,
                    stopping_criteria=stopping_criteria,
                    pad_token_id=generation_config.pad_token_id,
                    eos_token_id=generation_config.eos_token_id,
                    output_scores=generation_config.output_scores,
                    return_dict_in_generate=generation_config.return_dict_in_generate,
                    synced_gpus=synced_gpus,
                    streamer=streamer,
                    **model_kwargs,
                )

            else:
                raise ValueError(
                    "Got incompatible mode for generation, should be one of greedy or sampling. "
                    "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
                )

            output_ids = outputs.sequences if generation_config.return_dict_in_generate else outputs

            # apply the pattern mask to the final ids
            output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

            # revert the pattern delay mask by filtering the pad token id
            output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
                batch_size, self.decoder.num_codebooks, -1
            )

            # Drop the prompt re-fed from the previous chunk so its timesteps are not duplicated. Its real length is
            # tracked instead of assuming `stride_longform`: a previous chunk shorter than the stride yields a
            # shorter prompt, and stripping a full stride would throw away newly generated timesteps.
            new_tokens = output_ids[:, :, prompt_length:]
            generated_tokens.append(new_tokens)
            current_generated_length += new_tokens.shape[-1]

            if new_tokens.shape[-1] == 0:
                # Without progress the loop condition can never change; stop instead of looping forever.
                logger.warning("Long-form generation produced no new tokens; stopping early.")
                break

            # use the last generated timesteps as the prompt of the next chunk, adding back the frame dimension
            prompt_length = min(self.stride_longform, output_ids.shape[-1])
            output_ids = output_ids[None, :, :, -self.stride_longform :]

            model_kwargs = self._prepare_audio_encoder_kwargs_for_longform_generation(output_ids, model_kwargs)

            # Prepare new `input_ids` for the next chunk. Use the resolved `generation_config` (not
            # `self.generation_config`) so overrides passed by the caller apply to every chunk, as they do to the first.
            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                batch_size=batch_size,
                model_input_name="input_ids",
                model_kwargs=model_kwargs,
                decoder_start_token_id=generation_config.decoder_start_token_id,
                bos_token_id=generation_config.bos_token_id,
                device=input_ids.device,
            )
            # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen Melody)
            input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
                input_ids,
                pad_token_id=generation_config.decoder_start_token_id,
                max_length=generation_config.max_length,
            )
            # stash the delay mask so that we don't have to recompute in each forward pass
            model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask

            # TODO(YL): periodic prompt song (e.g. re-run `_prepare_encoder_hidden_states_kwargs_for_generation` here)

        # append the frame dimension back to the audio codes
        output_ids = torch.cat(generated_tokens, dim=-1)[None, ...]

        # Specific to this gradio demo
        if streamer is not None:
            streamer.end(True)

        audio_scales = model_kwargs.get("audio_scales")
        if audio_scales is None:
            audio_scales = [None] * batch_size

        if self.decoder.config.audio_channels == 1:
            output_values = self.audio_encoder.decode(
                output_ids,
                audio_scales=audio_scales,
            ).audio_values
        else:
            # stereo checkpoints interleave the left/right codebooks
            codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
            output_values_left = codec_outputs_left.audio_values

            codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
            output_values_right = codec_outputs_right.audio_values

            output_values = torch.cat([output_values_left, output_values_right], dim=1)

        if generation_config.return_dict_in_generate:
            outputs.sequences = output_values
            return outputs
        else:
            return output_values

# Hub checkpoint and revision shared by the model and its processor so the two always match.
_MUSICGEN_CHECKPOINT = "facebook/musicgen-melody"
_MUSICGEN_REVISION = "refs/pr/14"

# Long-form MusicGen Melody model and its processor.
# NOTE: passing attn_implementation="sdpa" to from_pretrained was left disabled here.
model = MusicgenMelodyForLongFormConditionalGeneration.from_pretrained(
    _MUSICGEN_CHECKPOINT,
    revision=_MUSICGEN_REVISION,
)
processor = AutoProcessor.from_pretrained(_MUSICGEN_CHECKPOINT, revision=_MUSICGEN_REVISION)

# Pretrained Demucs model ("htdemucs"); presumably used later in the file on the melody input - confirm.
demucs = pretrained.get_model('htdemucs')

# Title, markdown description and FAQ article for the demo page
# (presumably passed to the Gradio interface defined later in the file).
title = "Streaming Long-form MusicGen"

description = """
Stream the outputs of the MusicGen Melody text-to-music model by playing the generated audio as soon as the first chunk is ready. 

The generation loop is adapted to perform **long-form** music generation. In this demo, we limit the duration of the music generated, but in theory, it could run **endlessly**.

Demo uses [MusicGen Melody](https://huggingface.co/facebook/musicgen-melody) in the 🤗 Transformers library. Note that the 
demo works best on the Chrome browser. If there is no audio output, try switching browser to Chrome.
"""

article = """
## FAQ

### How Does It Work?

MusicGen is an auto-regressive transformer-based model, meaning it generates audio codes (tokens) in a causal fashion.

At each decoding step, the model generates a new set of audio codes, conditional on the text input and all previous audio codes. Given the 
frame rate of the [EnCodec model](https://huggingface.co/facebook/encodec_32khz) used to decode the generated codes to audio waveform, 
each set of generated audio codes corresponds to 0.02 seconds. This means we require a total of 1000 decoding steps to generate
20 seconds of audio.

Rather than waiting for the entire audio sequence to be generated, which would require the full 1000 decoding steps, we can start 
playing the audio after a specified number of decoding steps have been reached, a technique known as [*streaming*](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming). 

For example, after 250 steps we have the first 5 seconds of audio ready, and so can play this without waiting for the remaining 
750 decoding steps to be complete. As we continue to generate with the MusicGen model, we append new chunks of generated audio 
to our output waveform on-the-fly. After the full 1000 decoding steps, the generated audio is complete, and is composed of four 
chunks of audio, each corresponding to 250 tokens.

This method of playing incremental generations **reduces the latency** of the MusicGen model from the total time to generate 1000 tokens, 
to the time taken to play the first chunk of audio (250 tokens). This can result in **significant improvements** to perceived latency, 
particularly when the chunk size is chosen to be small. 

In practice, the chunk size should be tuned to your device: using a smaller chunk size will mean that the first chunk is ready faster, but should not be chosen so small that the model generates slower 
than the time it takes to play the audio.

For details on how the streaming class works, check out the source code for the [MusicgenStreamer](https://huggingface.co/spaces/sanchit-gandhi/musicgen-streaming/blob/main/app.py#L52).

### Could this be used for stereo music generation?

In theory, yes, but you would have to adapt the current demo a bit and use a checkpoint specifically made for stereo generation, for example, this [one](https://huggingface.co/facebook/musicgen-stereo-melody).

### Why is there a delay between the moment the first chunk is generated and the moment the audio starts playing?

This behaviour is specific to Gradio and the different components it uses. If you ever adapt this demo for a streaming use-case, you could have lower latency.
"""


class MusicgenStreamer(BaseStreamer):
    """Streamer that decodes MusicGen tokens into audio chunks and exposes them through an iterator."""

    def __init__(
        self,
        model: MusicgenMelodyForConditionalGeneration,
        device: Optional[str] = None,
        play_steps: Optional[int] = 10,
        stride: Optional[int] = None,
        timeout: Optional[float] = None,
        is_longform: Optional[bool] = False,
    ):
        """
        Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
        useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
        Gradio demo).

        Parameters:
            model (`MusicgenMelodyForConditionalGeneration`):
                The MusicGen model used to generate the audio waveform. Its `decoder`, `audio_encoder` and
                `generation_config` are used to turn the streamed tokens into audio.
            device (`str`, *optional*):
                The torch device on which to run the computation. If `None`, will default to the device of the model.
            play_steps (`int`, *optional*, defaults to 10):
                The number of generation steps with which to return the generated audio array. Using fewer steps will
                mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
                should be tuned to your device and latency requirements. Must be at least 1.
            stride (`int`, *optional*):
                Number of trailing audio samples held back from each intermediate chunk. They are emitted with a later
                chunk, after the whole token cache has been decoded again with more tokens, which smooths the boundary
                between chunks. Must be non-negative. If `None`, defaults to
                `hop_length * (play_steps - num_codebooks) // 6` samples (clamped at 0), where `hop_length` is the
                product of the codec's upsampling ratios.
            timeout (`float`, *optional*):
                The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
                in `.generate()`, when it is called in a separate thread.
            is_longform (`bool`, *optional*, defaults to `False`):
                If `True`, reads `model.stride_longform`, and `end()` only emits the stop signal when called with
                `stream_end=True`.

        Raises:
            ValueError: If `play_steps` is not a positive integer, or `stride` is negative.
        """
        if play_steps is None or play_steps < 1:
            # `put` computes `len % play_steps`; fail early with a clear message instead of a ZeroDivisionError later.
            raise ValueError(f"`play_steps` must be a positive integer, got {play_steps}.")

        self.decoder = model.decoder
        self.audio_encoder = model.audio_encoder
        self.generation_config = model.generation_config
        self.device = device if device is not None else model.device

        # variables used in the streaming process
        self.play_steps = play_steps
        if stride is not None:
            if stride < 0:
                raise ValueError(f"`stride` must be non-negative, got {stride}.")
            self.stride = stride
        else:
            hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
            # Clamp at 0: with play_steps < num_codebooks the formula would give a negative stride,
            # which turns the "hold back" slice in `put` into a nonsensical positive index.
            self.stride = int(max(0, hop_length * (play_steps - self.decoder.num_codebooks) // 6))
        self.token_cache = None
        # Number of samples (of the decoded waveform) already pushed to the queue.
        self.to_yield = 0

        self.is_longform = is_longform
        if is_longform:
            self.longform_stride = model.stride_longform
            self.longform_stride_applied = True

        # variables used in the thread process
        self.audio_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def apply_delay_pattern_mask(self, input_ids):
        """Decode the full token cache into a float32 numpy waveform.

        The MusicGen delay pattern is applied and reverted, and the codes are decoded by the audio encoder. Stereo
        checkpoints are decoded per channel (even/odd codebooks). Only channel 0 of batch item 0 is returned, so the
        result is mono (the left channel for stereo checkpoints).
        """
        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
        _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            pad_token_id=self.generation_config.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        # apply the pattern mask to the input ids
        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)

        # revert the pattern delay mask by filtering the pad token id
        input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
            1, self.decoder.num_codebooks, -1
        )

        # append the frame dimension back to the audio codes
        input_ids = input_ids[None, ...]

        # send the input_ids to the correct device
        input_ids = input_ids.to(self.audio_encoder.device)

        if self.decoder.config.audio_channels == 1:
            output_values = self.audio_encoder.decode(
                input_ids,
                audio_scales=[None],
            ).audio_values
        else:
            # stereo: even codebooks decode the left channel, odd codebooks the right channel
            codec_outputs_left = self.audio_encoder.decode(input_ids[:, :, ::2, :], audio_scales=[None])
            output_values_left = codec_outputs_left.audio_values

            codec_outputs_right = self.audio_encoder.decode(input_ids[:, :, 1::2, :], audio_scales=[None])
            output_values_right = codec_outputs_right.audio_values

            output_values = torch.cat([output_values_left, output_values_right], dim=1)

        audio_values = output_values[0, 0]
        return audio_values.cpu().float().numpy()

    def put(self, value):
        """Append newly generated tokens; every `play_steps` tokens, queue the newly decodable audio.

        Raises:
            ValueError: If the generation batch size is larger than 1.
        """
        batch_size = value.shape[0] // self.decoder.num_codebooks
        if batch_size > 1:
            raise ValueError("MusicgenStreamer only supports batch size 1")

        if self.token_cache is None:
            self.token_cache = value
        else:
            # NOTE(review): `longform_stride` / `longform_stride_applied` are tracked, but no longform stride is
            # trimmed from incoming tokens here. Confirm this is intended for longform generation.
            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)

        if self.token_cache.shape[-1] % self.play_steps == 0:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
            # Hold back the last `stride` samples; they are re-decoded together with the next tokens.
            # An explicit end index (rather than `-self.stride`) keeps `stride == 0` from yielding an empty
            # slice and dropping audio. `max` prevents `to_yield` from moving backwards, which would re-emit audio.
            end = max(self.to_yield, len(audio_values) - self.stride)
            self.on_finalized_audio(audio_values[self.to_yield : end])
            self.to_yield = end

    def end(self, stream_end=False):
        """Flush any remaining cached audio and, if the stream is ending, append the stop signal.

        Args:
            stream_end: Only meaningful when `is_longform` is True, where it decides whether the stop signal is sent.
                Without longform the stop signal is always sent.
        """
        if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
        else:
            audio_values = np.zeros(self.to_yield)

        stream_end = (not self.is_longform) or stream_end
        if self.is_longform:
            self.longform_stride_applied = False
        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=stream_end)

    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
        """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.audio_queue.put(audio, timeout=self.timeout)
        if stream_end:
            self.audio_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next audio chunk, blocking up to `timeout`; stop once the stop signal is dequeued."""
        value = self.audio_queue.get(timeout=self.timeout)
        # Identity check: audio chunks are numpy arrays, for which `==` would compare element-wise.
        if value is self.stop_signal:
            raise StopIteration()
        return value


# Properties of the audio codec, used to convert between generated token counts, seconds and samples.
_codec_config = model.audio_encoder.config
sampling_rate = _codec_config.sampling_rate
frame_rate = _codec_config.frame_rate

# Streamed PCM is signed 16-bit; float audio (presumably in [-1, 1]) is scaled by the dtype's peak value.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max

def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000):
    """Return WAV-file bytes (RIFF header followed by `frame_input`) for the given PCM format.

    Meant to be sent once, at the start of a streamed WAV response. Later PCM chunks should be
    sent without a header, otherwise an artifact can be heard at the start of each chunk.

    Args:
        frame_input: Raw PCM frames placed after the header (empty for a header-only chunk).
        channels: Number of interleaved audio channels.
        sample_width: Bytes per sample (2 for 16-bit PCM).
        sample_rate: Frames per second declared in the header.
    """
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:
        # (nchannels, sampwidth, framerate, nframes, comptype, compname); nframes is patched on close.
        writer.setparams((channels, sample_width, sample_rate, 0, "NONE", "not compressed"))
        writer.writeframes(frame_input)
    return buffer.getvalue()

@spaces.GPU(duration=90)
def generate_audio(text_prompt, audio, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
    """Generate music for `text_prompt` and stream it as WAV bytes.

    The first yielded item is a WAV header (mono, 16-bit, `sampling_rate` Hz). Every following item is a chunk of
    raw 16-bit PCM produced by `MusicgenStreamer` while `model.generate` runs in a background thread.

    Args:
        text_prompt: Text description of the music to generate.
        audio: Optional path to a conditioning audio file; `None` for text-only generation.
        audio_length_in_s: Requested length of the generated audio, in seconds.
        play_steps_in_s: Amount of audio, in seconds, generated between two streamed chunks.
        seed: Seed passed to `set_seed` before generation starts.
    """
    max_new_tokens = int(frame_rate * audio_length_in_s)
    play_steps = int(frame_rate * play_steps_in_s)

    if audio is not None:
        # Load the conditioning file, convert it to demucs.samplerate / demucs.audio_channels and run demucs on it.
        audio = torchaudio.load(audio)
        audio = convert_audio(audio[0], audio[1], demucs.samplerate, demucs.audio_channels)
        audio = apply_model(demucs, audio[None])

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if device != model.device:
        model.to(device)
        if device == "cuda:0":
            model.half()

    if audio is not None:
        inputs = processor(
            text=text_prompt,
            padding=True,
            return_tensors="pt",
            audio=audio, sampling_rate=demucs.samplerate
        )
        if device == "cuda:0":
            # The model was cast to half precision above, so the audio features must match.
            inputs["input_features"] = inputs["input_features"].to(torch.float16)
    else:
        inputs = processor(
            text=text_prompt,
            padding=True,
            return_tensors="pt",
        )

    streamer = MusicgenStreamer(model, device=device, play_steps=play_steps, is_longform=True)

    generation_kwargs = dict(
        **inputs.to(device),
        temperature=1.2,
        streamer=streamer,
        max_new_tokens=min(max_new_tokens, 1500),
        max_longform_generation_length=max_new_tokens,
    )

    # Seed *before* the generation thread starts sampling. Seeding afterwards races with generation
    # and makes the `seed` input ineffective.
    set_seed(seed)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Declare the codec's actual sampling rate in the header. The helper's 24 kHz default only
    # matches if the codec also runs at 24 kHz; otherwise playback speed and pitch are wrong.
    yield wave_header_chunk(sample_rate=sampling_rate)

    for new_audio in streamer:
        print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")

        # Clip to [-1, 1] before scaling so out-of-range float samples cannot overflow the int16 cast.
        new_audio = (np.clip(new_audio, -1.0, 1.0) * max_range).astype(target_dtype)
        yield new_audio.tobytes()



# Gradio UI. The order of `inputs` matches the positional parameters of `generate_audio`,
# and each `examples` row follows that same order.
demo = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
        gr.Audio(type="filepath", label="Conditioning audio. Use this for melody-guided generation."),
        gr.Slider(35, 60, value=45, step=5, label="(Approximate) Audio length in seconds."),
        gr.Slider(0.5, 2.5, value=1.5, step=0.5, label="Streaming interval in seconds.", info="Lower = shorter chunks, lower latency, more codec steps."),
        gr.Number(value=5, precision=0, step=1, minimum=0, label="Seed for random generations."),
    ],
    outputs=[
        # streaming=True lets the yielded WAV header + PCM chunks be played as they arrive.
        gr.Audio(label="Generated Music", autoplay=True, interactive=False, streaming=True)
    ],
    examples=[
        ["An 80s driving pop song with heavy drums and synth pads in the background", None, 45, 1.5, 5],
        ["Bossa nova with guitars and synthesizer", None, 45, 1.5, 5],
        ["90s rock song with electric guitar and heavy drums", None, 45, 1.5, 5],
        ["a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130", None, 45, 1.5, 5],
        ["lofi slow bpm electro chill with organic samples", None, 45, 1.5, 5],
    ],
    title=title,
    description=description,
    # Gradio documents `allow_flagging` as one of "never" / "auto" / "manual"; a boolean is not a documented value.
    allow_flagging="never",
    article=article,
    cache_examples=False,
)


demo.queue().launch(debug=True)