Fix training objective; lower model size
Browse files- .gitignore +3 -0
- gpt.ipynb +315 -203
.gitignore
CHANGED
@@ -2,6 +2,9 @@
|
|
2 |
checkpoints/
|
3 |
datasets/
|
4 |
|
|
|
|
|
|
|
5 |
# Byte-compiled / optimized / DLL files
|
6 |
__pycache__/
|
7 |
*.py[cod]
|
|
|
2 |
checkpoints/
|
3 |
datasets/
|
4 |
|
5 |
+
# Training Tensorboard runs
|
6 |
+
runs/
|
7 |
+
|
8 |
# Byte-compiled / optimized / DLL files
|
9 |
__pycache__/
|
10 |
*.py[cod]
|
gpt.ipynb
CHANGED
@@ -10,20 +10,19 @@
|
|
10 |
},
|
11 |
{
|
12 |
"cell_type": "code",
|
13 |
-
"execution_count":
|
14 |
"metadata": {},
|
15 |
"outputs": [],
|
16 |
"source": [
|
17 |
"import os\n",
|
18 |
"\n",
|
19 |
-
"# We always start with a dataset to train on. Let's download the tiny shakespeare dataset\n",
|
20 |
"if not os.path.isfile(\"./datasets/corpora/shakespeare.txt\"):\n",
|
21 |
-
" !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
|
22 |
]
|
23 |
},
|
24 |
{
|
25 |
"cell_type": "code",
|
26 |
-
"execution_count":
|
27 |
"metadata": {},
|
28 |
"outputs": [],
|
29 |
"source": [
|
@@ -31,6 +30,21 @@
|
|
31 |
" text = f.read()"
|
32 |
]
|
33 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
{
|
35 |
"attachments": {},
|
36 |
"cell_type": "markdown",
|
@@ -41,16 +55,86 @@
|
|
41 |
},
|
42 |
{
|
43 |
"cell_type": "code",
|
44 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
"metadata": {},
|
46 |
"outputs": [
|
47 |
{
|
48 |
"data": {
|
49 |
"text/plain": [
|
50 |
-
"<torch._C.Generator at
|
51 |
]
|
52 |
},
|
53 |
-
"execution_count":
|
54 |
"metadata": {},
|
55 |
"output_type": "execute_result"
|
56 |
}
|
@@ -71,7 +155,7 @@
|
|
71 |
},
|
72 |
{
|
73 |
"cell_type": "code",
|
74 |
-
"execution_count":
|
75 |
"metadata": {},
|
76 |
"outputs": [],
|
77 |
"source": [
|
@@ -85,16 +169,27 @@
|
|
85 |
},
|
86 |
{
|
87 |
"cell_type": "code",
|
88 |
-
"execution_count":
|
89 |
"metadata": {},
|
90 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
"source": [
|
92 |
"# Tensorify data, put it in dataset\n",
|
93 |
"data = torch.tensor(encode_text(text), dtype=torch.int32)\n",
|
94 |
"\n",
|
95 |
-
"
|
96 |
-
"
|
97 |
-
"
|
|
|
|
|
|
|
98 |
]
|
99 |
},
|
100 |
{
|
@@ -107,7 +202,7 @@
|
|
107 |
},
|
108 |
{
|
109 |
"cell_type": "code",
|
110 |
-
"execution_count":
|
111 |
"metadata": {},
|
112 |
"outputs": [],
|
113 |
"source": [
|
@@ -117,92 +212,15 @@
|
|
117 |
" self.context_size = context_size\n",
|
118 |
" \n",
|
119 |
" def __len__(self):\n",
|
120 |
-
" return len(self.data_tensor)\n",
|
121 |
"\n",
|
122 |
" def __getitem__(self, index):\n",
|
123 |
-
"
|
124 |
-
"
|
125 |
-
" else:\n",
|
126 |
-
" x = self.data_tensor[index - self.context_size:index]\n",
|
127 |
" \n",
|
128 |
-
" y = self.data_tensor[index]\n",
|
129 |
" return x, y"
|
130 |
]
|
131 |
},
|
132 |
-
{
|
133 |
-
"attachments": {},
|
134 |
-
"cell_type": "markdown",
|
135 |
-
"metadata": {},
|
136 |
-
"source": [
|
137 |
-
"NOTE 2023-03-25: I think this is bugged, and that's the reason the training loss is so damn high. Testing:"
|
138 |
-
]
|
139 |
-
},
|
140 |
-
{
|
141 |
-
"cell_type": "code",
|
142 |
-
"execution_count": 34,
|
143 |
-
"metadata": {},
|
144 |
-
"outputs": [
|
145 |
-
{
|
146 |
-
"name": "stdout",
|
147 |
-
"output_type": "stream",
|
148 |
-
"text": [
|
149 |
-
"Step 0:\n",
|
150 |
-
"[0, 0, 0, 0, 0, 0, 0, 0]\n",
|
151 |
-
"---\n",
|
152 |
-
"[0, 0, 0, 0, 0, 0, 0, 70]\n",
|
153 |
-
"---\n",
|
154 |
-
"['F', 'i']\n",
|
155 |
-
"Step 1:\n",
|
156 |
-
"[0, 0, 0, 0, 0, 0, 70, 105]\n",
|
157 |
-
"---\n",
|
158 |
-
"[0, 0, 0, 0, 0, 70, 105, 114]\n",
|
159 |
-
"---\n",
|
160 |
-
"['r', 's']\n",
|
161 |
-
"Step 2:\n",
|
162 |
-
"[0, 0, 0, 0, 70, 105, 114, 115]\n",
|
163 |
-
"---\n",
|
164 |
-
"[0, 0, 0, 70, 105, 114, 115, 116]\n",
|
165 |
-
"---\n",
|
166 |
-
"['t', ' ']\n",
|
167 |
-
"Step 3:\n",
|
168 |
-
"[0, 0, 70, 105, 114, 115, 116, 32]\n",
|
169 |
-
"---\n",
|
170 |
-
"[0, 70, 105, 114, 115, 116, 32, 67]\n",
|
171 |
-
"---\n",
|
172 |
-
"['C', 'i']\n",
|
173 |
-
"Step 4:\n",
|
174 |
-
"[70, 105, 114, 115, 116, 32, 67, 105]\n",
|
175 |
-
"---\n",
|
176 |
-
"[105, 114, 115, 116, 32, 67, 105, 116]\n",
|
177 |
-
"---\n",
|
178 |
-
"['t', 'i']\n",
|
179 |
-
"Step 5:\n",
|
180 |
-
"[114, 115, 116, 32, 67, 105, 116, 105]\n",
|
181 |
-
"---\n",
|
182 |
-
"[115, 116, 32, 67, 105, 116, 105, 122]\n",
|
183 |
-
"---\n",
|
184 |
-
"['z', 'e']\n"
|
185 |
-
]
|
186 |
-
}
|
187 |
-
],
|
188 |
-
"source": [
|
189 |
-
"train_dataset = TextDataset(train_data, 8)\n",
|
190 |
-
"train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=False)\n",
|
191 |
-
"\n",
|
192 |
-
"step = 0\n",
|
193 |
-
"for x, y in train_dataloader:\n",
|
194 |
-
" print(f\"Step {step}:\")\n",
|
195 |
-
" for b in x.tolist():\n",
|
196 |
-
" print(b)\n",
|
197 |
-
" print(\"---\")\n",
|
198 |
-
"\n",
|
199 |
-
" print(decode_text(y.tolist()))\n",
|
200 |
-
" step += 1\n",
|
201 |
-
" if step > 5:\n",
|
202 |
-
" break\n",
|
203 |
-
"\n"
|
204 |
-
]
|
205 |
-
},
|
206 |
{
|
207 |
"attachments": {},
|
208 |
"cell_type": "markdown",
|
@@ -213,7 +231,7 @@
|
|
213 |
},
|
214 |
{
|
215 |
"cell_type": "code",
|
216 |
-
"execution_count":
|
217 |
"metadata": {},
|
218 |
"outputs": [],
|
219 |
"source": [
|
@@ -226,15 +244,14 @@
|
|
226 |
" self.num_heads = num_heads\n",
|
227 |
" self.d_k = embed_dim // num_heads\n",
|
228 |
"\n",
|
229 |
-
" self.Q = nn.Linear(embed_dim, embed_dim, bias=
|
230 |
-
" self.K = nn.Linear(embed_dim, embed_dim, bias=
|
231 |
-
" self.V = nn.Linear(embed_dim, embed_dim, bias=
|
232 |
"\n",
|
233 |
" self.dropout = nn.Dropout(dropout)\n",
|
234 |
" self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
|
235 |
-
" nn.init.kaiming_normal_(self.out_proj.weight, mode='fan_in', nonlinearity='linear')\n",
|
236 |
"\n",
|
237 |
-
" def forward(self, query, key, value,
|
238 |
" batch_size = query.size(0)\n",
|
239 |
"\n",
|
240 |
" # Apply linear layers\n",
|
@@ -251,7 +268,7 @@
|
|
251 |
" scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) # [B, num_heads, C, C]\n",
|
252 |
"\n",
|
253 |
" # Apply mask, if necessary\n",
|
254 |
-
" if
|
255 |
" \"\"\"\n",
|
256 |
" MAY BE WORTH DEBUGGING\n",
|
257 |
"\n",
|
@@ -263,7 +280,7 @@
|
|
263 |
" key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, seq_len]\n",
|
264 |
" \"\"\"\n",
|
265 |
" # Apply the mask to attention scores\n",
|
266 |
-
" scores = scores.masked_fill(
|
267 |
"\n",
|
268 |
" # Scale by sqrt(k)\n",
|
269 |
" attn = F.softmax(scores, dim=-1)\n",
|
@@ -275,13 +292,13 @@
|
|
275 |
" out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)\n",
|
276 |
" # Project: give attention \"time to think\". Maybe this should be part of a different module but whatever\n",
|
277 |
" out = self.out_proj(out)\n",
|
278 |
-
" return(out)\n",
|
279 |
"\n"
|
280 |
]
|
281 |
},
|
282 |
{
|
283 |
"cell_type": "code",
|
284 |
-
"execution_count":
|
285 |
"metadata": {},
|
286 |
"outputs": [],
|
287 |
"source": [
|
@@ -290,9 +307,9 @@
|
|
290 |
" super().__init__()\n",
|
291 |
" self.net = nn.Sequential(\n",
|
292 |
" nn.Linear(embed_dim, 4 * embed_dim),\n",
|
293 |
-
" nn.
|
294 |
-
" nn.Dropout(dropout)\n",
|
295 |
" nn.Linear(4 * embed_dim, embed_dim),\n",
|
|
|
296 |
" )\n",
|
297 |
"\n",
|
298 |
" def forward(self, x):\n",
|
@@ -301,7 +318,7 @@
|
|
301 |
},
|
302 |
{
|
303 |
"cell_type": "code",
|
304 |
-
"execution_count":
|
305 |
"metadata": {},
|
306 |
"outputs": [],
|
307 |
"source": [
|
@@ -311,6 +328,7 @@
|
|
311 |
" super(Block, self).__init__() \n",
|
312 |
" self.register_buffer(\"mask\", mask)\n",
|
313 |
" self.head = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)\n",
|
|
|
314 |
" self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)\n",
|
315 |
" self.ln1 = nn.LayerNorm(embed_dim)\n",
|
316 |
" self.ln2 = nn.LayerNorm(embed_dim)\n",
|
@@ -318,58 +336,57 @@
|
|
318 |
" def forward(self, x):\n",
|
319 |
" # Residual connections\n",
|
320 |
" x = self.ln1(x)\n",
|
321 |
-
"
|
|
|
322 |
" out = x + self.ffwd(self.ln2(x))\n",
|
323 |
" return out\n"
|
324 |
]
|
325 |
},
|
326 |
{
|
327 |
"cell_type": "code",
|
328 |
-
"execution_count":
|
329 |
"metadata": {},
|
330 |
"outputs": [],
|
331 |
"source": [
|
332 |
"class GPT(nn.Module):\n",
|
333 |
-
" def __init__(self, embedding_dim, vocab_size, context_size
|
334 |
-
" # Inherit PyTorch stuff\n",
|
335 |
" super(GPT, self).__init__()\n",
|
336 |
"\n",
|
337 |
-
" # Save variables for later\n",
|
338 |
" self.embedding_dim = embedding_dim\n",
|
339 |
" self.output_dim = vocab_size\n",
|
340 |
" self.context_size = context_size\n",
|
341 |
"\n",
|
342 |
-
"
|
|
|
|
|
|
|
343 |
" self.tok_embed = nn.Embedding(vocab_size, embedding_dim)\n",
|
344 |
" self.pos_embed = nn.Embedding(context_size, embedding_dim)\n",
|
345 |
"\n",
|
346 |
-
" NUM_HEADS=6\n",
|
347 |
-
" NUM_LAYERS=6\n",
|
348 |
-
" \n",
|
349 |
" mask = torch.tril(torch.ones(self.context_size, self.context_size)).bool()\n",
|
350 |
" mask = ~mask\n",
|
351 |
-
" self.register_buffer(mask)\n",
|
352 |
"\n",
|
353 |
" self.blocks = nn.Sequential(\n",
|
354 |
-
" *[Block(embed_dim=embedding_dim, num_heads=NUM_HEADS, mask=mask) for _ in range(NUM_LAYERS)]
|
355 |
-
" nn.Dropout(0.2)\n",
|
356 |
" )\n",
|
357 |
"\n",
|
|
|
358 |
" # Final feed-forward layer from embeddings\n",
|
359 |
-
" self.ffwd = nn.Linear(embedding_dim, out_features=vocab_size)\n",
|
360 |
"\n",
|
361 |
" def forward(self, x):\n",
|
362 |
" tok_embed = self.tok_embed(x)\n",
|
363 |
-
"
|
364 |
-
"
|
|
|
365 |
" x = tok_embed + pos_embed\n",
|
366 |
"\n",
|
367 |
-
" # The actual attention is all you need here!\n",
|
368 |
-
" # B*C*C cutting out the future\n",
|
369 |
" x = self.blocks(x)\n",
|
|
|
370 |
"\n",
|
371 |
-
"
|
372 |
-
" return(
|
373 |
" \n",
|
374 |
" def infer(self, x):\n",
|
375 |
" with torch.no_grad():\n",
|
@@ -387,93 +404,114 @@
|
|
387 |
},
|
388 |
{
|
389 |
"cell_type": "code",
|
390 |
-
"execution_count":
|
391 |
"metadata": {},
|
392 |
"outputs": [],
|
393 |
"source": [
|
394 |
"def compute_loss(model, criterion, x, y):\n",
|
395 |
" logits = model(x)\n",
|
396 |
-
"
|
397 |
-
"
|
398 |
-
"
|
|
|
399 |
" return loss"
|
400 |
]
|
401 |
},
|
402 |
{
|
403 |
"cell_type": "code",
|
404 |
-
"execution_count":
|
405 |
"metadata": {},
|
406 |
"outputs": [],
|
407 |
"source": [
|
408 |
-
"
|
409 |
-
"VOCAB_SIZE = 128\n",
|
410 |
-
"BATCH_SIZE=64\n",
|
411 |
-
"# \"Context window\"\n",
|
412 |
-
"BLOCK_SIZE=256\n",
|
413 |
-
"LR=1e-3\n",
|
414 |
"\n",
|
415 |
"train_dataset = TextDataset(train_data, BLOCK_SIZE)\n",
|
416 |
-
"test_dataset = TextDataset(
|
417 |
"\n",
|
418 |
"# Janky training code\n",
|
419 |
"model = GPT(\n",
|
420 |
" embedding_dim=EMBEDDING_NDIM, \n",
|
421 |
" vocab_size=VOCAB_SIZE,\n",
|
422 |
" context_size=BLOCK_SIZE,\n",
|
423 |
-
" lr=LR\n",
|
424 |
" )\n",
|
425 |
"\n",
|
426 |
"model = model.to('cuda')\n",
|
427 |
"optimizer = optim.AdamW(model.parameters(), lr=LR)\n",
|
428 |
-
"#
|
429 |
-
"
|
430 |
-
"
|
|
|
431 |
]
|
432 |
},
|
433 |
{
|
434 |
"cell_type": "code",
|
435 |
-
"execution_count":
|
436 |
"metadata": {},
|
437 |
"outputs": [
|
438 |
{
|
439 |
"name": "stdout",
|
440 |
"output_type": "stream",
|
441 |
"text": [
|
442 |
-
"Step 0; loss:
|
443 |
-
"Step 100; loss:
|
444 |
-
"Step 200; loss:
|
445 |
-
"Step 300; loss:
|
446 |
-
"Step 400; loss:
|
447 |
-
"Step 500; loss:
|
448 |
-
"Step 600; loss:
|
449 |
-
"Step 700; loss:
|
450 |
-
"Step 800; loss:
|
451 |
-
"Step 900; loss:
|
452 |
-
"Step 1000; loss:
|
453 |
-
"Step 1100; loss:
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
"\
|
462 |
-
"
|
463 |
-
"
|
464 |
-
"
|
465 |
-
"
|
466 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
]
|
468 |
}
|
469 |
],
|
470 |
"source": [
|
471 |
-
"from torch.utils
|
|
|
472 |
"EPOCHS = 1\n",
|
473 |
"STEPS = 5000\n",
|
474 |
"VAL_INTERVAL = 100\n",
|
475 |
"\n",
|
476 |
-
"losses = []\n",
|
477 |
"model.train()\n",
|
478 |
"\n",
|
479 |
"train_dataloader = DataLoader(\n",
|
@@ -485,7 +523,10 @@
|
|
485 |
"\n",
|
486 |
"test_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=4, shuffle=True)\n",
|
487 |
"\n",
|
|
|
|
|
488 |
"step = 0\n",
|
|
|
489 |
"for epoch in range(EPOCHS):\n",
|
490 |
" for data, target in train_dataloader:\n",
|
491 |
" data = data.to('cuda')\n",
|
@@ -497,11 +538,16 @@
|
|
497 |
" optimizer.zero_grad()\n",
|
498 |
" loss.backward()\n",
|
499 |
" optimizer.step()\n",
|
500 |
-
" scheduler.step()\n",
|
501 |
"\n",
|
502 |
-
"
|
|
|
503 |
"\n",
|
|
|
504 |
" if step % VAL_INTERVAL == 0:\n",
|
|
|
|
|
|
|
505 |
" with torch.no_grad():\n",
|
506 |
" model.eval()\n",
|
507 |
" for x, y in test_dataloader:\n",
|
@@ -514,18 +560,23 @@
|
|
514 |
" if total_samples > 10:\n",
|
515 |
" break\n",
|
516 |
"\n",
|
517 |
-
"
|
518 |
-
"
|
519 |
-
"
|
|
|
|
|
|
|
520 |
"\n",
|
521 |
" step += 1\n",
|
522 |
" if step >= STEPS:\n",
|
523 |
-
" break\n"
|
|
|
|
|
524 |
]
|
525 |
},
|
526 |
{
|
527 |
"cell_type": "code",
|
528 |
-
"execution_count":
|
529 |
"metadata": {},
|
530 |
"outputs": [],
|
531 |
"source": [
|
@@ -534,7 +585,7 @@
|
|
534 |
},
|
535 |
{
|
536 |
"cell_type": "code",
|
537 |
-
"execution_count":
|
538 |
"metadata": {},
|
539 |
"outputs": [],
|
540 |
"source": [
|
@@ -568,16 +619,16 @@
|
|
568 |
},
|
569 |
{
|
570 |
"cell_type": "code",
|
571 |
-
"execution_count":
|
572 |
"metadata": {},
|
573 |
"outputs": [
|
574 |
{
|
575 |
"data": {
|
576 |
"text/plain": [
|
577 |
-
"
|
578 |
]
|
579 |
},
|
580 |
-
"execution_count":
|
581 |
"metadata": {},
|
582 |
"output_type": "execute_result"
|
583 |
}
|
@@ -589,14 +640,14 @@
|
|
589 |
},
|
590 |
{
|
591 |
"cell_type": "code",
|
592 |
-
"execution_count":
|
593 |
"metadata": {},
|
594 |
"outputs": [
|
595 |
{
|
596 |
"name": "stdout",
|
597 |
"output_type": "stream",
|
598 |
"text": [
|
599 |
-
"
|
600 |
]
|
601 |
}
|
602 |
],
|
@@ -605,16 +656,17 @@
|
|
605 |
"total_loss = 0.0\n",
|
606 |
"total_samples = 0\n",
|
607 |
"\n",
|
608 |
-
"
|
|
|
609 |
"with torch.no_grad():\n",
|
610 |
-
" for x, y in
|
611 |
" x = x.to(\"cuda\")\n",
|
612 |
" y = y.to(\"cuda\")\n",
|
613 |
"\n",
|
614 |
" batch_loss = compute_loss(model, criterion, x, y)\n",
|
615 |
" total_loss += batch_loss.item() * x.size(0)\n",
|
616 |
" total_samples += x.size(0)\n",
|
617 |
-
" if total_samples >
|
618 |
" break\n",
|
619 |
"\n",
|
620 |
" average_loss = total_loss / total_samples\n",
|
@@ -623,44 +675,98 @@
|
|
623 |
},
|
624 |
{
|
625 |
"cell_type": "code",
|
626 |
-
"execution_count":
|
627 |
"metadata": {},
|
628 |
-
"outputs": [
|
629 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
630 |
},
|
631 |
{
|
632 |
"attachments": {},
|
633 |
"cell_type": "markdown",
|
634 |
"metadata": {},
|
635 |
"source": [
|
636 |
-
"Finally, we generate:"
|
637 |
]
|
638 |
},
|
639 |
{
|
640 |
"cell_type": "code",
|
641 |
-
"execution_count":
|
642 |
"metadata": {},
|
643 |
"outputs": [
|
644 |
{
|
645 |
"name": "stdout",
|
646 |
"output_type": "stream",
|
647 |
"text": [
|
648 |
-
",
|
649 |
-
"
|
650 |
-
"
|
651 |
-
"
|
652 |
-
"I
|
653 |
-
"
|
654 |
-
"
|
655 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
656 |
]
|
657 |
}
|
658 |
],
|
659 |
"source": [
|
660 |
"g_cuda = torch.Generator(device='cuda')\n",
|
661 |
"\n",
|
662 |
-
"
|
663 |
-
"
|
|
|
|
|
|
|
|
|
664 |
"\n",
|
665 |
"model.eval()\n",
|
666 |
"for i in range(GEN_LENGTH):\n",
|
@@ -668,13 +774,19 @@
|
|
668 |
" # What happens if GEN_LENGTH > CONTEXT? don't worry about it\n",
|
669 |
" #x = F.pad(contexts[:, -BLOCK_SIZE:], (0, BLOCK_SIZE - contexts.size(0)), \"constant\", 0)\n",
|
670 |
" x = contexts[-BLOCK_SIZE:]\n",
|
671 |
-
"
|
|
|
|
|
|
|
|
|
672 |
" preds = model.infer(x)\n",
|
673 |
" preds = preds.squeeze(0)\n",
|
674 |
" probs = torch.softmax(preds, dim=-1)\n",
|
675 |
"\n",
|
676 |
" # TODO: Broken because of bug with the trailing 0s. FIX THIS\n",
|
677 |
-
" next_char = torch.multinomial(torch.exp(preds[(-1 if i >= BLOCK_SIZE else i), :]), num_samples=1, generator=g_cuda)\n",
|
|
|
|
|
678 |
" #context = torch.cat(context, next_char)\n",
|
679 |
" contexts = torch.cat((contexts, next_char), dim=0)\n",
|
680 |
" print(decode_text(next_char.cpu().numpy())[-1], end=\"\")\n",
|
|
|
10 |
},
|
11 |
{
|
12 |
"cell_type": "code",
|
13 |
+
"execution_count": 39,
|
14 |
"metadata": {},
|
15 |
"outputs": [],
|
16 |
"source": [
|
17 |
"import os\n",
|
18 |
"\n",
|
|
|
19 |
"if not os.path.isfile(\"./datasets/corpora/shakespeare.txt\"):\n",
|
20 |
+
" !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O datasets/corpora/shakespeare.txt"
|
21 |
]
|
22 |
},
|
23 |
{
|
24 |
"cell_type": "code",
|
25 |
+
"execution_count": 40,
|
26 |
"metadata": {},
|
27 |
"outputs": [],
|
28 |
"source": [
|
|
|
30 |
" text = f.read()"
|
31 |
]
|
32 |
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": 41,
|
36 |
+
"metadata": {},
|
37 |
+
"outputs": [],
|
38 |
+
"source": [
|
39 |
+
"# Putting hyperparameters at the top because I learned this the hard way\n",
|
40 |
+
"# 64 * NUM_HEADS\n",
|
41 |
+
"EMBEDDING_NDIM=256\n",
|
42 |
+
"VOCAB_SIZE=128\n",
|
43 |
+
"BATCH_SIZE=64\n",
|
44 |
+
"# \"Context window\"\n",
|
45 |
+
"BLOCK_SIZE=256"
|
46 |
+
]
|
47 |
+
},
|
48 |
{
|
49 |
"attachments": {},
|
50 |
"cell_type": "markdown",
|
|
|
55 |
},
|
56 |
{
|
57 |
"cell_type": "code",
|
58 |
+
"execution_count": 42,
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [
|
61 |
+
{
|
62 |
+
"name": "stdout",
|
63 |
+
"output_type": "stream",
|
64 |
+
"text": [
|
65 |
+
"Requirement already satisfied: torch in ./venv/lib/python3.10/site-packages (2.0.0)\n",
|
66 |
+
"Requirement already satisfied: pandas in ./venv/lib/python3.10/site-packages (1.5.3)\n",
|
67 |
+
"Requirement already satisfied: numpy in ./venv/lib/python3.10/site-packages (1.24.1)\n",
|
68 |
+
"Requirement already satisfied: tensorboard in ./venv/lib/python3.10/site-packages (2.12.0)\n",
|
69 |
+
"Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in ./venv/lib/python3.10/site-packages (from torch) (2.14.3)\n",
|
70 |
+
"Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in ./venv/lib/python3.10/site-packages (from torch) (8.5.0.96)\n",
|
71 |
+
"Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in ./venv/lib/python3.10/site-packages (from torch) (11.4.0.1)\n",
|
72 |
+
"Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in ./venv/lib/python3.10/site-packages (from torch) (11.7.99)\n",
|
73 |
+
"Requirement already satisfied: networkx in ./venv/lib/python3.10/site-packages (from torch) (3.0)\n",
|
74 |
+
"Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in ./venv/lib/python3.10/site-packages (from torch) (10.2.10.91)\n",
|
75 |
+
"Requirement already satisfied: filelock in ./venv/lib/python3.10/site-packages (from torch) (3.10.4)\n",
|
76 |
+
"Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in ./venv/lib/python3.10/site-packages (from torch) (11.7.91)\n",
|
77 |
+
"Requirement already satisfied: typing-extensions in ./venv/lib/python3.10/site-packages (from torch) (4.5.0)\n",
|
78 |
+
"Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in ./venv/lib/python3.10/site-packages (from torch) (11.10.3.66)\n",
|
79 |
+
"Requirement already satisfied: sympy in ./venv/lib/python3.10/site-packages (from torch) (1.11.1)\n",
|
80 |
+
"Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in ./venv/lib/python3.10/site-packages (from torch) (11.7.99)\n",
|
81 |
+
"Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in ./venv/lib/python3.10/site-packages (from torch) (10.9.0.58)\n",
|
82 |
+
"Requirement already satisfied: jinja2 in ./venv/lib/python3.10/site-packages (from torch) (3.1.2)\n",
|
83 |
+
"Requirement already satisfied: triton==2.0.0 in ./venv/lib/python3.10/site-packages (from torch) (2.0.0)\n",
|
84 |
+
"Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in ./venv/lib/python3.10/site-packages (from torch) (11.7.101)\n",
|
85 |
+
"Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in ./venv/lib/python3.10/site-packages (from torch) (11.7.4.91)\n",
|
86 |
+
"Requirement already satisfied: wheel in ./venv/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (0.40.0)\n",
|
87 |
+
"Requirement already satisfied: setuptools in ./venv/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (65.5.0)\n",
|
88 |
+
"Requirement already satisfied: lit in ./venv/lib/python3.10/site-packages (from triton==2.0.0->torch) (16.0.0)\n",
|
89 |
+
"Requirement already satisfied: cmake in ./venv/lib/python3.10/site-packages (from triton==2.0.0->torch) (3.26.1)\n",
|
90 |
+
"Requirement already satisfied: python-dateutil>=2.8.1 in ./venv/lib/python3.10/site-packages (from pandas) (2.8.2)\n",
|
91 |
+
"Requirement already satisfied: pytz>=2020.1 in ./venv/lib/python3.10/site-packages (from pandas) (2023.2)\n",
|
92 |
+
"Requirement already satisfied: requests<3,>=2.21.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.28.2)\n",
|
93 |
+
"Requirement already satisfied: werkzeug>=1.0.1 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.2.3)\n",
|
94 |
+
"Requirement already satisfied: google-auth<3,>=1.6.3 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.16.3)\n",
|
95 |
+
"Requirement already satisfied: protobuf>=3.19.6 in ./venv/lib/python3.10/site-packages (from tensorboard) (4.22.1)\n",
|
96 |
+
"Requirement already satisfied: markdown>=2.6.8 in ./venv/lib/python3.10/site-packages (from tensorboard) (3.4.3)\n",
|
97 |
+
"Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in ./venv/lib/python3.10/site-packages (from tensorboard) (0.4.6)\n",
|
98 |
+
"Requirement already satisfied: grpcio>=1.48.2 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.51.3)\n",
|
99 |
+
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (0.7.0)\n",
|
100 |
+
"Requirement already satisfied: absl-py>=0.4 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.4.0)\n",
|
101 |
+
"Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.8.1)\n",
|
102 |
+
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (5.3.0)\n",
|
103 |
+
"Requirement already satisfied: pyasn1-modules>=0.2.1 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (0.2.8)\n",
|
104 |
+
"Requirement already satisfied: six>=1.9.0 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (1.16.0)\n",
|
105 |
+
"Requirement already satisfied: rsa<5,>=3.1.4 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (4.9)\n",
|
106 |
+
"Requirement already satisfied: requests-oauthlib>=0.7.0 in ./venv/lib/python3.10/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (1.3.1)\n",
|
107 |
+
"Requirement already satisfied: idna<4,>=2.5 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (3.4)\n",
|
108 |
+
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (1.26.15)\n",
|
109 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (3.1.0)\n",
|
110 |
+
"Requirement already satisfied: certifi>=2017.4.17 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (2022.12.7)\n",
|
111 |
+
"Requirement already satisfied: MarkupSafe>=2.1.1 in ./venv/lib/python3.10/site-packages (from werkzeug>=1.0.1->tensorboard) (2.1.2)\n",
|
112 |
+
"Requirement already satisfied: mpmath>=0.19 in ./venv/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n",
|
113 |
+
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in ./venv/lib/python3.10/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard) (0.4.8)\n",
|
114 |
+
"Requirement already satisfied: oauthlib>=3.0.0 in ./venv/lib/python3.10/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (3.2.2)\n",
|
115 |
+
"\n",
|
116 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n",
|
117 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
118 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
119 |
+
]
|
120 |
+
}
|
121 |
+
],
|
122 |
+
"source": [
|
123 |
+
"%pip install torch pandas numpy tensorboard"
|
124 |
+
]
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"cell_type": "code",
|
128 |
+
"execution_count": 43,
|
129 |
"metadata": {},
|
130 |
"outputs": [
|
131 |
{
|
132 |
"data": {
|
133 |
"text/plain": [
|
134 |
+
"<torch._C.Generator at 0x7fef50768610>"
|
135 |
]
|
136 |
},
|
137 |
+
"execution_count": 43,
|
138 |
"metadata": {},
|
139 |
"output_type": "execute_result"
|
140 |
}
|
|
|
155 |
},
|
156 |
{
|
157 |
"cell_type": "code",
|
158 |
+
"execution_count": 44,
|
159 |
"metadata": {},
|
160 |
"outputs": [],
|
161 |
"source": [
|
|
|
169 |
},
|
170 |
{
|
171 |
"cell_type": "code",
|
172 |
+
"execution_count": 45,
|
173 |
"metadata": {},
|
174 |
+
"outputs": [
|
175 |
+
{
|
176 |
+
"name": "stdout",
|
177 |
+
"output_type": "stream",
|
178 |
+
"text": [
|
179 |
+
"1115394 chars of data\n"
|
180 |
+
]
|
181 |
+
}
|
182 |
+
],
|
183 |
"source": [
|
184 |
"# Tensorify data, put it in dataset\n",
|
185 |
"data = torch.tensor(encode_text(text), dtype=torch.int32)\n",
|
186 |
"\n",
|
187 |
+
"test_split_idx = int(0.8 * len(data))\n",
|
188 |
+
"val_split_idx = int(0.9 * len(data))\n",
|
189 |
+
"train_data = data[:test_split_idx]\n",
|
190 |
+
"test_data = data[test_split_idx:val_split_idx]\n",
|
191 |
+
"val_data = data[val_split_idx:]\n",
|
192 |
+
"print(f\"{len(data)} chars of data\")"
|
193 |
]
|
194 |
},
|
195 |
{
|
|
|
202 |
},
|
203 |
{
|
204 |
"cell_type": "code",
|
205 |
+
"execution_count": 46,
|
206 |
"metadata": {},
|
207 |
"outputs": [],
|
208 |
"source": [
|
|
|
212 |
" self.context_size = context_size\n",
|
213 |
" \n",
|
214 |
" def __len__(self):\n",
|
215 |
+
" return len(self.data_tensor) - self.context_size\n",
|
216 |
"\n",
|
217 |
" def __getitem__(self, index):\n",
|
218 |
+
" x = self.data_tensor[index:index + self.context_size]\n",
|
219 |
+
" y = self.data_tensor[index + 1:index + self.context_size + 1]\n",
|
|
|
|
|
220 |
" \n",
|
|
|
221 |
" return x, y"
|
222 |
]
|
223 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
{
|
225 |
"attachments": {},
|
226 |
"cell_type": "markdown",
|
|
|
231 |
},
|
232 |
{
|
233 |
"cell_type": "code",
|
234 |
+
"execution_count": 66,
|
235 |
"metadata": {},
|
236 |
"outputs": [],
|
237 |
"source": [
|
|
|
244 |
" self.num_heads = num_heads\n",
|
245 |
" self.d_k = embed_dim // num_heads\n",
|
246 |
"\n",
|
247 |
+
" self.Q = nn.Linear(embed_dim, embed_dim, bias=False)\n",
|
248 |
+
" self.K = nn.Linear(embed_dim, embed_dim, bias=False)\n",
|
249 |
+
" self.V = nn.Linear(embed_dim, embed_dim, bias=False)\n",
|
250 |
"\n",
|
251 |
" self.dropout = nn.Dropout(dropout)\n",
|
252 |
" self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
|
|
|
253 |
"\n",
|
254 |
+
" def forward(self, query, key, value, attn_mask=None):\n",
|
255 |
" batch_size = query.size(0)\n",
|
256 |
"\n",
|
257 |
" # Apply linear layers\n",
|
|
|
268 |
" scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) # [B, num_heads, C, C]\n",
|
269 |
"\n",
|
270 |
" # Apply mask, if necessary\n",
|
271 |
+
" if attn_mask is not None:\n",
|
272 |
" \"\"\"\n",
|
273 |
" MAY BE WORTH DEBUGGING\n",
|
274 |
"\n",
|
|
|
280 |
" key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, seq_len]\n",
|
281 |
" \"\"\"\n",
|
282 |
" # Apply the mask to attention scores\n",
|
283 |
+
" scores = scores.masked_fill(attn_mask, float('-inf'))\n",
|
284 |
"\n",
|
285 |
" # Scale by sqrt(k)\n",
|
286 |
" attn = F.softmax(scores, dim=-1)\n",
|
|
|
292 |
" out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)\n",
|
293 |
" # Project: give attention \"time to think\". Maybe this should be part of a different module but whatever\n",
|
294 |
" out = self.out_proj(out)\n",
|
295 |
+
" return((out, None))\n",
|
296 |
"\n"
|
297 |
]
|
298 |
},
|
299 |
{
|
300 |
"cell_type": "code",
|
301 |
+
"execution_count": 48,
|
302 |
"metadata": {},
|
303 |
"outputs": [],
|
304 |
"source": [
|
|
|
307 |
" super().__init__()\n",
|
308 |
" self.net = nn.Sequential(\n",
|
309 |
" nn.Linear(embed_dim, 4 * embed_dim),\n",
|
310 |
+
" nn.GELU(),\n",
|
|
|
311 |
" nn.Linear(4 * embed_dim, embed_dim),\n",
|
312 |
+
" nn.Dropout(dropout),\n",
|
313 |
" )\n",
|
314 |
"\n",
|
315 |
" def forward(self, x):\n",
|
|
|
318 |
},
|
319 |
{
|
320 |
"cell_type": "code",
|
321 |
+
"execution_count": 60,
|
322 |
"metadata": {},
|
323 |
"outputs": [],
|
324 |
"source": [
|
|
|
328 |
" super(Block, self).__init__() \n",
|
329 |
" self.register_buffer(\"mask\", mask)\n",
|
330 |
" self.head = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)\n",
|
331 |
+
" #self.head = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True)\n",
|
332 |
" self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)\n",
|
333 |
" self.ln1 = nn.LayerNorm(embed_dim)\n",
|
334 |
" self.ln2 = nn.LayerNorm(embed_dim)\n",
|
|
|
336 |
" def forward(self, x):\n",
|
337 |
" # Residual connections\n",
|
338 |
" x = self.ln1(x)\n",
|
339 |
+
" attn_output, _ = self.head(x, x, x, attn_mask=self.mask) \n",
|
340 |
+
" x = x + attn_output\n",
|
341 |
" out = x + self.ffwd(self.ln2(x))\n",
|
342 |
" return out\n"
|
343 |
]
|
344 |
},
|
345 |
{
|
346 |
"cell_type": "code",
|
347 |
+
"execution_count": 50,
|
348 |
"metadata": {},
|
349 |
"outputs": [],
|
350 |
"source": [
|
351 |
"class GPT(nn.Module):\n",
|
352 |
+
" def __init__(self, embedding_dim, vocab_size, context_size):\n",
|
|
|
353 |
" super(GPT, self).__init__()\n",
|
354 |
"\n",
|
|
|
355 |
" self.embedding_dim = embedding_dim\n",
|
356 |
" self.output_dim = vocab_size\n",
|
357 |
" self.context_size = context_size\n",
|
358 |
"\n",
|
359 |
+
" NUM_HEADS=4\n",
|
360 |
+
" NUM_LAYERS=4\n",
|
361 |
+
" \n",
|
362 |
+
" # Initialize layers\n",
|
363 |
" self.tok_embed = nn.Embedding(vocab_size, embedding_dim)\n",
|
364 |
" self.pos_embed = nn.Embedding(context_size, embedding_dim)\n",
|
365 |
"\n",
|
|
|
|
|
|
|
366 |
" mask = torch.tril(torch.ones(self.context_size, self.context_size)).bool()\n",
|
367 |
" mask = ~mask\n",
|
368 |
+
" self.register_buffer(\"mask\", mask)\n",
|
369 |
"\n",
|
370 |
" self.blocks = nn.Sequential(\n",
|
371 |
+
" *[Block(embed_dim=embedding_dim, num_heads=NUM_HEADS, mask=mask, dropout=0.2) for _ in range(NUM_LAYERS)]\n",
|
|
|
372 |
" )\n",
|
373 |
"\n",
|
374 |
+
" self.ln_f = nn.LayerNorm(self.embedding_dim)\n",
|
375 |
" # Final feed-forward layer from embeddings\n",
|
376 |
+
" self.ffwd = nn.Linear(embedding_dim, out_features=vocab_size, bias=False)\n",
|
377 |
"\n",
|
378 |
" def forward(self, x):\n",
|
379 |
" tok_embed = self.tok_embed(x)\n",
|
380 |
+
" pos_embed = self.pos_embed(\n",
|
381 |
+
" torch.arange(0, self.context_size, device=\"cuda\")\n",
|
382 |
+
" )\n",
|
383 |
" x = tok_embed + pos_embed\n",
|
384 |
"\n",
|
|
|
|
|
385 |
" x = self.blocks(x)\n",
|
386 |
+
" x = self.ln_f(x)\n",
|
387 |
"\n",
|
388 |
+
" logits = self.ffwd(x)\n",
|
389 |
+
" return(logits)\n",
|
390 |
" \n",
|
391 |
" def infer(self, x):\n",
|
392 |
" with torch.no_grad():\n",
|
|
|
404 |
},
|
405 |
{
|
406 |
"cell_type": "code",
|
407 |
+
"execution_count": 51,
|
408 |
"metadata": {},
|
409 |
"outputs": [],
|
410 |
"source": [
|
411 |
"def compute_loss(model, criterion, x, y):\n",
|
412 |
" logits = model(x)\n",
|
413 |
+
" B,C,V = logits.shape\n",
|
414 |
+
" logits = logits.view(B*C, V)\n",
|
415 |
+
" y = y.view(B*C)\n",
|
416 |
+
" loss = F.cross_entropy(logits, y.long())\n",
|
417 |
" return loss"
|
418 |
]
|
419 |
},
|
420 |
{
|
421 |
"cell_type": "code",
|
422 |
+
"execution_count": 67,
|
423 |
"metadata": {},
|
424 |
"outputs": [],
|
425 |
"source": [
|
426 |
+
"LR=3e-4\n",
|
|
|
|
|
|
|
|
|
|
|
427 |
"\n",
|
428 |
"train_dataset = TextDataset(train_data, BLOCK_SIZE)\n",
|
429 |
+
"test_dataset = TextDataset(test_data, BLOCK_SIZE)\n",
|
430 |
"\n",
|
431 |
"# Janky training code\n",
|
432 |
"model = GPT(\n",
|
433 |
" embedding_dim=EMBEDDING_NDIM, \n",
|
434 |
" vocab_size=VOCAB_SIZE,\n",
|
435 |
" context_size=BLOCK_SIZE,\n",
|
|
|
436 |
" )\n",
|
437 |
"\n",
|
438 |
"model = model.to('cuda')\n",
|
439 |
"optimizer = optim.AdamW(model.parameters(), lr=LR)\n",
|
440 |
+
"#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)\n",
|
441 |
+
"criterion = F.cross_entropy\n",
|
442 |
+
"\n",
|
443 |
+
"global_step = 0"
|
444 |
]
|
445 |
},
|
446 |
{
|
447 |
"cell_type": "code",
|
448 |
+
"execution_count": 68,
|
449 |
"metadata": {},
|
450 |
"outputs": [
|
451 |
{
|
452 |
"name": "stdout",
|
453 |
"output_type": "stream",
|
454 |
"text": [
|
455 |
+
"Step 0; loss: 4.62758731842041\n",
|
456 |
+
"Step 100; loss: 2.5372843742370605\n",
|
457 |
+
"Step 200; loss: 2.486722946166992\n",
|
458 |
+
"Step 300; loss: 2.3916263580322266\n",
|
459 |
+
"Step 400; loss: 2.269087314605713\n",
|
460 |
+
"Step 500; loss: 2.1484358310699463\n",
|
461 |
+
"Step 600; loss: 2.057586193084717\n",
|
462 |
+
"Step 700; loss: 1.9845455884933472\n",
|
463 |
+
"Step 800; loss: 1.910020351409912\n",
|
464 |
+
"Step 900; loss: 1.8550803661346436\n",
|
465 |
+
"Step 1000; loss: 1.8193731307983398\n",
|
466 |
+
"Step 1100; loss: 1.767741322517395\n",
|
467 |
+
"Step 1200; loss: 1.7612113952636719\n",
|
468 |
+
"Step 1300; loss: 1.7009034156799316\n",
|
469 |
+
"Step 1400; loss: 1.6827564239501953\n",
|
470 |
+
"Step 1500; loss: 1.6604313850402832\n",
|
471 |
+
"Step 1600; loss: 1.633068323135376\n",
|
472 |
+
"Step 1700; loss: 1.6335963010787964\n",
|
473 |
+
"Step 1800; loss: 1.6095472574234009\n",
|
474 |
+
"Step 1900; loss: 1.6086715459823608\n",
|
475 |
+
"Step 2000; loss: 1.5876469612121582\n",
|
476 |
+
"Step 2100; loss: 1.5713247060775757\n",
|
477 |
+
"Step 2200; loss: 1.5546257495880127\n",
|
478 |
+
"Step 2300; loss: 1.5589814186096191\n",
|
479 |
+
"Step 2400; loss: 1.5507397651672363\n",
|
480 |
+
"Step 2500; loss: 1.5470337867736816\n",
|
481 |
+
"Step 2600; loss: 1.547551155090332\n",
|
482 |
+
"Step 2700; loss: 1.5338884592056274\n",
|
483 |
+
"Step 2800; loss: 1.5179914236068726\n",
|
484 |
+
"Step 2900; loss: 1.5240544080734253\n",
|
485 |
+
"Step 3000; loss: 1.5162924528121948\n",
|
486 |
+
"Step 3100; loss: 1.5197933912277222\n",
|
487 |
+
"Step 3200; loss: 1.5107413530349731\n",
|
488 |
+
"Step 3300; loss: 1.5017006397247314\n",
|
489 |
+
"Step 3400; loss: 1.4874128103256226\n",
|
490 |
+
"Step 3500; loss: 1.4917751550674438\n",
|
491 |
+
"Step 3600; loss: 1.5251762866973877\n",
|
492 |
+
"Step 3700; loss: 1.4957225322723389\n",
|
493 |
+
"Step 3800; loss: 1.507473111152649\n",
|
494 |
+
"Step 3900; loss: 1.4815101623535156\n",
|
495 |
+
"Step 4000; loss: 1.4824676513671875\n",
|
496 |
+
"Step 4100; loss: 1.4799575805664062\n",
|
497 |
+
"Step 4200; loss: 1.4820805788040161\n",
|
498 |
+
"Step 4300; loss: 1.4852553606033325\n",
|
499 |
+
"Step 4400; loss: 1.469815731048584\n",
|
500 |
+
"Step 4500; loss: 1.4853312969207764\n",
|
501 |
+
"Step 4600; loss: 1.4830256700515747\n",
|
502 |
+
"Step 4700; loss: 1.468559741973877\n",
|
503 |
+
"Step 4800; loss: 1.4680243730545044\n",
|
504 |
+
"Step 4900; loss: 1.464580774307251\n"
|
505 |
]
|
506 |
}
|
507 |
],
|
508 |
"source": [
|
509 |
+
"from torch.utils.tensorboard import SummaryWriter\n",
|
510 |
+
"\n",
|
511 |
"EPOCHS = 1\n",
|
512 |
"STEPS = 5000\n",
|
513 |
"VAL_INTERVAL = 100\n",
|
514 |
"\n",
|
|
|
515 |
"model.train()\n",
|
516 |
"\n",
|
517 |
"train_dataloader = DataLoader(\n",
|
|
|
523 |
"\n",
|
524 |
"test_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=4, shuffle=True)\n",
|
525 |
"\n",
|
526 |
+
"writer = SummaryWriter()\n",
|
527 |
+
"\n",
|
528 |
"step = 0\n",
|
529 |
+
"\n",
|
530 |
"for epoch in range(EPOCHS):\n",
|
531 |
" for data, target in train_dataloader:\n",
|
532 |
" data = data.to('cuda')\n",
|
|
|
538 |
" optimizer.zero_grad()\n",
|
539 |
" loss.backward()\n",
|
540 |
" optimizer.step()\n",
|
541 |
+
" #scheduler.step()\n",
|
542 |
"\n",
|
543 |
+
" writer.add_scalar(\"Loss/train\", loss.cpu().detach().numpy(), global_step)\n",
|
544 |
+
" global_step += 1\n",
|
545 |
"\n",
|
546 |
+
" # TODO!!! WTF???\n",
|
547 |
" if step % VAL_INTERVAL == 0:\n",
|
548 |
+
" total_loss = 0\n",
|
549 |
+
" total_samples = 0\n",
|
550 |
+
"\n",
|
551 |
" with torch.no_grad():\n",
|
552 |
" model.eval()\n",
|
553 |
" for x, y in test_dataloader:\n",
|
|
|
560 |
" if total_samples > 10:\n",
|
561 |
" break\n",
|
562 |
"\n",
|
563 |
+
" model.train()\n",
|
564 |
+
" average_loss = total_loss / total_samples\n",
|
565 |
+
"\n",
|
566 |
+
" print(f\"Step {step}; loss: {average_loss}\")\n",
|
567 |
+
" writer.add_scalar(\"Loss/val\", average_loss, global_step)\n",
|
568 |
+
"\n",
|
569 |
"\n",
|
570 |
" step += 1\n",
|
571 |
" if step >= STEPS:\n",
|
572 |
+
" break\n",
|
573 |
+
"\n",
|
574 |
+
"writer.close()\n"
|
575 |
]
|
576 |
},
|
577 |
{
|
578 |
"cell_type": "code",
|
579 |
+
"execution_count": 69,
|
580 |
"metadata": {},
|
581 |
"outputs": [],
|
582 |
"source": [
|
|
|
585 |
},
|
586 |
{
|
587 |
"cell_type": "code",
|
588 |
+
"execution_count": 70,
|
589 |
"metadata": {},
|
590 |
"outputs": [],
|
591 |
"source": [
|
|
|
619 |
},
|
620 |
{
|
621 |
"cell_type": "code",
|
622 |
+
"execution_count": 26,
|
623 |
"metadata": {},
|
624 |
"outputs": [
|
625 |
{
|
626 |
"data": {
|
627 |
"text/plain": [
|
628 |
+
"841"
|
629 |
]
|
630 |
},
|
631 |
+
"execution_count": 26,
|
632 |
"metadata": {},
|
633 |
"output_type": "execute_result"
|
634 |
}
|
|
|
640 |
},
|
641 |
{
|
642 |
"cell_type": "code",
|
643 |
+
"execution_count": 57,
|
644 |
"metadata": {},
|
645 |
"outputs": [
|
646 |
{
|
647 |
"name": "stdout",
|
648 |
"output_type": "stream",
|
649 |
"text": [
|
650 |
+
"1.7774584962397206\n"
|
651 |
]
|
652 |
}
|
653 |
],
|
|
|
656 |
"total_loss = 0.0\n",
|
657 |
"total_samples = 0\n",
|
658 |
"\n",
|
659 |
+
"val_dataset = TextDataset(val_data, BLOCK_SIZE)\n",
|
660 |
+
"val_dataloader = DataLoader(val_dataset, batch_size=512, num_workers=4)\n",
|
661 |
"with torch.no_grad():\n",
|
662 |
+
" for x, y in val_dataloader:\n",
|
663 |
" x = x.to(\"cuda\")\n",
|
664 |
" y = y.to(\"cuda\")\n",
|
665 |
"\n",
|
666 |
" batch_loss = compute_loss(model, criterion, x, y)\n",
|
667 |
" total_loss += batch_loss.item() * x.size(0)\n",
|
668 |
" total_samples += x.size(0)\n",
|
669 |
+
" if total_samples > 100000:\n",
|
670 |
" break\n",
|
671 |
"\n",
|
672 |
" average_loss = total_loss / total_samples\n",
|
|
|
675 |
},
|
676 |
{
|
677 |
"cell_type": "code",
|
678 |
+
"execution_count": 71,
|
679 |
"metadata": {},
|
680 |
+
"outputs": [
|
681 |
+
{
|
682 |
+
"data": {
|
683 |
+
"text/plain": [
|
684 |
+
"3286528"
|
685 |
+
]
|
686 |
+
},
|
687 |
+
"execution_count": 71,
|
688 |
+
"metadata": {},
|
689 |
+
"output_type": "execute_result"
|
690 |
+
}
|
691 |
+
],
|
692 |
+
"source": [
|
693 |
+
"num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
694 |
+
"num_params"
|
695 |
+
]
|
696 |
},
|
697 |
{
|
698 |
"attachments": {},
|
699 |
"cell_type": "markdown",
|
700 |
"metadata": {},
|
701 |
"source": [
|
702 |
+
"Finally, we generate. NOTE: seeds shorter than 256 chars have nonsense until you reach the context window. I think it's because Karpathy jammed the whole Shakespeare into one file with no act/scene breaks and both he and I didn't split it, so there's only one padding that the model sees, ever. TODO: fix this in the data loading step"
|
703 |
]
|
704 |
},
|
705 |
{
|
706 |
"cell_type": "code",
|
707 |
+
"execution_count": 58,
|
708 |
"metadata": {},
|
709 |
"outputs": [
|
710 |
{
|
711 |
"name": "stdout",
|
712 |
"output_type": "stream",
|
713 |
"text": [
|
714 |
+
"Tutus, to Marcius, noble Marcius\n",
|
715 |
+
"Made to my voices! doing and hangs upon them!\n",
|
716 |
+
"Take it to down our foes and hates with stain,\n",
|
717 |
+
"Which thus follows slay with on I meland,\n",
|
718 |
+
"What I am after her to her fearful haunt it?\n",
|
719 |
+
"\n",
|
720 |
+
"PAULINA:\n",
|
721 |
+
"But you are well to hold the king.\n",
|
722 |
+
"\n",
|
723 |
+
"ISABELLA:\n",
|
724 |
+
"And I will not go royalty to thy hand.\n",
|
725 |
+
"\n",
|
726 |
+
"LUCIO:\n",
|
727 |
+
"Since I do not well in such goodly talk of.\n",
|
728 |
+
"I think I have a stay of it!\n",
|
729 |
+
"\n",
|
730 |
+
"HENRY BOLINGBROKE:\n",
|
731 |
+
"Who say I hate been a day's mind;\n",
|
732 |
+
"Till we here and so very little and way,\n",
|
733 |
+
"And wash the city has nest seen the feast.\n",
|
734 |
+
"\n",
|
735 |
+
"DUCHESS OF YORK:\n",
|
736 |
+
"No, by the matter.\n",
|
737 |
+
"\n",
|
738 |
+
"ISABELLA:\n",
|
739 |
+
"Flitter than desire never yet looks so.\n",
|
740 |
+
"\n",
|
741 |
+
"HENRY BOLINGBROKE:\n",
|
742 |
+
"I am not possible perceived\n",
|
743 |
+
"And both place, where I may not rafes,\n",
|
744 |
+
"And like me one air. What you'll your love day?\n",
|
745 |
+
"\n",
|
746 |
+
"KING RICHARD II:\n",
|
747 |
+
"Then be thou--\n",
|
748 |
+
"\n",
|
749 |
+
"GLOUCESTER:\n",
|
750 |
+
"No, Lord Hastings:\n",
|
751 |
+
"Else queen, though my trowbers grands me to-morrow\n",
|
752 |
+
"Here to Bolingbroke's match;\n",
|
753 |
+
"When the your life and spur at homely speak.\n",
|
754 |
+
"\n",
|
755 |
+
"BUCKINGHAM:\n",
|
756 |
+
"My father was I follow: if you be your your kingdom,\n",
|
757 |
+
"My approbations an"
|
758 |
]
|
759 |
}
|
760 |
],
|
761 |
"source": [
|
762 |
"g_cuda = torch.Generator(device='cuda')\n",
|
763 |
"\n",
|
764 |
+
"seed = \"\"\"\n",
|
765 |
+
"Plot histograms of the gradient values during training. If you notice a significant number of gradients are near zero (vanishing gradients) or very large values (exploding gradients), it could be a problem. TensorBoard is a useful tool for visualizing these histograms.\n",
|
766 |
+
"\"\"\"\n",
|
767 |
+
"\n",
|
768 |
+
"contexts = torch.tensor(encode_text(seed), dtype=torch.int32).to('cuda')\n",
|
769 |
+
"GEN_LENGTH=1024\n",
|
770 |
"\n",
|
771 |
"model.eval()\n",
|
772 |
"for i in range(GEN_LENGTH):\n",
|
|
|
774 |
" # What happens if GEN_LENGTH > CONTEXT? don't worry about it\n",
|
775 |
" #x = F.pad(contexts[:, -BLOCK_SIZE:], (0, BLOCK_SIZE - contexts.size(0)), \"constant\", 0)\n",
|
776 |
" x = contexts[-BLOCK_SIZE:]\n",
|
777 |
+
" if x.size(0) < BLOCK_SIZE:\n",
|
778 |
+
" x = F.pad(x, (0, BLOCK_SIZE - x.size(0)), \"constant\", 0).unsqueeze(0) # B*T\n",
|
779 |
+
" else:\n",
|
780 |
+
" x = x.unsqueeze(0)\n",
|
781 |
+
"\n",
|
782 |
" preds = model.infer(x)\n",
|
783 |
" preds = preds.squeeze(0)\n",
|
784 |
" probs = torch.softmax(preds, dim=-1)\n",
|
785 |
"\n",
|
786 |
" # TODO: Broken because of bug with the trailing 0s. FIX THIS\n",
|
787 |
+
" # next_char = torch.multinomial(torch.exp(preds[(-1 if i >= BLOCK_SIZE else i), :]), num_samples=1, generator=g_cuda)\n",
|
788 |
+
" next_char = torch.multinomial(torch.exp(preds[-1, :]), num_samples=1, generator=g_cuda)\n",
|
789 |
+
"\n",
|
790 |
" #context = torch.cat(context, next_char)\n",
|
791 |
" contexts = torch.cat((contexts, next_char), dim=0)\n",
|
792 |
" print(decode_text(next_char.cpu().numpy())[-1], end=\"\")\n",
|