eusip committed on
Commit
811864c
·
verified ·
1 Parent(s): 9cae843

Upload musicgen_app.py

Browse files
Files changed (1) hide show
  1. demos/musicgen_app.py +196 -94
demos/musicgen_app.py CHANGED
@@ -8,30 +8,31 @@
8
  # also released under the MIT license.
9
 
10
  import argparse
11
- from concurrent.futures import ProcessPoolExecutor
12
  import logging
13
  import os
14
- from pathlib import Path
15
  import subprocess as sp
16
  import sys
17
- from tempfile import NamedTemporaryFile
18
  import time
19
  import typing as tp
20
  import warnings
 
 
 
21
 
22
- from einops import rearrange
23
- import torch
24
  import gradio as gr
 
 
25
 
26
- from audiocraft.data.audio_utils import convert_audio
27
  from audiocraft.data.audio import audio_write
 
 
28
  from audiocraft.models.encodec import InterleaveStereoCompressionModel
29
- from audiocraft.models import MusicGen, MultiBandDiffusion
30
-
31
 
32
  MODEL = None # Last used model
33
- SPACE_ID = os.environ.get('SPACE_ID', '')
34
- IS_BATCHED = "facebook/MusicGen" in SPACE_ID or 'musicgen-internal/musicgen_dev' in SPACE_ID
 
 
35
  print(IS_BATCHED)
36
  MAX_BATCH_SIZE = 12
37
  BATCHED_DURATION = 15
@@ -43,8 +44,8 @@ _old_call = sp.call
43
 
44
  def _call_nostderr(*args, **kwargs):
45
  # Avoid ffmpeg vomiting on the logs.
46
- kwargs['stderr'] = sp.DEVNULL
47
- kwargs['stdout'] = sp.DEVNULL
48
  _old_call(*args, **kwargs)
49
 
50
 
@@ -86,17 +87,19 @@ def make_waveform(*args, **kwargs):
86
  # Further remove some warnings.
87
  be = time.time()
88
  with warnings.catch_warnings():
89
- warnings.simplefilter('ignore')
90
  out = gr.make_waveform(*args, **kwargs)
91
  print("Make a video took", time.time() - be)
92
  return out
93
 
94
 
95
- def load_model(version='facebook/musicgen-melody'):
96
  global MODEL
97
  print("Loading model", version)
98
  if MODEL is None or MODEL.name != version:
 
99
  del MODEL
 
100
  MODEL = None # in case loading would crash
101
  MODEL = MusicGen.get_pretrained(version)
102
 
@@ -108,9 +111,16 @@ def load_diffusion():
108
  MBD = MultiBandDiffusion.get_mbd_musicgen()
109
 
110
 
111
- def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=None, **gen_kwargs):
 
 
112
  MODEL.set_generation_params(duration=duration, **gen_kwargs)
113
- print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
 
 
 
 
 
114
  be = time.time()
115
  processed_melodies = []
116
  target_sr = 32000
@@ -119,10 +129,13 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N
119
  if melody is None:
120
  processed_melodies.append(None)
121
  else:
122
- sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
 
 
 
123
  if melody.dim() == 1:
124
  melody = melody[None]
125
- melody = melody[..., :int(sr * duration)]
126
  melody = convert_audio(melody, sr, target_sr, target_ac)
127
  processed_melodies.append(melody)
128
 
@@ -133,15 +146,17 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N
133
  melody_wavs=processed_melodies,
134
  melody_sample_rate=target_sr,
135
  progress=progress,
136
- return_tokens=USE_DIFFUSION
137
  )
138
  else:
139
- outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
 
 
140
  except RuntimeError as e:
141
  raise gr.Error("Error while generating " + e.args[0])
142
  if USE_DIFFUSION:
143
  if gradio_progress is not None:
144
- gradio_progress(1, desc='Running MultiBandDiffusion...')
145
  tokens = outputs[1]
146
  if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
147
  left, right = MODEL.compression_model.get_left_right_codes(tokens)
@@ -149,7 +164,9 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N
149
  outputs_diffusion = MBD.tokens_to_wav(tokens)
150
  if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
151
  assert outputs_diffusion.shape[1] == 1 # output is mono
152
- outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
 
 
153
  outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
154
  outputs = outputs.detach().cpu().float()
155
  pending_videos = []
@@ -157,8 +174,14 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N
157
  for output in outputs:
