MIDI-Renderer / app.py
asigalov61's picture
Update app.py
79a1a3f verified
raw
history blame
10.3 kB
import argparse
import glob
import os.path
import hashlib
import time
import datetime
from pytz import timezone
import gradio as gr
import pickle
import tqdm
import json
import TMIDIX
from midi_to_colab_audio import midi_to_colab_audio
import copy
from collections import Counter
import random
import statistics
import matplotlib.pyplot as plt
#==========================================================================================================
# Flag that is True when the SYSTEM environment variable equals "spaces"
# (presumably set by the Hugging Face Spaces runtime). Not referenced
# elsewhere in this file as shown.
in_space = os.environ.get("SYSTEM") == "spaces"
#==========================================================================================================
def _extract_melody(cscore, melody_patch):
    """Return a single-voice melody taken from the first note of each chord.

    Chords whose first event is on channel 9 (drums) are skipped entirely.
    Every kept note is deep-copied, moved to channel 0, given program
    ``melody_patch`` (cast to int and clamped to 0-127), and raised into the
    octave starting at pitch 60 when it lies below 60. Each note except the
    last is then shortened so it ends before the next note starts. Returns an
    empty list when no chord qualifies.
    """
    # The Gradio slider may deliver a float; MIDI programs must be ints.
    patch = max(0, min(127, int(melody_patch)))

    melody = copy.deepcopy([c[0] for c in cscore if c[0][3] != 9])
    for e in melody:
        e[3] = 0
        e[6] = patch
        if e[4] < 60:
            e[4] = (e[4] % 12) + 60

    fixed_score = []
    for note, next_note in zip(melody, melody[1:]):
        nmt = next_note[1]
        if note[1] + note[2] >= nmt:
            # Overlaps the next note: end one time unit before it starts.
            note_dur = nmt - note[1] - 1
        else:
            note_dur = note[2]
        fixed_score.append([note[0], note[1], note_dur, note[3], note[4], note[5], note[6]])

    # The last note has no successor and is kept unchanged; guarding on
    # `melody` avoids an IndexError when nothing was extracted.
    if melody:
        fixed_score.append(melody[-1])

    return fixed_score


def _invert_pitches(escore):
    """Return a deep copy of ``escore`` with non-drum pitches mirrored to ``117 - pitch``.

    Results are clamped at 0: pitches above 117 would otherwise become
    negative, which is not a valid MIDI note number.
    """
    output_score = copy.deepcopy(escore)
    for e in output_score:
        if e[3] != 9:
            e[4] = max(0, (127 - e[4]) - 10)
    return output_score


def _repair_chords(cscore):
    """Filter each chord's non-drum notes through TMIDIX.check_and_fix_tones_chord.

    Behavior per chord:
    - Chords containing at least one non-drum note:
      - the first event is always kept when it is not a drum;
      - the other non-drum notes are kept only if their pitch class is in the
        fixed tones chord;
      - every drum event (channel 9) is kept exactly once.
    - Drum-only chords are kept unchanged.

    Returns deep copies, so the caller can rescale timings in place without
    touching ``cscore``.
    """
    output_score = []
    for c in cscore:
        tones_chord = sorted(set([t[4] % 12 for t in c if t[3] != 9]))
        drums_events = [t for t in c if t[3] == 9]

        if not tones_chord:
            output_score.extend(c)
            continue

        new_tones_chord = TMIDIX.check_and_fix_tones_chord(tones_chord)

        # A drum-leading chord must not add c[0] here as well as through
        # drums_events: the duplicate would be the same object and get its
        # timings multiplied twice later.
        if c[0][3] != 9:
            output_score.append(c[0])

        for cc in c[1:]:
            if cc[3] != 9 and (cc[4] % 12) in new_tones_chord:
                output_score.append(cc)

        output_score.extend(drums_events)

    return copy.deepcopy(output_score)


