Fabrice-TIERCELIN commited on
Commit
297d013
·
verified ·
1 Parent(s): ae6f193

Warn for wrong parameters without queuing

Browse files
Files changed (1) hide show
  1. demos/musicgen_app.py +30 -14
demos/musicgen_app.py CHANGED
@@ -178,6 +178,15 @@ def predict_batched(texts, melodies):
178
  return res
179
 
180
 
 
 
 
 
 
 
 
 
 
181
  def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, output_hint, progress=gr.Progress()):
182
  global INTERRUPTING
183
  global USE_DIFFUSION
@@ -191,12 +200,6 @@ def predict_full(model, model_path, decoder, text, melody, duration, topk, topp,
191
  raise gr.Error(f"Model path {model_path} must be a folder containing "
192
  "state_dict.bin and compression_state_dict_.bin.")
193
  model = model_path
194
- if temperature < 0:
195
- raise gr.Error("Temperature must be >= 0.")
196
- if topk < 0:
197
- raise gr.Error("Topk must be non-negative.")
198
- if topp < 0:
199
- raise gr.Error("Topp must be non-negative.")
200
 
201
  topk = int(topk)
202
  if decoder == "MultiBand_Diffusion":
@@ -286,10 +289,10 @@ def ui_full(launch_kwargs):
286
  duration = gr.Slider(label = "Duration", info = "(in seconds)", minimum = 1, maximum = 120, value = 30, interactive = True)
287
  with gr.Accordion("Advanced options", open = False):
288
  with gr.Row():
289
- topk = gr.Number(label = "Top-k", info = "Number of tokens shortlisted", value = 250, interactive = True)
290
- topp = gr.Number(label = "Top-p", info = "Percent of tokens shortlisted", value = 0, interactive = True)
291
  temperature = gr.Number(label = "Temperature", info = "lower=Always similar, higher=More creative", value = 1.0, interactive = True)
292
- cfg_coef = gr.Number(label = "Classifier-Free Guidance", info = "lower=Audio quality, higher=Follow the prompt", value = 3.0, interactive = True)
293
  with gr.Row():
294
  decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
295
  label = "Decoder", value = "Default", interactive = True)
@@ -312,12 +315,25 @@ def ui_full(launch_kwargs):
312
  output_hint = gr.Label(label = "Information")
313
  diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
314
  audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
 
315
  submit.click(toggle_diffusion, decoder, [
316
  diffusion_output,
317
  audio_diffusion
318
- ], queue=False, show_progress=False).then(hide_information, decoder, [
319
  output_hint
320
- ], queue=False, show_progress=False).then(predict_full, inputs = [
 
 
 
 
 
 
 
 
 
 
 
 
321
  model,
322
  model_path,
323
  decoder,
@@ -337,13 +353,13 @@ def ui_full(launch_kwargs):
337
  audio_diffusion
338
  ], scroll_to_output = True).then(show_information, decoder, [
339
  output_hint
340
- ], queue=False, show_progress=False)
341
 
342
  radio.change(toggle_audio_src, radio, [melody], queue = False, show_progress = False)
343
 
344
  gr.Examples(
345
- fn=predict_full,
346
- examples=[
347
  [
348
  "An angry propulsive industrial score with distorted synthesizers and tortured vocals.",
349
  None,
 
178
  return res
179
 
180
 
181
+ def check(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, output_hint):
182
+ if temperature < 0:
183
+ raise gr.Error("Temperature must not be negative.")
184
+ if topk < 0:
185
+ raise gr.Error("Topk must not be negative.")
186
+ if topp < 0:
187
+ raise gr.Error("Topp must not be negative.")
188
+
189
+
190
  def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, output_hint, progress=gr.Progress()):
191
  global INTERRUPTING
192
  global USE_DIFFUSION
 
200
  raise gr.Error(f"Model path {model_path} must be a folder containing "
201
  "state_dict.bin and compression_state_dict_.bin.")
202
  model = model_path
 
 
 
 
 
 
203
 
204
  topk = int(topk)
205
  if decoder == "MultiBand_Diffusion":
 
289
  duration = gr.Slider(label = "Duration", info = "(in seconds)", minimum = 1, maximum = 120, value = 30, interactive = True)
290
  with gr.Accordion("Advanced options", open = False):
291
  with gr.Row():
292
+ topk = gr.Number(label = "Top-k", info = "Number of tokens shortlisted", value = 250, minimum = 0, interactive = True)
293
+ topp = gr.Number(label = "Top-p", info = "Percent of tokens shortlisted", value = 0, minimum = 0, interactive = True)
294
  temperature = gr.Number(label = "Temperature", info = "lower=Always similar, higher=More creative", value = 1.0, interactive = True)
295
+ cfg_coef = gr.Number(label = "Classifier-Free Guidance", info = "lower=Audio quality, higher=Follow the prompt", value = 3.0, minimum = 1, interactive = True)
296
  with gr.Row():
297
  decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
298
  label = "Decoder", value = "Default", interactive = True)
 
315
  output_hint = gr.Label(label = "Information")
316
  diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
317
  audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
318
+
319
  submit.click(toggle_diffusion, decoder, [
320
  diffusion_output,
321
  audio_diffusion
322
+ ], queue = False, show_progress = False).then(hide_information, decoder, [
323
  output_hint
324
+ ], queue = False, show_progress = False).then(check, inputs = [
325
+ model,
326
+ model_path,
327
+ decoder,
328
+ text,
329
+ melody,
330
+ duration,
331
+ topk,
332
+ topp,
333
+ temperature,
334
+ cfg_coef,
335
+ output_hint
336
+ ], outputs = [], queue = False, show_progress = False).success(predict_full, inputs = [
337
  model,
338
  model_path,
339
  decoder,
 
353
  audio_diffusion
354
  ], scroll_to_output = True).then(show_information, decoder, [
355
  output_hint
356
+ ], queue = False, show_progress = False)
357
 
358
  radio.change(toggle_audio_src, radio, [melody], queue = False, show_progress = False)
359
 
360
  gr.Examples(
361
+ fn = predict_full,
362
+ examples = [
363
  [
364
  "An angry propulsive industrial score with distorted synthesizers and tortured vocals.",
365
  None,