skytnt commited on
Commit
d5353b5
·
1 Parent(s): 9d7618f

update tokenizer

Browse files
Files changed (2) hide show
  1. app.py +15 -6
  2. midi_tokenizer.py +111 -9
app.py CHANGED
@@ -121,7 +121,8 @@ def send_msgs(msgs):
121
  return json.dumps(msgs)
122
 
123
 
124
- def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, midi_opt, seed, seed_rand,
 
125
  gen_events, temp, top_p, top_k, allow_cc):
126
  mid_seq = []
127
  bpm = int(bpm)
@@ -153,8 +154,11 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, midi_opt,
153
  disable_patch_change = True
154
  disable_channels = [i for i in range(16) if i not in patches]
155
  elif mid is not None:
156
- eps = 4 if midi_opt else 0
157
- mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps)
 
 
 
158
  mid = np.asarray(mid, dtype=np.int64)
159
  mid = mid[:int(midi_events)]
160
  for token_seq in mid:
@@ -306,7 +310,10 @@ if __name__ == "__main__":
306
  input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
307
  step=1,
308
  value=128)
309
- input_midi_opt = gr.Checkbox(label="optimise midi (uncheck if your midi is generate from this model)", value=True)
 
 
 
310
  example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
311
  [input_midi, input_midi_events])
312
 
@@ -330,8 +337,10 @@ if __name__ == "__main__":
330
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
331
  output_midi = gr.File(label="output midi", file_types=[".mid"])
332
  run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
333
- input_midi, input_midi_events, input_midi_opt, input_seed, input_seed_rand,
334
- input_gen_events, input_temp, input_top_p, input_top_k, input_allow_cc],
 
 
335
  [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
336
  concurrency_limit=3)
337
  stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
 
121
  return json.dumps(msgs)
122
 
123
 
124
+ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
125
+ reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
126
  gen_events, temp, top_p, top_k, allow_cc):
127
  mid_seq = []
128
  bpm = int(bpm)
 
154
  disable_patch_change = True
155
  disable_channels = [i for i in range(16) if i not in patches]
156
  elif mid is not None:
157
+ eps = 4 if reduce_cc_st else 0
158
+ mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
159
+ remap_track_channel=remap_track_channel,
160
+ add_default_instr=add_default_instr,
161
+ remove_empty_channels=remove_empty_channels)
162
  mid = np.asarray(mid, dtype=np.int64)
163
  mid = mid[:int(midi_events)]
164
  for token_seq in mid:
 
310
  input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
311
  step=1,
312
  value=128)
313
+ input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
314
+ input_remap_track_channel = gr.Checkbox(label="remap tracks and channels to have only one channel per track", value=True)
315
+ input_add_default_instr = gr.Checkbox(label="add a default instrument to channels that don't have an instrument", value=True)
316
+ input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
317
  example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
318
  [input_midi, input_midi_events])
319
 
 
337
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
338
  output_midi = gr.File(label="output midi", file_types=[".mid"])
339
  run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
340
+ input_midi, input_midi_events, input_reduce_cc_st, input_remap_track_channel,
341
+ input_add_default_instr, input_remove_empty_channels, input_seed,
342
+ input_seed_rand, input_gen_events, input_temp, input_top_p, input_top_k,
343
+ input_allow_cc],
344
  [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
345
  concurrency_limit=3)
346
  stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
midi_tokenizer.py CHANGED
@@ -42,9 +42,16 @@ class MIDITokenizer:
42
  tempo = int((60 / bpm) * 10 ** 6)
43
  return tempo
44
 
45
- def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4):
 
46
  ticks_per_beat = midi_score[0]
47
  event_list = {}
 
 
 
 
 
 
48
  for track_idx, track in enumerate(midi_score[1:129]):
49
  last_notes = {}
50
  patch_dict = {}
@@ -53,9 +60,18 @@ class MIDITokenizer:
53
  for event in track:
54
  if event[0] not in self.events:
55
  continue
 
56
  t = round(16 * event[1] / ticks_per_beat) # quantization
57
  new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
58
  if event[0] == "note":
 
 
 
 
 
 
 
 
59
  new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
60
  elif event[0] == "set_tempo":
61
  if new_event[4] == 0: # invalid tempo
@@ -68,12 +84,18 @@ class MIDITokenizer:
68
  key = tuple(new_event[:-1])
69
  if event[0] == "patch_change":
70
  c, p = event[2:]
 
 
71
  last_p = patch_dict.setdefault(c, None)
72
  if last_p == p:
73
  continue
74
  patch_dict[c] = p
 
 
75
  elif event[0] == "control_change":
76
  c, cc, v = event[2:]
 
 
77
  last_v = control_dict.setdefault((c, cc), 0)
78
  if abs(last_v - v) < cc_eps:
79
  continue
@@ -84,6 +106,13 @@ class MIDITokenizer:
84
  continue
85
  last_tempo = tempo
