import os
import time as reqtime
import datetime
from pytz import timezone

import copy
from itertools import groupby
import tqdm

import spaces
import gradio as gr

import torch
from x_transformer_1_23_2 import *
import random

import TMIDIX

from midi_to_colab_audio import midi_to_colab_audio

from huggingface_hub import hf_hub_download

print('Loading model...')
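
# Core model constants: the sequence length and pad token index
# must match the values the checkpoint below was trained with.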
SEQ_LEN = 1802
PAD_IDX = 771
DEVICE = 'cuda' # the request handler runs on GPU (see @spaces.GPU below)
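
# Autoregressive music transformer: an 8-layer, 8-head decoder with
# rotary positional embeddings and flash attention.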
model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 1024,
                          depth = 8,
                          heads = 8,
                          rotary_pos_emb = True,
                          attn_flash = True
                          )
    )

model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

model_checkpoint = hf_hub_download(repo_id='asigalov61/Score-2-Performance-Transformer',
                                   filename='Score_2_Performance_Transformer_Final_Small_Trained_Model_4496_steps_1.5185_loss_0.5589_acc.pth'
                                   )

model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))

model = torch.compile(model, mode='max-autotune')

dtype = torch.bfloat16

ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

print('=' * 70)
print('Done!')
print('=' * 70)
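
# Converts an input MIDI file into overlapping, fixed-size source chunks
# of [delta-time, duration, pitch, velocity] note values for the model.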
def load_midi(midi_file):

    print('Loading MIDI...')

    raw_score = TMIDIX.midi2single_track_ms_score(midi_file)

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)

    if escore_notes[0]:

        escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], timings_divider=16)

        pe = escore_notes[0]

        melody_chords = []

        seen = []

        for e in escore_notes:

            if e[3] != 9: # skip drums (MIDI channel 9)

                dtime = max(0, min(255, e[1]-pe[1]))

                # A new start time means a new chord: reset the list
                # of already-seen pitches used for de-duplication
                if dtime != 0:
                    seen = []

                dur = max(1, min(255, e[2]))

                ptc = max(1, min(127, e[4]))

                vel = max(1, min(127, e[5]))

                if ptc not in seen:

                    melody_chords.append([dtime, dur, ptc, vel])

                    seen.append(ptc)

                pe = e

        print('=' * 70)
        print('Number of notes in a composition:', len(melody_chords))
        print('=' * 70)
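
        # Slice the composition into 300-note chunks taken at a
        # 150-note stride so consecutive chunks overlap by half.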
        src_melody_chords_f = []

        for i in range(0, len(melody_chords), 150):

            chunk = melody_chords[i:i+300]

            src = []

            for mm in chunk:
                # Tokenize each note: dtime as-is, pitch+256, dur+384, vel+640
                src.append([mm[0], mm[2]+256, mm[1]+384, mm[3]+640])

            clen = len(src)

            # Pad short chunks to exactly 300 notes by repetition
            if clen < 300:

                chunk_mult = (300 // clen) + 1

                src += src * chunk_mult

            src_melody_chords_f.append([clen, src[:300]])

        print('Done!')
        print('=' * 70)
        print('Number of composition chunks:', len(src_melody_chords_f))
        print('=' * 70)

        return src_melody_chords_f
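
# Main conversion routine: regenerates durations and/or velocities for
# the uploaded MIDI, chunk by chunk, on a GPU worker.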
@spaces.GPU
def Convert_Score_to_Performance(input_midi,
                                 input_midi_type,
                                 input_conv_type,
                                 input_number_prime_notes,
                                 input_number_conv_notes,
                                 input_model_dur_top_k,
                                 input_model_dur_temperature,
                                 input_model_vel_temperature
                                 ):

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    fn = os.path.basename(input_midi.name)
    fn1 = fn.split('.')[0]

    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    print('Input MIDI file name:', fn)
    print('Input MIDI type:', input_midi_type)
    print('Conversion type:', input_conv_type)
    print('Number of prime notes:', input_number_prime_notes)
    print('Number of notes to convert:', input_number_conv_notes)
    print('Model durations sampling top value:', input_model_dur_top_k)
    print('Model durations temperature:', input_model_dur_temperature)
    print('Model velocities temperature:', input_model_vel_temperature)
    print('=' * 70)

    src_melody_chords_f = load_midi(input_midi.name)

    print('Sample output events', src_melody_chords_f[0][1][:3])
    print('=' * 70)
    print('Generating...')

    model.to(DEVICE)
    model.eval()
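
    # Sampling settings: the 'Score' and 'Performance' presets override
    # the custom top-k and temperature slider values below.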
    num_prime_notes = input_number_prime_notes
    dur_top_k = input_model_dur_top_k

    dur_temperature = input_model_dur_temperature
    vel_temperature = input_model_vel_temperature

    if input_midi_type == 'Score':

        dur_top_k = 1
        dur_temperature = 1.1
        vel_temperature = 1.5

    elif input_midi_type == 'Performance':

        dur_top_k = 100
        dur_temperature = 1.5
        vel_temperature = 1.9

    else:

        dur_top_k = input_model_dur_top_k

        dur_temperature = input_model_dur_temperature
        vel_temperature = input_model_vel_temperature

    final_song = []
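
    # Prompt layout per chunk: [768] start token + 600 source tokens
    # (dtime and pitch for 300 notes) + [769] separator + up to 1200
    # target tokens (dtime, pitch, dur, vel per note) == SEQ_LEN (1802).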
    for cc, (song_chunk_len, song_chunk) in enumerate(src_melody_chords_f):

        print('=' * 70)
        print('Rendering song chunk #', cc)
        print('=' * 70)

        song = [768]

        if cc == 0:

            for m in song_chunk:
                song.extend(m[:2])

            song.append(769)

            sidx = 0
            eidx = 300

        else:

            # Carry over the previous chunk's second half: psrc holds its
            # source tokens and ptrg its generated target tokens, so this
            # chunk's first 150 notes are already done
            for m in song_chunk[150:]:
                psrc.extend(m[:2])

            psrc.append(769)

            song = copy.deepcopy([768] + psrc + ptrg)

            sidx = 150
            eidx = 300
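
        # Autoregressive generation: for each note, copy its dtime and
        # pitch tokens from the source, then sample a duration and/or a
        # velocity token from the model, resampling until the token
        # falls in the valid range for its type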
        for i in tqdm.tqdm(range(sidx, eidx)):

            song.extend(song_chunk[i][:2])

            if 'Durations' in input_conv_type:

                if i < num_prime_notes and cc == 0:
                    song.append(song_chunk[i][2])

                else:

                    x = torch.LongTensor(song).cuda()

                    y = 0

                    # Resample until a valid duration token (385-639) is drawn
                    while not 384 < y < 640:

                        with ctx:
                            out = model.generate(x,
                                                 1,
                                                 temperature=dur_temperature,
                                                 filter_logits_fn=top_k,
                                                 filter_kwargs={'k': dur_top_k},
                                                 return_prime=False,
                                                 verbose=False)

                        y = out.tolist()[0][0]

                    song.append(y)

            else:
                song.append(song_chunk[i][2])

            if 'Velocities' in input_conv_type:

                if i < num_prime_notes and cc == 0:
                    song.append(song_chunk[i][3])

                else:

                    x = torch.LongTensor(song).cuda()

                    y = 0

                    # Resample until a valid velocity token (641-767) is drawn
                    while not 640 < y < 768:

                        with ctx:
                            out = model.generate(x,
                                                 1,
                                                 temperature=vel_temperature,
                                                 return_prime=False,
                                                 verbose=False)

                        y = out.tolist()[0][0]

                    song.append(y)

            else:
                song.append(song_chunk[i][3])

        if cc == 0:
            final_song.extend(song[602:][:(song_chunk_len * 4)])

        else:
            # Skip the 150 overlap notes already emitted by the previous chunk
            final_song.extend(song[602:][600:(song_chunk_len * 4)])

        # Save this chunk's second half (source and generated target tokens)
        # as the prompt for the next chunk
        psrc = copy.deepcopy(song[301:601])
        ptrg = copy.deepcopy(song[602:][600:])

        if len(final_song) >= input_number_conv_notes * 4:
            break

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    print('Rendering results...')

    print('=' * 70)
    print('Sample INTs', final_song[:15])
    print('=' * 70)

    song_f = []

    if len(final_song) != 0:

        time = 0
        dur = 0
        vel = 90
        pitch = 60
        channel = 0
        patch = 0

        patches = [0] * 16
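
        # Decode the flat token stream back into TMIDIX note events;
        # a note is emitted when its velocity token arrives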
        for ss in final_song:

            if 0 <= ss < 256:
                time += ss * 16

            if 256 <= ss < 384:
                pitch = ss-256

            if 384 <= ss < 640:
                dur = (ss-384) * 16

            if 640 <= ss < 768:
                vel = (ss-640)
                song_f.append(['note', time, dur, channel, pitch, vel, patch])

    fn1 = "Score-2-Performance-Transformer-Composition"

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                              output_signature = 'Score 2 Performance Transformer',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches
                                                              )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    output_midi_title = str(fn1)
    output_midi_summary = str(song_f[:3])
    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
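
# Gradio UI: builds the app layout and wires the conversion handler
# and the bundled example scores to it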
if __name__ == "__main__":

    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    app = gr.Blocks()

    with app:

        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Score 2 Performance Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Convert any MIDI score to a nice performance</h1>")

        gr.Markdown("## Upload your MIDI or select a sample example MIDI below")

        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])

        gr.Markdown("## Select MIDI type")

        input_midi_type = gr.Radio(["Score", "Performance", "Custom"],
                                   value="Score",
                                   label="Input MIDI type",
                                   info="Select 'Custom' option to enable model top_k and temperature settings below"
                                   )

        gr.Markdown("## Select conversion type")

        input_conv_type = gr.Radio(["Durations and Velocities", "Durations", "Velocities"],
                                   value="Durations and Velocities",
                                   label="Conversion type"
                                   )

        gr.Markdown("## Conversion options")

        input_number_prime_notes = gr.Slider(0, 512, value=0, step=8, label="Number of prime notes")
        input_number_conv_notes = gr.Slider(8, 2048, value=512, step=8, label="Number of notes to convert")

        gr.Markdown("## Custom MIDI type model options")

        input_model_dur_top_k = gr.Slider(1, 100, value=1, step=1, label="Model sampling top k value for durations")
        input_model_dur_temperature = gr.Slider(0.5, 1.5, value=1.1, step=0.05, label="Model temperature for durations")
        input_model_vel_temperature = gr.Slider(0.5, 1.5, value=1.5, step=0.05, label="Model temperature for velocities")

        run_btn = gr.Button("convert", variant="primary")

        gr.Markdown("## Generation results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = run_btn.click(Convert_Score_to_Performance, [input_midi,
                                                                 input_midi_type,
                                                                 input_conv_type,
                                                                 input_number_prime_notes,
                                                                 input_number_conv_notes,
                                                                 input_model_dur_top_k,
                                                                 input_model_dur_temperature,
                                                                 input_model_vel_temperature
                                                                 ],
                                  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])

        gr.Examples(
            [["asap_midi_score_21.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_45.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_69.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_118.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_167.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
            ],
            [input_midi,
             input_midi_type,
             input_conv_type,
             input_number_prime_notes,
             input_number_conv_notes,
             input_model_dur_top_k,
             input_model_dur_temperature,
             input_model_vel_temperature
            ],
            [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
            Convert_Score_to_Performance
        )

    app.queue().launch()