def _assign_patches(score):
    """Build the 16-entry channel -> MIDI program list for the MIDI writer.

    Channel 9 is fixed to 9. Every other channel takes the program (``e[6]``)
    of its first note. When a later note on an already-assigned channel uses a
    different program, that program is written to the first still-unassigned
    channel, or to channel 15 when none is free. Channels that stay unassigned
    default to program 0.

    NOTE(review): notes are not moved to the channel their extra program was
    parked on, so they presumably still sound with their channel's first
    program. Confirm this is intended.
    """
    patches = [-1] * 16
    patches[9] = 9
    for e in score:
        chan, patch = e[3], e[6]
        if chan == 9:
            continue
        if patches[chan] == -1:
            patches[chan] = patch
        elif patches[chan] != patch:
            if -1 in patches:
                patches[patches.index(-1)] = patch
            else:
                patches[-1] = patch
    return [p if p != -1 else 0 for p in patches]


def render_midi(input_midi, render_type, soundfont_bank, render_sample_rate, melody_patch):
    """Transform an uploaded MIDI file and render it to MIDI, audio and a plot.

    This is a Gradio generator callback. It yields one 6-tuple:
    (output MIDI md5, output title, metadata summary, output MIDI path,
    (sample_rate, audio), score plot).

    Args:
        input_midi: Path of the uploaded MIDI file.
        render_type: "Render as-is", "Extract melody", "Transform" or "Repair".
            Empty or unrecognized values fall back to "Render as-is".
        soundfont_bank: "General MIDI", "Nice strings plus orchestra" or
            "Real choir". Any other value selects the first bank.
        render_sample_rate: "16000", "32000" or "44100" as a string. Any other
            value selects 16000.
        melody_patch: MIDI program used by "Extract melody" (clamped to 0-127).

    Relies on the module-level names ``PDT`` and ``soundfonts``, which are only
    assigned inside the ``if __name__ == "__main__":`` block. The output MIDI
    file is written to the current working directory.

    Raises:
        gr.Error: if no file was uploaded or the file contains no notes.
    """
    if not input_midi:
        raise gr.Error('Please upload a MIDI file first.')

    print('*' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = time.time()

    print('=' * 70)
    print('Loading MIDI...')

    fn = os.path.basename(input_midi)
    # splitext strips only the extension ("my.song.mid" -> "my.song");
    # split('.')[0] truncated such names to "my".
    fn1 = os.path.splitext(fn)[0]

    with open(input_midi, 'rb') as f:
        fdata = f.read()

    input_midi_md5hash = hashlib.md5(fdata).hexdigest()

    print('=' * 70)
    print('Input MIDI file name:', fn)
    print('Input MIDI md5 hash', input_midi_md5hash)
    print('Render type:', render_type)
    print('Soundfont bank', soundfont_bank)
    print('Audio render sample rate', render_sample_rate)
    print('Melody patch', melody_patch)
    print('=' * 70)

    print('Processing MIDI...Please wait...')

    #=======================================================
    # START PROCESSING

    raw_score = TMIDIX.midi2single_track_ms_score(fdata, recalculate_channels=False)

    escore = TMIDIX.advanced_score_processor(raw_score, return_score_analysis=False, return_enhanced_score_notes=True)[0]

    if not escore:
        raise gr.Error('The uploaded MIDI file does not contain any notes.')

    # Must run before the timings below are rescaled, so the first note still
    # matches its raw event.
    first_note_index = raw_score[1].index(escore[0][:6])

    # Quantize start times and durations to 16-unit steps. This is undone by
    # the "* 16" further down; presumably it lets chordify_score group
    # near-simultaneous notes.
    for e in escore:
        e[1] = int(e[1] / 16)
        e[2] = int(e[2] / 16)

    # Stable sorts from least to most significant key. The final order is
    # start time, then pitch (descending), then patch.
    escore.sort(key=lambda x: x[6])
    escore.sort(key=lambda x: x[4], reverse=True)
    escore.sort(key=lambda x: x[1])

    cscore = TMIDIX.chordify_score([1000, escore])

    meta_data = raw_score[1][:first_note_index] + [escore[0]] + [escore[-1]] + [raw_score[1][-1]]

    print('Done!')
    print('=' * 70)
    print('Input MIDI metadata:', meta_data)
    print('=' * 70)
    print('Processing...Please wait...')

    # Every branch yields fresh copies, so the in-place timing restore below
    # never touches escore/cscore (or the notes referenced by meta_data).
    if render_type == "Extract melody":
        output_score = _extract_melody(cscore, melody_patch)
    elif render_type == "Transform":
        output_score = _invert_pitches(escore)
    elif render_type == 'Repair':
        output_score = _repair_chords(cscore)
    else:
        # "Render as-is", an empty value, or an unrecognized render type.
        output_score = copy.deepcopy(escore)

    print('Done processing!')
    print('=' * 70)

    print('Recalculating timings...')
    print('=' * 70)

    # Undo the 16x downscaling applied before processing.
    for e in output_score:
        e[1] = e[1] * 16
        e[2] = e[2] * 16

    print('Done recalculating timings!')
    print('=' * 70)
    print('Sample output events', output_score[:5])
    print('=' * 70)
    print('Final processing...')

    new_fn = fn1 + '.mid'

    patches = _assign_patches(output_score)

    # Writes the output MIDI file; new_fn below is the resulting path
    # (fn1 + '.mid').
    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature='Advanced MIDI Renderer',
                                             output_file_name=fn1,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches
                                             )

    sf2_bank_names = ["General MIDI", "Nice strings plus orchestra", "Real choir"]
    sf2bank = sf2_bank_names.index(soundfont_bank) if soundfont_bank in sf2_bank_names else 0

    srate = int(render_sample_rate) if render_sample_rate in ["16000", "32000", "44100"] else 16000

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfonts[sf2bank],
                                sample_rate=srate,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    with open(new_fn, 'rb') as f:
        new_md5_hash = hashlib.md5(f.read()).hexdigest()

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi_md5 = str(new_md5_hash)
    output_midi_title = str(fn1)
    output_midi_summary = str(meta_data)
    output_midi = str(new_fn)
    output_audio = (srate, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score, plot_title=output_midi)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI hash:', output_midi_md5)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    #========================================================

    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (time.time() - start_time), 'sec')
    print('*' * 70)

    #========================================================

    yield output_midi_md5, output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
