# -*- coding: utf-8 -*-
"""Melody2Song_Seq2Seq_Music_Transformer.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1La3iHCib9tluuv4AfsIHCwt1zu0wzl8B
# Melody2Song Seq2Seq Music Transformer (ver. 1.0)
***
Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools
***
WARNING: This complete implementation is a functioning artificial intelligence model. Please exercise great humility, care, and respect. https://www.nscai.gov/
***
#### Project Los Angeles
#### Tegridy Code 2024
***
# (GPU CHECK)
"""
# @title NVIDIA GPU Check
!nvidia-smi
"""# (SETUP ENVIRONMENT)"""
# @title Install requirements
!git clone --depth 1 https://github.com/asigalov61/tegridy-tools
!pip install einops
!pip install torch-summary
!apt install fluidsynth
# Commented out IPython magic to ensure Python compatibility.
# @title Load all needed modules
print('=' * 70)
print('Loading needed modules...')
print('=' * 70)
import os
import pickle
import random
import secrets
import tqdm
import math
import torch
import matplotlib.pyplot as plt
from torchsummary import summary
# %cd /content/tegridy-tools/tegridy-tools/
import TMIDIX
from midi_to_colab_audio import midi_to_colab_audio
# %cd /content/tegridy-tools/tegridy-tools/X-Transformer
from x_transformer_1_23_2 import *
# %cd /content/
import random
from sklearn import metrics
from IPython.display import Audio, display
from huggingface_hub import hf_hub_download
from google.colab import files
print('=' * 70)
print('Done')
print('=' * 70)
print('Torch version:', torch.__version__)
print('=' * 70)
print('Enjoy! :)')
print('=' * 70)
"""# (SETUP DATA AND MODEL)"""
#@title Load Melody2Song Seq2Seq Music Transformer Data and Pre-Trained Model
#@markdown Model precision option
model_precision = "bfloat16" # @param ["bfloat16", "float16"]
plot_tokens_embeddings = True # @param {type:"boolean"}
print('=' * 70)
print('Downloading Melody2Song Seq2Seq Music Transformer Data File...')
print('=' * 70)
data_path = '/content'
if os.path.isfile(data_path+'/Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data.pickle'):
    print('Data file already exists...')

else:
    hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',
                    repo_type='space',
                    filename='Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data.pickle',
                    local_dir=data_path,
                    )
print('=' * 70)
seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data')
print('=' * 70)
print('Loading Melody2Song Seq2Seq Music Transformer Pre-Trained Model...')
print('Please wait...')
print('=' * 70)
full_path_to_models_dir = "/content"
model_checkpoint_file_name = 'Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth'
model_path = full_path_to_models_dir+'/'+model_checkpoint_file_name
num_layers = 24
if os.path.isfile(model_path):
    print('Model already exists...')

else:
    hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',
                    repo_type='space',
                    filename=model_checkpoint_file_name,
                    local_dir=full_path_to_models_dir,
                    )
print('=' * 70)
print('Instantiating model...')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda'
# Use bfloat16 only when requested and supported by the GPU; otherwise fall back to float16
if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
SEQ_LEN = 2560
PAD_IDX = 514
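# Token vocabulary layout (as used by the melody encoder and the MIDI decoder below):
#   0-127   : note delta start-time, in 32 ms steps
#   129-255 : note duration + 128, in 32 ms steps
#   257-511 : (128 * channel) + pitch + 256 (channel 0 = accompaniment, channel 1 = melody)
#   512/513 : melody start/end markers
#   514     : padding (PAD_IDX)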
# instantiate the model
model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 1024, depth = num_layers, heads = 16, attn_flash = True)
)
model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
model.cuda()
print('=' * 70)
print('Loading model checkpoint...')
model.load_state_dict(torch.load(model_path))
print('=' * 70)
model.eval()
print('Done!')
print('=' * 70)
print('Model will use', dtype, 'precision...')
print('=' * 70)
# Model stats
print('Model summary...')
summary(model)
if plot_tokens_embeddings:
    tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()

    cos_sim = metrics.pairwise_distances(
        tok_emb, metric='cosine'
    )
    plt.figure(figsize=(7, 7))
    plt.imshow(cos_sim, cmap="inferno", interpolation="nearest")
    im_ratio = cos_sim.shape[0] / cos_sim.shape[1]
    plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)
    plt.xlabel("Token index")
    plt.ylabel("Token index")
    plt.tight_layout()
    plt.savefig("/content/Melody2Song-Seq2Seq-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight")
    plt.show()
"""# (LOAD SEED MELODY)"""
# @title Load desired seed melody
#@markdown NOTE: If custom MIDI file is not provided, sample seed melody will be used instead
full_path_to_custom_seed_melody_MIDI_file = "/content/tegridy-tools/tegridy-tools/seed-melody.mid" # @param {type:"string"}
sample_seed_melody_number = 0 # @param {type:"slider", min:0, max:203664, step:1}
print('=' * 70)
print('Loading seed melody...')
print('=' * 70)
if full_path_to_custom_seed_melody_MIDI_file != '':

    #===============================================================================
    # Raw single-track ms score
    raw_score = TMIDIX.midi2single_track_ms_score(full_path_to_custom_seed_melody_MIDI_file)

    #===============================================================================
    # Enhanced score notes
    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

    #===============================================================================
    # Augmented enhanced score notes
    escore_notes = TMIDIX.recalculate_score_timings(TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32))

    cscore = TMIDIX.chordify_score([1000, escore_notes])

    # Reduce each chord to a single note to obtain a monophonic melody line
    fixed_mel_score = TMIDIX.fix_monophonic_score_durations([c[0] for c in cscore])

    # Encode each melody note as a [delta-time, duration+128, channel/pitch+256] token triplet
    melody = []

    pe = fixed_mel_score[0]

    for s in fixed_mel_score:

        dtime = max(0, min(127, s[1]-pe[1]))
        dur = max(1, min(127, s[2]))
        ptc = max(1, min(127, s[4]))
        chan = 1

        melody.extend([dtime, dur+128, (128 * chan)+ptc+256])

        pe = s

    # Trim or loop the melody to exactly 192 tokens (64 notes), then add start/end markers
    if len(melody) >= 192:
        melody = [512] + melody[:192] + [513]

    else:
        mult = math.ceil(192 / len(melody))
        melody = melody * mult
        melody = [512] + melody[:192] + [513]

    print('Loaded custom MIDI melody:', full_path_to_custom_seed_melody_MIDI_file)
    print('=' * 70)

else:
    melody = seed_melodies_data[sample_seed_melody_number]
    print('Loaded sample seed melody #', sample_seed_melody_number)
    print('=' * 70)
print('Sample melody INTs:', melody[:10])
print('=' * 70)
print('Done!')
print('=' * 70)
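# Optional sanity check: a minimal sketch (not part of the original pipeline, helper
# name is hypothetical) that inverts the triplet encoding above so the loaded seed
# melody can be inspected as human-readable (delta-time, duration, pitch) values.
# Assumes the token layout used throughout this notebook: 0-127 delta-times,
# 129-255 durations, 257-511 channel/pitch, 512/513 start/end markers.

def decode_melody_tokens(mel_tokens):
    # Drop start/end markers and any padding (tokens >= 512)
    toks = [t for t in mel_tokens if t < 512]
    notes = []
    for i in range(0, len(toks) - 2, 3):
        dtime = toks[i]                   # delta start-time (32 ms steps)
        dur = toks[i+1] - 128             # duration (32 ms steps)
        ptc = (toks[i+2] - 256) % 128     # MIDI pitch
        notes.append((dtime, dur, ptc))
    return notes

print('First decoded seed melody notes:', decode_melody_tokens(melody)[:5])
print('=' * 70)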
"""# (GENERATE)"""
# @title Generate song from melody
melody_MIDI_patch_number = 40 # @param {type:"slider", min:0, max:127, step:1}
accompaniment_MIDI_patch_number = 0 # @param {type:"slider", min:0, max:127, step:1}
number_of_tokens_to_generate = 900 # @param {type:"slider", min:15, max:2354, step:3}
number_of_batches_to_generate = 4 # @param {type:"slider", min:1, max:16, step:1}
top_k_value = 25 # @param {type:"slider", min:1, max:50, step:1}
temperature = 0.9 # @param {type:"slider", min:0.1, max:1, step:0.05}
render_MIDI_to_audio = True # @param {type:"boolean"}
print('=' * 70)
print('Melody2Song Seq2Seq Music Transformer Model Generator')
print('=' * 70)
print('Generating...')
print('=' * 70)
model.eval()
torch.cuda.empty_cache()
# Prime each batch element with the same seed melody
x = torch.tensor([melody] * number_of_batches_to_generate, dtype=torch.long, device='cuda')

with ctx:
    out = model.generate(x,
                         number_of_tokens_to_generate,
                         filter_logits_fn=top_k,
                         filter_kwargs={'k': top_k_value},
                         temperature=temperature,
                         return_prime=False,
                         verbose=True)
output = out.tolist()
print('=' * 70)
print('Done!')
print('=' * 70)
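# Optional sanity check (a sketch, not part of the original pipeline): confirm that
# every generated token falls inside the model's vocabulary range (0-514).
print('Min/max generated tokens:',
      min(min(o) for o in output), '/',
      max(max(o) for o in output))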
#======================================================================
print('Rendering results...')
for i in range(number_of_batches_to_generate):

    print('=' * 70)
    print('Batch #', i)
    print('=' * 70)

    out1 = output[i]

    print('Sample INTs', out1[:12])
    print('=' * 70)

    if len(out1) != 0:

        song = out1
        song_f = []

        time = 0
        dur = 0
        vel = 90
        pitch = 0
        channel = 0

        patches = [0] * 16
        patches[0] = accompaniment_MIDI_patch_number
        patches[3] = melody_MIDI_patch_number

        for ss in song:

            if 0 < ss < 128:
                # Delta start-time token (32 ms steps)
                time += (ss * 32)

            if 128 < ss < 256:
                # Duration token (32 ms steps)
                dur = (ss-128) * 32

            if 256 < ss < 512:
                # Channel/pitch token: channel 1 is the melody, channel 0 the accompaniment
                pitch = (ss-256) % 128
                channel = (ss-256) // 128

                if channel == 1:
                    channel = 3
                    vel = 110 + (pitch % 12)
                    song_f.append(['note', time, dur, channel, pitch, vel, melody_MIDI_patch_number])

                else:
                    vel = 80 + (pitch % 12)
                    channel = 0
                    song_f.append(['note', time, dur, channel, pitch, vel, accompaniment_MIDI_patch_number])

        detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                                  output_signature = 'Melody2Song Seq2Seq Music Transformer',
                                                                  output_file_name = '/content/Melody2Song-Seq2Seq-Music-Transformer-Composition_'+str(i),
                                                                  track_name='Project Los Angeles',
                                                                  list_of_MIDI_patches=patches
                                                                  )

        print('=' * 70)
        print('Displaying resulting composition...')
        print('=' * 70)

        fname = '/content/Melody2Song-Seq2Seq-Music-Transformer-Composition_'+str(i)

        if render_MIDI_to_audio:
            midi_audio = midi_to_colab_audio(fname + '.mid')
            display(Audio(midi_audio, rate=16000, normalize=False))

        TMIDIX.plot_ms_SONG(song_f, plot_title=fname)
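        # Optional (a sketch, not in the original notebook): download the rendered MIDI
        # locally using the google.colab files module imported above; uncomment to enable.
        # files.download(fname + '.mid')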
"""# Congrats! You did it! :)"""