# asigalov61's Hugging Face Space — app.py, commit a15bc89 ("Update app.py", verified).
import os.path
import time as reqtime
import datetime
from pytz import timezone
import torch
import spaces
import gradio as gr
from x_transformer_1_23_2 import *
import random
import tqdm
import pprint
import io
from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX
import matplotlib.pyplot as plt
# True when the SYSTEM environment variable equals "spaces" (presumably set by
# the Hugging Face Spaces runtime). Not read anywhere in the code visible here.
in_space = os.environ.get("SYSTEM") == "spaces"
# =================================================================================================
# Trained weights loaded by _load_model (path relative to the working directory).
_CHECKPOINT_PATH = 'Descriptive_Music_Transformer_Trained_Model_20631_steps_0.3218_loss_0.8947_acc.pth'

# Description characters are stored as (ASCII code point + 648); token 648 itself
# was the encoder's fallback for characters outside 1..127.
_TEXT_TOKEN_OFFSET = 648


def _load_model(device):
    """Build the Descriptive Music Transformer, load its checkpoint onto *device*, and return it in eval mode."""
    seq_len = 2048
    pad_idx = 780
    model = TransformerWrapper(
        num_tokens=pad_idx + 1,
        max_seq_len=seq_len,
        attn_layers=Decoder(dim=1024, depth=32, heads=16, attn_flash=True),
    )
    model = AutoregressiveWrapper(model, ignore_index=pad_idx, pad_value=pad_idx)
    model.to(device)
    print('=' * 70)
    print('Loading model checkpoint...')
    model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location=device))
    print('=' * 70)
    model.eval()
    return model


def _tokens2txt(tokens):
    """Decode description tokens to a list of characters, skipping every non-text token."""
    return [chr(tok - _TEXT_TOKEN_OFFSET) for tok in tokens
            if _TEXT_TOKEN_OFFSET < tok < _TEXT_TOKEN_OFFSET + 128]


def _format_description(tokens):
    """Return display text built from the first two generated description sentences.

    Each sentence is capitalized and terminated with a period, and the two are
    separated by a blank line. Returns '' when no description text was generated.
    Unlike the old pprint/replace/[1:-2] post-processing, this does not chop
    characters off short sentences, keeps apostrophes, and does not raise
    IndexError when fewer than two sentences are present.
    """
    text = ''.join(_tokens2txt(tokens))
    sentences = []
    for chunk in text.split('. ')[:2]:
        chunk = chunk.strip().rstrip('.').strip()
        if chunk:
            sentences.append(chunk.capitalize())
    if not sentences:
        return ''
    return '.\n\n'.join(sentences) + '.'


def _tokens_to_score(tokens):
    """Decode generated music tokens into ['note', time, dur, channel, pitch, vel, patch] events.

    Token ranges (anything else, e.g. description text or token 777, is ignored):
      1..127    time delta; accumulated into absolute time as tok * 32
      129..255  duration = (tok - 128) * 32
      256..384  patch = tok - 256; MIDI channel = patch // 8, with channel 9
                remapped to 15 and channel 16 (patch 128) to 9
                (presumably to reserve channel 9 for drums)
      385..639  pitch = (tok - 384) % 128
      640..647  velocity = (tok - 640 + 1) * 15; emits one note from the current state

    NOTE(review): the original's indentation was lost. This assumes the note
    was appended under the velocity branch (velocity = last token of a note).
    Confirm against the training encoding.
    """
    score = []
    abs_time = 0
    dur = 0
    vel = 90
    pitch = 0
    pat = 0
    channel = 0
    for tok in tokens:
        if 0 < tok < 128:
            abs_time += tok * 32
        elif 128 < tok < 256:
            dur = (tok - 128) * 32
        elif 256 <= tok <= 384:
            pat = tok - 256
            channel = pat // 8
            if channel == 9:
                channel = 15
            elif channel == 16:
                channel = 9
        elif 384 < tok < 640:
            pitch = (tok - 384) % 128
        elif 640 <= tok < 648:
            vel = ((tok - 640) + 1) * 15
            score.append(['note', abs_time, dur, channel, pitch, vel, pat])
    return score


@spaces.GPU
def GenerateMusic():
    """Generate one composition with its text description and render it for the UI.

    Builds the model, samples 1536 tokens after a single priming token, decodes the
    description and the notes, writes a MIDI file to the working directory, renders
    it to 16 kHz audio with the module-level ``soundfont``, and plots the score.
    Reads the module-level globals ``PDT`` and ``soundfont``.

    Returns:
        (title, description, midi_path, (16000, audio), plot)
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    device = 'cuda'
    print('Loading model...')
    # NOTE(review): the model is rebuilt and the checkpoint re-read on every request;
    # consider caching it if the ZeroGPU (spaces.GPU) lifecycle permits.
    model = _load_model(device)
    # bfloat16 autocast (the original picked the same dtype for CPU and CUDA).
    ctx = torch.amp.autocast(device_type=device, dtype=torch.bfloat16)
    print('Done!')
    print('=' * 70)

    number_of_tokens_to_generate = 1024 + 512
    number_of_batches_to_generate = 1
    temperature = 0.9

    print('=' * 70)
    print('Generating...')
    print('=' * 70)
    print('Descriptive Music Transformer Model Generator')
    print('=' * 70)

    # Every sequence is primed with token 777 (presumably the start token — confirm).
    prime = [777]
    torch.cuda.empty_cache()
    inp = torch.LongTensor([prime] * number_of_batches_to_generate).to(device)
    with ctx:
        out = model.generate(inp,
                             number_of_tokens_to_generate,
                             temperature=temperature,
                             return_prime=True,
                             verbose=False)
    out1 = out.tolist()[0]  # only the first batch item is rendered
    print('=' * 70)
    print('Done!')
    print('=' * 70)

    print('Rendering results...')
    print('=' * 70)
    print('Sample INTs', out1[:12])
    print('=' * 70)

    generated_song_description = _format_description(out1)

    song_f = _tokens_to_score(out1)
    song_f, patches, _overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)

    fn1 = "Descriptive-Music-Transformer-Composition"
    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                             output_signature='Descriptive Music Transformer',
                                             output_file_name=fn1,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches)
    new_fn = fn1 + '.mid'
    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True)
    print('Done!')
    print('=' * 70)

    output_midi_title = str(fn1).replace('-', ' ')
    output_midi_summary = str(generated_song_description)
    output_midi = str(new_fn)
    output_audio = (16000, audio)
    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)
    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')
    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
# =================================================================================================
# US/Pacific timezone used for the request/app timestamps in the logs.
# Defined at module level (not only under the __main__ guard) because
# GenerateMusic reads it as a global, so it must exist even when this
# module is imported rather than run as a script.
PDT = timezone('US/Pacific')

# SoundFont file that GenerateMusic passes to midi_to_colab_audio.
# Also module-level so that GenerateMusic works when the module is imported.
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"


if __name__ == "__main__":
    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    app = gr.Blocks()
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Descriptive Music Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>A music transformer that describes music it generates</h1>")
        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Descriptive-Music-Transformer&style=flat)\n\n"
            'This is a demo for Annotated MIDI Dataset.\n\n'
            "Check out [Annotated MIDI Dataset](https://huggingface.co/datasets/asigalov61/Annotated-MIDI-Dataset) on Hugging Face!\n\n"
        )
        run_btn = gr.Button("generate", variant="primary")

        gr.Markdown("## Generation results")
        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Generated music description")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        # Output order must match the tuple returned by GenerateMusic.
        run_btn.click(GenerateMusic,
                      outputs=[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])

    app.queue().launch()