skytnt commited on
Commit
1fd2f8b
·
1 Parent(s): cd46b84

speed up rendering audio

Browse files
Files changed (2) hide show
  1. app.py +41 -9
  2. midi_synthesizer.py +6 -2
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import spaces
2
  import random
3
  import argparse
@@ -240,7 +242,8 @@ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instr
240
 
241
  def finish_run(model_name, mid_seq):
242
  if mid_seq is None:
243
- return None, None, []
 
244
  tokenizer = models[model_name].tokenizer
245
  outputs = []
246
  end_msgs = [create_msg("progress", [0, 0])]
@@ -249,16 +252,36 @@ def finish_run(model_name, mid_seq):
249
  for i in range(OUTPUT_BATCH_SIZE):
250
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
251
  mid = tokenizer.detokenize(mid_seq[i])
252
- audio = synthesizer.synthesis(MIDI.score2opus(mid))
253
  with open(f"outputs/output{i + 1}.mid", 'wb') as f:
254
  f.write(MIDI.score2midi(mid))
255
- outputs += [(44100, audio), f"outputs/output{i + 1}.mid"]
256
  end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
257
  create_msg("visualizer_append", [i, events]),
258
  create_msg("visualizer_end", i)]
259
  return *outputs, send_msgs(end_msgs)
260
 
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  def undo_continuation(model_name, mid_seq, continuation_state):
263
  if mid_seq is None or len(continuation_state) < 2:
264
  return mid_seq, continuation_state, send_msgs([])
@@ -324,6 +347,7 @@ if __name__ == "__main__":
324
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
325
  opt = parser.parse_args()
326
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
 
327
  synthesizer = MidiSynthesizer(soundfont_path)
328
  models_info = {
329
  "generic pretrain model (tv2o-medium) by skytnt": ["skytnt/midi-model-tv2o-medium", "", "tv2o-medium"],
@@ -442,20 +466,23 @@ if __name__ == "__main__":
442
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
443
  input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=30)
444
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
 
445
  example3 = gr.Examples([[1, 0.95, 128], [1, 0.98, 20], [1, 0.98, 12]],
446
  [input_temp, input_top_p, input_top_k])
447
  run_btn = gr.Button("generate", variant="primary")
448
  # stop_btn = gr.Button("stop and output")
449
  output_midi_seq = gr.State()
450
  output_continuation_state = gr.State([0])
451
- batch_outputs = []
 
452
  with gr.Tabs(elem_id="output_tabs"):
453
  for i in range(OUTPUT_BATCH_SIZE):
454
  with gr.TabItem(f"output {i + 1}") as tab1:
455
  output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
456
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
457
  output_midi = gr.File(label="output midi", file_types=[".mid"])
458
- batch_outputs += [output_audio, output_midi]
 
459
  run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
460
  input_continuation_select, input_instruments, input_drum_kit, input_bpm,
461
  input_time_sig, input_key_sig, input_midi, input_midi_events,
@@ -465,12 +492,17 @@ if __name__ == "__main__":
465
  input_top_k, input_allow_cc],
466
  [output_midi_seq, output_continuation_state, input_seed, js_msg],
467
  concurrency_limit=10, queue=True)
468
- run_event.then(fn=finish_run,
469
- inputs=[input_model, output_midi_seq],
470
- outputs=batch_outputs + [js_msg],
471
- queue=False)
 
 
 
 
472
  # stop_btn.click(None, [], [], cancels=run_event,
473
  # queue=False)
474
  undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
475
  [output_midi_seq, output_continuation_state, js_msg], queue=False)
476
  app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
 
1
+ from concurrent.futures import ThreadPoolExecutor
2
+
3
  import spaces
4
  import random
5
  import argparse
 
242
 
243
  def finish_run(model_name, mid_seq):
244
  if mid_seq is None:
245
+ outputs = [None] * OUTPUT_BATCH_SIZE
246
+ return *outputs, []
247
  tokenizer = models[model_name].tokenizer
248
  outputs = []
249
  end_msgs = [create_msg("progress", [0, 0])]
 
252
  for i in range(OUTPUT_BATCH_SIZE):
253
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
254
  mid = tokenizer.detokenize(mid_seq[i])
 
255
  with open(f"outputs/output{i + 1}.mid", 'wb') as f:
256
  f.write(MIDI.score2midi(mid))
257
+ outputs.append(f"outputs/output{i + 1}.mid")
258
  end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
259
  create_msg("visualizer_append", [i, events]),
260
  create_msg("visualizer_end", i)]
261
  return *outputs, send_msgs(end_msgs)
262
 
263
 
