jkeisling committed on
Commit
f747ce5
·
1 Parent(s): fb24f54

Fix training objective; lower model size

Browse files
Files changed (2) hide show
  1. .gitignore +3 -0
  2. gpt.ipynb +315 -203
.gitignore CHANGED
@@ -2,6 +2,9 @@
2
  checkpoints/
3
  datasets/
4
 
 
 
 
5
  # Byte-compiled / optimized / DLL files
6
  __pycache__/
7
  *.py[cod]
 
2
  checkpoints/
3
  datasets/
4
 
5
+ # Training Tensorboard runs
6
+ runs/
7
+
8
  # Byte-compiled / optimized / DLL files
9
  __pycache__/
10
  *.py[cod]
gpt.ipynb CHANGED
@@ -10,20 +10,19 @@
10
  },
11
  {
12
  "cell_type": "code",
13
- "execution_count": 36,
14
  "metadata": {},
15
  "outputs": [],
16
  "source": [
17
  "import os\n",
18
  "\n",
19
- "# We always start with a dataset to train on. Let's download the tiny shakespeare dataset\n",
20
  "if not os.path.isfile(\"./datasets/corpora/shakespeare.txt\"):\n",
21
- " !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt > datasets/corpora/shakespeare.txt"
22
  ]
23
  },
