Z committed on
Commit b7920e6
Parent: 0fca9d6
Files changed (4)
  1. .gitignore +2 -0
  2. app.py +111 -0
  3. audio.py +58 -0
  4. generator.py +31 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
__pycache__
static/
app.py ADDED
@@ -0,0 +1,111 @@
import gradio as gr
import os, json
from generator import HijackedMusicGen
from audiocraft.data.audio import audio_write
from audio import predict
from itertools import zip_longest

def split_prompt(bigly_prompt, num_segments):
    prompts = bigly_prompt.split(',,')
    num_segments = int(num_segments)  # 'segments' arrives as a string from the Gradio slider
    # Repeat the last prompt to fill in any remaining segments
    if len(prompts) < num_segments:
        prompts += [prompts[-1]] * (num_segments - len(prompts))
    elif len(prompts) > num_segments:
        prompts = prompts[:num_segments]
    return prompts

loaded_model = None
audio_files = []

def model_interface(model_name, top_k, top_p, temperature, cfg_coef, segments, overlap, duration, optional_audio, prompt):
    global loaded_model

    # Load (or reload) the model only when the requested size changes
    if loaded_model is None or loaded_model.name != model_name:
        loaded_model = HijackedMusicGen.get_pretrained(None, name=model_name)

    print(optional_audio)

    loaded_model.set_generation_params(
        use_sampling=True,
        duration=duration,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        cfg_coef=cfg_coef,
    )

    extension_parameters = {"segments": segments, "overlap": overlap}

    prompts = split_prompt(prompt, segments)
    first_prompt = prompts[0]
    sample_rate, audio = predict(loaded_model, prompts, None, extension_parameters)

    # Pick a unique filename so repeated prompts don't overwrite earlier results
    counter = 1
    audio_path = "static/"
    audio_name = first_prompt
    while os.path.exists(audio_path + audio_name + ".wav"):
        audio_name = f"{first_prompt}({counter})"
        counter += 1

    file = audio_write(audio_path + audio_name, audio.squeeze(), sample_rate, strategy="loudness")
    audio_files.append(file)

    audio_list_html = "<br>".join([
        f'''
        <div style="border:1px solid #000; padding:10px; margin-bottom:10px;">
            <div>{os.path.splitext(os.path.basename(file))[0]}</div>
            <audio controls><source src="/file={file}" type="audio/wav"></audio>
        </div>
        '''
        for file in reversed(audio_files)
    ])

    return audio_list_html

slider_param = {
    "top_k": {"minimum": 0, "maximum": 1000, "value": 0, "label": "Top K"},
    "top_p": {"minimum": 0.0, "maximum": 1.0, "value": 0.0, "label": "Top P"},
    "temperature": {"minimum": 0.1, "maximum": 10.0, "value": 1.0, "label": "Temperature"},
    "cfg_coef": {"minimum": 0.0, "maximum": 10.0, "value": 4.0, "label": "CFG Coefficient"},
    "segments": {"minimum": 1, "maximum": 10, "value": 1, "step": 1, "label": "Number of Segments"},
    "overlap": {"minimum": 0.0, "maximum": 10.0, "value": 1.0, "label": "Segment Overlap"},
    "duration": {"minimum": 1, "maximum": 300, "value": 10, "label": "Duration"},
}

# Populated inside the Blocks context below; components created outside it would not render
slider_params = {}

with gr.Blocks() as interface:
    with gr.Row():

        with gr.Column():
            with gr.Row():
                model_dropdown = gr.components.Dropdown(choices=["small", "medium", "large", "melody"], label="Model Size", value="large")
                optional_audio = gr.components.Audio(source="upload", type="filepath", label="Optional Audio", interactive=True)

            # Lay the sliders out two per row
            slider_keys = list(slider_param.keys())
            slider_pairs = list(zip_longest(slider_keys[::2], slider_keys[1::2]))

            for key1, key2 in slider_pairs:
                with gr.Row():
                    with gr.Column():
                        slider_params[key1] = gr.components.Slider(**slider_param[key1])
                    if key2 is not None:
                        with gr.Column():
                            slider_params[key2] = gr.components.Slider(**slider_param[key2])

            prompt_box = gr.components.Textbox(lines=5, placeholder="""Insert a double comma ,, to indicate this should prompt a new segment. For example:
Rock Opera,,Dueling Banjos
This allows you to prompt each segment individually. If you only provide one prompt, every segment will use that one prompt. If you provide multiple prompts but fewer than the number of segments, the last prompt will be used to fill in the rest.
""")
            submit = gr.Button("Submit")

        with gr.Column():
            output = gr.outputs.HTML()

    inputs_list = [model_dropdown] + list(slider_params.values()) + [optional_audio] + [prompt_box]
    submit.click(model_interface, inputs=inputs_list, outputs=[output])

interface.queue()
interface.launch()
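A note on the prompt splitting in app.py (illustrative values, not part of the commit): split_prompt pads with the last prompt or truncates so the returned list always matches the segment count.

    split_prompt("Rock Opera,,Dueling Banjos", 4)
    # -> ['Rock Opera', 'Dueling Banjos', 'Dueling Banjos', 'Dueling Banjos']
    split_prompt("Rock Opera,,Dueling Banjos", 1)
    # -> ['Rock Opera']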
audio.py ADDED
@@ -0,0 +1,58 @@
import numpy as np
import os, re, json, sys
import torch, torchaudio, pathlib

def load_and_process_audio(model, melody, sample_rate):
    if melody is not None:
        melody = torch.from_numpy(melody).to(model.device).float().t().unsqueeze(0)
        if melody.dim() == 2:
            melody = melody[None]
        # Trim the melody to the model's maximum segment duration
        melody = melody[..., :int(sample_rate * model.lm.cfg.dataset.segment_duration)]
        return melody
    else:
        return None

# From https://colab.research.google.com/drive/154CqogsdP-D_TfSF9S2z8-BY98GN_na4?usp=sharing#scrollTo=exKxNU_Z4i5I
# Thank you DragonForged for the link
def extend_audio(model, prompt_waveform, prompts, prompt_sr, segments=5, overlap=2):
    # Calculate the number of samples corresponding to the overlap
    overlap_samples = int(overlap * prompt_sr)

    device = model.device
    prompt_waveform = prompt_waveform.to(device)

    for i in range(1, segments):
        # Grab the end of the waveform
        end_waveform = prompt_waveform[..., -overlap_samples:]

        # Process the trimmed waveform using the model
        new_audio = model.generate_continuation(end_waveform, descriptions=[prompts[i]], prompt_sample_rate=prompt_sr, progress=True)

        # Cut the seed audio off the newly generated audio
        new_audio = new_audio[..., overlap_samples:]

        prompt_waveform = torch.cat([prompt_waveform, new_audio], dim=2)

    return prompt_waveform

def predict(model, prompts, melody_parameters, extension_parameters):
    melody = None  # load_and_process_audio(model, **melody_parameters)

    if melody is not None:
        output = model.generate_with_chroma(
            descriptions=[prompts[0]],
            melody_wavs=melody,
            melody_sample_rate=melody_parameters['sample_rate'],
            progress=False
        )
    else:
        output = model.generate(descriptions=[prompts[0]], progress=True)

    sample_rate = model.sample_rate

    if extension_parameters['segments'] > 1:
        output_tensors = extend_audio(model, output, prompts, sample_rate, **extension_parameters).detach().cpu().float()
    else:
        output_tensors = output.detach().cpu().float()

    return sample_rate, output_tensors
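For a sense of the output length from extend_audio: each continuation is generated at the configured duration and the overlap seed is trimmed off its front, so (assuming generate_continuation returns the seed plus the newly generated audio, as the trimming implies) each extra segment contributes duration - overlap seconds.

    # total_seconds ≈ duration + (segments - 1) * (duration - overlap)
    # e.g. duration=10, segments=3, overlap=1.0  ->  10 + 2 * 9 = 28 seconds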
generator.py ADDED
@@ -0,0 +1,31 @@
import torch
import time
import typing as tp
from audiocraft.models import MusicGen
from audiocraft.modules.conditioners import ConditioningAttributes

class HijackedMusicGen(MusicGen):
    def __init__(self, socketio=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.socketio = socketio
        self._progress_callback = self._timed_progress_callback if socketio is not None else None
        self._last_update_time = time.time()

    def _timed_progress_callback(self, generated_tokens: int, tokens_to_generate: int):
        current_time = time.time()
        if current_time - self._last_update_time >= 0.1:  # at least 0.1 seconds have passed
            self.socketio.emit('progress', {'generated_tokens': generated_tokens, 'tokens_to_generate': tokens_to_generate})
            self._last_update_time = current_time

    @staticmethod
    def get_pretrained(socketio, name: str = 'melody', device='cuda'):
        music_gen = MusicGen.get_pretrained(name, device)
        return HijackedMusicGen(socketio, music_gen.name, music_gen.compression_model, music_gen.lm)

    @property
    def progress_callback(self):
        raise Exception("Progress callback is write-only")

    @progress_callback.setter
    def progress_callback(self, callback):
        self._progress_callback = callback
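Because progress_callback is exposed as a write-only property, callers can attach their own reporter without subclassing again; a minimal sketch (the console callback here is hypothetical, not part of the commit):

    model = HijackedMusicGen.get_pretrained(None, name="small")
    model.progress_callback = lambda done, total: print(f"{done}/{total} tokens generated")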