# Source: Hugging Face Space by asigalov61 — commit b3b8a5d ("Update app.py", verified)
# https://huggingface.co/spaces/asigalov61/Bridge-Music-Transformer
import os
import time as reqtime
import datetime
from pytz import timezone
import torch
import spaces
import gradio as gr
from x_transformer_1_23_2 import *
import random
import tqdm
from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX
import matplotlib.pyplot as plt
# True when the SYSTEM environment variable equals "spaces"
# (presumably set by the Hugging Face Spaces runtime). Not read elsewhere in this file.
in_space = "spaces" == os.getenv("SYSTEM")
# =================================================================================================
# Pacific time zone used for request timestamps in the logs. Defined here so that
# GenerateBridge does not depend on the `PDT` global created only under __main__.
_PDT = timezone('US/Pacific')

# Soundfont used when the __main__ section has not defined a module-level `soundfont`.
_DEFAULT_SOUNDFONT = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

# Model / checkpoint configuration.
_SEQ_LEN = 3074
_PAD_IDX = 653
_CHECKPOINT_PATH = 'Bridge_Music_Transformer_Trained_Model_30023_steps_0.482_loss_0.8523_acc.pth'
# NOTE(review): GenerateBridge is decorated with @spaces.GPU, but inference runs on CPU;
# switch to 'cuda' if GPU inference is intended.
_DEVICE = 'cpu'  # 'cuda'

# Prime layout: at most _SEQ_L input tokens; the first and last _STEP tokens of that
# window are the two fragments the model bridges.
_SEQ_L = 3060
_STEP = _SEQ_L // 3

# Loaded models keyed by device, so the checkpoint is read once per process
# instead of on every request.
_MODEL_CACHE = {}


def _load_model(device):
    """Return the Bridge Music Transformer on `device`, building it on first use."""
    model = _MODEL_CACHE.get(device)
    if model is not None:
        print('Using cached model')
        return model

    model = TransformerWrapper(
        num_tokens=_PAD_IDX + 1,
        max_seq_len=_SEQ_LEN,
        attn_layers=Decoder(dim=1024, depth=32, heads=16, attn_flash=True),
    )
    model = AutoregressiveWrapper(model, ignore_index=_PAD_IDX, pad_value=_PAD_IDX)
    model.to(device)

    print('=' * 70)
    print('Loading model checkpoint...')
    model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location=device))
    print('=' * 70)

    model.eval()
    _MODEL_CACHE[device] = model
    return model


def _midi_to_tokens(midi_path):
    """Tokenize a MIDI file into a flat list of 5-token note groups.

    Each note becomes [delta_time, dur+128, patch+256, pitch+384, velocity+640].
    delta_time and dur are clamped to 0..127 after TMIDIX timing augmentation
    with timings_divider=32. patch is clamped to 0..128. pitch is clamped to
    1..127, with +128 added for channel-9 notes. Velocity is quantized to 0..7.
    Returns an empty list when the file yields no notes.
    """
    # Raw single-track ms score -> enhanced score notes -> augmented/retimed notes
    raw_score = TMIDIX.midi2single_track_ms_score(midi_path)
    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
    escore_notes = TMIDIX.recalculate_score_timings(
        TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32))

    tokens = []
    if not escore_notes:
        return tokens

    pe = escore_notes[0]
    for e in escore_notes:
        # Timing relative to the previous note
        delta_time = max(0, min(127, e[1] - pe[1]))
        # Durations and channels
        dur = max(0, min(127, e[2]))
        cha = max(0, min(15, e[3]))
        # Patches
        pat = max(0, min(128, e[6]))
        # Pitches; channel-9 notes are shifted into a separate 128..255 range
        ptc = max(1, min(127, e[4]))
        if cha == 9:
            ptc += 128
        # Octo-velocity: 8..127 quantized to 0..7
        velocity = max(8, min(127, e[5]))
        vel = round(velocity / 15) - 1

        tokens.extend([delta_time, dur + 128, pat + 256, ptc + 384, vel + 640])
        pe = e

    return tokens


def _tokens_to_song(tokens):
    """Decode flat tokens into TMIDIX ['note', time, dur, channel, pitch, vel, patch] events.

    Time deltas and durations are multiplied by 32. A note event is emitted on each
    velocity token (640..647), using the most recent time/dur/patch/pitch values.
    Tokens outside the note ranges (e.g. the 649-651 prime markers) are ignored.
    """
    song_f = []
    abs_time = 0
    dur = 0
    vel = 90
    pitch = 0
    pat = 0
    channel = 0

    for ss in tokens:
        if 0 < ss < 128:
            abs_time += ss * 32
        if 128 < ss < 256:
            dur = (ss - 128) * 32
        if 256 <= ss <= 384:
            pat = ss - 256
            channel = pat // 8
            # Patch groups 72..79 would land on channel 9, so move them to 15;
            # patch 128 (channel 16) maps to channel 9.
            if channel == 9:
                channel = 15
            elif channel == 16:
                channel = 9
        if 384 < ss < 640:
            pitch = (ss - 384) % 128
        if 640 <= ss < 648:
            vel = ((ss - 640) + 1) * 15
            song_f.append(['note', abs_time, dur, channel, pitch, vel, pat])

    return song_f


