asigalov61 commited on
Commit
70080cf
·
verified ·
1 Parent(s): c5f1657

Upload Monster_Piano_Transformer_No_Velocity_Maker.ipynb

Browse files
training_code/Monster_Piano_Transformer_No_Velocity_Maker.ipynb ADDED
@@ -0,0 +1,809 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "VGrGd6__l5ch"
7
+ },
8
+ "source": [
9
+ "# Monster Piano Transformer No Velocity Maker (ver. 1.0)\n",
10
+ "\n",
11
+ "***\n",
12
+ "\n",
13
+ "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n",
14
+ "\n",
15
+ "***\n",
16
+ "\n",
17
+ "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please excercise great humility, care, and respect. https://www.nscai.gov/\n",
18
+ "\n",
19
+ "***\n",
20
+ "\n",
21
+ "#### Project Los Angeles\n",
22
+ "\n",
23
+ "#### Tegridy Code 2025\n",
24
+ "\n",
25
+ "***"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {
31
+ "id": "shLrgoXdl5cj"
32
+ },
33
+ "source": [
34
+ "# GPU check"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {
41
+ "id": "X3rABEpKCO02"
42
+ },
43
+ "outputs": [],
44
+ "source": [
45
+ "!nvidia-smi"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "metadata": {
51
+ "id": "0RcVC4btl5ck"
52
+ },
53
+ "source": [
54
+ "# Setup environment"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {
61
+ "id": "viHgEaNACPTs"
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "!git clone --depth 1 https://github.com/asigalov61/tegridy-tools"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "metadata": {
72
+ "id": "vK40g6V_BTNj"
73
+ },
74
+ "outputs": [],
75
+ "source": [
76
+ "!pip install datasets\n",
77
+ "!pip install huggingface_hub\n",
78
+ "!pip install hf-transfer\n",
79
+ "!pip install ipywidgets\n",
80
+ "!pip install tqdm\n",
81
+ "\n",
82
+ "!sudo pip install einops\n",
83
+ "!sudo pip install torch-summary\n",
84
+ "!sudo pip install -U tqdm\n",
85
+ "!sudo pip install huggingface_hub"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "markdown",
90
+ "metadata": {},
91
+ "source": [
92
+ "# Import Modules"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {
99
+ "id": "DzCOZU_gBiQV"
100
+ },
101
+ "outputs": [],
102
+ "source": [
103
+ "# Load modules and make data dir\n",
104
+ "\n",
105
+ "print('Loading modules...')\n",
106
+ "\n",
107
+ "import os\n",
108
+ "\n",
109
+ "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n",
110
+ "\n",
111
+ "import pickle\n",
112
+ "import random\n",
113
+ "import secrets\n",
114
+ "import tqdm\n",
115
+ "import math\n",
116
+ "\n",
117
+ "import gc\n",
118
+ "\n",
119
+ "!set USE_FLASH_ATTENTION=1\n",
120
+ "os.environ['USE_FLASH_ATTENTION'] = '1'\n",
121
+ "\n",
122
+ "import torch\n",
123
+ "import torch.optim as optim\n",
124
+ "\n",
125
+ "from torch.utils.data import DataLoader, Dataset\n",
126
+ "\n",
127
+ "import matplotlib.pyplot as plt\n",
128
+ "\n",
129
+ "from torchsummary import summary\n",
130
+ "from sklearn import metrics\n",
131
+ "\n",
132
+ "from datasets import load_dataset\n",
133
+ "\n",
134
+ "from huggingface_hub import hf_hub_download\n",
135
+ "\n",
136
+ "%cd /home/ubuntu/tegridy-tools/tegridy-tools/\n",
137
+ "\n",
138
+ "import TMIDIX\n",
139
+ "\n",
140
+ "%cd /home/ubuntu/tegridy-tools/tegridy-tools/X-Transformer\n",
141
+ "\n",
142
+ "from x_transformer_1_23_2 import *\n",
143
+ "\n",
144
+ "torch.set_float32_matmul_precision('high')\n",
145
+ "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
146
+ "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
147
+ "torch.backends.cuda.enable_flash_sdp(True)\n",
148
+ "torch.backends.cuda.enable_cudnn_sdp(False)\n",
149
+ "\n",
150
+ "!set USE_FLASH_ATTENTION=1\n",
151
+ "\n",
152
+ "%cd /home/ubuntu/\n",
153
+ "\n",
154
+ "if not os.path.exists('/home/ubuntu/INTS'):\n",
155
+ " os.makedirs('/home/ubuntu/INTS')\n",
156
+ "\n",
157
+ "import random\n",
158
+ "\n",
159
+ "print('Done')\n",
160
+ "\n",
161
+ "print('Torch version:', torch.__version__)"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "markdown",
166
+ "metadata": {
167
+ "id": "cd-51e9wooMs"
168
+ },
169
+ "source": [
170
+ "# Load Training Data"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "monster_piano = load_dataset('asigalov61/Monster-Piano')"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "markdown",
184
+ "metadata": {},
185
+ "source": [
186
+ "# Prep Training Data"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "SEQ_LEN = 2048\n",
196
+ "PAD_IDX = 384 # Model pad index\n",
197
+ "\n",
198
+ "#==========================================================================\n",
199
+ "\n",
200
+ "print('=' * 70)\n",
201
+ "print('Loading data files...')\n",
202
+ "print('Please wait...')\n",
203
+ "print('=' * 70)\n",
204
+ "\n",
205
+ "train_data = set()\n",
206
+ "\n",
207
+ "chunks_counter = 0\n",
208
+ "\n",
209
+ "for entry in tqdm.tqdm(monster_piano['train']):\n",
210
+ "\n",
211
+ " score = entry['midi_score']\n",
212
+ " score = [t for t in score if t < 384]\n",
213
+ "\n",
214
+ " if 0 <= max(score) < PAD_IDX: # final data integrity check\n",
215
+ "\n",
216
+ " for i in range(0, len(score), SEQ_LEN-1024):\n",
217
+ " \n",
218
+ " chunk = score[i:i+SEQ_LEN+1]\n",
219
+ "\n",
220
+ " chunks_counter += 1\n",
221
+ "\n",
222
+ " if len(chunk) < SEQ_LEN+1:\n",
223
+ " chunk += [PAD_IDX] * (SEQ_LEN+1 - len(chunk))\n",
224
+ "\n",
225
+ " train_data.add(tuple(chunk))\n",
226
+ "\n",
227
+ " else:\n",
228
+ " print('Bad data!!!')\n",
229
+ "\n",
230
+ "#==========================================================================\n",
231
+ "\n",
232
+ "train_data = list(train_data)\n",
233
+ "\n",
234
+ "#==========================================================================\n",
235
+ "\n",
236
+ "print('Done!')\n",
237
+ "print('=' * 70)\n",
238
+ "print('Total number of main chunks:', chunks_counter)\n",
239
+ "print('All data is good:', len(max(train_data, key=len)) == len(min(train_data, key=len)))\n",
240
+ "print('=' * 70)\n",
241
+ "print('Randomizing train data...')\n",
242
+ "random.shuffle(train_data)\n",
243
+ "print('Done!')\n",
244
+ "print('=' * 70)\n",
245
+ "print('Total length of train data:', len(train_data))\n",
246
+ "print('=' * 70)"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "markdown",
251
+ "metadata": {
252
+ "id": "VhZqBvqVl5cn"
253
+ },
254
+ "source": [
255
+ "# Setup model"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "metadata": {
262
+ "id": "mfwp06xzzPZ5"
263
+ },
264
+ "outputs": [],
265
+ "source": [
266
+ "# Setup model\n",
267
+ "\n",
268
+ "# constants\n",
269
+ "\n",
270
+ "VALIDATE_EVERY = 500\n",
271
+ "SAVE_EVERY = 2500\n",
272
+ "GENERATE_EVERY = 1000\n",
273
+ "GENERATE_LENGTH = 512\n",
274
+ "PRINT_STATS_EVERY = 50\n",
275
+ "\n",
276
+ "NUM_EPOCHS = 10\n",
277
+ "\n",
278
+ "BATCH_SIZE = 116\n",
279
+ "GRADIENT_ACCUMULATE_EVERY = 1\n",
280
+ "\n",
281
+ "LEARNING_RATE = 1e-4\n",
282
+ "GRAD_CLIP = 1.5\n",
283
+ "\n",
284
+ "# instantiate the model\n",
285
+ "\n",
286
+ "model = TransformerWrapper(\n",
287
+ " num_tokens = PAD_IDX+1,\n",
288
+ " max_seq_len = SEQ_LEN,\n",
289
+ " attn_layers = Decoder(dim = 2048,\n",
290
+ " depth = 4,\n",
291
+ " heads = 32,\n",
292
+ " rotary_pos_emb = True,\n",
293
+ " attn_flash = True\n",
294
+ " )\n",
295
+ " )\n",
296
+ "\n",
297
+ "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)\n",
298
+ "\n",
299
+ "model.cuda()\n",
300
+ "\n",
301
+ "print('Done!')\n",
302
+ "\n",
303
+ "summary(model)\n",
304
+ "\n",
305
+ "# Dataloader\n",
306
+ "\n",
307
+ "def get_train_data_batch(tdata, index, seq_len, batch_size, pad_idx):\n",
308
+ "\n",
309
+ " batch = tdata[(index*batch_size):(index*batch_size)+batch_size]\n",
310
+ "\n",
311
+ " return torch.LongTensor(batch).cuda()\n",
312
+ "\n",
313
+ "# precision/optimizer/scaler\n",
314
+ "\n",
315
+ "dtype = torch.bfloat16\n",
316
+ "\n",
317
+ "ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)\n",
318
+ "\n",
319
+ "optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
320
+ "\n",
321
+ "scaler = torch.amp.GradScaler('cuda')"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "markdown",
326
+ "metadata": {
327
+ "id": "xJPxxFiwl5cn"
328
+ },
329
+ "source": [
330
+ "# Train"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": null,
336
+ "metadata": {
337
+ "id": "HETGqz_6K1ml",
338
+ "scrolled": true
339
+ },
340
+ "outputs": [],
341
+ "source": [
342
+ "# Train the model\n",
343
+ "\n",
344
+ "train_losses = []\n",
345
+ "val_losses = []\n",
346
+ "\n",
347
+ "train_accs = []\n",
348
+ "val_accs = []\n",
349
+ "\n",
350
+ "nsteps = 0\n",
351
+ "\n",
352
+ "for ep in range(NUM_EPOCHS):\n",
353
+ "\n",
354
+ " print('=' * 70)\n",
355
+ " print('Randomizing train data...')\n",
356
+ " random.shuffle(train_data)\n",
357
+ " print('=' * 70)\n",
358
+ "\n",
359
+ " print('=' * 70)\n",
360
+ " print('Epoch #', ep)\n",
361
+ " print('=' * 70)\n",
362
+ "\n",
363
+ " NUM_BATCHES = len(train_data) // BATCH_SIZE // GRADIENT_ACCUMULATE_EVERY\n",
364
+ "\n",
365
+ " model.train()\n",
366
+ "\n",
367
+ " for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):\n",
368
+ "\n",
369
+ " optim.zero_grad()\n",
370
+ "\n",
371
+ " for j in range(GRADIENT_ACCUMULATE_EVERY):\n",
372
+ " with ctx:\n",
373
+ " loss, acc = model(get_train_data_batch(train_data, (i*GRADIENT_ACCUMULATE_EVERY)+j, SEQ_LEN, BATCH_SIZE, PAD_IDX))\n",
374
+ " #loss = loss / GRADIENT_ACCUMULATE_EVERY\n",
375
+ " scaler.scale(loss).backward()\n",
376
+ "\n",
377
+ " if i % PRINT_STATS_EVERY == 0:\n",
378
+ " print(f'Training loss: {loss.item() * GRADIENT_ACCUMULATE_EVERY}')\n",
379
+ " print(f'Training acc: {acc.item()}')\n",
380
+ "\n",
381
+ " train_losses.append(loss.item() * GRADIENT_ACCUMULATE_EVERY)\n",
382
+ " train_accs.append(acc.item())\n",
383
+ "\n",
384
+ " scaler.unscale_(optim)\n",
385
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n",
386
+ " scaler.step(optim)\n",
387
+ " scaler.update()\n",
388
+ "\n",
389
+ " nsteps += 1\n",
390
+ "\n",
391
+ " if i % VALIDATE_EVERY == 0:\n",
392
+ " model.eval()\n",
393
+ " with torch.no_grad():\n",
394
+ " with ctx:\n",
395
+ " val_loss, val_acc = model(get_train_data_batch(train_data, i, SEQ_LEN, BATCH_SIZE, PAD_IDX))\n",
396
+ "\n",
397
+ " print(f'Validation loss: {val_loss.item()}')\n",
398
+ " print(f'Validation acc: {val_acc.item()}')\n",
399
+ "\n",
400
+ " val_losses.append(val_loss.item())\n",
401
+ " val_accs.append(val_acc.item())\n",
402
+ "\n",
403
+ " print('Plotting training loss graph...')\n",
404
+ "\n",
405
+ " tr_loss_list = train_losses\n",
406
+ " plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n",
407
+ " plt.show()\n",
408
+ " plt.close()\n",
409
+ " print('Done!')\n",
410
+ "\n",
411
+ " print('Plotting training acc graph...')\n",
412
+ "\n",
413
+ " tr_loss_list = train_accs\n",
414
+ " plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n",
415
+ " plt.show()\n",
416
+ " plt.close()\n",
417
+ " print('Done!')\n",
418
+ "\n",
419
+ " print('Plotting validation loss graph...')\n",
420
+ " tr_loss_list = val_losses\n",
421
+ " plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n",
422
+ " plt.show()\n",
423
+ " plt.close()\n",
424
+ " print('Done!')\n",
425
+ "\n",
426
+ " print('Plotting validation acc graph...')\n",
427
+ " tr_loss_list = val_accs\n",
428
+ " plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n",
429
+ " plt.show()\n",
430
+ " plt.close()\n",
431
+ " print('Done!')\n",
432
+ "\n",
433
+ " model.train()\n",
434
+ "\n",
435
+ " if i % GENERATE_EVERY == 0:\n",
436
+ " model.eval()\n",
437
+ "\n",
438
+ " inp = random.choice(get_train_data_batch(train_data, i, SEQ_LEN, BATCH_SIZE, PAD_IDX))[:GENERATE_LENGTH]\n",
439
+ "\n",
440
+ " print(inp)\n",
441
+ "\n",
442
+ " with ctx:\n",
443
+ " sample = model.generate(inp[None, ...], GENERATE_LENGTH)\n",
444
+ "\n",
445
+ " print(sample)\n",
446
+ "\n",
447
+ " data = sample.tolist()[0]\n",
448
+ "\n",
449
+ " print('Sample INTs', data[:15])\n",
450
+ "\n",
451
+ " if len(data) != 0:\n",
452
+ "\n",
453
+ " song = data\n",
454
+ " song_f = []\n",
455
+ "\n",
456
+ " time = 0\n",
457
+ " dur = 1\n",
458
+ " vel = 90\n",
459
+ " pitch = 60\n",
460
+ " channel = 0\n",
461
+ " patch = 0\n",
462
+ "\n",
463
+ " patches = [0] * 16\n",
464
+ "\n",
465
+ " for m in song:\n",
466
+ "\n",
467
+ " if 0 <= m < 128:\n",
468
+ " time += m * 32\n",
469
+ " \n",
470
+ " elif 128 < m < 256:\n",
471
+ " dur = (m-128) * 32\n",
472
+ " \n",
473
+ " elif 256 < m < 384:\n",
474
+ " pitch = (m-256)\n",
475
+ " \n",
476
+ " song_f.append(['note', time, dur, 0, pitch, vel, 0])\n",
477
+ "\n",
478
+ "\n",
479
+ " detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n",
480
+ " output_signature = 'Monster Piano Transformer',\n",
481
+ " output_file_name = '/home/ubuntu/Monster-Piano-Transformer-Composition',\n",
482
+ " track_name='Project Los Angeles',\n",
483
+ " list_of_MIDI_patches=patches\n",
484
+ " )\n",
485
+ "\n",
486
+ " print('Done!')\n",
487
+ "\n",
488
+ " model.train()\n",
489
+ "\n",
490
+ " if i % SAVE_EVERY == 0:\n",
491
+ "\n",
492
+ " print('Saving model progress. Please wait...')\n",
493
+ " print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')\n",
494
+ "\n",
495
+ " fname = '/home/ubuntu/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'\n",
496
+ "\n",
497
+ " torch.save(model.state_dict(), fname)\n",
498
+ "\n",
499
+ " data = [train_losses, train_accs, val_losses, val_accs]\n",
500
+ "\n",
501
+ " TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/home/ubuntu/losses_accs')\n",
502
+ "\n",
503
+ " print('Done!')"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "markdown",
508
+ "metadata": {
509
+ "id": "wBkMH2gWl5co"
510
+ },
511
+ "source": [
512
+ "# Final Save"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": null,
518
+ "metadata": {
519
+ "id": "gjBJnzZxWslL"
520
+ },
521
+ "outputs": [],
522
+ "source": [
523
+ "print('Saving model progress. Please wait...')\n",
524
+ "print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')\n",
525
+ "\n",
526
+ "fname = '/home/ubuntu/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'\n",
527
+ "\n",
528
+ "torch.save(model.state_dict(), fname)\n",
529
+ "#torch.save(optim.state_dict(), fname+'_opt')\n",
530
+ "\n",
531
+ "print('Done!')\n",
532
+ "\n",
533
+ "data = [train_losses, train_accs, val_losses, val_accs]\n",
534
+ "\n",
535
+ "TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/home/ubuntu/losses_accuracies')\n",
536
+ "\n",
537
+ "# Save training loss graph\n",
538
+ "\n",
539
+ "plt.plot([i for i in range(len(train_losses))] ,train_losses, 'b')\n",
540
+ "plt.savefig('/home/ubuntu/training_loss_graph.png')\n",
541
+ "plt.close()\n",
542
+ "print('Done!')\n",
543
+ "\n",
544
+ "# Save training acc graph\n",
545
+ "\n",
546
+ "plt.plot([i for i in range(len(train_accs))] ,train_accs, 'b')\n",
547
+ "plt.savefig('/home/ubuntu/training_acc_graph.png')\n",
548
+ "plt.close()\n",
549
+ "print('Done!')\n",
550
+ "\n",
551
+ "# Save validation loss graph\n",
552
+ "\n",
553
+ "plt.plot([i for i in range(len(val_losses))] ,val_losses, 'b')\n",
554
+ "plt.savefig('/home/ubuntu/validation_loss_graph.png')\n",
555
+ "plt.close()\n",
556
+ "print('Done!')\n",
557
+ "\n",
558
+ "# Save validation acc graph\n",
559
+ "\n",
560
+ "plt.plot([i for i in range(len(val_accs))] ,val_accs, 'b')\n",
561
+ "plt.savefig('/home/ubuntu/validation_acc_graph.png')\n",
562
+ "plt.close()\n",
563
+ "print('Done!')"
564
+ ]
565
+ },
566
+ {
567
+ "cell_type": "markdown",
568
+ "metadata": {
569
+ "id": "feXay_Ed7mG5"
570
+ },
571
+ "source": [
572
+ "# Eval"
573
+ ]
574
+ },
575
+ {
576
+ "cell_type": "code",
577
+ "execution_count": null,
578
+ "metadata": {
579
+ "id": "SA8qQSzbWslM"
580
+ },
581
+ "outputs": [],
582
+ "source": [
583
+ "hf_hub_download(repo_id='asigalov61/Monster-Piano-Transformer',\n",
584
+ " filename='Monster_Piano_Transformer_No_Velocity_Trained_Model_161960_steps_0.7775_loss_0.7661_acc.pth',\n",
585
+ " local_dir='/home/ubuntu/Models/',\n",
586
+ " )"
587
+ ]
588
+ },
589
+ {
590
+ "cell_type": "code",
591
+ "execution_count": null,
592
+ "metadata": {
593
+ "id": "gSvqSRLaWslM"
594
+ },
595
+ "outputs": [],
596
+ "source": [
597
+ "SEQ_LEN = 2048\n",
598
+ "PAD_IDX = 384\n",
599
+ "\n",
600
+ "model = TransformerWrapper(\n",
601
+ " num_tokens = PAD_IDX+1,\n",
602
+ " max_seq_len = SEQ_LEN,\n",
603
+ " attn_layers = Decoder(dim = 2048,\n",
604
+ " depth = 4,\n",
605
+ " heads = 32,\n",
606
+ " rotary_pos_emb = True,\n",
607
+ " attn_flash = True\n",
608
+ " )\n",
609
+ " )\n",
610
+ "\n",
611
+ "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)\n",
612
+ "\n",
613
+ "print('=' * 70)\n",
614
+ "print('Loading model checkpoint...')\n",
615
+ "\n",
616
+ "model_path = 'Models/Monster_Piano_Transformer_No_Velocity_Trained_Model_161960_steps_0.7775_loss_0.7661_acc.pth'\n",
617
+ "\n",
618
+ "model.load_state_dict(torch.load(model_path, weights_only=True))\n",
619
+ "\n",
620
+ "print('=' * 70)\n",
621
+ "\n",
622
+ "model.cuda()\n",
623
+ "model.eval()\n",
624
+ "\n",
625
+ "print('Done!')\n",
626
+ "\n",
627
+ "summary(model)\n",
628
+ "\n",
629
+ "dtype = torch.bfloat16\n",
630
+ "\n",
631
+ "ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "code",
636
+ "execution_count": null,
637
+ "metadata": {
638
+ "id": "enHpaHxaWslM"
639
+ },
640
+ "outputs": [],
641
+ "source": [
642
+ "midi_file = '/home/ubuntu/tegridy-tools/tegridy-tools/seed2.mid'\n",
643
+ "\n",
644
+ "raw_score = TMIDIX.midi2single_track_ms_score(midi_file)\n",
645
+ "escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]\n",
646
+ "escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)\n",
647
+ "\n",
648
+ "sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes, keep_drums=False)\n",
649
+ "zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)\n",
650
+ "\n",
651
+ "cscore = TMIDIX.chordify_score([1000, zscore])\n",
652
+ "\n",
653
+ "score = []\n",
654
+ "\n",
655
+ "pc = cscore[0]\n",
656
+ "\n",
657
+ "for c in cscore:\n",
658
+ " score.append(max(0, min(127, c[0][1]-pc[0][1])))\n",
659
+ "\n",
660
+ " for n in c:\n",
661
+ " score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])\n",
662
+ "\n",
663
+ " pc = c\n",
664
+ "\n",
665
+ "print('Done!')\n",
666
+ "print('=' * 70)\n",
667
+ "print(len(score))\n",
668
+ "print('=' * 70)"
669
+ ]
670
+ },
671
+ {
672
+ "cell_type": "code",
673
+ "execution_count": null,
674
+ "metadata": {
675
+ "id": "naf65RxUXwDg"
676
+ },
677
+ "outputs": [],
678
+ "source": [
679
+ "x = torch.LongTensor(score[:1024]).cuda()\n",
680
+ "\n",
681
+ "with ctx:\n",
682
+ " out = model.generate(x,\n",
683
+ " 1024,\n",
684
+ " temperature=0.9,\n",
685
+ " #filter_logits_fn=top_k,\n",
686
+ " #filter_kwargs={'k': 15},\n",
687
+ " return_prime=True,\n",
688
+ " verbose=True)\n",
689
+ "\n",
690
+ "y = out.tolist()\n",
691
+ "\n",
692
+ "print('---------------')"
693
+ ]
694
+ },
695
+ {
696
+ "cell_type": "code",
697
+ "execution_count": null,
698
+ "metadata": {
699
+ "id": "tlBzqWpAnZna"
700
+ },
701
+ "outputs": [],
702
+ "source": [
703
+ "#@title Test INTs\n",
704
+ "\n",
705
+ "data = y[0]\n",
706
+ "\n",
707
+ "print('Sample INTs', data[:15])\n",
708
+ "\n",
709
+ "if len(data) != 0:\n",
710
+ "\n",
711
+ " song = data\n",
712
+ " song_f = []\n",
713
+ "\n",
714
+ " time = 0\n",
715
+ " dur = 1\n",
716
+ " vel = 90\n",
717
+ " pitch = 60\n",
718
+ " channel = 0\n",
719
+ " patch = 0\n",
720
+ "\n",
721
+ " patches = [0] * 16\n",
722
+ "\n",
723
+ " for m in song:\n",
724
+ "\n",
725
+ " if 0 <= m < 128:\n",
726
+ " time += m * 32\n",
727
+ "\n",
728
+ " elif 128 < m < 256:\n",
729
+ " dur = (m-128) * 32\n",
730
+ "\n",
731
+ " elif 256 < m < 384:\n",
732
+ " pitch = (m-256)\n",
733
+ "\n",
734
+ " song_f.append(['note', time, dur, 0, pitch, vel, 0])\n",
735
+ "\n",
736
+ "\n",
737
+ " detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n",
738
+ " output_signature = 'Monster Piano Transformer',\n",
739
+ " output_file_name = '/home/ubuntu/Monster-Piano-Transformer-Composition',\n",
740
+ " track_name='Project Los Angeles',\n",
741
+ " list_of_MIDI_patches=patches\n",
742
+ " )\n",
743
+ "\n",
744
+ "print('Done!')"
745
+ ]
746
+ },
747
+ {
748
+ "cell_type": "code",
749
+ "execution_count": null,
750
+ "metadata": {
751
+ "id": "al3TDlH7T8m7"
752
+ },
753
+ "outputs": [],
754
+ "source": [
755
+ "tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()\n",
756
+ "\n",
757
+ "cos_sim = metrics.pairwise_distances(\n",
758
+ " tok_emb, metric='cosine'\n",
759
+ ")\n",
760
+ "plt.figure(figsize=(7, 7))\n",
761
+ "plt.imshow(cos_sim, cmap=\"inferno\", interpolation=\"nearest\")\n",
762
+ "im_ratio = cos_sim.shape[0] / cos_sim.shape[1]\n",
763
+ "plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)\n",
764
+ "plt.xlabel(\"Position\")\n",
765
+ "plt.ylabel(\"Position\")\n",
766
+ "plt.tight_layout()\n",
767
+ "plt.plot()\n",
768
+ "plt.savefig(\"/home/ubuntu/Monster-Piano-Transformer-Tokens-Embeddings-Plot.png\", bbox_inches=\"tight\")"
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "markdown",
773
+ "metadata": {
774
+ "id": "z87TlDTVl5cp"
775
+ },
776
+ "source": [
777
+ "# Congrats! You did it! :)"
778
+ ]
779
+ }
780
+ ],
781
+ "metadata": {
782
+ "accelerator": "GPU",
783
+ "colab": {
784
+ "gpuClass": "premium",
785
+ "gpuType": "T4",
786
+ "private_outputs": true,
787
+ "provenance": []
788
+ },
789
+ "kernelspec": {
790
+ "display_name": "Python 3 (ipykernel)",
791
+ "language": "python",
792
+ "name": "python3"
793
+ },
794
+ "language_info": {
795
+ "codemirror_mode": {
796
+ "name": "ipython",
797
+ "version": 3
798
+ },
799
+ "file_extension": ".py",
800
+ "mimetype": "text/x-python",
801
+ "name": "python",
802
+ "nbconvert_exporter": "python",
803
+ "pygments_lexer": "ipython3",
804
+ "version": "3.10.12"
805
+ }
806
+ },
807
+ "nbformat": 4,
808
+ "nbformat_minor": 4
809
+ }