asigalov61 commited on
Commit
ef7a679
1 Parent(s): 6f6e0e1

Delete Melody2Song_Seq2Seq_Music_Transformer.ipynb

Browse files
Melody2Song_Seq2Seq_Music_Transformer.ipynb DELETED
@@ -1,523 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "VGrGd6__l5ch"
7
- },
8
- "source": [
9
- "# Melody2Song Seq2Seq Music Transformer (ver. 1.0)\n",
10
- "\n",
11
- "***\n",
12
- "\n",
13
- "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n",
14
- "\n",
15
- "***\n",
16
- "\n",
17
- "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please excercise great humility, care, and respect. https://www.nscai.gov/\n",
18
- "\n",
19
- "***\n",
20
- "\n",
21
- "#### Project Los Angeles\n",
22
- "\n",
23
- "#### Tegridy Code 2024\n",
24
- "\n",
25
- "***"
26
- ]
27
- },
28
- {
29
- "cell_type": "markdown",
30
- "metadata": {
31
- "id": "shLrgoXdl5cj"
32
- },
33
- "source": [
34
- "# (GPU CHECK)"
35
- ]
36
- },
37
- {
38
- "cell_type": "code",
39
- "execution_count": null,
40
- "metadata": {
41
- "id": "X3rABEpKCO02",
42
- "cellView": "form"
43
- },
44
- "outputs": [],
45
- "source": [
46
- "# @title NVIDIA GPU Check\n",
47
- "!nvidia-smi"
48
- ]
49
- },
50
- {
51
- "cell_type": "markdown",
52
- "metadata": {
53
- "id": "0RcVC4btl5ck"
54
- },
55
- "source": [
56
- "# (SETUP ENVIRONMENT)"
57
- ]
58
- },
59
- {
60
- "cell_type": "code",
61
- "execution_count": null,
62
- "metadata": {
63
- "id": "viHgEaNACPTs",
64
- "cellView": "form"
65
- },
66
- "outputs": [],
67
- "source": [
68
- "# @title Install requirements\n",
69
- "!git clone --depth 1 https://github.com/asigalov61/tegridy-tools\n",
70
- "!pip install einops\n",
71
- "!pip install torch-summary\n",
72
- "!apt install fluidsynth"
73
- ]
74
- },
75
- {
76
- "cell_type": "code",
77
- "execution_count": null,
78
- "metadata": {
79
- "id": "DzCOZU_gBiQV",
80
- "cellView": "form"
81
- },
82
- "outputs": [],
83
- "source": [
84
- "# @title Load all needed modules\n",
85
- "\n",
86
- "print('=' * 70)\n",
87
- "print('Loading needed modules...')\n",
88
- "print('=' * 70)\n",
89
- "\n",
90
- "import os\n",
91
- "import pickle\n",
92
- "import random\n",
93
- "import secrets\n",
94
- "import tqdm\n",
95
- "import math\n",
96
- "import torch\n",
97
- "\n",
98
- "import matplotlib.pyplot as plt\n",
99
- "\n",
100
- "from torchsummary import summary\n",
101
- "\n",
102
- "%cd /content/tegridy-tools/tegridy-tools/\n",
103
- "\n",
104
- "import TMIDIX\n",
105
- "from midi_to_colab_audio import midi_to_colab_audio\n",
106
- "\n",
107
- "%cd /content/tegridy-tools/tegridy-tools/X-Transformer\n",
108
- "\n",
109
- "from x_transformer_1_23_2 import *\n",
110
- "\n",
111
- "%cd /content/\n",
112
- "\n",
113
- "import random\n",
114
- "\n",
115
- "from sklearn import metrics\n",
116
- "\n",
117
- "from IPython.display import Audio, display\n",
118
- "\n",
119
- "from huggingface_hub import hf_hub_download\n",
120
- "\n",
121
- "from google.colab import files\n",
122
- "\n",
123
- "print('=' * 70)\n",
124
- "print('Done')\n",
125
- "print('=' * 70)\n",
126
- "print('Torch version:', torch.__version__)\n",
127
- "print('=' * 70)\n",
128
- "print('Enjoy! :)')\n",
129
- "print('=' * 70)"
130
- ]
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "source": [
135
- "# (SETUP DATA AND MODEL)"
136
- ],
137
- "metadata": {
138
- "id": "SQ1_7P4bLdtB"
139
- }
140
- },
141
- {
142
- "cell_type": "code",
143
- "source": [
144
- "#@title Load Melody2Song Seq2Seq Music Trnasofmer Data and Pre-Trained Model\n",
145
- "\n",
146
- "#@markdown Model precision option\n",
147
- "\n",
148
- "model_precision = \"bfloat16\" # @param [\"bfloat16\", \"float16\"]\n",
149
- "\n",
150
- "plot_tokens_embeddings = True # @param {type:\"boolean\"}\n",
151
- "\n",
152
- "print('=' * 70)\n",
153
- "print('Donwloading Melody2Song Seq2Seq Music Transformer Data File...')\n",
154
- "print('=' * 70)\n",
155
- "\n",
156
- "data_path = '/content'\n",
157
- "\n",
158
- "if os.path.isfile(data_path+'/Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data.pickle'):\n",
159
- " print('Data file already exists...')\n",
160
- "\n",
161
- "else:\n",
162
- " hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',\n",
163
- " repo_type='space',\n",
164
- " filename='Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data.pickle',\n",
165
- " local_dir=data_path,\n",
166
- " )\n",
167
- "\n",
168
- "print('=' * 70)\n",
169
- "seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data')\n",
170
- "\n",
171
- "print('=' * 70)\n",
172
- "print('Loading Melody2Song Seq2Seq Music Transformer Pre-Trained Model...')\n",
173
- "print('Please wait...')\n",
174
- "print('=' * 70)\n",
175
- "\n",
176
- "full_path_to_models_dir = \"/content\"\n",
177
- "\n",
178
- "model_checkpoint_file_name = 'Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth'\n",
179
- "model_path = full_path_to_models_dir+'/'+model_checkpoint_file_name\n",
180
- "num_layers = 24\n",
181
- "if os.path.isfile(model_path):\n",
182
- " print('Model already exists...')\n",
183
- "\n",
184
- "else:\n",
185
- " hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',\n",
186
- " repo_type='space',\n",
187
- " filename=model_checkpoint_file_name,\n",
188
- " local_dir=full_path_to_models_dir,\n",
189
- " )\n",
190
- "\n",
191
- "\n",
192
- "print('=' * 70)\n",
193
- "print('Instantiating model...')\n",
194
- "\n",
195
- "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
196
- "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
197
- "device_type = 'cuda'\n",
198
- "\n",
199
- "if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported():\n",
200
- " dtype = 'bfloat16'\n",
201
- "else:\n",
202
- " dtype = 'float16'\n",
203
- "\n",
204
- "if model_precision == 'float16':\n",
205
- " dtype = 'float16'\n",
206
- "\n",
207
- "ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]\n",
208
- "ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)\n",
209
- "\n",
210
- "SEQ_LEN = 2560\n",
211
- "PAD_IDX = 514\n",
212
- "\n",
213
- "# instantiate the model\n",
214
- "\n",
215
- "model = TransformerWrapper(\n",
216
- " num_tokens = PAD_IDX+1,\n",
217
- " max_seq_len = SEQ_LEN,\n",
218
- " attn_layers = Decoder(dim = 1024, depth = num_layers, heads = 16, attn_flash = True)\n",
219
- ")\n",
220
- "\n",
221
- "model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)\n",
222
- "\n",
223
- "model.cuda()\n",
224
- "print('=' * 70)\n",
225
- "\n",
226
- "print('Loading model checkpoint...')\n",
227
- "\n",
228
- "model.load_state_dict(torch.load(model_path))\n",
229
- "print('=' * 70)\n",
230
- "\n",
231
- "model.eval()\n",
232
- "\n",
233
- "print('Done!')\n",
234
- "print('=' * 70)\n",
235
- "\n",
236
- "print('Model will use', dtype, 'precision...')\n",
237
- "print('=' * 70)\n",
238
- "\n",
239
- "# Model stats\n",
240
- "print('Model summary...')\n",
241
- "summary(model)\n",
242
- "\n",
243
- "if plot_tokens_embeddings:\n",
244
- "\n",
245
- " tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()\n",
246
- "\n",
247
- " cos_sim = metrics.pairwise_distances(\n",
248
- " tok_emb, metric='cosine'\n",
249
- " )\n",
250
- " plt.figure(figsize=(7, 7))\n",
251
- " plt.imshow(cos_sim, cmap=\"inferno\", interpolation=\"nearest\")\n",
252
- " im_ratio = cos_sim.shape[0] / cos_sim.shape[1]\n",
253
- " plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)\n",
254
- " plt.xlabel(\"Position\")\n",
255
- " plt.ylabel(\"Position\")\n",
256
- " plt.tight_layout()\n",
257
- " plt.plot()\n",
258
- " plt.savefig(\"/content/Melody2Song-Seq2Seq-Music-Transformer-Tokens-Embeddings-Plot.png\", bbox_inches=\"tight\")"
259
- ],
260
- "metadata": {
261
- "cellView": "form",
262
- "id": "z7QLJ6FajxPA"
263
- },
264
- "execution_count": null,
265
- "outputs": []
266
- },
267
- {
268
- "cell_type": "markdown",
269
- "source": [
270
- "# (LOAD SEED MELODY)"
271
- ],
272
- "metadata": {
273
- "id": "NdJ1_A8gNoV3"
274
- }
275
- },
276
- {
277
- "cell_type": "code",
278
- "execution_count": null,
279
- "metadata": {
280
- "id": "AIvb6MmSO9R3",
281
- "cellView": "form"
282
- },
283
- "outputs": [],
284
- "source": [
285
- "# @title Load desired seed melody\n",
286
- "\n",
287
- "#@markdown NOTE: If custom MIDI file is not provided, sample seed melody will be used instead\n",
288
- "\n",
289
- "full_path_to_custom_seed_melody_MIDI_file = \"/content/tegridy-tools/tegridy-tools/seed-melody.mid\" # @param {type:\"string\"}\n",
290
- "sample_seed_melody_number = 0 # @param {type:\"slider\", min:0, max:203664, step:1}\n",
291
- "\n",
292
- "print('=' * 70)\n",
293
- "print('Loading seed melody...')\n",
294
- "print('=' * 70)\n",
295
- "\n",
296
- "if full_path_to_custom_seed_melody_MIDI_file != '':\n",
297
- "\n",
298
- " #===============================================================================\n",
299
- " # Raw single-track ms score\n",
300
- "\n",
301
- " raw_score = TMIDIX.midi2single_track_ms_score(full_path_to_custom_seed_melody_MIDI_file)\n",
302
- "\n",
303
- " #===============================================================================\n",
304
- " # Enhanced score notes\n",
305
- "\n",
306
- " escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]\n",
307
- "\n",
308
- " #===============================================================================\n",
309
- " # Augmented enhanced score notes\n",
310
- "\n",
311
- " escore_notes = TMIDIX.recalculate_score_timings(TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32))\n",
312
- "\n",
313
- " cscore = TMIDIX.chordify_score([1000, escore_notes])\n",
314
- "\n",
315
- " fixed_mel_score = TMIDIX.fix_monophonic_score_durations([c[0] for c in cscore])\n",
316
- "\n",
317
- " melody = []\n",
318
- "\n",
319
- " pe = fixed_mel_score[0]\n",
320
- "\n",
321
- " for s in fixed_mel_score:\n",
322
- "\n",
323
- " dtime = max(0, min(127, s[1]-pe[1]))\n",
324
- " dur = max(1, min(127, s[2]))\n",
325
- " ptc = max(1, min(127, s[4]))\n",
326
- "\n",
327
- " chan = 1\n",
328
- "\n",
329
- " melody.extend([dtime, dur+128, (128 * chan)+ptc+256])\n",
330
- "\n",
331
- " pe = s\n",
332
- "\n",
333
- " if len(melody) >= 192:\n",
334
- " melody = [512] + melody[:192] + [513]\n",
335
- "\n",
336
- " else:\n",
337
- " mult = math.ceil(192 / len(melody))\n",
338
- " melody = melody * mult\n",
339
- " melody = [512] + melody[:192] + [513]\n",
340
- "\n",
341
- " print('Loaded custom MIDI melody:', full_path_to_custom_seed_melody_MIDI_file)\n",
342
- " print('=' * 70)\n",
343
- "\n",
344
- "else:\n",
345
- " melody = seed_melodies_data[sample_seed_melody_number]\n",
346
- " print('Loaded sample seed melody #', sample_seed_melody_number)\n",
347
- " print('=' * 70)\n",
348
- "\n",
349
- "print('Sample melody INTs:', melody[:10])\n",
350
- "print('=' * 70)\n",
351
- "print('Done!')\n",
352
- "print('=' * 70)"
353
- ]
354
- },
355
- {
356
- "cell_type": "markdown",
357
- "metadata": {
358
- "id": "feXay_Ed7mG5"
359
- },
360
- "source": [
361
- "# (GENERATE)"
362
- ]
363
- },
364
- {
365
- "cell_type": "code",
366
- "execution_count": null,
367
- "metadata": {
368
- "id": "naf65RxUXwDg",
369
- "cellView": "form"
370
- },
371
- "outputs": [],
372
- "source": [
373
- "# @title Generate song from melody\n",
374
- "\n",
375
- "melody_MIDI_patch_number = 40 # @param {type:\"slider\", min:0, max:127, step:1}\n",
376
- "accompaniment_MIDI_patch_number = 0 # @param {type:\"slider\", min:0, max:127, step:1}\n",
377
- "number_of_tokens_to_generate = 900 # @param {type:\"slider\", min:15, max:2354, step:3}\n",
378
- "number_of_batches_to_generate = 4 # @param {type:\"slider\", min:1, max:16, step:1}\n",
379
- "top_k_value = 25 # @param {type:\"slider\", min:1, max:50, step:1}\n",
380
- "temperature = 0.9 # @param {type:\"slider\", min:0.1, max:1, step:0.05}\n",
381
- "render_MIDI_to_audio = True # @param {type:\"boolean\"}\n",
382
- "\n",
383
- "print('=' * 70)\n",
384
- "print('Melody2Song Seq1Seq Music Transformer Model Generator')\n",
385
- "print('=' * 70)\n",
386
- "\n",
387
- "print('Generating...')\n",
388
- "print('=' * 70)\n",
389
- "\n",
390
- "model.eval()\n",
391
- "\n",
392
- "torch.cuda.empty_cache()\n",
393
- "\n",
394
- "x = (torch.tensor([melody] * number_of_batches_to_generate, dtype=torch.long, device='cuda'))\n",
395
- "\n",
396
- "with ctx:\n",
397
- " out = model.generate(x,\n",
398
- " number_of_tokens_to_generate,\n",
399
- " filter_logits_fn=top_k,\n",
400
- " filter_kwargs={'k': top_k_value},\n",
401
- " temperature=0.9,\n",
402
- " return_prime=False,\n",
403
- " verbose=True)\n",
404
- "\n",
405
- "output = out.tolist()\n",
406
- "\n",
407
- "print('=' * 70)\n",
408
- "print('Done!')\n",
409
- "print('=' * 70)\n",
410
- "\n",
411
- "#======================================================================\n",
412
- "print('Rendering results...')\n",
413
- "\n",
414
- "for i in range(number_of_batches_to_generate):\n",
415
- "\n",
416
- " print('=' * 70)\n",
417
- " print('Batch #', i)\n",
418
- " print('=' * 70)\n",
419
- "\n",
420
- " out1 = output[i]\n",
421
- "\n",
422
- " print('Sample INTs', out1[:12])\n",
423
- " print('=' * 70)\n",
424
- "\n",
425
- " if len(out1) != 0:\n",
426
- "\n",
427
- " song = out1\n",
428
- " song_f = []\n",
429
- "\n",
430
- " time = 0\n",
431
- " dur = 0\n",
432
- " vel = 90\n",
433
- " pitch = 0\n",
434
- " channel = 0\n",
435
- "\n",
436
- " patches = [0] * 16\n",
437
- " patches[0] = accompaniment_MIDI_patch_number\n",
438
- " patches[3] = melody_MIDI_patch_number\n",
439
- "\n",
440
- " for ss in song:\n",
441
- "\n",
442
- " if 0 < ss < 128:\n",
443
- "\n",
444
- " time += (ss * 32)\n",
445
- "\n",
446
- " if 128 < ss < 256:\n",
447
- "\n",
448
- " dur = (ss-128) * 32\n",
449
- "\n",
450
- " if 256 < ss < 512:\n",
451
- "\n",
452
- " pitch = (ss-256) % 128\n",
453
- "\n",
454
- " channel = (ss-256) // 128\n",
455
- "\n",
456
- " if channel == 1:\n",
457
- " channel = 3\n",
458
- " vel = 110 + (pitch % 12)\n",
459
- " song_f.append(['note', time, dur, channel, pitch, vel, melody_MIDI_patch_number])\n",
460
- "\n",
461
- " else:\n",
462
- " vel = 80 + (pitch % 12)\n",
463
- " channel = 0\n",
464
- " song_f.append(['note', time, dur, channel, pitch, vel, accompaniment_MIDI_patch_number])\n",
465
- "\n",
466
- " detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n",
467
- " output_signature = 'Melody2Song Seq2Seq Music Transformer',\n",
468
- " output_file_name = '/content/Melody2Song-Seq2Seq-Music-Transformer-Composition_'+str(i),\n",
469
- " track_name='Project Los Angeles',\n",
470
- " list_of_MIDI_patches=patches\n",
471
- " )\n",
472
- " print('=' * 70)\n",
473
- " print('Displaying resulting composition...')\n",
474
- " print('=' * 70)\n",
475
- "\n",
476
- " fname = '/content/Melody2Song-Seq2Seq-Music-Transformer-Composition_'+str(i)\n",
477
- "\n",
478
- " if render_MIDI_to_audio:\n",
479
- " midi_audio = midi_to_colab_audio(fname + '.mid')\n",
480
- " display(Audio(midi_audio, rate=16000, normalize=False))\n",
481
- "\n",
482
- " TMIDIX.plot_ms_SONG(song_f, plot_title=fname)"
483
- ]
484
- },
485
- {
486
- "cell_type": "markdown",
487
- "metadata": {
488
- "id": "z87TlDTVl5cp"
489
- },
490
- "source": [
491
- "# Congrats! You did it! :)"
492
- ]
493
- }
494
- ],
495
- "metadata": {
496
- "accelerator": "GPU",
497
- "colab": {
498
- "gpuClass": "premium",
499
- "gpuType": "L4",
500
- "private_outputs": true,
501
- "provenance": [],
502
- "machine_shape": "hm"
503
- },
504
- "kernelspec": {
505
- "display_name": "Python 3",
506
- "name": "python3"
507
- },
508
- "language_info": {
509
- "codemirror_mode": {
510
- "name": "ipython",
511
- "version": 3
512
- },
513
- "file_extension": ".py",
514
- "mimetype": "text/x-python",
515
- "name": "python",
516
- "nbconvert_exporter": "python",
517
- "pygments_lexer": "ipython3",
518
- "version": "3.10.12"
519
- }
520
- },
521
- "nbformat": 4,
522
- "nbformat_minor": 0
523
- }