skytnt commited on
Commit
239bed9
1 Parent(s): 3fe9868

signature options and MidiSynthesizer optimise

Browse files
Files changed (2) hide show
  1. app.py +43 -8
  2. midi_synthesizer.py +70 -47
app.py CHANGED
@@ -11,7 +11,7 @@ import tqdm
11
  from huggingface_hub import hf_hub_download
12
 
13
  import MIDI
14
- from midi_synthesizer import synthesis
15
  from midi_tokenizer import MIDITokenizer
16
 
17
  MAX_SEED = np.iinfo(np.int32).max
@@ -121,12 +121,28 @@ def send_msgs(msgs):
121
  return json.dumps(msgs)
122
 
123
 
124
- def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, mid, midi_events,
125
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
126
  gen_events, temp, top_p, top_k, allow_cc):
127
  model = models[model_name]
128
  tokenizer = model[2]
129
  bpm = int(bpm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  gen_events = int(gen_events)
131
  max_len = gen_events
132
  if seed_rand:
@@ -137,6 +153,11 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, mid, midi_events,
137
  if tab == 0:
138
  i = 0
139
  mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
 
 
 
 
 
140
  if bpm != 0:
141
  mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
142
  patches = {}
@@ -148,7 +169,7 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, mid, midi_events,
148
  if drum_kit != "None":
149
  patches[9] = drum_kits2number[drum_kit]
150
  for i, (c, p) in enumerate(patches.items()):
151
- mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p]))
152
  mid_seq = mid
153
  mid = np.asarray(mid, dtype=np.int64)
154
  if len(instruments) > 0:
@@ -181,11 +202,11 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, mid, midi_events,
181
  init_msgs = [create_msg("visualizer_clear", tokenizer.version),
182
  create_msg("visualizer_append", events)]
183
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
184
- t = time.time() + 1
185
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
186
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
187
  disable_channels=disable_channels, generator=generator)
188
  events = []
 
189
  for i, token_seq in enumerate(midi_generator):
190
  token_seq = token_seq.tolist()
191
  mid_seq.append(token_seq)
@@ -200,7 +221,7 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, mid, midi_events,
200
  mid = tokenizer.detokenize(mid_seq)
201
  with open(f"output.mid", 'wb') as f:
202
  f.write(MIDI.score2midi(mid))
203
- audio = synthesis(MIDI.score2opus(mid), soundfont_path)
204
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
205
  yield mid_seq, "output.mid", (44100, audio), seed, send_msgs([create_msg("visualizer_end", events)])
206
 
@@ -212,7 +233,7 @@ def cancel_run(model_name, mid_seq):
212
  mid = tokenizer.detokenize(mid_seq)
213
  with open(f"output.mid", 'wb') as f:
214
  f.write(MIDI.score2midi(mid))
215
- audio = synthesis(MIDI.score2opus(mid), soundfont_path)
216
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
217
  return "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", events)])
218
 
@@ -268,6 +289,8 @@ number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Elec
268
  40: "Blush", 48: "Orchestra"}
269
  patch2number = {v: k for k, v in MIDI.Number2patch.items()}
270
  drum_kits2number = {v: k for k, v in number2drum_kits.items()}
 
 
271
 
272
  if __name__ == "__main__":
273
  parser = argparse.ArgumentParser()
@@ -276,6 +299,7 @@ if __name__ == "__main__":
276
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
277
  opt = parser.parse_args()
278
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
 
279
  models_info = {"generic pretrain model (tv2o-large) by asigalov61": ["asigalov61/Music-Llama", "", "tv2o-large"],
280
  "generic pretrain model (tv2o-medium) by asigalov61": ["asigalov61/Music-Llama-Medium", "", "tv2o-medium"],
281
  "generic pretrain model (tv1-medium) by skytnt": ["skytnt/midi-model", "", "tv1-medium"],
@@ -324,6 +348,16 @@ if __name__ == "__main__":
324
  input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
325
  step=1,
326
  value=0)
 
 
 
 
 
 
 
 
 
 
