asigalov61's picture
Update app.py
6339a18 verified
raw
history blame
6.82 kB
#=======================================================================================
# https://huggingface.co/spaces/asigalov61/Imagen-POP-Music-Medley-Diffusion-Transformer
#=======================================================================================
import os
import time as reqtime
import datetime
from pytz import timezone
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset
import spaces
import gradio as gr
import numpy as np
import random
import tqdm
import TMIDIX
import TPLOTS
from midi_to_colab_audio import midi_to_colab_audio
# =================================================================================================
@spaces.GPU
def Generate_POP_Medley(input_num_medley_comps):
    """Generate a POP music medley with the unconditional Imagen model.

    Builds the Unet/Imagen/ImagenTrainer stack, loads the POP909 checkpoint,
    samples ``input_num_medley_comps`` single-channel 128x128 images, then:
    binarizes each image, converts it to TMIDIX escore notes, joins all of
    them into one labelled medley, writes a MIDI file, renders that file to
    audio and plots the score.

    Args:
        input_num_medley_comps: Number of compositions (images) to sample.
            Gradio sliders may hand this over as a float, so it is coerced
            to ``int`` before being used as a batch size.

    Returns:
        ``(title, summary, midi_path, (sample_rate, audio), plot)`` in the
        order expected by the Gradio outputs wired up in the app.

    NOTE(review): reads the module globals ``PDT`` (timezone for log
    timestamps) and ``soundfont`` (path used for audio rendering).
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    # --- Model ---------------------------------------------------------------
    print('Loading model...')

    DIM = 64
    CHANS = 1        # single-channel images
    TSTEPS = 1000    # diffusion timesteps
    DEVICE = 'cuda'  # 'cpu'

    unet = Unet(
        dim=DIM,
        dim_mults=(1, 2, 4, 8),
        num_resnet_blocks=1,
        channels=CHANS,
        layer_attns=(False, False, False, True),
        layer_cross_attns=False,
    )

    imagen = Imagen(
        condition_on_text=False,  # this must be set to False for unconditional Imagen
        unets=unet,
        channels=CHANS,
        image_sizes=128,
        timesteps=TSTEPS,
    )

    trainer = ImagenTrainer(
        imagen=imagen,
        split_valid_from_train=True,  # whether to split the validation dataset from the training
    ).to(DEVICE)

    print('=' * 70)
    print('Loading model checkpoint...')
    trainer.load('Imagen_POP909_64_dim_12638_steps_0.00983_loss.ckpt')
    print('Done!')
    print('=' * 70)

    # --- Sampling ------------------------------------------------------------
    # batch_size must be an integer; slider values are not guaranteed to be.
    num_comps = int(input_num_medley_comps)

    print('Req number of medley compositions:', num_comps)
    print('=' * 70)
    print('Generating...')

    images = trainer.sample(batch_size=num_comps, return_pil_images=True)

    # Binarize every sampled image: pixels below the threshold -> 0, others -> 1.
    threshold = 128
    imgs_array = [np.where(np.array(img) < threshold, 0, 1) for img in images]

    print('Done!')
    print('=' * 70)

    #===============================================================================
    print('Converting images to scores...')

    medley_compositions_escores = [
        TMIDIX.binary_matrix_to_original_escore_notes(
            TPLOTS.images_to_binary_matrix([img_arr])
        )
        for img_arr in imgs_array
    ]

    print('Done!')
    print('=' * 70)
    print('Creating medley score...')

    medley_labels = ['Composition #' + str(i)
                     for i in range(1, len(medley_compositions_escores) + 1)]
    medley_escore = TMIDIX.escore_notes_medley(medley_compositions_escores, medley_labels)

    #===============================================================================
    print('Rendering results...')
    print('=' * 70)
    print('Sample INTs', medley_escore[:15])
    print('=' * 70)

    fn1 = "Imagen-POP-Music-Medley-Diffusion-Transformer-Composition"

    # The medley escore is the song that is written to MIDI, rendered and plotted.
    # NOTE(review): timings come straight from TMIDIX's
    # binary_matrix_to_original_escore_notes (presumably image-row units),
    # while the converter below is the "ms" variant -- confirm whether a
    # time scale should be applied here.
    song_f = medley_escore

    # Patch 0 (GM Acoustic Grand Piano) on all 16 MIDI channels.
    patches = [0] * 16

    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
        song_f,
        output_signature='Imagen POP Music Medley',
        output_file_name=fn1,
        track_name='Project Los Angeles',
        list_of_MIDI_patches=patches,
    )

    new_fn = fn1 + '.mid'

    # One constant so the rendered audio and the rate reported to Gradio agree.
    sample_rate = 16000

    audio = midi_to_colab_audio(
        new_fn,
        soundfont_path=soundfont,
        sample_rate=sample_rate,
        volume_scale=10,
        output_for_gradio=True,
    )

    print('Done!')
    print('=' * 70)

    #========================================================
    output_midi_title = str(fn1)
    output_midi_summary = str(song_f[:3])
    output_midi = str(new_fn)
    output_audio = (sample_rate, audio)
    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    #========================================================
    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
# =================================================================================================
# Defined at module level rather than only under the __main__ guard:
# Generate_POP_Medley reads both names as globals, so they must also exist
# when this module is imported instead of executed as a script.
PDT = timezone('US/Pacific')
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

if __name__ == "__main__":
    # Build and launch the Gradio UI only when run as a script.
    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    app = gr.Blocks()

    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Imagen POP Music Medley Diffusion Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Generate unique POP music medleys with Imagen diffusion transformer</h1>")
        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Imagen-POP-Music-Medley-Diffusion-Transformer&style=flat)\n\n"
                    "This is a demo for MIDI Images dataset\n\n"
                    "Please see [MIDI Images](https://huggingface.co/datasets/asigalov61/MIDI-Images) Hugging Face repo for more information\n\n"
                    )

        # Inputs
        input_num_medley_comps = gr.Slider(1, 12, value=8, step=1, label="Number of medley compositions")
        run_btn = gr.Button("Generate POP Medley", variant="primary")

        # Outputs -- order of the list passed to click() must match the
        # return tuple of Generate_POP_Medley.
        gr.Markdown("## Generation results")
        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = run_btn.click(Generate_POP_Medley, [input_num_medley_comps],
                                  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])

    app.queue().launch()