158
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
159
  audio_write(
160
- file.name, output, MODEL.sample_rate, strategy="loudness",
161
- loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
 
 
 
 
 
 
162
  pending_videos.append(pool.submit(make_waveform, file.name))
163
  out_wavs.append(file.name)
164
  file_cleaner.add(file.name)
@@ -173,12 +196,24 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N
173
  def predict_batched(texts, melodies):
174
  max_text_length = 512
175
  texts = [text[:max_text_length] for text in texts]
176
- load_model('facebook/musicgen-stereo-melody')
177
  res = _do_predictions(texts, melodies, BATCHED_DURATION)
178
  return res
179
 
180
 
181
- def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
 
 
 
 
 
 
 
 
 
 
 
 
182
  global INTERRUPTING
183
  global USE_DIFFUSION
184
  INTERRUPTING = False
@@ -188,8 +223,10 @@ def predict_full(model, model_path, decoder, text, melody, duration, topk, topp,
188
  if not Path(model_path).exists():
189
  raise gr.Error(f"Model path {model_path} doesn't exist.")
190
  if not Path(model_path).is_dir():
191
- raise gr.Error(f"Model path {model_path} must be a folder containing "
192
- "state_dict.bin and compression_state_dict_.bin.")
 
 
193
  model = model_path
194
  if temperature < 0:
195
  raise gr.Error("Temperature must be >= 0.")
@@ -215,12 +252,20 @@ def predict_full(model, model_path, decoder, text, melody, duration, topk, topp,
215
  progress((min(max_generated, to_generate), to_generate))
216
  if INTERRUPTING:
217
  raise gr.Error("Interrupted.")
 
218
  MODEL.set_custom_progress_callback(_progress)
219
 
220
  videos, wavs = _do_predictions(
221
- [text], [melody], duration, progress=True,
222
- top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef,
223
- gradio_progress=progress)
 
 
 
 
 
 
 
224
  if USE_DIFFUSION:
225
  return videos[0], wavs[0], videos[1], wavs[1]
226
  return videos[0], wavs[0], None, None
@@ -255,42 +300,86 @@ def ui_full(launch_kwargs):
255
  with gr.Row():
256
  text = gr.Text(label="Input Text", interactive=True)
257
  with gr.Column():
258
- radio = gr.Radio(["file", "mic"], value="file",
259
- label="Condition on a melody (optional) File or Mic")
260
- melody = gr.Audio(source="upload", type="numpy", label="File",
261
- interactive=True, elem_id="melody-input")
 
 
 
 
 
 
 
 
262
  with gr.Row():
263
  submit = gr.Button("Submit")
264
  # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
265
  _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
266
  with gr.Row():
267
- model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
268
- "facebook/musicgen-large", "facebook/musicgen-melody-large",
269
- "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium",
270
- "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large",
271
- "facebook/musicgen-stereo-melody-large"],
272
- label="Model", value="facebook/musicgen-stereo-melody", interactive=True)
273
  model_path = gr.Text(label="Model Path (custom models)")
274
  with gr.Row():
275
- decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
276
- label="Decoder", value="Default", interactive=True)
 
 
 
 
277
  with gr.Row():
278
- duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
 
 
 
 
 
 
279
  with gr.Row():
280
  topk = gr.Number(label="Top-k", value=250, interactive=True)
281
  topp = gr.Number(label="Top-p", value=0, interactive=True)
282
- temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
283
- cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
 
 
 
 
284
  with gr.Column():
285
  output = gr.Video(label="Generated Music")
286
- audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
287
  diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
288
- audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
289
- submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False,
290
- show_progress=False).then(predict_full, inputs=[model, model_path, decoder, text, melody, duration, topk, topp,
291
- temperature, cfg_coef],
292
- outputs=[output, audio_output, diffusion_output, audio_diffusion])
293
- radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
  gr.Examples(
296
  fn=predict_full,
@@ -299,41 +388,41 @@ def ui_full(launch_kwargs):
299
  "An 80s driving pop song with heavy drums and synth pads in the background",
300
  "./assets/bach.mp3",
301
  "facebook/musicgen-stereo-melody",
302
- "Default"
303
  ],
304
  [
305
  "A cheerful country song with acoustic guitars",
306
  "./assets/bolero_ravel.mp3",
307
  "facebook/musicgen-stereo-melody",
308
- "Default"
309
  ],
310
  [
311
  "90s rock song with electric guitar and heavy drums",
312
  None,
313
  "facebook/musicgen-stereo-medium",
314
- "Default"
315
  ],