#==========================================================================================================
if __name__ == "__main__":

    # Timezone for the logged timestamps. render_midi reads this module-level
    # name at call time, so it must stay a global assigned here.
    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    # --share requests a public Gradio link; --port picks the server port.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    arg_parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    opt = arg_parser.parse_args()

    # SoundFont files in the same order as the "SoundFont bank" choices below.
    # render_midi indexes this module-level list.
    soundfonts = [
        "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2",
        "Nice-Strings-PlusOrchestra-v1.6.sf2",
        "KBH-Real-Choir-V2.5.sf2",
    ]

    intro_markdown = (
        "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Advanced-MIDI-Renderer&style=flat)\n\n"
        "This is a demo for tegridy-tools\n\n"
        "Please see [tegridy-tools](https://github.com/asigalov61/tegridy-tools) GitHub repo for more information\n\n"
    )

    with gr.Blocks() as app:

        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Advanced MIDI Renderer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Transform and render any MIDI</h1>")
        gr.Markdown(intro_markdown)

        # Inputs, in the positional order render_midi expects.
        gr.Markdown("## Upload your MIDI")
        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"], type="filepath")

        gr.Markdown("## Select desired render type")
        render_type = gr.Radio(
            choices=["Render as-is", "Extract melody", "Transform", "Repair"],
            label="Render type",
            value="Render as-is",
        )

        gr.Markdown("## Select desired render options")
        soundfont_bank = gr.Radio(
            choices=["General MIDI", "Nice strings plus orchestra", "Real choir"],
            label="SoundFont bank",
            value="General MIDI",
        )
        render_sample_rate = gr.Radio(
            choices=["16000", "32000", "44100"],
            label="MIDI audio render sample rate",
            value="16000",
        )
        melody_patch = gr.Slider(minimum=0, maximum=127, value=40, label="Melody MIDI Patch")

        submit = gr.Button()

        # Outputs, in the order of the tuple render_midi yields.
        gr.Markdown("## Render results")
        output_midi_md5 = gr.Textbox(label="Output MIDI md5 hash")
        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        render_inputs = [input_midi, render_type, soundfont_bank, render_sample_rate, melody_patch]
        render_outputs = [output_midi_md5, output_midi_title, output_midi_summary,
                          output_midi, output_audio, output_plot]
        run_event = submit.click(render_midi, render_inputs, render_outputs)

    app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)