86
 
 
 
 
 
 
 
 
87
  if event[0] == "note": # to eliminate note overlap due to quantization
88
  cp = tuple(new_event[5:7])
89
  if cp in last_notes:
@@ -95,8 +124,79 @@ class MIDITokenizer:
95
  last_notes[cp] = (key, new_event)
96
  event_list[key] = new_event
97
  event_list = list(event_list.values())
98
- event_list = sorted(event_list, key=lambda e: e[1:4])
99
- midi_seq = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  setup_events = {}
101
  notes_in_setup = False
102
  for i, event in enumerate(event_list): # optimise setup
@@ -113,7 +213,7 @@ class MIDITokenizer:
113
  pre_event = event_list[i - 1]
114
  has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
115
  if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
116
- event_list = sorted(setup_events.values(), key=lambda e: 1 if e[0] == "note" else 0) + event_list[i:]
117
  break
118
  else:
119
  if event[0] == "note":
@@ -122,7 +222,10 @@ class MIDITokenizer:
122
  setup_events[key] = new_event
123
 
124
  last_t1 = 0
 
125
  for event in event_list:
 
 
126
  cur_t1 = event[1]
127
  event[1] = event[1] - last_t1
128
  tokens = self.event2tokens(event)
@@ -181,7 +284,7 @@ class MIDITokenizer:
181
  if track_idx not in tracks_dict:
182
  tracks_dict[track_idx] = []
183
  tracks_dict[track_idx].append([event[0], t] + event[4:])
184
- tracks = list(tracks_dict.values())
185
 
186
  for i in range(len(tracks)): # to eliminate note overlap
187
  track = tracks[i]
@@ -292,7 +395,6 @@ class MIDITokenizer:
292
  notes_bandwidth_list = []
293
  instruments = {}
294
  piano_channels = []
295
- undef_instrument = False
296
  abs_t1 = 0
297
  last_t = 0
298
  for tsi, tokens in enumerate(midi_seq):
@@ -309,7 +411,9 @@ class MIDITokenizer:
309
  time_hist[t2] += 1
310
  if c != 9: # ignore drum channel
311
  if c not in instruments:
312
- undef_instrument = True
 
 
313
  note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
314
  if last_t != t:
315
  notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
@@ -330,8 +434,6 @@ class MIDITokenizer:
330
  reasons.append("total_min")
331
  if total_notes > total_notes_max:
332
  reasons.append("total_max")
333
- if undef_instrument:
334
- reasons.append("undef_instr")
335
  if len(note_windows) == 0 and total_notes > 0:
336
  reasons.append("drum_only")
337
  if reasons:
 
42
  tempo = int((60 / bpm) * 10 ** 6)
43
  return tempo
44
 
45
+ def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
46
+ remap_track_channel=False, add_default_instr=False, remove_empty_channels=False):
47
  ticks_per_beat = midi_score[0]
48
  event_list = {}
49
+ track_idx_map = {i: dict() for i in range(16)}
50
+ track_idx_dict = {}
51
+ channels = []
52
+ patch_channels = []
53
+ empty_channels = [True]*16
54
+ channel_note_tracks = {i: list() for i in range(16)}
55
  for track_idx, track in enumerate(midi_score[1:129]):
56
  last_notes = {}
57
  patch_dict = {}
 
60
  for event in track:
61
  if event[0] not in self.events:
62
  continue
63
+ c = -1
64
  t = round(16 * event[1] / ticks_per_beat) # quantization
65
  new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
66
  if event[0] == "note":
67
+ c = event[3]
68
+ if c > 15 or c < 0:
69
+ continue
70
+ empty_channels[c] = False
71
+ track_idx_dict.setdefault(c, track_idx)
72
+ note_tracks = channel_note_tracks[c]
73
+ if track_idx not in note_tracks:
74
+ note_tracks.append(track_idx)
75
  new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
76
  elif event[0] == "set_tempo":
77
  if new_event[4] == 0: # invalid tempo
 
84
  key = tuple(new_event[:-1])
85
  if event[0] == "patch_change":
86
  c, p = event[2:]
87
+ if c > 15 or c < 0:
88
+ continue
89
  last_p = patch_dict.setdefault(c, None)
90
  if last_p == p:
91
  continue
92
  patch_dict[c] = p
93
+ if c not in patch_channels:
94
+ patch_channels.append(c)
95
  elif event[0] == "control_change":
96
  c, cc, v = event[2:]
97
+ if c > 15 or c < 0:
98
+ continue
99
  last_v = control_dict.setdefault((c, cc), 0)
100
  if abs(last_v - v) < cc_eps:
101
  continue
 
106
  continue
107
  last_tempo = tempo
108
 
109
+ if c != -1:
110
+ if c not in channels:
111
+ channels.append(c)
112
+ tr_map = track_idx_map[c]
113
+ if track_idx not in tr_map:
114
+ tr_map[track_idx] = 0
115
+
116
  if event[0] == "note": # to eliminate note overlap due to quantization