316
  [
317
  "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
318
  "./assets/bach.mp3",
319
  "facebook/musicgen-stereo-melody",
320
- "Default"
321
  ],
322
  [
323
  "lofi slow bpm electro chill with organic samples",
324
  None,
325
  "facebook/musicgen-stereo-medium",
326
- "Default"
327
  ],
328
  [
329
  "Punk rock with loud drum and power guitar",
330
  None,
331
  "facebook/musicgen-stereo-medium",
332
- "MultiBand_Diffusion"
333
  ],
334
  ],
335
  inputs=[text, melody, model, decoder],
336
- outputs=[output]
337
  )
338
  gr.Markdown(
339
  """
@@ -403,20 +492,37 @@ def ui_batched(launch_kwargs):
403
  with gr.Row():
404
  with gr.Column():
405
  with gr.Row():
406
- text = gr.Text(label="Describe your music", lines=2, interactive=True)
 
 
407
  with gr.Column():
408
- radio = gr.Radio(["file", "mic"], value="file",
409
- label="Condition on a melody (optional) File or Mic")
410
- melody = gr.Audio(source="upload", type="numpy", label="File",
411
- interactive=True, elem_id="melody-input")
 
 
 
 
 
 
 
 
412
  with gr.Row():
413
  submit = gr.Button("Generate")
414
  with gr.Column():
415
  output = gr.Video(label="Generated Music")
416
- audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
417
- submit.click(predict_batched, inputs=[text, melody],
418
- outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
419
- radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
 
 
 
 
 
 
 
420
  gr.Examples(
421
  fn=predict_batched,
422
  examples=[
@@ -442,7 +548,7 @@ def ui_batched(launch_kwargs):
442
  ],
443
  ],
444
  inputs=[text, melody],
445
- outputs=[output]
446
  )
447
  gr.Markdown("""
448
  ### More details
@@ -476,50 +582,46 @@ def ui_batched(launch_kwargs):
476
  if __name__ == "__main__":
477
  parser = argparse.ArgumentParser()
478
  parser.add_argument(
479
- '--listen',
480
  type=str,
481
- default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
482
- help='IP to listen on for connections to Gradio',
483
  )
484
  parser.add_argument(
485
- '--username', type=str, default='', help='Username for authentication'
486
  )
487
  parser.add_argument(
488
- '--password', type=str, default='', help='Password for authentication'
489
  )
490
  parser.add_argument(
491
- '--server_port',
492
  type=int,
493
  default=0,
494
- help='Port to run the server listener on',
495
- )
496
- parser.add_argument(
497
- '--inbrowser', action='store_true', help='Open in browser'
498
- )
499
- parser.add_argument(
500
- '--share', action='store_true', help='Share the gradio UI'
501
  )
 
 
502
 
503
  args = parser.parse_args()
504
 
505
  launch_kwargs = {}
506
- launch_kwargs['server_name'] = args.listen
507
 
508
  if args.username and args.password:
509
- launch_kwargs['auth'] = (args.username, args.password)
510
  if args.server_port:
511
- launch_kwargs['server_port'] = args.server_port
512
  if args.inbrowser:
513
- launch_kwargs['inbrowser'] = args.inbrowser
514
  if args.share:
515
- launch_kwargs['share'] = args.share
516
 
517
  logging.basicConfig(level=logging.INFO, stream=sys.stderr)
518
 
519
  # Show the interface
520
- if IS_BATCHED:
521
- global USE_DIFFUSION
522
- USE_DIFFUSION = False
523
- ui_batched(launch_kwargs)
524
- else:
525
- ui_full(launch_kwargs)
 
8
  # also released under the MIT license.
9
 
10
  import argparse
 
11
  import logging
12
  import os
 
13
  import subprocess as sp
14
  import sys
 
15
  import time
16
  import typing as tp
17
  import warnings
18
+ from concurrent.futures import ProcessPoolExecutor
19
+ from pathlib import Path
20
+ from tempfile import NamedTemporaryFile
21
 
 
 
22
  import gradio as gr
23
+ import torch
24
+ from einops import rearrange
25
 
 
26
  from audiocraft.data.audio import audio_write
27
+ from audiocraft.data.audio_utils import convert_audio
28
+ from audiocraft.models import MultiBandDiffusion, MusicGen
29
  from audiocraft.models.encodec import InterleaveStereoCompressionModel
 
 
30
 
31
  MODEL = None # Last used model
32
+ SPACE_ID = os.environ.get("SPACE_ID", "")
33
+ IS_BATCHED = (
34
+ "facebook/MusicGen" in SPACE_ID or "musicgen-internal/musicgen_dev" in SPACE_ID
35
+ )
36
  print(IS_BATCHED)
37
  MAX_BATCH_SIZE = 12
38
  BATCHED_DURATION = 15
 
44
 
45
  def _call_nostderr(*args, **kwargs):
46
  # Avoid ffmpeg vomiting on the logs.
47
+ kwargs["stderr"] = sp.DEVNULL
48
+ kwargs["stdout"] = sp.DEVNULL
49
  _old_call(*args, **kwargs)
50
 
51
 
 
87
  # Further remove some warnings.
88
  be = time.time()
89
  with warnings.catch_warnings():
90
+ warnings.simplefilter("ignore")
91
  out = gr.make_waveform(*args, **kwargs)
92
  print("Make a video took", time.time() - be)
93
  return out
94
 
95
 
96
+ def load_model(version="facebook/musicgen-melody"):
97
  global MODEL
98
  print("Loading model", version)
99
  if MODEL is None or MODEL.name != version:
100
+ # Clear PyTorch CUDA cache and delete model
101
  del MODEL
102
+ torch.cuda.empty_cache()
103
  MODEL = None # in case loading would crash
104
  MODEL = MusicGen.get_pretrained(version)
105
 
 
111
  MBD = MultiBandDiffusion.get_mbd_musicgen()
112
 
113
 
114
+ def _do_predictions(
115
+ texts, melodies, duration, progress=False, gradio_progress=None, **gen_kwargs
116
+ ):
117
  MODEL.set_generation_params(duration=duration, **gen_kwargs)
118
+ print(
119
+ "new batch",
120
+ len(texts),
121
+ texts,
122
+ [None if m is None else (m[0], m[1].shape) for m in melodies],
123
+ )
124
  be = time.time()
125
  processed_melodies = []
126
  target_sr = 32000
 
129
  if melody is None:
130
  processed_melodies.append(None)
131
  else:
132
+ sr, melody = (
133
+ melody[0],
134
+ torch.from_numpy(melody[1]).to(MODEL.device).float().t(),
135
+ )
136
  if melody.dim() == 1:
137
  melody = melody[None]
138
+ melody = melody[..., : int(sr * duration)]
139
  melody = convert_audio(melody, sr, target_sr, target_ac)
140
  processed_melodies.append(melody)
141
 
 
146
  melody_wavs=processed_melodies,
147
  melody_sample_rate=target_sr,
148
  progress=progress,
149
+ return_tokens=USE_DIFFUSION,
150
  )
