# https://huggingface.co/spaces/asigalov61/Bridge-Music-Transformer

import os
import time as reqtime
import datetime
from pytz import timezone

import torch

import spaces
import gradio as gr

from x_transformer_1_23_2 import *
import random
import tqdm

from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX

import matplotlib.pyplot as plt

in_space = os.getenv("SYSTEM") == "spaces"

# =================================================================================================

@spaces.GPU
def GenerateBridge(input_midi, input_start_note):

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    print('=' * 70)

    fn = os.path.basename(input_midi.name)
    fn1 = fn.split('.')[0]

    print('-' * 70)
    print('Input file name:', fn)
    print('Start note', input_start_note)
    print('-' * 70)

    print('Loading model...')

    SEQ_LEN = 3074
    PAD_IDX = 653
    DEVICE = 'cpu' # 'cuda'

    # instantiate the model
    model = TransformerWrapper(
        num_tokens = PAD_IDX+1,
        max_seq_len = SEQ_LEN,
        attn_layers = Decoder(dim = 1024, depth = 32, heads = 16, attn_flash = True)
    )

    model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)

    model.to(DEVICE)
    print('=' * 70)

    print('Loading model checkpoint...')

    model.load_state_dict(
        torch.load('Bridge_Music_Transformer_Trained_Model_30023_steps_0.482_loss_0.8523_acc.pth',
                   map_location=DEVICE))

    print('=' * 70)

    model.eval()

    if DEVICE == 'cpu':
        dtype = torch.bfloat16
    else:
        dtype = torch.bfloat16

    ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

    print('Done!')
    print('=' * 70)

    print('Loading MIDI...')

    #===============================================================================
    # Raw single-track ms score

    raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)

    #===============================================================================
    # Enhanced score notes

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

    #===============================================================================
    # Augmented enhanced score notes

    escore_notes = TMIDIX.recalculate_score_timings(TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32))

    #=======================================================
    # FINAL PROCESSING

    melody_chords = []

    #=======================================================
    # MAIN PROCESSING CYCLE
    #=======================================================

    pe = escore_notes[0]

    for e in escore_notes:

        #=======================================================
        # Timings...
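        # (Note: as can be read off the code below, each note is written as five
        #  consecutive tokens: delta time 0-127 in 32 ms steps, duration+128,
        #  patch+256, pitch+384 with drum-channel pitches offset by a further 128,
        #  and quantized "octo" velocity+640.)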
        delta_time = max(0, min(127, e[1]-pe[1]))

        # Durations and channels

        dur = max(0, min(127, e[2]))
        cha = max(0, min(15, e[3]))

        # Patches

        pat = max(0, min(128, e[6]))

        # Pitches

        if cha != 9:
            ptc = max(1, min(127, e[4]))
        else:
            ptc = max(1, min(127, e[4]))+128

        # Velocities

        # Calculating octo-velocity

        velocity = max(8, min(127, e[5]))
        vel = round(velocity / 15)-1

        #=======================================================
        # FINAL NOTE SEQ

        # Writing final note synchronously

        melody_chords.extend([delta_time, dur+128, pat+256, ptc+384, vel+640])

        pe = e

    #=======================================================

    melody_chords = melody_chords[input_start_note*5:]

    SEQ_L = 3060

    STEP = SEQ_L // 3

    score_chunk = melody_chords[:SEQ_L]

    td = [649]
    td.extend(score_chunk[:STEP])

    td += [650]
    td.extend(score_chunk[-STEP:])

    td += [651]

    start_note = score_chunk[:STEP][-5:]
    end_note = score_chunk[-STEP:][:5]

    print('Done!')
    print('=' * 70)

    print('Start note', start_note)
    print('End note', end_note)
    print('=' * 70)

    print('Generating...')

    x = (torch.tensor(td, dtype=torch.long, device=DEVICE)[None, ...])

    with ctx:
        out = model.generate(x,
                             1032,
                             temperature=0.9,
                             return_prime=False,
                             verbose=False)

    y = out.tolist()

    output = []

    for i in range(0, len(y[0]), 5):
        if len(y[0][i:i+5]) == 5:
            output.append(y[0][i:i+5])

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    start_note_idx = output.index(start_note)
    end_note_idx = len(output)-output[::-1].index(end_note)-1

    print('Start note check:', start_note in output, '---', start_note_idx)
    print('End note check:', end_note in output, '---', end_note_idx)

    #===============================================================================

    print('Rendering results...')

    data = score_chunk[:STEP] + TMIDIX.flatten(output[:end_note_idx]) + score_chunk[-STEP:]

    print('=' * 70)
    print('Sample INTs', data[:15])
    print('=' * 70)

    if len(data) != 0:

        song = data
        song_f = []

        time = 0
        dur = 0
        vel = 90
        pitch = 0
        pat = 0
        channel = 0

        for ss in song:

            if 0 < ss < 128:
                time += (ss * 32)

            if 128 < ss < 256:
                dur = (ss-128) * 32

            if 256 <= ss <= 384:
                pat = (ss-256)

                channel = pat // 8

                if channel == 9:
                    channel = 15
                if channel == 16:
                    channel = 9

            if 384 < ss < 640:
                pitch = (ss-384) % 128

            if 640 <= ss < 648:
                vel = ((ss-640)+1) * 15

                song_f.append(['note', time, dur, channel, pitch, vel, pat])

        song_f, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)

    fn1 = "Bridge-Music-Transformer-Composition"

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                              output_signature = 'Bridge Music Transformer',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches
                                                              )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi_title = str(fn1)
    output_midi_summary = str(song_f[:3])
    output_midi = str(new_fn)
    output_audio = (16000, audio)
    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
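# Illustrative sketch only (not part of the original app): GenerateBridge expects an
# object exposing a .name attribute that points at a MIDI file (as Gradio's gr.File
# provides) plus a start-note offset, and it also relies on the module-level globals
# PDT and soundfont that are set in the __main__ block below. A hypothetical local
# call could look like this (_SeedFile is an assumed stand-in, not a real class):
#
#   PDT = timezone('US/Pacific')
#   soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
#
#   class _SeedFile:
#       def __init__(self, name):
#           self.name = name
#
#   title, summary, midi_path, audio, plot = GenerateBridge(_SeedFile('seed.mid'), 0)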
# =================================================================================================

if __name__ == "__main__":

    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    app = gr.Blocks()

    with app:

        gr.Markdown("# Bridge Music Transformer")
        gr.Markdown("## Generate a seamless bridge between two parts of any composition")

        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Bridge-Music-Transformer&style=flat)\n\n")

        gr.Markdown("## Upload your MIDI or select a sample example MIDI below")
        gr.Markdown("### Please note that the MIDI must have at least 615 notes for this demo to work properly")

        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
        input_start_note = gr.Slider(0, 205, value=0, step=1, label="Start note number")

        run_btn = gr.Button("generate", variant="primary")

        gr.Markdown("## Generation results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = run_btn.click(GenerateBridge,
                                  [input_midi, input_start_note],
                                  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])

        gr.Examples(
            [["Sharing The Night Together.kar", 0],
             ["Sharing The Night Together.kar", 100],
             ["Deep Relaxation Melody #6.mid", 0],
             ["Deep Relaxation Melody #6.mid", 100]
             ],
            [input_midi, input_start_note],
            [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
            GenerateBridge,
            cache_examples=False,
        )

        app.queue().launch()