327
  example1 = gr.Examples([
328
  [[], "None"],
329
  [["Acoustic Grand"], "None"],
@@ -375,8 +409,9 @@ if __name__ == "__main__":
375
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
376
  output_midi = gr.File(label="output midi", file_types=[".mid"])
377
  run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, input_instruments,
378
- input_drum_kit, input_bpm, input_midi, input_midi_events, input_reduce_cc_st,
379
- input_remap_track_channel, input_add_default_instr, input_remove_empty_channels,
 
380
  input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
381
  input_top_k, input_allow_cc],
382
  [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
 
11
  from huggingface_hub import hf_hub_download
12
 
13
  import MIDI
14
+ from midi_synthesizer import MidiSynthesizer
15
  from midi_tokenizer import MIDITokenizer
16
 
17
  MAX_SEED = np.iinfo(np.int32).max
 
121
  return json.dumps(msgs)
122
 
123
 
124
+ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
125
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
126
  gen_events, temp, top_p, top_k, allow_cc):
127
  model = models[model_name]
128
  tokenizer = model[2]
129
  bpm = int(bpm)
130
+ if time_sig == "auto":
131
+ time_sig = None
132
+ time_sig_nn = 4
133
+ time_sig_dd = 2
134
+ else:
135
+ time_sig_nn, time_sig_dd = time_sig.split('/')
136
+ time_sig_nn = int(time_sig_nn)
137
+ time_sig_dd = {2: 1, 4: 2, 8: 3}[int(time_sig_dd)]
138
+ if key_sig == 0:
139
+ key_sig = None
140
+ key_sig_sf = 0
141
+ key_sig_mi = 0
142
+ else:
143
+ key_sig = (key_sig - 1)
144
+ key_sig_sf = key_sig // 2 - 7
145
+ key_sig_mi = key_sig % 2
146
  gen_events = int(gen_events)
147
  max_len = gen_events
148
  if seed_rand:
 
153
  if tab == 0:
154
  i = 0
155
  mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
156
+ if tokenizer.version == "v2":
157
+ if time_sig is not None:
158
+ mid.append(tokenizer.event2tokens(["time_signature", 0, 0, 0, time_sig_nn - 1, time_sig_dd - 1]))
159
+ if key_sig is not None:
160
+ mid.append(tokenizer.event2tokens(["key_signature", 0, 0, 0, key_sig_sf + 7, key_sig_mi]))
161
  if bpm != 0:
162
  mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
163
  patches = {}
 
169
  if drum_kit != "None":
170
  patches[9] = drum_kits2number[drum_kit]
171
  for i, (c, p) in enumerate(patches.items()):
172
+ mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
173
  mid_seq = mid
174
  mid = np.asarray(mid, dtype=np.int64)
175
  if len(instruments) > 0:
 
202
  init_msgs = [create_msg("visualizer_clear", tokenizer.version),
203
  create_msg("visualizer_append", events)]
204
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
 
205
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
206
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
207
  disable_channels=disable_channels, generator=generator)
208
  events = []
209
+ t = time.time() + 1
210
  for i, token_seq in enumerate(midi_generator):
211
  token_seq = token_seq.tolist()
212
  mid_seq.append(token_seq)
 
221
  mid = tokenizer.detokenize(mid_seq)
222
  with open(f"output.mid", 'wb') as f:
223
  f.write(MIDI.score2midi(mid))
224
+ audio = synthesizer.synthesis(MIDI.score2opus(mid))
225
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
226
  yield mid_seq, "output.mid", (44100, audio), seed, send_msgs([create_msg("visualizer_end", events)])
227
 
 
233
  mid = tokenizer.detokenize(mid_seq)
234
  with open(f"output.mid", 'wb') as f:
235
  f.write(MIDI.score2midi(mid))
