Upload utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chord_recognition
|
2 |
+
import numpy as np
|
3 |
+
import miditoolkit
|
4 |
+
import copy
|
5 |
+
|
6 |
+
# parameters for input
|
7 |
+
DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 32+1, dtype=np.int)
|
8 |
+
DEFAULT_FRACTION = 16
|
9 |
+
DEFAULT_DURATION_BINS = np.arange(60, 3841, 60, dtype=int)
|
10 |
+
DEFAULT_TEMPO_INTERVALS = [range(30, 90), range(90, 150), range(150, 210)]
|
11 |
+
|
12 |
+
# parameters for output
|
13 |
+
DEFAULT_RESOLUTION = 480
|
14 |
+
|
15 |
+
# define "Item" for general storage
|
16 |
+
class Item(object):
|
17 |
+
def __init__(self, name, start, end, velocity, pitch):
|
18 |
+
self.name = name
|
19 |
+
self.start = start
|
20 |
+
self.end = end
|
21 |
+
self.velocity = velocity
|
22 |
+
self.pitch = pitch
|
23 |
+
|
24 |
+
def __repr__(self):
|
25 |
+
return 'Item(name={}, start={}, end={}, velocity={}, pitch={})'.format(
|
26 |
+
self.name, self.start, self.end, self.velocity, self.pitch)
|
27 |
+
|
28 |
+
# read notes and tempo changes from midi (assume there is only one track)
|
29 |
+
def read_items(file_path):
|
30 |
+
midi_obj = miditoolkit.midi.parser.MidiFile(file_path)
|
31 |
+
# note
|
32 |
+
note_items = []
|
33 |
+
notes = midi_obj.instruments[0].notes
|
34 |
+
notes.sort(key=lambda x: (x.start, x.pitch))
|
35 |
+
for note in notes:
|
36 |
+
note_items.append(Item(
|
37 |
+
name='Note',
|
38 |
+
start=note.start,
|
39 |
+
end=note.end,
|
40 |
+
velocity=note.velocity,
|
41 |
+
pitch=note.pitch))
|
42 |
+
note_items.sort(key=lambda x: x.start)
|
43 |
+
# tempo
|
44 |
+
tempo_items = []
|
45 |
+
for tempo in midi_obj.tempo_changes:
|
46 |
+
tempo_items.append(Item(
|
47 |
+
name='Tempo',
|
48 |
+
start=tempo.time,
|
49 |
+
end=None,
|
50 |
+
velocity=None,
|
51 |
+
pitch=int(tempo.tempo)))
|
52 |
+
tempo_items.sort(key=lambda x: x.start)
|
53 |
+
# expand to all beat
|
54 |
+
max_tick = tempo_items[-1].start
|
55 |
+
existing_ticks = {item.start: item.pitch for item in tempo_items}
|
56 |
+
wanted_ticks = np.arange(0, max_tick+1, DEFAULT_RESOLUTION)
|
57 |
+
output = []
|
58 |
+
for tick in wanted_ticks:
|
59 |
+
if tick in existing_ticks:
|
60 |
+
output.append(Item(
|
61 |
+
name='Tempo',
|
62 |
+
start=tick,
|
63 |
+
end=None,
|
64 |
+
velocity=None,
|
65 |
+
pitch=existing_ticks[tick]))
|
66 |
+
else:
|
67 |
+
output.append(Item(
|
68 |
+
name='Tempo',
|
69 |
+
start=tick,
|
70 |
+
end=None,
|
71 |
+
velocity=None,
|
72 |
+
pitch=output[-1].pitch))
|
73 |
+
tempo_items = output
|
74 |
+
return note_items, tempo_items
|
75 |
+
|
76 |
+
# quantize items
|
77 |
+
def quantize_items(items, ticks=120):
|
78 |
+
# grid
|
79 |
+
grids = np.arange(0, items[-1].start, ticks, dtype=int)
|
80 |
+
# process
|
81 |
+
for item in items:
|
82 |
+
index = np.argmin(abs(grids - item.start))
|
83 |
+
shift = grids[index] - item.start
|
84 |
+
item.start += shift
|
85 |
+
item.end += shift
|
86 |
+
return items
|
87 |
+
|
88 |
+
# extract chord
|
89 |
+
def extract_chords(items):
|
90 |
+
method = chord_recognition.MIDIChord()
|
91 |
+
chords = method.extract(notes=items)
|
92 |
+
output = []
|
93 |
+
for chord in chords:
|
94 |
+
output.append(Item(
|
95 |
+
name='Chord',
|
96 |
+
start=chord[0],
|
97 |
+
end=chord[1],
|
98 |
+
velocity=None,
|
99 |
+
pitch=chord[2].split('/')[0]))
|
100 |
+
return output
|
101 |
+
|
102 |
+
# group items
|
103 |
+
def group_items(items, max_time, ticks_per_bar=DEFAULT_RESOLUTION*4):
|
104 |
+
items.sort(key=lambda x: x.start)
|
105 |
+
downbeats = np.arange(0, max_time+ticks_per_bar, ticks_per_bar)
|
106 |
+
groups = []
|
107 |
+
for db1, db2 in zip(downbeats[:-1], downbeats[1:]):
|
108 |
+
insiders = []
|
109 |
+
for item in items:
|
110 |
+
if (item.start >= db1) and (item.start < db2):
|
111 |
+
insiders.append(item)
|
112 |
+
overall = [db1] + insiders + [db2]
|
113 |
+
groups.append(overall)
|
114 |
+
return groups
|
115 |
+
|
116 |
+
# define "Event" for event storage
|
117 |
+
class Event(object):
|
118 |
+
def __init__(self, name, time, value, text):
|
119 |
+
self.name = name
|
120 |
+
self.time = time
|
121 |
+
self.value = value
|
122 |
+
self.text = text
|
123 |
+
|
124 |
+
def __repr__(self):
|
125 |
+
return 'Event(name={}, time={}, value={}, text={})'.format(
|
126 |
+
self.name, self.time, self.value, self.text)
|
127 |
+
|
128 |
+
# item to event
|
129 |
+
def item2event(groups):
|
130 |
+
events = []
|
131 |
+
n_downbeat = 0
|
132 |
+
for i in range(len(groups)):
|
133 |
+
if 'Note' not in [item.name for item in groups[i][1:-1]]:
|
134 |
+
continue
|
135 |
+
bar_st, bar_et = groups[i][0], groups[i][-1]
|
136 |
+
n_downbeat += 1
|
137 |
+
events.append(Event(
|
138 |
+
name='Bar',
|
139 |
+
time=None,
|
140 |
+
value=None,
|
141 |
+
text='{}'.format(n_downbeat)))
|
142 |
+
for item in groups[i][1:-1]:
|
143 |
+
# position
|
144 |
+
flags = np.linspace(bar_st, bar_et, DEFAULT_FRACTION, endpoint=False)
|
145 |
+
index = np.argmin(abs(flags-item.start))
|
146 |
+
events.append(Event(
|
147 |
+
name='Position',
|
148 |
+
time=item.start,
|
149 |
+
value='{}/{}'.format(index+1, DEFAULT_FRACTION),
|
150 |
+
text='{}'.format(item.start)))
|
151 |
+
if item.name == 'Note':
|
152 |
+
# velocity
|
153 |
+
velocity_index = np.searchsorted(
|
154 |
+
DEFAULT_VELOCITY_BINS,
|
155 |
+
item.velocity,
|
156 |
+
side='right') - 1
|
157 |
+
events.append(Event(
|
158 |
+
name='Note Velocity',
|
159 |
+
time=item.start,
|
160 |
+
value=velocity_index,
|
161 |
+
text='{}/{}'.format(item.velocity, DEFAULT_VELOCITY_BINS[velocity_index])))
|
162 |
+
# pitch
|
163 |
+
events.append(Event(
|
164 |
+
name='Note On',
|
165 |
+
time=item.start,
|
166 |
+
value=item.pitch,
|
167 |
+
text='{}'.format(item.pitch)))
|
168 |
+
# duration
|
169 |
+
duration = item.end - item.start
|
170 |
+
index = np.argmin(abs(DEFAULT_DURATION_BINS-duration))
|
171 |
+
events.append(Event(
|
172 |
+
name='Note Duration',
|
173 |
+
time=item.start,
|
174 |
+
value=index,
|
175 |
+
text='{}/{}'.format(duration, DEFAULT_DURATION_BINS[index])))
|
176 |
+
elif item.name == 'Chord':
|
177 |
+
events.append(Event(
|
178 |
+
name='Chord',
|
179 |
+
time=item.start,
|
180 |
+
value=item.pitch,
|
181 |
+
text='{}'.format(item.pitch)))
|
182 |
+
elif item.name == 'Tempo':
|
183 |
+
tempo = item.pitch
|
184 |
+
if tempo in DEFAULT_TEMPO_INTERVALS[0]:
|
185 |
+
tempo_style = Event('Tempo Class', item.start, 'slow', None)
|
186 |
+
tempo_value = Event('Tempo Value', item.start,
|
187 |
+
tempo-DEFAULT_TEMPO_INTERVALS[0].start, None)
|
188 |
+
elif tempo in DEFAULT_TEMPO_INTERVALS[1]:
|
189 |
+
tempo_style = Event('Tempo Class', item.start, 'mid', None)
|
190 |
+
tempo_value = Event('Tempo Value', item.start,
|
191 |
+
tempo-DEFAULT_TEMPO_INTERVALS[1].start, None)
|
192 |
+
elif tempo in DEFAULT_TEMPO_INTERVALS[2]:
|
193 |
+
tempo_style = Event('Tempo Class', item.start, 'fast', None)
|
194 |
+
tempo_value = Event('Tempo Value', item.start,
|
195 |
+
tempo-DEFAULT_TEMPO_INTERVALS[2].start, None)
|
196 |
+
elif tempo < DEFAULT_TEMPO_INTERVALS[0].start:
|
197 |
+
tempo_style = Event('Tempo Class', item.start, 'slow', None)
|
198 |
+
tempo_value = Event('Tempo Value', item.start, 0, None)
|
199 |
+
elif tempo > DEFAULT_TEMPO_INTERVALS[2].stop:
|
200 |
+
tempo_style = Event('Tempo Class', item.start, 'fast', None)
|
201 |
+
tempo_value = Event('Tempo Value', item.start, 59, None)
|
202 |
+
events.append(tempo_style)
|
203 |
+
events.append(tempo_value)
|
204 |
+
return events
|
205 |
+
|
206 |
+
#############################################################################################
|
207 |
+
# WRITE MIDI
|
208 |
+
#############################################################################################
|
209 |
+
def word_to_event(words, word2event):
|
210 |
+
events = []
|
211 |
+
for word in words:
|
212 |
+
event_name, event_value = word2event.get(word).split('_')
|
213 |
+
events.append(Event(event_name, None, event_value, None))
|
214 |
+
return events
|
215 |
+
|
216 |
+
def write_midi(words, word2event, output_path, prompt_path=None):
|
217 |
+
events = word_to_event(words, word2event)
|
218 |
+
# get downbeat and note (no time)
|
219 |
+
temp_notes = []
|
220 |
+
temp_chords = []
|
221 |
+
temp_tempos = []
|
222 |
+
for i in range(len(events)-3):
|
223 |
+
if events[i].name == 'Bar' and i > 0:
|
224 |
+
temp_notes.append('Bar')
|
225 |
+
temp_chords.append('Bar')
|
226 |
+
temp_tempos.append('Bar')
|
227 |
+
elif events[i].name == 'Position' and \
|
228 |
+
events[i+1].name == 'Note Velocity' and \
|
229 |
+
events[i+2].name == 'Note On' and \
|
230 |
+
events[i+3].name == 'Note Duration':
|
231 |
+
# start time and end time from position
|
232 |
+
position = int(events[i].value.split('/')[0]) - 1
|
233 |
+
# velocity
|
234 |
+
index = int(events[i+1].value)
|
235 |
+
velocity = int(DEFAULT_VELOCITY_BINS[index])
|
236 |
+
# pitch
|
237 |
+
pitch = int(events[i+2].value)
|
238 |
+
# duration
|
239 |
+
index = int(events[i+3].value)
|
240 |
+
duration = DEFAULT_DURATION_BINS[index]
|
241 |
+
# adding
|
242 |
+
temp_notes.append([position, velocity, pitch, duration])
|
243 |
+
elif events[i].name == 'Position' and events[i+1].name == 'Chord':
|
244 |
+
position = int(events[i].value.split('/')[0]) - 1
|
245 |
+
temp_chords.append([position, events[i+1].value])
|
246 |
+
elif events[i].name == 'Position' and \
|
247 |
+
events[i+1].name == 'Tempo Class' and \
|
248 |
+
events[i+2].name == 'Tempo Value':
|
249 |
+
position = int(events[i].value.split('/')[0]) - 1
|
250 |
+
if events[i+1].value == 'slow':
|
251 |
+
tempo = DEFAULT_TEMPO_INTERVALS[0].start + int(events[i+2].value)
|
252 |
+
elif events[i+1].value == 'mid':
|
253 |
+
tempo = DEFAULT_TEMPO_INTERVALS[1].start + int(events[i+2].value)
|
254 |
+
elif events[i+1].value == 'fast':
|
255 |
+
tempo = DEFAULT_TEMPO_INTERVALS[2].start + int(events[i+2].value)
|
256 |
+
temp_tempos.append([position, tempo])
|
257 |
+
# get specific time for notes
|
258 |
+
ticks_per_beat = DEFAULT_RESOLUTION
|
259 |
+
ticks_per_bar = DEFAULT_RESOLUTION * 4 # assume 4/4
|
260 |
+
notes = []
|
261 |
+
current_bar = 0
|
262 |
+
for note in temp_notes:
|
263 |
+
if note == 'Bar':
|
264 |
+
current_bar += 1
|
265 |
+
else:
|
266 |
+
position, velocity, pitch, duration = note
|
267 |
+
# position (start time)
|
268 |
+
current_bar_st = current_bar * ticks_per_bar
|
269 |
+
current_bar_et = (current_bar + 1) * ticks_per_bar
|
270 |
+
flags = np.linspace(current_bar_st, current_bar_et, DEFAULT_FRACTION, endpoint=False, dtype=int)
|
271 |
+
st = flags[position]
|
272 |
+
# duration (end time)
|
273 |
+
et = st + duration
|
274 |
+
notes.append(miditoolkit.Note(velocity, pitch, st, et))
|
275 |
+
# get specific time for chords
|
276 |
+
if len(temp_chords) > 0:
|
277 |
+
chords = []
|
278 |
+
current_bar = 0
|
279 |
+
for chord in temp_chords:
|
280 |
+
if chord == 'Bar':
|
281 |
+
current_bar += 1
|
282 |
+
else:
|
283 |
+
position, value = chord
|
284 |
+
# position (start time)
|
285 |
+
current_bar_st = current_bar * ticks_per_bar
|
286 |
+
current_bar_et = (current_bar + 1) * ticks_per_bar
|
287 |
+
flags = np.linspace(current_bar_st, current_bar_et, DEFAULT_FRACTION, endpoint=False, dtype=int)
|
288 |
+
st = flags[position]
|
289 |
+
chords.append([st, value])
|
290 |
+
# get specific time for tempos
|
291 |
+
tempos = []
|
292 |
+
current_bar = 0
|
293 |
+
for tempo in temp_tempos:
|
294 |
+
if tempo == 'Bar':
|
295 |
+
current_bar += 1
|
296 |
+
else:
|
297 |
+
position, value = tempo
|
298 |
+
# position (start time)
|
299 |
+
current_bar_st = current_bar * ticks_per_bar
|
300 |
+
current_bar_et = (current_bar + 1) * ticks_per_bar
|
301 |
+
flags = np.linspace(current_bar_st, current_bar_et, DEFAULT_FRACTION, endpoint=False, dtype=int)
|
302 |
+
st = flags[position]
|
303 |
+
tempos.append([int(st), value])
|
304 |
+
# write
|
305 |
+
if prompt_path:
|
306 |
+
midi = miditoolkit.midi.parser.MidiFile(prompt_path)
|
307 |
+
#
|
308 |
+
last_time = DEFAULT_RESOLUTION * 4 * 4
|
309 |
+
# note shift
|
310 |
+
for note in notes:
|
311 |
+
note.start += last_time
|
312 |
+
note.end += last_time
|
313 |
+
midi.instruments[0].notes.extend(notes)
|
314 |
+
# tempo changes
|
315 |
+
temp_tempos = []
|
316 |
+
for tempo in midi.tempo_changes:
|
317 |
+
if tempo.time < DEFAULT_RESOLUTION*4*4:
|
318 |
+
temp_tempos.append(tempo)
|
319 |
+
else:
|
320 |
+
break
|
321 |
+
for st, bpm in tempos:
|
322 |
+
st += last_time
|
323 |
+
temp_tempos.append(miditoolkit.midi.containers.TempoChange(bpm, st))
|
324 |
+
midi.tempo_changes = temp_tempos
|
325 |
+
# write chord into marker
|
326 |
+
if len(temp_chords) > 0:
|
327 |
+
for c in chords:
|
328 |
+
midi.markers.append(
|
329 |
+
miditoolkit.midi.containers.Marker(text=c[1], time=c[0]+last_time))
|
330 |
+
else:
|
331 |
+
midi = miditoolkit.midi.parser.MidiFile()
|
332 |
+
midi.ticks_per_beat = DEFAULT_RESOLUTION
|
333 |
+
# write instrument
|
334 |
+
inst = miditoolkit.midi.containers.Instrument(0, is_drum=False)
|
335 |
+
inst.notes = notes
|
336 |
+
midi.instruments.append(inst)
|
337 |
+
# write tempo
|
338 |
+
tempo_changes = []
|
339 |
+
for st, bpm in tempos:
|
340 |
+
tempo_changes.append(miditoolkit.midi.containers.TempoChange(bpm, st))
|
341 |
+
midi.tempo_changes = tempo_changes
|
342 |
+
# write chord into marker
|
343 |
+
if len(temp_chords) > 0:
|
344 |
+
for c in chords:
|
345 |
+
midi.markers.append(
|
346 |
+
miditoolkit.midi.containers.Marker(text=c[1], time=c[0]))
|
347 |
+
# write
|
348 |
+
midi.dump(output_path)
|