PicturesOfMIDI / pom /transpose_midi.py
drscotthawley's picture
dummy dir pom
b1e308f
#! /usr/bin/env python3
# This script takes a directory of MIDI files and transposes them between -max and +max semitones
# It was intended for the P909 dataset
# For each input MIDI file it writes the transposed copies next to the original,
# with `_transposed_<n>` inserted before the file extension (n = semitone shift)
import os
import sys
import pretty_midi
from multiprocessing import Pool, cpu_count, set_start_method
from tqdm import tqdm
from control_toys.data import fast_scandir
from functools import partial
import argparse
from copy import deepcopy
def transpose_midi(args, midi_file):
    """Write transposed copies of one MIDI file next to the original.

    For every non-zero shift in ``[-|args.transpose|, +|args.transpose|]`` a
    copy of the file is written as ``<root>_transposed_<shift><ext>``, where
    ``<root>``/``<ext>`` come from ``os.path.splitext(midi_file)``.

    Args:
        args: parsed command-line namespace; only ``args.transpose`` (the
            maximum shift in semitones) is read.
        midi_file: path of the MIDI file to transpose.

    Returns:
        None. Paths that already contain ``_transposed_`` are skipped, and
        nothing is read or written when the maximum shift is 0.
    """
    #print("midi_file = ",midi_file)
    if '_transposed_' in midi_file:
        return  # don't transpose files that have already been transposed

    # abs() so a negative maximum still means "+/- that many semitones"
    # instead of yielding an empty range and silently doing nothing.
    max_shift = abs(args.transpose)
    if max_shift == 0:
        return  # no copies to write, so don't bother parsing the file

    # Split off only the real extension.  str.replace('.mid', ...) would also
    # rewrite '.mid' inside directory names and, when the path has no
    # lower-case '.mid' at all (e.g. 'SONG.MID'), would return the path
    # unchanged and make write() overwrite the original file.
    root, ext = os.path.splitext(midi_file)

    midi_save = pretty_midi.PrettyMIDI(midi_file)
    for transpose_by in range(-max_shift, max_shift + 1):
        if transpose_by == 0:
            continue  # the untransposed file already exists
        midi = deepcopy(midi_save)  # keep the parsed original pristine
        for instrument in midi.instruments:
            if instrument.is_drum:
                # On a percussion track the note number selects the drum
                # sound, so shifting it would swap instruments, not pitch.
                continue
            for note in instrument.notes:
                note.pitch += transpose_by
            # MIDI note numbers are 7-bit; drop notes pushed outside 0..127
            # rather than letting the write fail and abort the whole run.
            instrument.notes = [note for note in instrument.notes
                                if 0 <= note.pitch <= 127]
        midi.write(f'{root}_transposed_{transpose_by}{ext}')
    return
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is accepted as false because, before this parser
    # existed, it was the only value that disabled the option.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    # Command line: a maximum shift, the multiprocessing start method, and one
    # or more directories (scanned for MIDI files) and/or individual files.
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-t','--transpose', type=int, default=0, help='transpose by this maximum number of semitones (+/-)')
    parser.add_argument('--start-method', type=str, default='fork',
                        choices=['fork', 'forkserver', 'spawn'],
                        help='the multiprocessing start method')
    # Without a type, every non-empty value (including "False") was a truthy
    # string, so the option could not be switched off in the natural way.
    parser.add_argument('--skip-versions', type=_str2bool, default=True, help='skip extra versions of the same song')
    parser.add_argument("midi_dirs", nargs='+', help="directories containing MIDI files")
    args = parser.parse_args()
    print("args = ",args)

    set_start_method(args.start_method)

    # Classify every argument on its own: directories are scanned, files are
    # taken as-is.  Looking only at the first argument left `midi_files`
    # undefined for a bad path and mishandled mixed file/directory input.
    midi_files = []
    for path in args.midi_dirs:
        if os.path.isdir(path):
            _subdirs, found = fast_scandir(path, ['mid', 'midi'])
            if found:
                midi_files.extend(found)
        elif os.path.isfile(path):
            midi_files.append(path)
        else:
            parser.error(f'not a file or directory: {path}')

    if args.skip_versions:
        # Normalise separators so the filter also works where os.sep != '/'.
        midi_files = [f for f in midi_files
                      if '/versions/' not in f.replace(os.sep, '/')]
    #print("midi_files = ",midi_files)   # just a check for debugging

    # transpose all the midi files
    if args.transpose != 0:
        transpose_one = partial(transpose_midi, args)
        # `pool`, not `p`: the old name shadowed the argument parser.
        with Pool(cpu_count()) as pool:
            # list() drains the lazy imap iterator so every file is processed
            # and tqdm can advance as results come back.
            list(tqdm(pool.imap(transpose_one, midi_files), total=len(midi_files), desc='Transposing MIDI files'))
    print("Finished")