151
  else:
152
+ outputs = MODEL.generate(
153
+ texts, progress=progress, return_tokens=USE_DIFFUSION
154
+ )
155
  except RuntimeError as e:
156
  raise gr.Error("Error while generating " + e.args[0])
157
  if USE_DIFFUSION:
158
  if gradio_progress is not None:
159
+ gradio_progress(1, desc="Running MultiBandDiffusion...")
160
  tokens = outputs[1]
161
  if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
162
  left, right = MODEL.compression_model.get_left_right_codes(tokens)
 
164
  outputs_diffusion = MBD.tokens_to_wav(tokens)
165
  if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
166
  assert outputs_diffusion.shape[1] == 1 # output is mono
167
+ outputs_diffusion = rearrange(
168
+ outputs_diffusion, "(s b) c t -> b (s c) t", s=2
169
+ )
170
  outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
171
  outputs = outputs.detach().cpu().float()
172
  pending_videos = []
 
174
  for output in outputs:
175
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
176
  audio_write(
177
+ file.name,
178
+ output,
179
+ MODEL.sample_rate,
180
+ strategy="loudness",
181
+ loudness_headroom_db=16,
182
+ loudness_compressor=True,
183
+ add_suffix=False,
184
+ )
185
  pending_videos.append(pool.submit(make_waveform, file.name))
186
  out_wavs.append(file.name)
187
  file_cleaner.add(file.name)
 
196
  def predict_batched(texts, melodies):
197
  max_text_length = 512
198
  texts = [text[:max_text_length] for text in texts]
199
+ load_model("facebook/musicgen-stereo-melody")
200
  res = _do_predictions(texts, melodies, BATCHED_DURATION)
201
  return res
202
 
203
 