236
+ audio = synthesizer.synthesis(MIDI.score2opus(mid))
237
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
238
  return "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", events)])
239
 
 
289
  40: "Blush", 48: "Orchestra"}
290
  patch2number = {v: k for k, v in MIDI.Number2patch.items()}
291
  drum_kits2number = {v: k for k, v in number2drum_kits.items()}
292
+ key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
293
+ 'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
294
 
295
  if __name__ == "__main__":
296
  parser = argparse.ArgumentParser()
 
299
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
300
  opt = parser.parse_args()
301
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
302
+ synthesizer = MidiSynthesizer(soundfont_path)
303
  models_info = {"generic pretrain model (tv2o-large) by asigalov61": ["asigalov61/Music-Llama", "", "tv2o-large"],
304
  "generic pretrain model (tv2o-medium) by asigalov61": ["asigalov61/Music-Llama-Medium", "", "tv2o-medium"],
305
  "generic pretrain model (tv1-medium) by skytnt": ["skytnt/midi-model", "", "tv1-medium"],
 
348
  input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
349
  step=1,
350
  value=0)
351
+ input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
352
+ value="auto",
353
+ choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
354
+ "2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"]
355
+ )
356
+ input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
357
+ value="auto",
358
+ choices=["auto"] + key_signatures,
359
+ type="index"
360
+ )
361
  example1 = gr.Examples([
362
  [[], "None"],
363
  [["Acoustic Grand"], "None"],
 
409
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
410
  output_midi = gr.File(label="output midi", file_types=[".mid"])
411
  run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, input_instruments,
412
+ input_drum_kit, input_bpm, input_time_sig, input_key_sig, input_midi,
413
+ input_midi_events, input_reduce_cc_st, input_remap_track_channel,
414
+ input_add_default_instr, input_remove_empty_channels,
415
  input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
416
  input_top_k, input_allow_cc],
417
  [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
midi_synthesizer.py CHANGED
@@ -1,53 +1,76 @@
1
  import fluidsynth
2
  import numpy as np
3
 
 
 
 
 
 
 
 
4
 
5
- def synthesis(midi_opus, soundfont_path, sample_rate=44100):
6
- ticks_per_beat = midi_opus[0]
7
- event_list = []
8
- for track_idx, track in enumerate(midi_opus[1:]):
9
- abs_t = 0
10
- for event in track:
11
- abs_t += event[1]
12
- event_new = [*event]
13
- event_new[1] = abs_t
14
- event_list.append(event_new)
15
- event_list = sorted(event_list, key=lambda e: e[1])
16
 
17
- tempo = int((60 / 120) * 10 ** 6) # default 120 bpm
18
- ss = np.empty((0, 2), dtype=np.int16)
19
- fl = fluidsynth.Synth(samplerate=float(sample_rate))
20
- sfid = fl.sfload(soundfont_path)
21
- last_t = 0
22
- for c in range(16):
23
- fl.program_select(c, sfid, 128 if c == 9 else 0, 0)
24
- for event in event_list:
25
- name = event[0]
26
- sample_len = int(((event[1] / ticks_per_beat) * tempo / (10 ** 6)) * sample_rate)
27
- sample_len -= int(((last_t / ticks_per_beat) * tempo / (10 ** 6)) * sample_rate)
28
- last_t = event[1]
29
- if sample_len > 0:
30
- sample = fl.get_samples(sample_len).reshape(sample_len, 2)
31
- ss = np.concatenate([ss, sample])
32
- if name == "set_tempo":
33
- tempo = event[2]
34
- elif name == "patch_change":
35
- c, p = event[2:4]
36
- fl.program_select(c, sfid, 128 if c == 9 else 0, p)
37
- elif name == "control_change":
38
- c, cc, v = event[2:5]
39
- fl.cc(c, cc, v)
40
- elif name == "note_on" and event[3] > 0:
41
- c, p, v = event[2:5]
42
- fl.noteon(c, p, v)
43
- elif name == "note_off" or (name == "note_on" and event[3] == 0):
44
- c, p = event[2:4]
45
- fl.noteoff(c, p)
46
 
47
- fl.delete()
48
- if ss.shape[0] > 0:
49
- max_val = np.abs(ss).max()
50
- if max_val != 0:
51
- ss = (ss / max_val) * np.iinfo(np.int16).max
52
- ss = ss.astype(np.int16)
53
- return ss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import fluidsynth
2
  import numpy as np
3
 
4
+ class MidiSynthesizer:
5
+ def __init__(self, soundfont_path, sample_rate=44100):
6
+ self.soundfont_path = soundfont_path
7
+ self.sample_rate = sample_rate
8
+ fl = fluidsynth.Synth(samplerate=float(sample_rate))
9
+ sfid = fl.sfload(soundfont_path)
10
+ self.devices = [[fl, sfid, False]]
11
 
12
+ def get_fluidsynth(self):
13
+ for device in self.devices:
14
+ if not device[2]:
15
+ device[2] = True
16
+ return device
17
+ fl = fluidsynth.Synth(samplerate=float(self.sample_rate))
18
+ sfid = fl.sfload(self.soundfont_path)
19
+ device = [fl, sfid, True]
20
+ self.devices.append(device)
21
+ return device
 
22
 
23
+ def release_fluidsynth(self, device):
24
+ device[0].system_reset()
25
+ device[0].get_samples(self.sample_rate*5) # wait for silence
26
+ device[2] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ def synthesis(self, midi_opus):
29
+ ticks_per_beat = midi_opus[0]
30
+ event_list = []
31
+ for track_idx, track in enumerate(midi_opus[1:]):
32
+ abs_t = 0
33
+ for event in track:
34
+ abs_t += event[1]
35
+ event_new = [*event]
36
+ event_new[1] = abs_t
37
+ event_list.append(event_new)
38
+ event_list = sorted(event_list, key=lambda e: e[1])
39
+
40
+ tempo = int((60 / 120) * 10 ** 6) # default 120 bpm
41
+ ss = np.empty((0, 2), dtype=np.int16)
42
+ device = self.get_fluidsynth()
43
+ fl, sfid = device[:-1]
44
+ last_t = 0
45
+ for c in range(16):
46
+ fl.program_select(c, sfid, 128 if c == 9 else 0, 0)
47
+ for event in event_list:
48
+ name = event[0]
49
+ sample_len = int(((event[1] / ticks_per_beat) * tempo / (10 ** 6)) * self.sample_rate)
50
+ sample_len -= int(((last_t / ticks_per_beat) * tempo / (10 ** 6)) * self.sample_rate)
51
+ last_t = event[1]
52
+ if sample_len > 0:
53
+ sample = fl.get_samples(sample_len).reshape(sample_len, 2)
54
+ ss = np.concatenate([ss, sample])
55
+ if name == "set_tempo":
56
+ tempo = event[2]
57
+ elif name == "patch_change":
58
+ c, p = event[2:4]
59
+ fl.program_select(c, sfid, 128 if c == 9 else 0, p)
60
+ elif name == "control_change":
61
+ c, cc, v = event[2:5]
62
+ fl.cc(c, cc, v)
63
+ elif name == "note_on" and event[3] > 0:
64
+ c, p, v = event[2:5]
65
+ fl.noteon(c, p, v)
66
+ elif name == "note_off" or (name == "note_on" and event[3] == 0):
67
+ c, p = event[2:4]
68
+ fl.noteoff(c, p)
69
+
70
+ self.release_fluidsynth(device)
71
+ if ss.shape[0] > 0:
72
+ max_val = np.abs(ss).max()
73
+ if max_val != 0:
74
+ ss = (ss / max_val) * np.iinfo(np.int16).max
75
+ ss = ss.astype(np.int16)
76
+ return ss