|
|
|
"""Melody2Song_Seq2Seq_Music_Transformer.ipynb |
|
|
|
Automatically generated by Colab. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1La3iHCib9tluuv4AfsIHCwt1zu0wzl8B |
|
|
|
# Melody2Song Seq2Seq Music Transformer (ver. 1.0) |
|
|
|
*** |
|
|
|
Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools |
|
|
|
*** |
|
|
|
WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/ |
|
|
|
*** |
|
|
|
#### Project Los Angeles |
|
|
|
#### Tegridy Code 2024 |
|
|
|
*** |
|
|
|
# (GPU CHECK) |
|
""" |
|
|
|
|
|
!nvidia-smi |
|
|
|
"""# (SETUP ENVIRONMENT)""" |
|
|
|
|
|
!git clone --depth 1 https://github.com/asigalov61/tegridy-tools |
|
!pip install einops |
|
!pip install torch-summary |
|
!apt install fluidsynth |
|
|
|
|
|
|
|
|
|
print('=' * 70) |
|
print('Loading needed modules...') |
|
print('=' * 70) |
|
|
|
import os |
|
import pickle |
|
import random |
|
import secrets |
|
import tqdm |
|
import math |
|
import torch |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
from torchsummary import summary |
|
|
|
|
|
|
|
import TMIDIX |
|
from midi_to_colab_audio import midi_to_colab_audio |
|
|
|
|
|
|
|
from x_transformer_1_23_2 import * |
|
|
|
|
|
|
|
import random |
|
|
|
from sklearn import metrics |
|
|
|
from IPython.display import Audio, display |
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
from google.colab import files |
|
|
|
print('=' * 70) |
|
print('Done') |
|
print('=' * 70) |
|
print('Torch version:', torch.__version__) |
|
print('=' * 70) |
|
print('Enjoy! :)') |
|
print('=' * 70) |
|
|
|
"""# (SETUP DATA AND MODEL)""" |
|
|
|
|
|
|
|
|
|
|
|
# ----------------------------------------------------------------------
# User-configurable options for this cell
# ----------------------------------------------------------------------

# Requested inference precision; read later when the autocast dtype is chosen.
model_precision = "bfloat16"

# When True, a cosine-distance plot of the token embeddings is saved after
# the model has been loaded.
plot_tokens_embeddings = True

print('=' * 70)
print('Downloading Melody2Song Seq2Seq Music Transformer Data File...')
print('=' * 70)

# Directory that holds (or will receive) the seed melodies pickle.
data_path = '/content'

# Single source of truth for the data file name, so that the existence check,
# the download and the read below cannot drift apart.
seed_melodies_data_name = 'Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data'
seed_melodies_data_file = os.path.join(data_path, seed_melodies_data_name + '.pickle')

if os.path.isfile(seed_melodies_data_file):
    print('Data file already exists...')

else:
    hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',
                    repo_type='space',
                    filename=seed_melodies_data_name + '.pickle',
                    local_dir=data_path,
                    )

print('=' * 70)

# Read the file from data_path (where it was checked for / downloaded to)
# instead of relying on the current working directory being data_path.
# NOTE(review): the name is passed without extension, exactly as before;
# presumably the reader appends '.pickle' itself -- confirm against TMIDIX.
# NOTE(security): this presumably unpickles a file downloaded from the
# network; only point the download at a repository you trust.
seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader(
    os.path.join(data_path, seed_melodies_data_name)
)
|
|
|
print('=' * 70)
print('Loading Melody2Song Seq2Seq Music Transformer Pre-Trained Model...')
print('Please wait...')
print('=' * 70)

# Location and name of the pre-trained checkpoint.
full_path_to_models_dir = "/content"
model_checkpoint_file_name = 'Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth'
model_path = f'{full_path_to_models_dir}/{model_checkpoint_file_name}'

# Decoder depth used when the model is instantiated further down.
num_layers = 24

# Fetch the checkpoint from the Hugging Face space only when it is not
# already present locally.
if not os.path.isfile(model_path):
    hf_hub_download(
        repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',
        repo_type='space',
        filename=model_checkpoint_file_name,
        local_dir=full_path_to_models_dir,
    )

else:
    print('Model already exists...')
|
|
|
|
|
print('=' * 70)
print('Instantiating model...')

# Allow TF32 for CUDA matmul and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device_type = 'cuda'