204
+ def predict_full(
205
+ model,
206
+ model_path,
207
+ decoder,
208
+ text,
209
+ melody,
210
+ duration,
211
+ topk,
212
+ topp,
213
+ temperature,
214
+ cfg_coef,
215
+ progress=gr.Progress(),
216
+ ):
217
  global INTERRUPTING
218
  global USE_DIFFUSION
219
  INTERRUPTING = False
 
223
  if not Path(model_path).exists():
224
  raise gr.Error(f"Model path {model_path} doesn't exist.")
225
  if not Path(model_path).is_dir():
226
+ raise gr.Error(
227
+ f"Model path {model_path} must be a folder containing "
228
+ "state_dict.bin and compression_state_dict_.bin."
229
+ )
230
  model = model_path
231
  if temperature < 0:
232
  raise gr.Error("Temperature must be >= 0.")
 
252
  progress((min(max_generated, to_generate), to_generate))
253
  if INTERRUPTING:
254
  raise gr.Error("Interrupted.")
255
+
256
  MODEL.set_custom_progress_callback(_progress)
257
 
258
  videos, wavs = _do_predictions(
259
+ [text],
260
+ [melody],
261
+ duration,
262
+ progress=True,
263
+ top_k=topk,
264
+ top_p=topp,
265
+ temperature=temperature,
266
+ cfg_coef=cfg_coef,
267
+ gradio_progress=progress,
268
+ )
269
  if USE_DIFFUSION:
270
  return videos[0], wavs[0], videos[1], wavs[1]
271
  return videos[0], wavs[0], None, None
 
300
  with gr.Row():
301
  text = gr.Text(label="Input Text", interactive=True)
302
  with gr.Column():
303
+ radio = gr.Radio(
304
+ ["file", "mic"],
305
+ value="file",
306
+ label="Condition on a melody (optional) File or Mic",
307
+ )
308
+ melody = gr.Audio(
309
+ sources=["upload"],
310
+ type="numpy",
311
+ label="File",
312
+ interactive=True,
313
+ elem_id="melody-input",
314
+ )
315
  with gr.Row():
316
  submit = gr.Button("Submit")
317
  # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
318
  _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
319
  with gr.Row():
320
+ # model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
321
+ # "facebook/musicgen-large", "facebook/musicgen-melody-large",
322
+ # "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium",
323
+ # "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large",
324
+ # "facebook/musicgen-stereo-melody-large"],
325
+ # label="Model", value="facebook/musicgen-stereo-melody", interactive=True)
326
  model_path = gr.Text(label="Model Path (custom models)")
327
  with gr.Row():
328
+ decoder = gr.Radio(
329
+ ["Default", "MultiBand_Diffusion"],
330
+ label="Decoder",
331
+ value="Default",
332
+ interactive=True,
333
+ )
334
  with gr.Row():
335
+ duration = gr.Slider(
336
+ minimum=1,
337
+ maximum=60,
338
+ value=10,
339
+ label="Duration",
340
+ interactive=True,
341
+ )
342
  with gr.Row():
343
  topk = gr.Number(label="Top-k", value=250, interactive=True)
344
  topp = gr.Number(label="Top-p", value=0, interactive=True)
345
+ temperature = gr.Number(
346
+ label="Temperature", value=1.0, interactive=True
347
+ )
348
+ cfg_coef = gr.Number(
349
+ label="Classifier Free Guidance", value=3.0, interactive=True
350
+ )
351
  with gr.Column():
352
  output = gr.Video(label="Generated Music")
353
+ audio_output = gr.Audio(label="Generated Music (wav)", type="filepath")
354
  diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
355
+ audio_diffusion = gr.Audio(
356
+ label="MultiBand Diffusion Decoder (wav)", type="filepath"
357
+ )
358
+ submit.click(
359
+ toggle_diffusion,
360
+ decoder,
361
+ [diffusion_output, audio_diffusion],
362
+ queue=False,
363
+ show_progress=False,
364
+ ).then(
365
+ predict_full,
366
+ inputs=[
367
+ model,
368
+ model_path,
369
+ decoder,
370
+ text,
371
+ melody,
372
+ duration,
373
+ topk,
374
+ topp,
375
+ temperature,
376
+ cfg_coef,
377
+ ],
378
+ outputs=[output, audio_output, diffusion_output, audio_diffusion],
379
+ )
380
+ radio.change(
381
+ toggle_audio_src, radio, [melody], queue=False, show_progress=False
382
+ )
383
 