264
+ def synthesis_task(mid):
265
+ return synthesizer.synthesis(MIDI.score2opus(mid))
266
+
267
+ def render_audio(model_name, mid_seq, should_render_audio):
268
+ if (not should_render_audio) or mid_seq is None:
269
+ outputs = [None] * OUTPUT_BATCH_SIZE
270
+ return tuple(outputs)
271
+ tokenizer = models[model_name].tokenizer
272
+ outputs = []
273
+ if not os.path.exists("outputs"):
274
+ os.mkdir("outputs")
275
+ audio_futures = []
276
+ for i in range(OUTPUT_BATCH_SIZE):
277
+ mid = tokenizer.detokenize(mid_seq[i])
278
+ audio_future = thread_pool.submit(synthesis_task, mid)
279
+ audio_futures.append(audio_future)
280
+ for future in audio_futures:
281
+ outputs.append((44100, future.result()))
282
+ return tuple(outputs)
283
+
284
+
285
  def undo_continuation(model_name, mid_seq, continuation_state):
286
  if mid_seq is None or len(continuation_state) < 2:
287
  return mid_seq, continuation_state, send_msgs([])
 
347
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
348
  opt = parser.parse_args()
349
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
350
+ thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
351
  synthesizer = MidiSynthesizer(soundfont_path)
352
  models_info = {
353
  "generic pretrain model (tv2o-medium) by skytnt": ["skytnt/midi-model-tv2o-medium", "", "tv2o-medium"],
 
466
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
467
  input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=30)
468
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
469
+ input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
470
  example3 = gr.Examples([[1, 0.95, 128], [1, 0.98, 20], [1, 0.98, 12]],
471
  [input_temp, input_top_p, input_top_k])
472
  run_btn = gr.Button("generate", variant="primary")
473
  # stop_btn = gr.Button("stop and output")
474
  output_midi_seq = gr.State()
475
  output_continuation_state = gr.State([0])
476
+ midi_outputs = []
477
+ audio_outputs = []
478
  with gr.Tabs(elem_id="output_tabs"):
479
  for i in range(OUTPUT_BATCH_SIZE):
480
  with gr.TabItem(f"output {i + 1}") as tab1:
481
  output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
482
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
483
  output_midi = gr.File(label="output midi", file_types=[".mid"])
484
+ midi_outputs.append(output_midi)
485
+ audio_outputs.append(output_audio)
486
  run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
487
  input_continuation_select, input_instruments, input_drum_kit, input_bpm,
488
  input_time_sig, input_key_sig, input_midi, input_midi_events,
 
492
  input_top_k, input_allow_cc],
493
  [output_midi_seq, output_continuation_state, input_seed, js_msg],
494
  concurrency_limit=10, queue=True)
495
+ finish_run_event = run_event.then(fn=finish_run,
496
+ inputs=[input_model, output_midi_seq],
497
+ outputs=midi_outputs + [js_msg],
498
+ queue=False)
499
+ finish_run_event.then(fn=render_audio,
500
+ inputs=[input_model, output_midi_seq, input_render_audio],
501
+ outputs=audio_outputs,
502
+ queue=False)
503
  # stop_btn.click(None, [], [], cancels=run_event,
504
  # queue=False)
505
  undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
506
  [output_midi_seq, output_continuation_state, js_msg], queue=False)
507
  app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True)
508
+ thread_pool.shutdown()
midi_synthesizer.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import fluidsynth
2
  import numpy as np
3
 
@@ -9,14 +11,16 @@ class MidiSynthesizer:
9
  fl = fluidsynth.Synth(samplerate=float(sample_rate))
10
  sfid = fl.sfload(soundfont_path)
11
  self.devices = [[fl, sfid, False]]
 
12
 
13
  def get_fluidsynth(self):
14
  for device in self.devices:
15
  if not device[2]:
16
  device[2] = True
17
  return device
18
- fl = fluidsynth.Synth(samplerate=float(self.sample_rate))
19
- sfid = fl.sfload(self.soundfont_path)
 
20
  device = [fl, sfid, True]
21
  self.devices.append(device)
22
  return device
 
1
+ from threading import Lock
2
+
3
  import fluidsynth
4
  import numpy as np
5
 
 
11
  fl = fluidsynth.Synth(samplerate=float(sample_rate))
12
  sfid = fl.sfload(soundfont_path)
13
  self.devices = [[fl, sfid, False]]
14
+ self.file_lock = Lock()
15
 
16
  def get_fluidsynth(self):
17
  for device in self.devices:
18
  if not device[2]:
19
  device[2] = True
20
  return device
21
+ with self.file_lock:
22
+ fl = fluidsynth.Synth(samplerate=float(self.sample_rate))
23
+ sfid = fl.sfload(self.soundfont_path)
24
  device = [fl, sfid, True]
25
  self.devices.append(device)
26
  return device