@spaces.GPU
def GenerateBridge(input_midi, input_start_note):
    """Generate a musical bridge between two fragments of an uploaded MIDI file.

    The MIDI is tokenized (5 tokens per note) starting at note `input_start_note`.
    A window of at most 3060 tokens is taken. Its first and last 1020 tokens are
    fed to the model as a prime, and the model generates the connecting tokens.
    The bridge (generated tokens up to the last regenerated copy of the end
    fragment's first note) is spliced between the two fragments. The result is
    written to "Bridge-Music-Transformer-Composition.mid", rendered to 16 kHz audio,
    and plotted.

    Args:
        input_midi: Uploaded file object; only its `.name` (a filesystem path) is used.
        input_start_note: Index of the first input note to use (cast to int).

    Returns:
        (title, summary, midi_path, (16000, audio), plot) for the Gradio outputs.

    Raises:
        gr.Error: if no notes remain at or after the selected start note.
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(_PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    fn = os.path.basename(input_midi.name)
    # Slider values may arrive as floats; slicing needs an int.
    start_note_number = int(input_start_note)

    print('-' * 70)
    print('Input file name:', fn)
    print('Start note', input_start_note)
    print('-' * 70)

    print('Loading model...')
    model = _load_model(_DEVICE)
    ctx = torch.amp.autocast(device_type=_DEVICE, dtype=torch.bfloat16)
    print('Done!')
    print('=' * 70)

    print('Loading MIDI...')
    melody_chords = _midi_to_tokens(input_midi.name)[start_note_number * 5:]
    if not melody_chords:
        raise gr.Error('The input MIDI has no notes at or after the selected start note.')

    score_chunk = melody_chords[:_SEQ_L]
    prefix = score_chunk[:_STEP]
    suffix = score_chunk[-_STEP:]
    # Prime: <649> first fragment <650> second fragment <651>
    td = [649, *prefix, 650, *suffix, 651]

    start_note = prefix[-5:]
    end_note = suffix[:5]

    print('Done!')
    print('=' * 70)
    print('Start note', start_note)
    print('End note', end_note)
    print('=' * 70)

    print('Generating...')
    x = torch.tensor(td, dtype=torch.long, device=_DEVICE)[None, ...]
    with ctx:
        out = model.generate(x,
                             1032,
                             temperature=0.9,
                             return_prime=False,
                             verbose=False)

    generated = out.tolist()[0]
    # Regroup into complete 5-token notes; a trailing partial note is dropped.
    output = [generated[i:i + 5] for i in range(0, len(generated) - len(generated) % 5, 5)]

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    # The model is expected to regenerate both boundary notes. If it does not,
    # keep the whole generation instead of crashing with ValueError.
    start_found = start_note in output
    start_note_idx = output.index(start_note) if start_found else -1
    end_found = end_note in output
    if end_found:
        end_note_idx = len(output) - output[::-1].index(end_note) - 1
    else:
        end_note_idx = len(output)
    print('Start note check:', start_found, '---', start_note_idx)
    print('End note check:', end_found, '---', end_note_idx)

    print('Rendering results...')
    data = prefix + TMIDIX.flatten(output[:end_note_idx]) + suffix
    print('=' * 70)
    print('Sample INTs', data[:15])
    print('=' * 70)

    song_f = _tokens_to_song(data)
    song_f, patches, _overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)

    fn1 = "Bridge-Music-Transformer-Composition"
    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                             output_signature='Bridge Music Transformer',
                                             output_file_name=fn1,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches
                                             )
    new_fn = fn1 + '.mid'

    # Honor a `soundfont` global set by __main__, else fall back to the bundled default.
    soundfont_path = globals().get('soundfont', _DEFAULT_SOUNDFONT)
    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont_path,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    output_midi_title = str(fn1)
    output_midi_summary = str(song_f[:3])
    output_midi = str(new_fn)
    output_audio = (16000, audio)
    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(_PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
# =================================================================================================
if __name__ == "__main__":
    # Module-level globals read by GenerateBridge at request time.
    PDT = timezone('US/Pacific')
    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    # Example rows: (MIDI file, start note number).
    example_rows = [
        ["Sharing The Night Together.kar", 0],
        ["Sharing The Night Together.kar", 100],
        ["Deep Relaxation Melody #6.mid", 0],
        ["Deep Relaxation Melody #6.mid", 100],
    ]

    with gr.Blocks() as app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Bridge Music Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Generate a seamless bridge between two parts of any composition</h1>")
        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Bridge-Music-Transformer&style=flat)\n\n")

        # --- Inputs ---
        gr.Markdown("## Upload your MIDI or select a sample example MIDI below")
        gr.Markdown("### Please note that the MIDI must have at least 615 notes for this demo to work properly")
        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
        input_start_note = gr.Slider(0, 205, value=0, step=1, label="Start note number")
        run_btn = gr.Button("generate", variant="primary")

        # --- Outputs (creation order defines the on-page layout) ---
        gr.Markdown("## Generation results")
        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        # Order must match GenerateBridge's parameters and return tuple.
        bridge_inputs = [input_midi, input_start_note]
        bridge_outputs = [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot]

        run_event = run_btn.click(fn=GenerateBridge, inputs=bridge_inputs, outputs=bridge_outputs)

        gr.Examples(
            examples=example_rows,
            inputs=bridge_inputs,
            outputs=bridge_outputs,
            fn=GenerateBridge,
            cache_examples=False,
        )

    app.queue().launch()