384
  gr.Examples(
385
  fn=predict_full,
 
388
  "An 80s driving pop song with heavy drums and synth pads in the background",
389
  "./assets/bach.mp3",
390
  "facebook/musicgen-stereo-melody",
391
+ "Default",
392
  ],
393
  [
394
  "A cheerful country song with acoustic guitars",
395
  "./assets/bolero_ravel.mp3",
396
  "facebook/musicgen-stereo-melody",
397
+ "Default",
398
  ],
399
  [
400
  "90s rock song with electric guitar and heavy drums",
401
  None,
402
  "facebook/musicgen-stereo-medium",
403
+ "Default",
404
  ],
405
  [
406
  "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
407
  "./assets/bach.mp3",
408
  "facebook/musicgen-stereo-melody",
409
+ "Default",
410
  ],
411
  [
412
  "lofi slow bpm electro chill with organic samples",
413
  None,
414
  "facebook/musicgen-stereo-medium",
415
+ "Default",
416
  ],
417
  [
418
  "Punk rock with loud drum and power guitar",
419
  None,
420
  "facebook/musicgen-stereo-medium",
421
+ "MultiBand_Diffusion",
422
  ],
423
  ],
424
  inputs=[text, melody, model, decoder],
425
+ outputs=[output],
426
  )
427
  gr.Markdown(
428
  """
 
492
  with gr.Row():
493
  with gr.Column():
494
  with gr.Row():
495
+ text = gr.Text(
496
+ label="Describe your music", lines=2, interactive=True
497
+ )
498
  with gr.Column():
499
+ radio = gr.Radio(
500
+ ["file", "mic"],
501
+ value="file",
502
+ label="Condition on a melody (optional) File or Mic",
503
+ )
504
+ melody = gr.Audio(
505
+ source="upload",
506
+ type="numpy",
507
+ label="File",
508
+ interactive=True,
509
+ elem_id="melody-input",
510
+ )
511
  with gr.Row():
512
  submit = gr.Button("Generate")
513
  with gr.Column():
514
  output = gr.Video(label="Generated Music")
515
+ audio_output = gr.Audio(label="Generated Music (wav)", type="filepath")
516
+ submit.click(
517
+ predict_batched,
518
+ inputs=[text, melody],
519
+ outputs=[output, audio_output],
520
+ batch=True,
521
+ max_batch_size=MAX_BATCH_SIZE,
522
+ )
523
+ radio.change(
524
+ toggle_audio_src, radio, [melody], queue=False, show_progress=False
525
+ )
526
  gr.Examples(
527
  fn=predict_batched,
528
  examples=[
 
548
  ],
549
  ],
550
  inputs=[text, melody],
551
+ outputs=[output],
552
  )
553
  gr.Markdown("""
554
  ### More details
 
582
  if __name__ == "__main__":
583
  parser = argparse.ArgumentParser()
584
  parser.add_argument(
585
+ "--listen",
586
  type=str,
587
+ default="0.0.0.0" if "SPACE_ID" in os.environ else "127.0.0.1",
588
+ help="IP to listen on for connections to Gradio",
589
  )
590
  parser.add_argument(
591
+ "--username", type=str, default="", help="Username for authentication"
592
  )
593
  parser.add_argument(
594
+ "--password", type=str, default="", help="Password for authentication"
595
  )
596
  parser.add_argument(
597
+ "--server_port",
598
  type=int,
599
  default=0,
600
+ help="Port to run the server listener on",
 
 
 
 
 
 
601
  )
602
+ parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
603
+ parser.add_argument("--share", action="store_true", help="Share the gradio UI")
604
 
605
  args = parser.parse_args()
606
 
607
  launch_kwargs = {}
608
+ launch_kwargs["server_name"] = args.listen
609
 
610
  if args.username and args.password:
611
+ launch_kwargs["auth"] = (args.username, args.password)
612
  if args.server_port:
613
+ launch_kwargs["server_port"] = args.server_port
614
  if args.inbrowser:
615
+ launch_kwargs["inbrowser"] = args.inbrowser
616
  if args.share:
617
+ launch_kwargs["share"] = args.share
618
 
619
  logging.basicConfig(level=logging.INFO, stream=sys.stderr)
620
 
621
  # Show the interface
622
+ # if IS_BATCHED:
623
+ # global USE_DIFFUSION
624
+ # USE_DIFFUSION = False
625
+ # ui_batched(launch_kwargs)
626
+ # else:
627
+ ui_full(launch_kwargs)