# Use bfloat16 only when it was requested AND the GPU supports it; every
# other case falls back to float16. (The original had an extra
# `if model_precision == 'float16'` branch that could never change the
# result, so it is folded into this single decision.)
if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'

ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# Autocast context reused by the generation cell.
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

SEQ_LEN = 2560  # maximum sequence length passed to the model
PAD_IDX = 514   # pad token id; the vocabulary size is PAD_IDX + 1

model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 1024, depth = num_layers, heads = 16, attn_flash = True)
)

model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)

model.cuda()
print('=' * 70)

print('Loading model checkpoint...')

# The checkpoint is a file downloaded from the network, so restrict
# unpickling to tensors/plain containers (weights_only=True) and place the
# tensors on the target device explicitly instead of relying on whatever
# device they were saved from.
model.load_state_dict(
    torch.load(model_path, map_location=device_type, weights_only=True)
)
print('=' * 70)

model.eval()

print('Done!')
print('=' * 70)

print('Model will use', dtype, 'precision...')
print('=' * 70)

print('Model summary...')
summary(model)
|
|
|
if plot_tokens_embeddings:

    # Token-embedding weights, detached and moved to the CPU as nested lists.
    tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()

    # Pairwise cosine distances between all embedding rows.
    cos_sim = metrics.pairwise_distances(tok_emb, metric='cosine')

    plt.figure(figsize=(7, 7))
    plt.imshow(cos_sim, cmap="inferno", interpolation="nearest")

    # Scale the colorbar by the image aspect ratio.
    emb_rows, emb_cols = cos_sim.shape
    im_ratio = emb_rows / emb_cols
    plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)

    plt.xlabel("Position")
    plt.ylabel("Position")
    plt.tight_layout()
    plt.plot()

    # Persist the figure next to the other notebook outputs.
    plt.savefig(
        "/content/Melody2Song-Seq2Seq-Music-Transformer-Tokens-Embeddings-Plot.png",
        bbox_inches="tight",
    )
|
|
|
"""# (LOAD SEED MELODY)""" |
|
|
|
|
|
|
|
|
|
|
|
# Path to a custom seed melody MIDI file. Leave it empty to use one of the
# sample seed melodies from seed_melodies_data instead.
full_path_to_custom_seed_melody_MIDI_file = "/content/tegridy-tools/tegridy-tools/seed-melody.mid"

# Index into seed_melodies_data; only used when no custom MIDI file is given.
sample_seed_melody_number = 0

print('=' * 70)
print('Loading seed melody...')
print('=' * 70)

if full_path_to_custom_seed_melody_MIDI_file:

    # Load the MIDI file and convert it to enhanced score notes.
    raw_score = TMIDIX.midi2single_track_ms_score(full_path_to_custom_seed_melody_MIDI_file)

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

    # Augment with timings_divider=32 and recalculate the timings.
    escore_notes = TMIDIX.recalculate_score_timings(TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32))

    cscore = TMIDIX.chordify_score([1000, escore_notes])

    # Keep only the first note of every chord so the result is monophonic.
    fixed_mel_score = TMIDIX.fix_monophonic_score_durations([c[0] for c in cscore])

    # Fail with a clear message instead of an IndexError / ZeroDivisionError
    # further down when the file contains no usable notes.
    if not fixed_mel_score:
        raise ValueError('No notes found in seed melody MIDI file: '
                         + full_path_to_custom_seed_melody_MIDI_file)

    # The seed melody is always encoded on channel 1 (loop-invariant).
    chan = 1

    melody = []

    # Previous note; starts as the first note so its delta time is 0.
    pe = fixed_mel_score[0]

    for s in fixed_mel_score:

        # Each note becomes three tokens: delta start time (0..127),
        # duration (+128) and channel/pitch (+256).
        # s[1], s[2] and s[4] are used as start time, duration and pitch.
        dtime = max(0, min(127, s[1]-pe[1]))
        dur = max(1, min(127, s[2]))
        ptc = max(1, min(127, s[4]))

        melody.extend([dtime, dur+128, (128 * chan)+ptc+256])

        pe = s

    # Bring the melody to exactly 192 tokens: repeat it as often as needed
    # (mult is 1 when it is already long enough), cut to size, then wrap it
    # in the 512 / 513 marker tokens.
    mult = math.ceil(192 / len(melody))
    melody = [512] + (melody * mult)[:192] + [513]

    print('Loaded custom MIDI melody:', full_path_to_custom_seed_melody_MIDI_file)
    print('=' * 70)

