radames commited on
Commit
82e1128
1 Parent(s): 45e1876
Files changed (2) hide show
  1. app.py +105 -19
  2. share_btn.py +2 -0
app.py CHANGED
@@ -27,7 +27,16 @@ def load_model(version):
27
 
28
 
29
  def predict(
30
- text, melody_input, duration, continuation, topk, topp, temperature, cfg_coef
 
 
 
 
 
 
 
 
 
31
  ):
32
  global MODEL
33
  topk = int(topk)
@@ -36,8 +45,8 @@ def predict(
36
 
37
  if duration > MODEL.lm.cfg.dataset.segment_duration:
38
  raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
39
- if continuation >= duration:
40
- raise gr.Error("The continuation setting can't be higher or equal to duration!")
41
  MODEL.set_generation_params(
42
  use_sampling=True,
43
  top_k=topk,
@@ -49,18 +58,25 @@ def predict(
49
 
50
  if melody_input:
51
  melody, sr = torchaudio.load(melody_input)
 
 
 
52
  # sr, melody = melody_input[0], torch.from_numpy(melody_input[1]).to(MODEL.device).float().t().unsqueeze(0)
53
  if melody.dim() == 2:
54
  melody = melody[None]
55
  if continuation:
56
- prompt_waveform = melody[..., -int(sr * continuation) :]
 
 
 
57
  output = MODEL.generate_continuation(
58
- prompt=prompt_waveform,
59
  prompt_sample_rate=sr,
60
  descriptions=[text],
61
  progress=True,
62
  )
63
  else:
 
64
  melody_wavform = melody[
65
  ..., : int(sr * MODEL.lm.cfg.dataset.segment_duration)
66
  ]
@@ -71,6 +87,7 @@ def predict(
71
  progress=True,
72
  )
73
  else:
 
74
  output = MODEL.generate(descriptions=[text], progress=False)
75
 
76
  output = output.detach().cpu().float()[0]
@@ -85,7 +102,11 @@ def predict(
85
  add_suffix=False,
86
  )
87
  waveform_video = gr.make_waveform(file.name)
88
- return waveform_video, melody_input
 
 
 
 
89
 
90
 
91
  def ui(**kwargs):
@@ -95,10 +116,40 @@ def ui(**kwargs):
95
  else:
96
  return gr.update(source="upload", value=None, label="File")
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  with gr.Blocks(css=css) as interface:
99
  gr.Markdown(
100
  """
101
- # MusicGen
102
  This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
103
  presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
104
  """
@@ -149,26 +200,46 @@ def ui(**kwargs):
149
  minimum=1,
150
  maximum=30,
151
  value=10,
152
- label="Duration",
153
  interactive=True,
154
  )
155
  with gr.Row():
156
- continuation = gr.Slider(
 
 
157
  minimum=0,
158
  maximum=30,
 
159
  value=0,
160
- label="Continue from the end duration",
161
  interactive=True,
162
  )
163
- with gr.Row():
164
- topk = gr.Number(label="Top-k", value=250, interactive=True)
165
- topp = gr.Number(label="Top-p", value=0, interactive=True)
166
- temperature = gr.Number(
167
- label="Temperature", value=1.0, interactive=True
 
 
168
  )
169
- cfg_coef = gr.Number(
170
- label="Classifier Free Guidance", value=3.0, interactive=True
 
 
 
171
  )
 
 
 
 
 
 
 
 
 
 
 
 
172
  with gr.Column():
173
  output = gr.Video(label="Generated Music", elem_id="generated-video")
174
  output_melody = gr.Audio(label="Melody ", elem_id="melody-output")
@@ -180,6 +251,19 @@ def ui(**kwargs):
180
  "Share to community", elem_id="share-btn"
181
  )
182
  share_button.click(None, [], [], _js=share_js)
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  submit.click(
184
  lambda x: gr.update(visible=False),
185
  None,
@@ -193,6 +277,8 @@ def ui(**kwargs):
193
  melody,
194
  duration,
195
  continuation,
 
 
196
  topk,
197
  topp,
198
  temperature,
@@ -207,7 +293,7 @@ def ui(**kwargs):
207
  show_progress=False,
208
  )
209
  radio.change(toggle, radio, [melody], queue=False, show_progress=False)
210
- gr.Examples(
211
  fn=predict,
212
  examples=[
213
  [
@@ -218,7 +304,7 @@ def ui(**kwargs):
218
  "A cheerful country song with acoustic guitars",
219
  "./assets/bolero_ravel.mp3",
220
  ],
221
- ["90s rock song with electric guitar and heavy drums", None, "medium"],
222
  [
223
  "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
224
  "./assets/bach.mp3",
 
27
 
28
 
29
  def predict(
30
+ text,
31
+ melody_input,
32
+ duration=30,
33
+ continuation=False,
34
+ continuation_start=0,
35
+ continuation_end=30,
36
+ topk=250,
37
+ topp=0,
38
+ temperature=1,
39
+ cfg_coef=3,
40
  ):
41
  global MODEL
42
  topk = int(topk)
 
45
 
46
  if duration > MODEL.lm.cfg.dataset.segment_duration:
47
  raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
48
+ if continuation and continuation_end < continuation_start:
49
+ raise gr.Error("The end time must be greater than the start time!")
50
  MODEL.set_generation_params(
51
  use_sampling=True,
52
  top_k=topk,
 
58
 
59
  if melody_input:
60
  melody, sr = torchaudio.load(melody_input)
61
+ melody_duration = melody.shape[-1] / sr
62
+ if melody_duration < duration:
63
+ raise gr.Error("The duration must be greater than the melody duration!")
64
  # sr, melody = melody_input[0], torch.from_numpy(melody_input[1]).to(MODEL.device).float().t().unsqueeze(0)
65
  if melody.dim() == 2:
66
  melody = melody[None]
67
  if continuation:
68
+ print("\nGenerating continuation\n")
69
+ melody_wavform = melody[
70
+ ..., int(sr * continuation_start) : int(sr * continuation_end)
71
+ ]
72
  output = MODEL.generate_continuation(
73
+ prompt=melody_wavform,
74
  prompt_sample_rate=sr,
75
  descriptions=[text],
76
  progress=True,
77
  )
78
  else:
79
+ print("\nGenerating with melody\n")
80
  melody_wavform = melody[
81
  ..., : int(sr * MODEL.lm.cfg.dataset.segment_duration)
82
  ]
 
87
  progress=True,
88
  )
89
  else:
90
+ print("\nGenerating without melody\n")
91
  output = MODEL.generate(descriptions=[text], progress=False)
92
 
93
  output = output.detach().cpu().float()[0]
 
102
  add_suffix=False,
103
  )
104
  waveform_video = gr.make_waveform(file.name)
105
+
106
+ return (
107
+ waveform_video,
108
+ (sr, melody_wavform.numpy()) if melody_input else None,
109
+ )
110
 
111
 
112
  def ui(**kwargs):
 
116
  else:
117
  return gr.update(source="upload", value=None, label="File")
118
 
119
+ def check_melody_length(melody_input):
120
+ if not melody_input:
121
+ return gr.update(maximum=0, value=0), gr.update(maximum=0, value=0)
122
+ melody, sr = torchaudio.load(melody_input)
123
+ audio_length = melody.shape[-1] / sr
124
+ if melody.dim() == 2:
125
+ melody = melody[None]
126
+ return gr.update(maximum=audio_length, value=0), gr.update(
127
+ maximum=audio_length, value=audio_length
128
+ )
129
+
130
+ def preview_melody_cut(melody_input, continuation_start, continuation_end):
131
+ if not melody_input:
132
+ return gr.update(maximum=0, value=0), gr.update(maximum=0, value=0)
133
+ melody, sr = torchaudio.load(melody_input)
134
+ audio_length = melody.shape[-1] / sr
135
+ if melody.dim() == 2:
136
+ melody = melody[None]
137
+
138
+ if continuation_end < continuation_start:
139
+ raise gr.Error("The end time must be greater than the start time!")
140
+ if continuation_start < 0 or continuation_end > audio_length:
141
+ raise gr.Error("The continuation settings must be within the audio length!")
142
+ print("cutting", int(sr * continuation_start), int(sr * continuation_end))
143
+ prompt_waveform = melody[
144
+ ..., int(sr * continuation_start) : int(sr * continuation_end)
145
+ ]
146
+
147
+ return (sr, prompt_waveform.numpy())
148
+
149
  with gr.Blocks(css=css) as interface:
150
  gr.Markdown(
151
  """
152
+ # MusicGen Continuation
153
  This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
154
  presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
155
  """
 
200
  minimum=1,
201
  maximum=30,
202
  value=10,
203
+ label="Total Duration",
204
  interactive=True,
205
  )
206
  with gr.Row():
207
+ continuation = gr.Checkbox(value=False, label="Enable Continuation")
208
+ with gr.Row():
209
+ continuation_start = gr.Slider(
210
  minimum=0,
211
  maximum=30,
212
+ step=0.01,
213
  value=0,
214
+ label="melody cut start",
215
  interactive=True,
216
  )
217
+ continuation_end = gr.Slider(
218
+ minimum=0,
219
+ maximum=30,
220
+ step=0.01,
221
+ value=0,
222
+ label="melody cut end",
223
+ interactive=True,
224
  )
225
+ cut_btn = gr.Button("Cut Melody").style(full_width=False)
226
+ with gr.Row():
227
+ preview_cut = gr.Audio(
228
+ type="numpy",
229
+ label="Cut Preview",
230
  )
231
+ with gr.Accordion(label="Advanced Settings", open=False):
232
+ with gr.Row():
233
+ topk = gr.Number(label="Top-k", value=250, interactive=True)
234
+ topp = gr.Number(label="Top-p", value=0, interactive=True)
235
+ temperature = gr.Number(
236
+ label="Temperature", value=1.0, interactive=True
237
+ )
238
+ cfg_coef = gr.Number(
239
+ label="Classifier Free Guidance",
240
+ value=3.0,
241
+ interactive=True,
242
+ )
243
  with gr.Column():
244
  output = gr.Video(label="Generated Music", elem_id="generated-video")
245
  output_melody = gr.Audio(label="Melody ", elem_id="melody-output")
 
251
  "Share to community", elem_id="share-btn"
252
  )
253
  share_button.click(None, [], [], _js=share_js)
254
+ melody.change(
255
+ check_melody_length,
256
+ melody,
257
+ [continuation_start, continuation_end],
258
+ queue=False,
259
+ )
260
+ cut_btn.click(
261
+ preview_melody_cut,
262
+ [melody, continuation_start, continuation_end],
263
+ preview_cut,
264
+ queue=False,
265
+ )
266
+
267
  submit.click(
268
  lambda x: gr.update(visible=False),
269
  None,
 
277
  melody,
278
  duration,
279
  continuation,
280
+ continuation_start,
281
+ continuation_end,
282
  topk,
283
  topp,
284
  temperature,
 
293
  show_progress=False,
294
  )
295
  radio.change(toggle, radio, [melody], queue=False, show_progress=False)
296
+ examples = gr.Examples(
297
  fn=predict,
298
  examples=[
299
  [
 
304
  "A cheerful country song with acoustic guitars",
305
  "./assets/bolero_ravel.mp3",
306
  ],
307
+ ["90s rock song with electric guitar and heavy drums", None],
308
  [
309
  "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
310
  "./assets/bach.mp3",
share_btn.py CHANGED
@@ -106,6 +106,8 @@ ${generatedURL}
106
  ${(melodyURL && (typeof melodyURL === 'string'))? `
107
  ### Melody
108
  <audio controls src="${melodyURL}"></audio>` : ``}
 
 
109
  `;
110
  const params = new URLSearchParams({
111
  title: titleTxt,
 
106
  ${(melodyURL && (typeof melodyURL === 'string'))? `
107
  ### Melody
108
  <audio controls src="${melodyURL}"></audio>` : ``}
109
+
110
+ <small>made with continuation: https://huggingface.co/spaces/radames/MusicGen-Continuation</small>
111
  `;
112
  const params = new URLSearchParams({
113
  title: titleTxt,