24
  {
25
  "cell_type": "code",
26
- "execution_count": 10,
27
  "metadata": {},
28
  "outputs": [],
29
  "source": [
@@ -31,6 +30,21 @@
31
  " text = f.read()"
32
  ]
33
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  {
35
  "attachments": {},
36
  "cell_type": "markdown",
@@ -41,16 +55,86 @@
41
  },
42
  {
43
  "cell_type": "code",
44
- "execution_count": 11,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  "metadata": {},
46
  "outputs": [
47
  {
48
  "data": {
49
  "text/plain": [
50
- "<torch._C.Generator at 0x7f7b543cb430>"
51
  ]
52
  },
53
- "execution_count": 11,
54
  "metadata": {},
55
  "output_type": "execute_result"
56
  }
@@ -71,7 +155,7 @@
71
  },
72
  {
73
  "cell_type": "code",
74
- "execution_count": 12,
75
  "metadata": {},
76
  "outputs": [],
77
  "source": [
@@ -85,16 +169,27 @@
85
  },
86
  {
87
  "cell_type": "code",
88
- "execution_count": 13,
89
  "metadata": {},
90
- "outputs": [],
 
 
 
 
 
 
 
 
91
  "source": [
92
  "# Tensorify data, put it in dataset\n",
93
  "data = torch.tensor(encode_text(text), dtype=torch.int32)\n",
94
  "\n",
95
- "split_idx = int(0.9 * len(data))\n",
96
- "train_data = data[:split_idx]\n",
97
- "test_data = data[split_idx:]"
 
 
 
98
  ]
99
  },
100
  {
@@ -107,7 +202,7 @@
107
  },
108
  {
109
  "cell_type": "code",
110
- "execution_count": 31,
111
  "metadata": {},
112
  "outputs": [],
113
  "source": [
@@ -117,92 +212,15 @@
117
  " self.context_size = context_size\n",
118
  " \n",
119
  " def __len__(self):\n",
120
- " return len(self.data_tensor)\n",
121
  "\n",
122
  " def __getitem__(self, index):\n",
123
- " if index < self.context_size:\n",
124
- " x = F.pad(self.data_tensor[:index], (self.context_size - index, 0), value=0)\n",
125
- " else:\n",
126
- " x = self.data_tensor[index - self.context_size:index]\n",
127
  " \n",
128
- " y = self.data_tensor[index]\n",
129
  " return x, y"
130
  ]
131
  },
132
- {
133
- "attachments": {},
134
- "cell_type": "markdown",
135
- "metadata": {},
136
- "source": [
137
- "NOTE 2023-03-25: I think this is bugged, and that's the reason the training loss is so damn high. Testing:"
138
- ]
139
- },
140
- {
141
- "cell_type": "code",
142
- "execution_count": 34,
143
- "metadata": {},
144
- "outputs": [
145
- {
146
- "name": "stdout",
147
- "output_type": "stream",
148
- "text": [
149
- "Step 0:\n",
150
- "[0, 0, 0, 0, 0, 0, 0, 0]\n",
151
- "---\n",
152
- "[0, 0, 0, 0, 0, 0, 0, 70]\n",
153
- "---\n",
154
- "['F', 'i']\n",
155
- "Step 1:\n",
156
- "[0, 0, 0, 0, 0, 0, 70, 105]\n",
157
- "---\n",
158
- "[0, 0, 0, 0, 0, 70, 105, 114]\n",
159
- "---\n",
160
- "['r', 's']\n",
161
- "Step 2:\n",
162
- "[0, 0, 0, 0, 70, 105, 114, 115]\n",
163
- "---\n",
164
- "[0, 0, 0, 70, 105, 114, 115, 116]\n",
165
- "---\n",
166
- "['t', ' ']\n",
167
- "Step 3:\n",
168
- "[0, 0, 70, 105, 114, 115, 116, 32]\n",
169
- "---\n",
170
- "[0, 70, 105, 114, 115, 116, 32, 67]\n",
171
- "---\n",
172
- "['C', 'i']\n",
173
- "Step 4:\n",
174
- "[70, 105, 114, 115, 116, 32, 67, 105]\n",
175
- "---\n",
176
- "[105, 114, 115, 116, 32, 67, 105, 116]\n",
177
- "---\n",
178
- "['t', 'i']\n",
179
- "Step 5:\n",
180
- "[114, 115, 116, 32, 67, 105, 116, 105]\n",
181
- "---\n",
182
- "[115, 116, 32, 67, 105, 116, 105, 122]\n",
183
- "---\n",
184
- "['z', 'e']\n"
185
- ]
186
- }
187
- ],
188
- "source": [
189
- "train_dataset = TextDataset(train_data, 8)\n",
190
- "train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=False)\n",
191
- "\n",
192
- "step = 0\n",
193
- "for x, y in train_dataloader:\n",
194
- " print(f\"Step {step}:\")\n",
195
- " for b in x.tolist():\n",
196
- " print(b)\n",
197
- " print(\"---\")\n",
198
- "\n",
199
- " print(decode_text(y.tolist()))\n",
200
- " step += 1\n",
201
- " if step > 5:\n",
202
- " break\n",
203
- "\n"
204
- ]
205
- },
206
  {
207
  "attachments": {},
208
  "cell_type": "markdown",
@@ -213,7 +231,7 @@
213
  },
214
  {
215
  "cell_type": "code",
216
- "execution_count": 8,
217
  "metadata": {},
218
  "outputs": [],
219
  "source": [
@@ -226,15 +244,14 @@
226
  " self.num_heads = num_heads\n",
227
  " self.d_k = embed_dim // num_heads\n",
228
  "\n",
229
- " self.Q = nn.Linear(embed_dim, embed_dim, bias=bias)\n",
230
- " self.K = nn.Linear(embed_dim, embed_dim, bias=bias)\n",
231
- " self.V = nn.Linear(embed_dim, embed_dim, bias=bias)\n",
232
  "\n",
233
  " self.dropout = nn.Dropout(dropout)\n",
234
  " self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
235
- " nn.init.kaiming_normal_(self.out_proj.weight, mode='fan_in', nonlinearity='linear')\n",
236
  "\n",
237
- " def forward(self, query, key, value, key_padding_mask=None):\n",
238
  " batch_size = query.size(0)\n",
239
  "\n",
240
  " # Apply linear layers\n",
@@ -251,7 +268,7 @@
251
  " scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) # [B, num_heads, C, C]\n",
252
  "\n",
253
  " # Apply mask, if necessary\n",
254
- " if key_padding_mask is not None:\n",
255
  " \"\"\"\n",
256
  " MAY BE WORTH DEBUGGING\n",
257
  "\n",
@@ -263,7 +280,7 @@
263
  " key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, seq_len]\n",
264
  " \"\"\"\n",
265
  " # Apply the mask to attention scores\n",
266
- " scores = scores.masked_fill(key_padding_mask, float('-inf'))\n",
267
  "\n",
268
  " # Scale by sqrt(k)\n",
269
  " attn = F.softmax(scores, dim=-1)\n",
@@ -275,13 +292,13 @@
275
  " out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)\n",
276
  " # Project: give attention \"time to think\". Maybe this should be part of a different module but whatever\n",
277
  " out = self.out_proj(out)\n",
278
- " return(out)\n",
279
  "\n"
280
  ]
281
  },
282
  {
283
  "cell_type": "code",
284
- "execution_count": 9,
285
  "metadata": {},
286
  "outputs": [],
287
  "source": [
@@ -290,9 +307,9 @@
290
  " super().__init__()\n",
291
  " self.net = nn.Sequential(\n",
292
  " nn.Linear(embed_dim, 4 * embed_dim),\n",
293
- " nn.ReLU(),\n",
294
- " nn.Dropout(dropout)\n",
295
  " nn.Linear(4 * embed_dim, embed_dim),\n",
 
296
  " )\n",
297
  "\n",
298
  " def forward(self, x):\n",
@@ -301,7 +318,7 @@
301
  },
302
  {
303
  "cell_type": "code",
304
- "execution_count": 10,
305
  "metadata": {},
306
  "outputs": [],
307
  "source": [
@@ -311,6 +328,7 @@
311
  " super(Block, self).__init__() \n",
312
  " self.register_buffer(\"mask\", mask)\n",
313
  " self.head = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)\n",
 
314
  " self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)\n",
315
  " self.ln1 = nn.LayerNorm(embed_dim)\n",
316
  " self.ln2 = nn.LayerNorm(embed_dim)\n",
@@ -318,58 +336,57 @@
318
  " def forward(self, x):\n",
319
  " # Residual connections\n",
320
  " x = self.ln1(x)\n",
321
- " x = x + self.head.forward(x, x, x, key_padding_mask=self.mask) \n",
 
322
  " out = x + self.ffwd(self.ln2(x))\n",
323
  " return out\n"
324
  ]
325
  },
326
  {
327
  "cell_type": "code",
328
- "execution_count": 11,
329
  "metadata": {},
330
  "outputs": [],
331
  "source": [
332
  "class GPT(nn.Module):\n",
333
- " def __init__(self, embedding_dim, vocab_size, context_size, lr=1e-3):\n",
334
- " # Inherit PyTorch stuff\n",
335
  " super(GPT, self).__init__()\n",
336
  "\n",
337
- " # Save variables for later\n",
338
  " self.embedding_dim = embedding_dim\n",
339
  " self.output_dim = vocab_size\n",
340
  " self.context_size = context_size\n",
341
  "\n",
342
- " # Initialize layers. Sadly this breaks the whole \"self.layers: concept but whatever\n",
 
 
 
343
  " self.tok_embed = nn.Embedding(vocab_size, embedding_dim)\n",
344
  " self.pos_embed = nn.Embedding(context_size, embedding_dim)\n",
345
  "\n",
346
- " NUM_HEADS=6\n",
347
- " NUM_LAYERS=6\n",
348
- " \n",
349
  " mask = torch.tril(torch.ones(self.context_size, self.context_size)).bool()\n",
350
  " mask = ~mask\n",
351
- " self.register_buffer(mask)\n",
352
  "\n",
353
  " self.blocks = nn.Sequential(\n",
354
- " *[Block(embed_dim=embedding_dim, num_heads=NUM_HEADS, mask=mask) for _ in range(NUM_LAYERS)],\n",
355
- " nn.Dropout(0.2)\n",
356
  " )\n",
357
  "\n",
 
358
  " # Final feed-forward layer from embeddings\n",
359
- " self.ffwd = nn.Linear(embedding_dim, out_features=vocab_size)\n",
360
  "\n",
361
  " def forward(self, x):\n",
362
  " tok_embed = self.tok_embed(x)\n",
363
- " tok_embed = tok_embed.view(-1, self.context_size, self.embedding_dim)\n",
364
- " pos_embed = self.pos_embed(torch.arange(0, self.context_size, device=\"cuda\")).unsqueeze(0)\n",
 
365
  " x = tok_embed + pos_embed\n",
366
  "\n",
367
- " # The actual attention is all you need here!\n",
368
- " # B*C*C cutting out the future\n",
369
  " x = self.blocks(x)\n",
 
370
  "\n",
371
- " preds = self.ffwd(x)\n",
372
- " return(preds)\n",
373
  " \n",
374
  " def infer(self, x):\n",
375
  " with torch.no_grad():\n",
@@ -387,93 +404,114 @@
387
  },
388
  {
389
  "cell_type": "code",
390
- "execution_count": 19,
391
  "metadata": {},
392
  "outputs": [],
393
  "source": [
394
  "def compute_loss(model, criterion, x, y):\n",
395
  " logits = model(x)\n",
396
- " last_logits = logits[:, -1, :]\n",
397
- " log_probs = nn.LogSoftmax(dim=1)(last_logits)\n",
398
- " loss = criterion(log_probs, y.view(-1).long())\n",
 
399
  " return loss"
400
  ]
401
  },
402
  {
403
  "cell_type": "code",
404
- "execution_count": 47,
405
  "metadata": {},
406
  "outputs": [],
407
  "source": [
408
- "EMBEDDING_NDIM = 384\n",
409
- "VOCAB_SIZE = 128\n",
410
- "BATCH_SIZE=64\n",
411
- "# \"Context window\"\n",
412
- "BLOCK_SIZE=256\n",
413
- "LR=1e-3\n",
414
  "\n",
415
  "train_dataset = TextDataset(train_data, BLOCK_SIZE)\n",
416
- "test_dataset = TextDataset(train_data, BLOCK_SIZE)\n",
417
  "\n",
418
  "# Janky training code\n",
419
  "model = GPT(\n",
420
  " embedding_dim=EMBEDDING_NDIM, \n",
421
  " vocab_size=VOCAB_SIZE,\n",
422
  " context_size=BLOCK_SIZE,\n",
423
- " lr=LR\n",
424
  " )\n",
425
  "\n",
426
  "model = model.to('cuda')\n",
427
  "optimizer = optim.AdamW(model.parameters(), lr=LR)\n",
428
- "# TODO Fix this!\n",
429
- "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.2)\n",
430
- "criterion = nn.NLLLoss()"
 
431
  ]
432
  },
433
  {
434
  "cell_type": "code",
435
- "execution_count": 50,
436
  "metadata": {},
437
  "outputs": [
438
  {
439
  "name": "stdout",
440
  "output_type": "stream",
441
  "text": [
442
- "Step 0; loss: 3.3686537742614746\n",
443
- "Step 100; loss: 3.3535483678181968\n",
444
- "Step 200; loss: 3.3484479188919067\n",
445
- "Step 300; loss: 3.344235420227051\n",
446
- "Step 400; loss: 3.338580369949341\n",
447
- "Step 500; loss: 3.330465725490025\n",
448
- "Step 600; loss: 3.333183079957962\n",
449
- "Step 700; loss: 3.3319032986958823\n",
450
- "Step 800; loss: 3.332624101638794\n",
451
- "Step 900; loss: 3.3325188810175117\n",
452
- "Step 1000; loss: 3.331260542074839\n",
453
- "Step 1100; loss: 3.3311657355381894\n"
454
- ]
455
- },
456
- {
457
- "ename": "KeyboardInterrupt",
458
- "evalue": "",
459
- "output_type": "error",
460
- "traceback": [
461
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
462
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
463
- "\u001b[1;32m/home/ritsuko/projects/ai/micrograd/gpt.ipynb Cell 20\u001b[0m in \u001b[0;36m2\n\u001b[1;32m <a href='vscode-notebook-cell:/home/ritsuko/projects/ai/micrograd/gpt.ipynb#X24sZmlsZQ%3D%3D?line=24'>25</a>\u001b[0m \u001b[39m# Backward pass\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/ritsuko/projects/ai/micrograd/gpt.ipynb#X24sZmlsZQ%3D%3D?line=25'>26</a>\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/ritsuko/projects/ai/micrograd/gpt.ipynb#X24sZmlsZQ%3D%3D?line=26'>27</a>\u001b[0m loss\u001b[39m.\u001b[39;49mbackward()\n\u001b[1;32m <a href='vscode-notebook-cell:/home/ritsuko/projects/ai/micrograd/gpt.ipynb#X24sZmlsZQ%3D%3D?line=27'>28</a>\u001b[0m optimizer\u001b[39m.\u001b[39mstep()\n\u001b[1;32m <a href='vscode-notebook-cell:/home/ritsuko/projects/ai/micrograd/gpt.ipynb#X24sZmlsZQ%3D%3D?line=28'>29</a>\u001b[0m scheduler\u001b[39m.\u001b[39mstep()\n",
464
- "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/_tensor.py:396\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_unary(\u001b[39mself\u001b[39m):\n\u001b[1;32m 388\u001b[0m \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 389\u001b[0m Tensor\u001b[39m.\u001b[39mbackward,\n\u001b[1;32m 390\u001b[0m (\u001b[39mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 394\u001b[0m create_graph\u001b[39m=\u001b[39mcreate_graph,\n\u001b[1;32m 395\u001b[0m inputs\u001b[39m=\u001b[39minputs)\n\u001b[0;32m--> 396\u001b[0m torch\u001b[39m.\u001b[39;49mautograd\u001b[39m.\u001b[39;49mbackward(\u001b[39mself\u001b[39;49m, gradient, retain_graph, create_graph, inputs\u001b[39m=\u001b[39;49minputs)\n",
465
- "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:173\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 168\u001b[0m retain_graph \u001b[39m=\u001b[39m create_graph\n\u001b[1;32m 170\u001b[0m \u001b[39m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[39m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[39m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m Variable\u001b[39m.\u001b[39;49m_execution_engine\u001b[39m.\u001b[39;49mrun_backward( \u001b[39m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 174\u001b[0m tensors, grad_tensors_, retain_graph, create_graph, inputs,\n\u001b[1;32m 175\u001b[0m allow_unreachable\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, accumulate_grad\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
466
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
  ]
468
  }
469
  ],
470
  "source": [
471
- "from torch.utils.\n",
 
472
  "EPOCHS = 1\n",
473
  "STEPS = 5000\n",
474
  "VAL_INTERVAL = 100\n",
475
  "\n",
476
- "losses = []\n",
477
  "model.train()\n",
478
  "\n",
479
  "train_dataloader = DataLoader(\n",
@@ -485,7 +523,10 @@
485
  "\n",
486
  "test_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=4, shuffle=True)\n",
487
  "\n",
 
 
488
  "step = 0\n",
 
489
  "for epoch in range(EPOCHS):\n",
490
  " for data, target in train_dataloader:\n",
491
  " data = data.to('cuda')\n",
@@ -497,11 +538,16 @@
497
  " optimizer.zero_grad()\n",
498
  " loss.backward()\n",
499
  " optimizer.step()\n",
500
- " scheduler.step()\n",
501
  "\n",
502
- " losses.append(loss.cpu().detach().numpy())\n",
 
503
  "\n",
 
504
  " if step % VAL_INTERVAL == 0:\n",
 
 
 
505
  " with torch.no_grad():\n",
506
  " model.eval()\n",
507
  " for x, y in test_dataloader:\n",
@@ -514,18 +560,23 @@
514
  " if total_samples > 10:\n",
515
  " break\n",
516
  "\n",
517
- " average_loss = total_loss / total_samples\n",
518
- " print(f\"Step {step}; loss: {average_loss}\")\n",
519
- " model.train()\n",
 
 
 
520
  "\n",
521
  " step += 1\n",
522
  " if step >= STEPS:\n",
523
- " break\n"
 
 
524
  ]
525
  },
526
  {
527
  "cell_type": "code",
528
- "execution_count": 15,
529
  "metadata": {},
530
  "outputs": [],
531
  "source": [
@@ -534,7 +585,7 @@
534
  },
535
  {
536
  "cell_type": "code",
537
- "execution_count": 36,
538
  "metadata": {},
539
  "outputs": [],
540
  "source": [
@@ -568,16 +619,16 @@
568
  },
569
  {
570
  "cell_type": "code",
571
- "execution_count": 37,
572
  "metadata": {},
573
  "outputs": [
574
  {
575
  "data": {
576
  "text/plain": [
577
- "2399"
578
  ]
579
  },
580
- "execution_count": 37,
581
  "metadata": {},
582
  "output_type": "execute_result"
583
  }
@@ -589,14 +640,14 @@
589
  },
590
  {
591
  "cell_type": "code",
592
- "execution_count": 51,
593
  "metadata": {},
594
  "outputs": [
595
  {
596
  "name": "stdout",
597
  "output_type": "stream",
598
  "text": [
599
- "3.4188449382781982\n"
600
  ]
601
  }
602
  ],
@@ -605,16 +656,17 @@
605
  "total_loss = 0.0\n",
606
  "total_samples = 0\n",
607
  "\n",
608
- "test_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=4)\n",
 
609
  "with torch.no_grad():\n",
610
- " for x, y in test_dataloader:\n",
611
  " x = x.to(\"cuda\")\n",
612
  " y = y.to(\"cuda\")\n",
613
  "\n",
614
  " batch_loss = compute_loss(model, criterion, x, y)\n",
615
  " total_loss += batch_loss.item() * x.size(0)\n",
616
  " total_samples += x.size(0)\n",
617
- " if total_samples > 100:\n",
618
  " break\n",
619
  "\n",
620
  " average_loss = total_loss / total_samples\n",
@@ -623,44 +675,98 @@
623
  },
624
  {
625
  "cell_type": "code",
626
- "execution_count": null,
627
  "metadata": {},
628
- "outputs": [],
629
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  },
631
  {
632
  "attachments": {},
633
  "cell_type": "markdown",
634
  "metadata": {},
635
  "source": [
636
- "Finally, we generate:"
637
  ]
638
  },
639
  {
640
  "cell_type": "code",
641
- "execution_count": 52,
642
  "metadata": {},
643
  "outputs": [
644
  {
645
  "name": "stdout",
646
  "output_type": "stream",
647
  "text": [
648
- ",n aon mr\n",
649
- "nr\n",
650
- "egtel s.mangtVk h\n",
651
- " -hinSfii ol ihIraddeioi akpshaC.n trU d aamooaa eoeEhl:daoUabo'm-fddE auh hpyHs wv'erstiInnmwt hnAuNu ufl\n",
652
- "I: rl.T l!eool'lIhl:aynet nna:i yaneehtea hdel\n",
653
- " hse l;imi\n",
654
- " hgy f iuto eoh gBum.umhemvt\n",
655
- "a hFo lNsute oaaenh;byeon"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  ]
657
  }
658
  ],
659
  "source": [
660
  "g_cuda = torch.Generator(device='cuda')\n",
661
  "\n",
662
- "contexts = torch.tensor(encode_text(\"God\"), dtype=torch.int32).to('cuda')\n",
663
- "GEN_LENGTH=256\n",
 
 
 
 
664
  "\n",
665
  "model.eval()\n",
666
  "for i in range(GEN_LENGTH):\n",
@@ -668,13 +774,19 @@
668
  " # What happens if GEN_LENGTH > CONTEXT? don't worry about it\n",
669
  " #x = F.pad(contexts[:, -BLOCK_SIZE:], (0, BLOCK_SIZE - contexts.size(0)), \"constant\", 0)\n",
670
  " x = contexts[-BLOCK_SIZE:]\n",
671
- " x = F.pad(x, (0, BLOCK_SIZE - x.size(0)), \"constant\", 0).unsqueeze(0) # B*T\n",
 
 
 
 
672
  " preds = model.infer(x)\n",
673
  " preds = preds.squeeze(0)\n",
674
  " probs = torch.softmax(preds, dim=-1)\n",
675
  "\n",
676
  " # TODO: Broken because of bug with the trailing 0s. FIX THIS\n",
677
- " next_char = torch.multinomial(torch.exp(preds[(-1 if i >= BLOCK_SIZE else i), :]), num_samples=1, generator=g_cuda)\n",
 
 
678
  " #context = torch.cat(context, next_char)\n",
679
  " contexts = torch.cat((contexts, next_char), dim=0)\n",
680
  " print(decode_text(next_char.cpu().numpy())[-1], end=\"\")\n",
 
10
  },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 39,
14
  "metadata": {},
15
  "outputs": [],
16
  "source": [
17
  "import os\n",
18
  "\n",
 
19
  "if not os.path.isfile(\"./datasets/corpora/shakespeare.txt\"):\n",
20
+ " !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O datasets/corpora/shakespeare.txt"
21
  ]
22
  },
23
  {
24
  "cell_type": "code",
25
+ "execution_count": 40,
26
  "metadata": {},
27
  "outputs": [],
28
  "source": [
 
30
  " text = f.read()"
31
  ]
32
  },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 41,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "# Putting hyperparameters at the top because I learned this the hard way\n",
40
+ "# 64 * NUM_HEADS\n",
41
+ "EMBEDDING_NDIM=256\n",
42
+ "VOCAB_SIZE=128\n",
43
+ "BATCH_SIZE=64\n",
44
+ "# \"Context window\"\n",
45
+ "BLOCK_SIZE=256"
46
+ ]
47
+ },
48
  {
49
  "attachments": {},
50
  "cell_type": "markdown",
 
55
  },
56
  {
57
  "cell_type": "code",
58
+ "execution_count": 42,
59
+ "metadata": {},
60
+ "outputs": [
61
+ {
62
+ "name": "stdout",
63
+ "output_type": "stream",
64
+ "text": [
65
+ "Requirement already satisfied: torch in ./venv/lib/python3.10/site-packages (2.0.0)\n",
66
+ "Requirement already satisfied: pandas in ./venv/lib/python3.10/site-packages (1.5.3)\n",
67
+ "Requirement already satisfied: numpy in ./venv/lib/python3.10/site-packages (1.24.1)\n",
68
+ "Requirement already satisfied: tensorboard in ./venv/lib/python3.10/site-packages (2.12.0)\n",
69
+ "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in ./venv/lib/python3.10/site-packages (from torch) (2.14.3)\n",
70
+ "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in ./venv/lib/python3.10/site-packages (from torch) (8.5.0.96)\n",
71
+ "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in ./venv/lib/python3.10/site-packages (from torch) (11.4.0.1)\n",
72
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in ./venv/lib/python3.10/site-packages (from torch) (11.7.99)\n",
73
+ "Requirement already satisfied: networkx in ./venv/lib/python3.10/site-packages (from torch) (3.0)\n",
74
+ "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in ./venv/lib/python3.10/site-packages (from torch) (10.2.10.91)\n",
75
+ "Requirement already satisfied: filelock in ./venv/lib/python3.10/site-packages (from torch) (3.10.4)\n",
76
+ "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in ./venv/lib/python3.10/site-packages (from torch) (11.7.91)\n",
77
+ "Requirement already satisfied: typing-extensions in ./venv/lib/python3.10/site-packages (from torch) (4.5.0)\n",
78
+ "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in ./venv/lib/python3.10/site-packages (from torch) (11.10.3.66)\n",
79
+ "Requirement already satisfied: sympy in ./venv/lib/python3.10/site-packages (from torch) (1.11.1)\n",
80
+ "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in ./venv/lib/python3.10/site-packages (from torch) (11.7.99)\n",
81
+ "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in ./venv/lib/python3.10/site-packages (from torch) (10.9.0.58)\n",
82
+ "Requirement already satisfied: jinja2 in ./venv/lib/python3.10/site-packages (from torch) (3.1.2)\n",
83
+ "Requirement already satisfied: triton==2.0.0 in ./venv/lib/python3.10/site-packages (from torch) (2.0.0)\n",
84
+ "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in ./venv/lib/python3.10/site-packages (from torch) (11.7.101)\n",
85
+ "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in ./venv/lib/python3.10/site-packages (from torch) (11.7.4.91)\n",
86
+ "Requirement already satisfied: wheel in ./venv/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (0.40.0)\n",
87
+ "Requirement already satisfied: setuptools in ./venv/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (65.5.0)\n",
88
+ "Requirement already satisfied: lit in ./venv/lib/python3.10/site-packages (from triton==2.0.0->torch) (16.0.0)\n",
89
+ "Requirement already satisfied: cmake in ./venv/lib/python3.10/site-packages (from triton==2.0.0->torch) (3.26.1)\n",
90
+ "Requirement already satisfied: python-dateutil>=2.8.1 in ./venv/lib/python3.10/site-packages (from pandas) (2.8.2)\n",
91
+ "Requirement already satisfied: pytz>=2020.1 in ./venv/lib/python3.10/site-packages (from pandas) (2023.2)\n",
92
+ "Requirement already satisfied: requests<3,>=2.21.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.28.2)\n",
93
+ "Requirement already satisfied: werkzeug>=1.0.1 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.2.3)\n",
94
+ "Requirement already satisfied: google-auth<3,>=1.6.3 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.16.3)\n",
95
+ "Requirement already satisfied: protobuf>=3.19.6 in ./venv/lib/python3.10/site-packages (from tensorboard) (4.22.1)\n",
96
+ "Requirement already satisfied: markdown>=2.6.8 in ./venv/lib/python3.10/site-packages (from tensorboard) (3.4.3)\n",
97
+ "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in ./venv/lib/python3.10/site-packages (from tensorboard) (0.4.6)\n",
98
+ "Requirement already satisfied: grpcio>=1.48.2 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.51.3)\n",
99
+ "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (0.7.0)\n",
100
+ "Requirement already satisfied: absl-py>=0.4 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.4.0)\n",
101
+ "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.8.1)\n",
102
+ "Requirement already satisfied: cachetools<6.0,>=2.0.0 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (5.3.0)\n",
103
+ "Requirement already satisfied: pyasn1-modules>=0.2.1 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (0.2.8)\n",
104
+ "Requirement already satisfied: six>=1.9.0 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (1.16.0)\n",
105
+ "Requirement already satisfied: rsa<5,>=3.1.4 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (4.9)\n",
106
+ "Requirement already satisfied: requests-oauthlib>=0.7.0 in ./venv/lib/python3.10/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (1.3.1)\n",
107
+ "Requirement already satisfied: idna<4,>=2.5 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (3.4)\n",
108
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (1.26.15)\n",
109
+ "Requirement already satisfied: charset-normalizer<4,>=2 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (3.1.0)\n",
110
+ "Requirement already satisfied: certifi>=2017.4.17 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (2022.12.7)\n",
111
+ "Requirement already satisfied: MarkupSafe>=2.1.1 in ./venv/lib/python3.10/site-packages (from werkzeug>=1.0.1->tensorboard) (2.1.2)\n",
112
+ "Requirement already satisfied: mpmath>=0.19 in ./venv/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n",
113
+ "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in ./venv/lib/python3.10/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard) (0.4.8)\n",
114
+ "Requirement already satisfied: oauthlib>=3.0.0 in ./venv/lib/python3.10/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (3.2.2)\n",
115
+ "\n",
116
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n",
117
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
118
+ "Note: you may need to restart the kernel to use updated packages.\n"
119
+ ]
120
+ }
121
+ ],
122
+ "source": [
123
+ "%pip install torch pandas numpy tensorboard"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 43,
129
  "metadata": {},
130
  "outputs": [
131
  {
132
  "data": {
133
  "text/plain": [
134
+ "<torch._C.Generator at 0x7fef50768610>"
135
  ]
136
  },
137
+ "execution_count": 43,
138
  "metadata": {},
139
  "output_type": "execute_result"
140
  }
 
155
  },
156
  {
157
  "cell_type": "code",
158
+ "execution_count": 44,
159
  "metadata": {},
160
  "outputs": [],
161
  "source": [
 
169
  },
170
  {
171
  "cell_type": "code",
172
+ "execution_count": 45,
173
  "metadata": {},
174
+ "outputs": [
175
+ {
176
+ "name": "stdout",
177
+ "output_type": "stream",
178
+ "text": [
179
+ "1115394 chars of data\n"
180
+ ]
181
+ }
182
+ ],
183
  "source": [
184
  "# Tensorify data, put it in dataset\n",
185
  "data = torch.tensor(encode_text(text), dtype=torch.int32)\n",
186
  "\n",
187
+ "test_split_idx = int(0.8 * len(data))\n",
188
+ "val_split_idx = int(0.9 * len(data))\n",
189
+ "train_data = data[:test_split_idx]\n",
190
+ "test_data = data[test_split_idx:val_split_idx]\n",
191
+ "val_data = data[val_split_idx:]\n",
192
+ "print(f\"{len(data)} chars of data\")"
193
  ]
194
  },
195
  {
 
202
  },
203
  {
204
  "cell_type": "code",
205
+ "execution_count": 46,
206
  "metadata": {},
207
  "outputs": [],
208
  "source": [
 
212
  " self.context_size = context_size\n",
213
  " \n",
214
  " def __len__(self):\n",
215
+ " return len(self.data_tensor) - self.context_size\n",
216
  "\n",
217
  " def __getitem__(self, index):\n",
218
+ " x = self.data_tensor[index:index + self.context_size]\n",
219
+ " y = self.data_tensor[index + 1:index + self.context_size + 1]\n",
 
 
220
  " \n",
 
221
  " return x, y"
222
  ]
223
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  {
225
  "attachments": {},
226
  "cell_type": "markdown",
 
231
  },
232
  {
233
  "cell_type": "code",
234
+ "execution_count": 66,
235
  "metadata": {},
236
  "outputs": [],
237
  "source": [
 
244
  " self.num_heads = num_heads\n",
245
  " self.d_k = embed_dim // num_heads\n",
246
  "\n",
247
+ " self.Q = nn.Linear(embed_dim, embed_dim, bias=False)\n",
248
+ " self.K = nn.Linear(embed_dim, embed_dim, bias=False)\n",
249
+ " self.V = nn.Linear(embed_dim, embed_dim, bias=False)\n",
250
  "\n",
251
  " self.dropout = nn.Dropout(dropout)\n",
252
  " self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
 
253
  "\n",
254
+ " def forward(self, query, key, value, attn_mask=None):\n",
255
  " batch_size = query.size(0)\n",
256
  "\n",
257
  " # Apply linear layers\n",
 
268
  " scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) # [B, num_heads, C, C]\n",
269
  "\n",
270
  " # Apply mask, if necessary\n",
271
+ " if attn_mask is not None:\n",
272
  " \"\"\"\n",
273
  " MAY BE WORTH DEBUGGING\n",
274
  "\n",
 
280
  " key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, seq_len]\n",
281
  " \"\"\"\n",
282
  " # Apply the mask to attention scores\n",
283
+ " scores = scores.masked_fill(attn_mask, float('-inf'))\n",
284
  "\n",
285
  " # Scale by sqrt(k)\n",
286
  " attn = F.softmax(scores, dim=-1)\n",
 
292
  " out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)\n",
293
  " # Project: give attention \"time to think\". Maybe this should be part of a different module but whatever\n",
294
  " out = self.out_proj(out)\n",
295
+ " return((out, None))\n",
296
  "\n"
297
  ]
298
  },
299
  {
300
  "cell_type": "code",
301
+ "execution_count": 48,
302
  "metadata": {},
303
  "outputs": [],
304
  "source": [
 
307
  " super().__init__()\n",
308
  " self.net = nn.Sequential(\n",
309
  " nn.Linear(embed_dim, 4 * embed_dim),\n",
310
+ " nn.GELU(),\n",
 
311
  " nn.Linear(4 * embed_dim, embed_dim),\n",
312
+ " nn.Dropout(dropout),\n",
313
  " )\n",
314
  "\n",
315
  " def forward(self, x):\n",
 
318
  },
319
  {
320
  "cell_type": "code",
321
+ "execution_count": 60,
322
  "metadata": {},
323
  "outputs": [],
324
  "source": [
 
328
  " super(Block, self).__init__() \n",
329
  " self.register_buffer(\"mask\", mask)\n",
330
  " self.head = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)\n",
331
+ " #self.head = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True)\n",
332
  " self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)\n",
333
  " self.ln1 = nn.LayerNorm(embed_dim)\n",
334
  " self.ln2 = nn.LayerNorm(embed_dim)\n",
 
336
  " def forward(self, x):\n",
337
  " # Residual connections\n",
338
  " x = self.ln1(x)\n",
339
+ " attn_output, _ = self.head(x, x, x, attn_mask=self.mask) \n",
340
+ " x = x + attn_output\n",
341
  " out = x + self.ffwd(self.ln2(x))\n",
342
  " return out\n"
343
  ]
344
  },
345
  {
346
  "cell_type": "code",
347
+ "execution_count": 50,
348
  "metadata": {},
349
  "outputs": [],
350
  "source": [
351
  "class GPT(nn.Module):\n",
352
+ " def __init__(self, embedding_dim, vocab_size, context_size):\n",
 
353
  " super(GPT, self).__init__()\n",
354
  "\n",
 
355
  " self.embedding_dim = embedding_dim\n",
356
  " self.output_dim = vocab_size\n",
357
  " self.context_size = context_size\n",
358
  "\n",
359
+ " NUM_HEADS=4\n",
360
+ " NUM_LAYERS=4\n",
361
+ " \n",
362
+ " # Initialize layers\n",
363
  " self.tok_embed = nn.Embedding(vocab_size, embedding_dim)\n",
364
  " self.pos_embed = nn.Embedding(context_size, embedding_dim)\n",
365
  "\n",
 
 
 
366
  " mask = torch.tril(torch.ones(self.context_size, self.context_size)).bool()\n",
367
  " mask = ~mask\n",
368
+ " self.register_buffer(\"mask\", mask)\n",
369
  "\n",
370
  " self.blocks = nn.Sequential(\n",
371
+ " *[Block(embed_dim=embedding_dim, num_heads=NUM_HEADS, mask=mask, dropout=0.2) for _ in range(NUM_LAYERS)]\n",
 
372
  " )\n",
373
  "\n",
374
+ " self.ln_f = nn.LayerNorm(self.embedding_dim)\n",
375
  " # Final feed-forward layer from embeddings\n",
376
+ " self.ffwd = nn.Linear(embedding_dim, out_features=vocab_size, bias=False)\n",
377
  "\n",
378
  " def forward(self, x):\n",
379
  " tok_embed = self.tok_embed(x)\n",
380
+ " pos_embed = self.pos_embed(\n",
381
+ " torch.arange(0, self.context_size, device=\"cuda\")\n",
382
+ " )\n",
383
  " x = tok_embed + pos_embed\n",
384
  "\n",
 
 
385
  " x = self.blocks(x)\n",
386
+ " x = self.ln_f(x)\n",
387
  "\n",
388
+ " logits = self.ffwd(x)\n",
389
+ " return(logits)\n",
390
  " \n",
391
  " def infer(self, x):\n",
392
  " with torch.no_grad():\n",
 
404
  },
405
  {
406
  "cell_type": "code",
407
+ "execution_count": 51,
408
  "metadata": {},
409
  "outputs": [],
410
  "source": [
def compute_loss(model, criterion, x, y):
    """Run ``model`` on ``x`` and score its logits against the targets ``y``.

    Args:
        model: Callable returning logits with three dimensions; the last one
            is treated as the class (vocabulary) dimension.
        criterion: Loss callable invoked as ``criterion(logits_2d, targets_1d)``.
            ``None`` falls back to ``F.cross_entropy``.
        x: Model input.
        y: Integer targets with one entry per logit row (cast to int64 here).

    Returns:
        Whatever ``criterion`` returns (a scalar loss tensor for
        ``F.cross_entropy``).
    """
    logits = model(x)
    B, C, V = logits.shape

    # Honour the caller's criterion; it used to be accepted but ignored in
    # favour of a hard-coded F.cross_entropy.
    if criterion is None:
        criterion = F.cross_entropy

    # Flatten batch and sequence so each position is one classification
    # example. reshape (unlike view) also accepts non-contiguous tensors.
    flat_logits = logits.reshape(B * C, V)
    flat_targets = y.reshape(B * C).long()
    return criterion(flat_logits, flat_targets)
418
  ]
419
  },
420
  {
421
  "cell_type": "code",
422
+ "execution_count": 67,
423
  "metadata": {},
424
  "outputs": [],
425
  "source": [
426
+ "LR=3e-4\n",
 
 
 
 
 
427
  "\n",
428
  "train_dataset = TextDataset(train_data, BLOCK_SIZE)\n",
429
+ "test_dataset = TextDataset(test_data, BLOCK_SIZE)\n",
430
  "\n",
431
  "# Janky training code\n",
432
  "model = GPT(\n",
433
  " embedding_dim=EMBEDDING_NDIM, \n",
434
  " vocab_size=VOCAB_SIZE,\n",
435
  " context_size=BLOCK_SIZE,\n",
 
436
  " )\n",
437
  "\n",
438
  "model = model.to('cuda')\n",
439
  "optimizer = optim.AdamW(model.parameters(), lr=LR)\n",
440
+ "#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)\n",
441
+ "criterion = F.cross_entropy\n",
442
+ "\n",
443
+ "global_step = 0"
444
  ]
445
  },
446
  {
447
  "cell_type": "code",
448
+ "execution_count": 68,
449
  "metadata": {},
450
  "outputs": [
451
  {
452
  "name": "stdout",
453
  "output_type": "stream",
454
  "text": [
455
+ "Step 0; loss: 4.62758731842041\n",
456
+ "Step 100; loss: 2.5372843742370605\n",
457
+ "Step 200; loss: 2.486722946166992\n",
458
+ "Step 300; loss: 2.3916263580322266\n",
459
+ "Step 400; loss: 2.269087314605713\n",
460
+ "Step 500; loss: 2.1484358310699463\n",
461
+ "Step 600; loss: 2.057586193084717\n",
462
+ "Step 700; loss: 1.9845455884933472\n",
463
+ "Step 800; loss: 1.910020351409912\n",
464
+ "Step 900; loss: 1.8550803661346436\n",
465
+ "Step 1000; loss: 1.8193731307983398\n",
466
+ "Step 1100; loss: 1.767741322517395\n",
467
+ "Step 1200; loss: 1.7612113952636719\n",
468
+ "Step 1300; loss: 1.7009034156799316\n",
469
+ "Step 1400; loss: 1.6827564239501953\n",
470
+ "Step 1500; loss: 1.6604313850402832\n",
471
+ "Step 1600; loss: 1.633068323135376\n",
472
+ "Step 1700; loss: 1.6335963010787964\n",
473
+ "Step 1800; loss: 1.6095472574234009\n",
474
+ "Step 1900; loss: 1.6086715459823608\n",
475
+ "Step 2000; loss: 1.5876469612121582\n",
476
+ "Step 2100; loss: 1.5713247060775757\n",
477
+ "Step 2200; loss: 1.5546257495880127\n",
478
+ "Step 2300; loss: 1.5589814186096191\n",
479
+ "Step 2400; loss: 1.5507397651672363\n",
480
+ "Step 2500; loss: 1.5470337867736816\n",
481
+ "Step 2600; loss: 1.547551155090332\n",
482
+ "Step 2700; loss: 1.5338884592056274\n",
483
+ "Step 2800; loss: 1.5179914236068726\n",
484
+ "Step 2900; loss: 1.5240544080734253\n",
485
+ "Step 3000; loss: 1.5162924528121948\n",
486
+ "Step 3100; loss: 1.5197933912277222\n",
487
+ "Step 3200; loss: 1.5107413530349731\n",
488
+ "Step 3300; loss: 1.5017006397247314\n",
489
+ "Step 3400; loss: 1.4874128103256226\n",
490
+ "Step 3500; loss: 1.4917751550674438\n",
491
+ "Step 3600; loss: 1.5251762866973877\n",
492
+ "Step 3700; loss: 1.4957225322723389\n",
493
+ "Step 3800; loss: 1.507473111152649\n",
494
+ "Step 3900; loss: 1.4815101623535156\n",
495
+ "Step 4000; loss: 1.4824676513671875\n",
496
+ "Step 4100; loss: 1.4799575805664062\n",
497
+ "Step 4200; loss: 1.4820805788040161\n",
498
+ "Step 4300; loss: 1.4852553606033325\n",
499
+ "Step 4400; loss: 1.469815731048584\n",
500
+ "Step 4500; loss: 1.4853312969207764\n",
501
+ "Step 4600; loss: 1.4830256700515747\n",
502
+ "Step 4700; loss: 1.468559741973877\n",
503
+ "Step 4800; loss: 1.4680243730545044\n",
504
+ "Step 4900; loss: 1.464580774307251\n"
505
  ]
506
  }
507
  ],
508
  "source": [
509
+ "from torch.utils.tensorboard import SummaryWriter\n",
510
+ "\n",
511
  "EPOCHS = 1\n",
512
  "STEPS = 5000\n",
513
  "VAL_INTERVAL = 100\n",
514
  "\n",
 
515
  "model.train()\n",
516
  "\n",
517
  "train_dataloader = DataLoader(\n",
 
523
  "\n",
524
  "test_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=4, shuffle=True)\n",
525
  "\n",
526
+ "writer = SummaryWriter()\n",
527
+ "\n",
528
  "step = 0\n",
529
+ "\n",
530
  "for epoch in range(EPOCHS):\n",
531
  " for data, target in train_dataloader:\n",
532
  " data = data.to('cuda')\n",
 
538
  " optimizer.zero_grad()\n",
539
  " loss.backward()\n",
540
  " optimizer.step()\n",
541
+ " #scheduler.step()\n",
542
  "\n",
543
+ " writer.add_scalar(\"Loss/train\", loss.cpu().detach().numpy(), global_step)\n",
544
+ " global_step += 1\n",
545
  "\n",
546
+ " # TODO!!! WTF???\n",
547
  " if step % VAL_INTERVAL == 0:\n",
548
+ " total_loss = 0\n",
549
+ " total_samples = 0\n",
550
+ "\n",
551
  " with torch.no_grad():\n",
552
  " model.eval()\n",
553
  " for x, y in test_dataloader:\n",
 
560
  " if total_samples > 10:\n",
561
  " break\n",
562
  "\n",
563
+ " model.train()\n",
564
+ " average_loss = total_loss / total_samples\n",
565
+ "\n",
566
+ " print(f\"Step {step}; loss: {average_loss}\")\n",
567
+ " writer.add_scalar(\"Loss/val\", average_loss, global_step)\n",
568
+ "\n",
569
  "\n",
570
  " step += 1\n",
571
  " if step >= STEPS:\n",
572
+ " break\n",
573
+ "\n",
574
+ "writer.close()\n"
575
  ]
576
  },
577
  {
578
  "cell_type": "code",
579
+ "execution_count": 69,
580
  "metadata": {},
581
  "outputs": [],
582
  "source": [
 
585
  },
586
  {
587
  "cell_type": "code",
588
+ "execution_count": 70,
589
  "metadata": {},
590
  "outputs": [],
591
  "source": [
 
619
  },
620
  {
621
  "cell_type": "code",
622
+ "execution_count": 26,
623
  "metadata": {},
624
  "outputs": [
625
  {
626
  "data": {
627
  "text/plain": [
628
+ "841"
629
  ]
630
  },
631
+ "execution_count": 26,
632
  "metadata": {},
633
  "output_type": "execute_result"
634
  }
 
640
  },
641
  {
642
  "cell_type": "code",
643
+ "execution_count": 57,
644
  "metadata": {},
645
  "outputs": [
646
  {
647
  "name": "stdout",
648
  "output_type": "stream",
649
  "text": [
650
+ "1.7774584962397206\n"
651
  ]
652
  }
653
  ],
 
656
  "total_loss = 0.0\n",
657
  "total_samples = 0\n",
658
  "\n",
659
+ "val_dataset = TextDataset(val_data, BLOCK_SIZE)\n",
660
+ "val_dataloader = DataLoader(val_dataset, batch_size=512, num_workers=4)\n",
661
  "with torch.no_grad():\n",
662
+ " for x, y in val_dataloader:\n",
663
  " x = x.to(\"cuda\")\n",
664
  " y = y.to(\"cuda\")\n",
665
  "\n",
666
  " batch_loss = compute_loss(model, criterion, x, y)\n",
667
  " total_loss += batch_loss.item() * x.size(0)\n",
668
  " total_samples += x.size(0)\n",
669
+ " if total_samples > 100000:\n",
670
  " break\n",
671
  "\n",
672
  " average_loss = total_loss / total_samples\n",
 
675
  },
676
  {
677
  "cell_type": "code",
678
+ "execution_count": 71,
679
  "metadata": {},
680
+ "outputs": [
681
+ {
682
+ "data": {
683
+ "text/plain": [
684
+ "3286528"
685
+ ]
686
+ },
687
+ "execution_count": 71,
688
+ "metadata": {},
689
+ "output_type": "execute_result"
690
+ }
691
+ ],
692
+ "source": [
693
+ "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
694
+ "num_params"
695
+ ]
696
  },
697
  {
698
  "attachments": {},
699
  "cell_type": "markdown",
700
  "metadata": {},
701
  "source": [
702
+ "Finally, we generate. NOTE: seeds shorter than 256 characters produce nonsense until the text reaches the length of the context window. I think this is because Karpathy put the whole of Shakespeare into one file with no act/scene breaks and neither he nor I split it, so there is only one padding that the model ever sees. TODO: fix this in the data loading step."
703
  ]
704
  },
705
  {
706
  "cell_type": "code",
707
+ "execution_count": 58,
708
  "metadata": {},
709
  "outputs": [
710
  {
711
  "name": "stdout",
712
  "output_type": "stream",
713
  "text": [
714
+ "Tutus, to Marcius, noble Marcius\n",
715
+ "Made to my voices! doing and hangs upon them!\n",
716
+ "Take it to down our foes and hates with stain,\n",
717
+ "Which thus follows slay with on I meland,\n",
718
+ "What I am after her to her fearful haunt it?\n",
719
+ "\n",
720
+ "PAULINA:\n",
721
+ "But you are well to hold the king.\n",
722
+ "\n",
723
+ "ISABELLA:\n",
724
+ "And I will not go royalty to thy hand.\n",
725
+ "\n",
726
+ "LUCIO:\n",
727
+ "Since I do not well in such goodly talk of.\n",
728
+ "I think I have a stay of it!\n",
729
+ "\n",
730
+ "HENRY BOLINGBROKE:\n",
731
+ "Who say I hate been a day's mind;\n",
732
+ "Till we here and so very little and way,\n",
733
+ "And wash the city has nest seen the feast.\n",
734
+ "\n",
735
+ "DUCHESS OF YORK:\n",
736
+ "No, by the matter.\n",
737
+ "\n",
738
+ "ISABELLA:\n",
739
+ "Flitter than desire never yet looks so.\n",
740
+ "\n",
741
+ "HENRY BOLINGBROKE:\n",
742
+ "I am not possible perceived\n",
743
+ "And both place, where I may not rafes,\n",
744
+ "And like me one air. What you'll your love day?\n",
745
+ "\n",
746
+ "KING RICHARD II:\n",
747
+ "Then be thou--\n",
748
+ "\n",
749
+ "GLOUCESTER:\n",
750
+ "No, Lord Hastings:\n",
751
+ "Else queen, though my trowbers grands me to-morrow\n",
752
+ "Here to Bolingbroke's match;\n",
753
+ "When the your life and spur at homely speak.\n",
754
+ "\n",
755
+ "BUCKINGHAM:\n",
756
+ "My father was I follow: if you be your your kingdom,\n",
757
+ "My approbations an"
758
  ]
759
  }
760
  ],
761
  "source": [
762
  "g_cuda = torch.Generator(device='cuda')\n",
763
  "\n",
764
+ "seed = \"\"\"\n",
765
+ "Plot histograms of the gradient values during training. If you notice a significant number of gradients are near zero (vanishing gradients) or very large values (exploding gradients), it could be a problem. TensorBoard is a useful tool for visualizing these histograms.\n",
766
+ "\"\"\"\n",
767
+ "\n",
768
+ "contexts = torch.tensor(encode_text(seed), dtype=torch.int32).to('cuda')\n",
769
+ "GEN_LENGTH=1024\n",
770
  "\n",
771
  "model.eval()\n",
772
  "for i in range(GEN_LENGTH):\n",
 
774
  " # What happens if GEN_LENGTH > CONTEXT? don't worry about it\n",
775
  " #x = F.pad(contexts[:, -BLOCK_SIZE:], (0, BLOCK_SIZE - contexts.size(0)), \"constant\", 0)\n",
776
  " x = contexts[-BLOCK_SIZE:]\n",
777
+ " if x.size(0) < BLOCK_SIZE:\n",
778
+ " x = F.pad(x, (0, BLOCK_SIZE - x.size(0)), \"constant\", 0).unsqueeze(0) # B*T\n",
779
+ " else:\n",
780
+ " x = x.unsqueeze(0)\n",
781
+ "\n",
782
  " preds = model.infer(x)\n",
783
  " preds = preds.squeeze(0)\n",
784
  " probs = torch.softmax(preds, dim=-1)\n",
785
  "\n",
786
  " # TODO: Broken because of bug with the trailing 0s. FIX THIS\n",
787
+ " # next_char = torch.multinomial(torch.exp(preds[(-1 if i >= BLOCK_SIZE else i), :]), num_samples=1, generator=g_cuda)\n",
788
+ " next_char = torch.multinomial(torch.exp(preds[-1, :]), num_samples=1, generator=g_cuda)\n",
789
+ "\n",
790
  " #context = torch.cat(context, next_char)\n",
791
  " contexts = torch.cat((contexts, next_char), dim=0)\n",
792
  " print(decode_text(next_char.cpu().numpy())[-1], end=\"\")\n",