else:
    melody = seed_melodies_data[sample_seed_melody_number]
    print('Loaded sample seed melody #', sample_seed_melody_number)
    print('=' * 70)

print('Sample melody INTs:', melody[:10])
print('=' * 70)
print('Done!')
print('=' * 70)
|
|
|
"""# (GENERATE)""" |
|
|
|
|
|
|
|
# ----------------------------------------------------------------------
# Generation settings
# ----------------------------------------------------------------------

# MIDI patch numbers used when the generated tokens are rendered.
melody_MIDI_patch_number = 40
accompaniment_MIDI_patch_number = 0

# Number of tokens to generate for every batch item.
number_of_tokens_to_generate = 900

# Number of independent continuations generated from the same seed melody.
number_of_batches_to_generate = 4

# Sampling settings forwarded to model.generate().
top_k_value = 25
temperature = 0.9

# When True, every resulting MIDI file is also rendered to audio.
render_MIDI_to_audio = True

print('=' * 70)
print('Melody2Song Seq2Seq Music Transformer Model Generator')
print('=' * 70)

print('Generating...')
print('=' * 70)

model.eval()

torch.cuda.empty_cache()

# One copy of the seed melody per requested batch item.
x = (torch.tensor([melody] * number_of_batches_to_generate, dtype=torch.long, device='cuda'))

with ctx:
    out = model.generate(x,
                         number_of_tokens_to_generate,
                         filter_logits_fn=top_k,
                         filter_kwargs={'k': top_k_value},
                         # Use the setting above; the original passed a
                         # hard-coded 0.9 here, so changing `temperature`
                         # had no effect.
                         temperature=temperature,
                         return_prime=False,
                         verbose=True)

# Plain nested lists, one token list per batch item.
output = out.tolist()

print('=' * 70)
print('Done!')
print('=' * 70)
|
|
|
|
|
print('Rendering results...')

for i in range(number_of_batches_to_generate):

    print('=' * 70)
    print('Batch #', i)
    print('=' * 70)

    out1 = output[i]

    print('Sample INTs', out1[:12])
    print('=' * 70)

    # Nothing was generated for this batch item, so there is nothing to render.
    if len(out1) == 0:
        continue

    song = out1
    song_f = []

    # Running decoder state, updated as the tokens are consumed.
    time = 0
    dur = 0
    vel = 90
    pitch = 0
    channel = 0

    # Channel 0 carries the accompaniment, channel 3 the melody.
    patches = [0] * 16
    patches[0] = accompaniment_MIDI_patch_number
    patches[3] = melody_MIDI_patch_number

    for ss in song:

        if 0 < ss < 128:
            # Delta start time token: advance the clock.
            time += (ss * 32)

        elif 128 < ss < 256:
            # Duration token.
            dur = (ss-128) * 32

        elif 256 < ss < 512:
            # Combined channel/pitch token; it completes a note.
            channel, pitch = divmod(ss-256, 128)

            if channel == 1:
                # Melody note: moved to channel 3 with a louder velocity.
                channel = 3
                vel = 110 + (pitch % 12)
                note_patch = melody_MIDI_patch_number

            else:
                # Everything else is rendered as accompaniment on channel 0.
                vel = 80 + (pitch % 12)
                channel = 0
                note_patch = accompaniment_MIDI_patch_number

            song_f.append(['note', time, dur, channel, pitch, vel, note_patch])

    # Output path without extension; '.mid' is appended where the file is read.
    fname = '/content/Melody2Song-Seq2Seq-Music-Transformer-Composition_'+str(i)

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
        song_f,
        output_signature = 'Melody2Song Seq2Seq Music Transformer',
        output_file_name = fname,
        track_name='Project Los Angeles',
        list_of_MIDI_patches=patches
    )

    print('=' * 70)
    print('Displaying resulting composition...')
    print('=' * 70)

    if render_MIDI_to_audio:
        midi_audio = midi_to_colab_audio(fname + '.mid')
        display(Audio(midi_audio, rate=16000, normalize=False))

    TMIDIX.plot_ms_SONG(song_f, plot_title=fname)
|
|
|
"""# Congrats! You did it! :)""" |