117
  cp = tuple(new_event[5:7])
118
  if cp in last_notes:
 
124
  last_notes[cp] = (key, new_event)
125
  event_list[key] = new_event
126
  event_list = list(event_list.values())
127
+
128
+ empty_channels = [c for c in channels if empty_channels[c]]
129
+
130
+ if remap_track_channel:
131
+ patch_channels = []
132
+ channels_count = 0
133
+ channels_map = {9: 9} if 9 in channels else {}
134
+ for c in channels:
135
+ if c == 9:
136
+ continue
137
+ channels_map[c] = channels_count
138
+ channels_count += 1
139
+ if channels_count == 9:
140
+ channels_count = 10
141
+ channels = list(channels_map.values())
142
+
143
+ track_count = 0
144
+ track_idx_map_order = [k for k,v in sorted(list(channels_map.items()), key=lambda x: x[1])]
145
+ for c in track_idx_map_order: # tracks not to remove
146
+ if remove_empty_channels and c in empty_channels:
147
+ continue
148
+ tr_map = track_idx_map[c]
149
+ for track_idx in tr_map:
150
+ note_tracks = channel_note_tracks[c]
151
+ if len(note_tracks) != 0 and track_idx not in note_tracks:
152
+ continue
153
+ track_count += 1
154
+ tr_map[track_idx] = track_count
155
+ for c in track_idx_map_order: # tracks to remove
156
+ if not (remove_empty_channels and c in empty_channels):
157
+ continue
158
+ tr_map = track_idx_map[c]
159
+ for track_idx in tr_map:
160
+ note_tracks = channel_note_tracks[c]
161
+ if not (len(note_tracks) != 0 and track_idx not in note_tracks):
162
+ continue
163
+ track_count += 1
164
+ tr_map[track_idx] = track_count
165
+
166
+ empty_channels = [channels_map[c] for c in empty_channels]
167
+
168
+ for event in event_list:
169
+ name = event[0]
170
+ track_idx = event[3]
171
+ if name == "note":
172
+ c = event[5]
173
+ event[5] = channels_map[c]
174
+ event[3] = track_idx_map[c][track_idx]
175
+ track_idx_dict[event[5]] = event[3]
176
+ elif name == "set_tempo":
177
+ event[3] = 0
178
+ elif name == "control_change" or name == "patch_change":
179
+ c = event[4]
180
+ event[4] = channels_map[c]
181
+ tr_map = track_idx_map[c]
182
+ # move the event to first track of the channel if it's original track is empty
183
+ note_tracks = channel_note_tracks[c]
184
+ if len(note_tracks) != 0 and track_idx not in note_tracks:
185
+ track_idx = channel_note_tracks[c][0]
186
+ new_track_idx = tr_map.setdefault(track_idx, next(iter(tr_map.values())))
187
+ event[3] = new_track_idx
188
+ if name == "patch_change" and event[4] not in patch_channels:
189
+ patch_channels.append(event[4])
190
+
191
+ if add_default_instr:
192
+ for c in channels:
193
+ if c not in patch_channels:
194
+ event_list.append(["patch_change", 0,0, track_idx_dict[c], c, 0])
195
+
196
+ events_name_order = {"set_tempo":0, "patch_change":1, "control_change":2, "note":3}
197
+ events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
198
+ event_list = sorted(event_list, key=events_order)
199
+
200
  setup_events = {}
201
  notes_in_setup = False
202
  for i, event in enumerate(event_list): # optimise setup
 
213
  pre_event = event_list[i - 1]
214
  has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
215
  if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
216
+ event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
217
  break
218
  else:
219
  if event[0] == "note":
 
222
  setup_events[key] = new_event
223
 
224
  last_t1 = 0
225
+ midi_seq = []
226
  for event in event_list:
227
+ if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
228
+ continue
229
  cur_t1 = event[1]
230
  event[1] = event[1] - last_t1
231
  tokens = self.event2tokens(event)
 
284
  if track_idx not in tracks_dict:
285
  tracks_dict[track_idx] = []
286
  tracks_dict[track_idx].append([event[0], t] + event[4:])
287
+ tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]
288
 
289
  for i in range(len(tracks)): # to eliminate note overlap
290
  track = tracks[i]
 
395
  notes_bandwidth_list = []
396
  instruments = {}
397
  piano_channels = []
 
398
  abs_t1 = 0
399
  last_t = 0
400
  for tsi, tokens in enumerate(midi_seq):
 
411
  time_hist[t2] += 1
412
  if c != 9: # ignore drum channel
413
  if c not in instruments:
414
+ instruments[c] = 0
415
+ if c not in piano_channels:
416
+ piano_channels.append(c)
417
  note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
418
  if last_t != t:
419
  notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
 
434
  reasons.append("total_min")
435
  if total_notes > total_notes_max:
436
  reasons.append("total_max")
 
 
437
  if len(note_windows) == 0 and total_notes > 0:
438
  reasons.append("drum_only")
439
  if reasons: