WIP Parameterize gpt script, add CLI

Files changed:
- .gitignore (+1, -0)
- gpt.ipynb (+460, -166)
- gpt.py (+502, -0)
- gpt_config.json (+17, -0)
.gitignore
CHANGED
@@ -4,6 +4,7 @@ datasets/
 
 # Training Tensorboard runs
 runs/
+flagged/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
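The new `flagged/` entry presumably covers the directory Gradio creates by default when a user flags an output in the web UI introduced below.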
gpt.ipynb
CHANGED
@@ -10,7 +10,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 2,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -22,7 +22,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 3,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -32,7 +32,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 4,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -55,90 +55,18 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 5,
  "metadata": {},
- "outputs": [
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "Requirement already satisfied: torch in ./venv/lib/python3.10/site-packages (2.0.0)\n",
-    "Requirement already satisfied: pandas in ./venv/lib/python3.10/site-packages (1.5.3)\n",
-    "Requirement already satisfied: numpy in ./venv/lib/python3.10/site-packages (1.24.1)\n",
-    "Requirement already satisfied: tensorboard in ./venv/lib/python3.10/site-packages (2.12.0)\n",
-    "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in ./venv/lib/python3.10/site-packages (from torch) (2.14.3)\n",
-    "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in ./venv/lib/python3.10/site-packages (from torch) (8.5.0.96)\n",
-    "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in ./venv/lib/python3.10/site-packages (from torch) (11.4.0.1)\n",
-    "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in ./venv/lib/python3.10/site-packages (from torch) (11.7.99)\n",
-    "Requirement already satisfied: networkx in ./venv/lib/python3.10/site-packages (from torch) (3.0)\n",
-    "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in ./venv/lib/python3.10/site-packages (from torch) (10.2.10.91)\n",
-    "Requirement already satisfied: filelock in ./venv/lib/python3.10/site-packages (from torch) (3.10.4)\n",
-    "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in ./venv/lib/python3.10/site-packages (from torch) (11.7.91)\n",
-    "Requirement already satisfied: typing-extensions in ./venv/lib/python3.10/site-packages (from torch) (4.5.0)\n",
-    "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in ./venv/lib/python3.10/site-packages (from torch) (11.10.3.66)\n",
-    "Requirement already satisfied: sympy in ./venv/lib/python3.10/site-packages (from torch) (1.11.1)\n",
-    "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in ./venv/lib/python3.10/site-packages (from torch) (11.7.99)\n",
-    "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in ./venv/lib/python3.10/site-packages (from torch) (10.9.0.58)\n",
-    "Requirement already satisfied: jinja2 in ./venv/lib/python3.10/site-packages (from torch) (3.1.2)\n",
-    "Requirement already satisfied: triton==2.0.0 in ./venv/lib/python3.10/site-packages (from torch) (2.0.0)\n",
-    "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in ./venv/lib/python3.10/site-packages (from torch) (11.7.101)\n",
-    "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in ./venv/lib/python3.10/site-packages (from torch) (11.7.4.91)\n",
-    "Requirement already satisfied: wheel in ./venv/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (0.40.0)\n",
-    "Requirement already satisfied: setuptools in ./venv/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (65.5.0)\n",
-    "Requirement already satisfied: lit in ./venv/lib/python3.10/site-packages (from triton==2.0.0->torch) (16.0.0)\n",
-    "Requirement already satisfied: cmake in ./venv/lib/python3.10/site-packages (from triton==2.0.0->torch) (3.26.1)\n",
-    "Requirement already satisfied: python-dateutil>=2.8.1 in ./venv/lib/python3.10/site-packages (from pandas) (2.8.2)\n",
-    "Requirement already satisfied: pytz>=2020.1 in ./venv/lib/python3.10/site-packages (from pandas) (2023.2)\n",
-    "Requirement already satisfied: requests<3,>=2.21.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.28.2)\n",
-    "Requirement already satisfied: werkzeug>=1.0.1 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.2.3)\n",
-    "Requirement already satisfied: google-auth<3,>=1.6.3 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.16.3)\n",
-    "Requirement already satisfied: protobuf>=3.19.6 in ./venv/lib/python3.10/site-packages (from tensorboard) (4.22.1)\n",
-    "Requirement already satisfied: markdown>=2.6.8 in ./venv/lib/python3.10/site-packages (from tensorboard) (3.4.3)\n",
-    "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in ./venv/lib/python3.10/site-packages (from tensorboard) (0.4.6)\n",
-    "Requirement already satisfied: grpcio>=1.48.2 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.51.3)\n",
-    "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (0.7.0)\n",
-    "Requirement already satisfied: absl-py>=0.4 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.4.0)\n",
-    "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.8.1)\n",
-    "Requirement already satisfied: cachetools<6.0,>=2.0.0 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (5.3.0)\n",
-    "Requirement already satisfied: pyasn1-modules>=0.2.1 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (0.2.8)\n",
-    "Requirement already satisfied: six>=1.9.0 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (1.16.0)\n",
-    "Requirement already satisfied: rsa<5,>=3.1.4 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (4.9)\n",
-    "Requirement already satisfied: requests-oauthlib>=0.7.0 in ./venv/lib/python3.10/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (1.3.1)\n",
-    "Requirement already satisfied: idna<4,>=2.5 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (3.4)\n",
-    "Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (1.26.15)\n",
-    "Requirement already satisfied: charset-normalizer<4,>=2 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (3.1.0)\n",
-    "Requirement already satisfied: certifi>=2017.4.17 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (2022.12.7)\n",
-    "Requirement already satisfied: MarkupSafe>=2.1.1 in ./venv/lib/python3.10/site-packages (from werkzeug>=1.0.1->tensorboard) (2.1.2)\n",
-    "Requirement already satisfied: mpmath>=0.19 in ./venv/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n",
-    "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in ./venv/lib/python3.10/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard) (0.4.8)\n",
-    "Requirement already satisfied: oauthlib>=3.0.0 in ./venv/lib/python3.10/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (3.2.2)\n",
-    "\n",
-    "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n",
-    "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
-    "Note: you may need to restart the kernel to use updated packages.\n"
-   ]
-  }
- ],
+ "outputs": [],
  "source": [
-  "…"
+  "#%pip install torch pandas numpy tensorboard gradio"
  ]
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 6,
  "metadata": {},
- "outputs": [
-  {
-   "data": {
-    "text/plain": [
-     "<torch._C.Generator at 0x7fef50768610>"
-    ]
-   },
-   "execution_count": 43,
-   "metadata": {},
-   "output_type": "execute_result"
-  }
- ],
+ "outputs": [],
  "source": [
   "import torch\n",
   "import torch.nn as nn\n",
@@ -150,12 +78,14 @@
  "import numpy as np\n",
  "import math\n",
  "\n",
- "torch.manual_seed(1337)"
+ "torch.manual_seed(1337)\n",
+ "# Set device to CUDA if available, otherwise use CPU\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
  ]
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 7,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -169,7 +99,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 8,
  "metadata": {},
  "outputs": [
   {
@@ -202,7 +132,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 9,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -231,7 +161,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 10,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -298,7 +228,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 11,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -318,7 +248,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 12,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -344,7 +274,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 13,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -404,7 +334,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 14,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -419,7 +349,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 15,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -435,7 +365,7 @@
  "    context_size=BLOCK_SIZE,\n",
  "    )\n",
  "\n",
- "model = model.to(…)\n",
+ "model = model.to(device)\n",
  "optimizer = optim.AdamW(model.parameters(), lr=LR)\n",
  "#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)\n",
  "criterion = F.cross_entropy\n",
@@ -529,8 +459,8 @@
  "\n",
  "for epoch in range(EPOCHS):\n",
  "    for data, target in train_dataloader:\n",
- "        data = data.to(…)\n",
- "        target = target.to(…)\n",
+ "        data = data.to(device)\n",
+ "        target = target.to(device)\n",
  "\n",
  "        loss = compute_loss(model, criterion, data, target)\n",
  "\n",
@@ -551,8 +481,8 @@
  "    with torch.no_grad():\n",
  "        model.eval()\n",
  "        for x, y in test_dataloader:\n",
- "            x = x.to(…)\n",
- "            y = y.to(…)\n",
+ "            x = x.to(device)\n",
+ "            y = y.to(device)\n",
  "\n",
  "            batch_loss = compute_loss(model, criterion, x, y)\n",
  "            total_loss += batch_loss.item() * 512\n",
@@ -576,7 +506,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 16,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -600,7 +530,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 17,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -619,16 +549,16 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 21,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
-     "…"
+     "3"
     ]
    },
-   "execution_count": …,
+   "execution_count": 21,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -660,8 +590,8 @@
  "val_dataloader = DataLoader(val_dataset, batch_size=512, num_workers=4)\n",
  "with torch.no_grad():\n",
  "    for x, y in val_dataloader:\n",
- "        x = x.to(…)\n",
- "        y = y.to(…)\n",
+ "        x = x.to(device)\n",
+ "        y = y.to(device)\n",
  "\n",
  "        batch_loss = compute_loss(model, criterion, x, y)\n",
  "        total_loss += batch_loss.item() * x.size(0)\n",
@@ -704,94 +634,458 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 54,
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Tutus, to Marcius, noble Marcius\n",
-    "Made to my voices! doing and hangs upon them!\n",
-    "Take it to down our foes and hates with stain,\n",
-    "Which thus follows slay with on I meland,\n",
-    "What I am after her to her fearful haunt it?\n",
-    "\n",
-    …
-    "I think I have a stay of it!\n",
-    …
-    "\n",
-    "GLOUCESTER:\n",
-    …
-    "\n",
-    "BUCKINGHAM:\n",
-    …
+    "\n",
+    "Shepherd:\n",
+    "What should be gone to your deeds to prince\n",
+    "That let us away the house of Lancasters.\n",
+    "\n",
+    "Second Murderer:\n",
+    "Alas, if I be not the fire rogued man\n",
+    "That hath dead by the fault of the good stars;\n",
+    "And the business of some many that be sounded.\n",
+    "\n",
+    "AUTOLYCUS:\n",
+    "My soul are but not so dishonours.\n",
+    "\n",
+    "BRUTUS:\n",
+    "I will do not be point, but it is a soul,\n",
+    "I would I see my heart fair state,\n",
+    "That he is in the grief and be contented\n",
+    "That you may made the father, the day doth nothing\n",
+    "And lie at like in his world moved me at him.\n",
+    "\n",
+    "First Senator:\n",
+    "The rotten far the points of our honour.\n",
+    "\n",
+    "AUTOLYCUS:\n",
+    "The king hath been you think the duke of the foes,\n",
+    "And had so long and confessal abroad,\n",
+    "And said in the light save of men so meeting\n",
+    "May not shall do be as we as here.\n",
+    "\n",
+    "Second Servingman:\n",
+    "You are grace as the child and fight again,\n",
+    "If you do not shall not be much of the thing;\n",
+    "One but so the creature of the charge is were\n",
+    "Than the point of your several motion\n",
+    "May seem as her like a pride and as would all grave\n",
+    "The which of the blood"
    ]
   }
+  ,
+  {
+   "data": {
+    "text/plain": [
+     "'\\nPlot histograms of the gradient values during training. If you notice a significant number of gradients are near zero (vanishing gradients) or very large values (exploding gradients), it could be a problem. TensorBoard is a useful tool for visualizing these histograms.\\n\\nShepherd:\\nWhat should be gone to your deeds to prince\\nThat let us away the house of Lancasters.\\n\\nSecond Murderer:\\nAlas, if I be not the fire rogued man\\nThat hath dead by the fault of the good stars;\\nAnd the business of some many that be sounded.\\n\\nAUTOLYCUS:\\nMy soul are but not so dishonours.\\n\\nBRUTUS:\\nI will do not be point, but it is a soul,\\nI would I see my heart fair state,\\nThat he is in the grief and be contented\\nThat you may made the father, the day doth nothing\\nAnd lie at like in his world moved me at him.\\n\\nFirst Senator:\\nThe rotten far the points of our honour.\\n\\nAUTOLYCUS:\\nThe king hath been you think the duke of the foes,\\nAnd had so long and confessal abroad,\\nAnd said in the light save of men so meeting\\nMay not shall do be as we as here.\\n\\nSecond Servingman:\\nYou are grace as the child and fight again,\\nIf you do not shall not be much of the thing;\\nOne but so the creature of the charge is were\\nThan the point of your several motion\\nMay seem as her like a pride and as would all grave\\nThe which of the blood'"
+    ]
+   },
+   "execution_count": 54,
+   "metadata": {},
+   "output_type": "execute_result"
+  }
  ],
  "source": [
-  "…",
-  "\n",
-  …
-  "for i in range(GEN_LENGTH):\n",
-  "    transform = nn.LogSoftmax(1)\n",
-  "    # What happens if GEN_LENGTH > CONTEXT? don't worry about it\n",
-  "    #x = F.pad(contexts[:, -BLOCK_SIZE:], (0, BLOCK_SIZE - contexts.size(0)), \"constant\", 0)\n",
-  "    x = contexts[-BLOCK_SIZE:]\n",
-  "    if x.size(0) < BLOCK_SIZE:\n",
-  "        x = F.pad(x, (0, BLOCK_SIZE - x.size(0)), \"constant\", 0).unsqueeze(0) # B*T\n",
-  "    else:\n",
-  "        x = x.unsqueeze(0)\n",
-  "\n",
-  "    preds = model.infer(x)\n",
-  "    preds = preds.squeeze(0)\n",
-  "    probs = torch.softmax(preds, dim=-1)\n",
-  "\n",
-  "    # TODO: Broken because of bug with the trailing 0s. FIX THIS\n",
-  "    # next_char = torch.multinomial(torch.exp(preds[(-1 if i >= BLOCK_SIZE else i), :]), num_samples=1, generator=g_cuda)\n",
-  "    next_char = torch.multinomial(torch.exp(preds[-1, :]), num_samples=1, generator=g_cuda)\n",
-  "\n",
-  "    #context = torch.cat(context, next_char)\n",
-  "    contexts = torch.cat((contexts, next_char), dim=0)\n",
-  "    print(decode_text(next_char.cpu().numpy())[-1], end=\"\")\n",
-  "\n",
-  "#print(\"\".join(decode_text(contexts.cpu().numpy())))"
+  "seed = \"\"\"\n",
+  "Plot histograms of the gradient values during training. If you notice a significant number of gradients are near zero (vanishing gradients) or very large values (exploding gradients), it could be a problem. TensorBoard is a useful tool for visualizing these histograms.\n",
+  "\"\"\"\n",
+  "GEN_LENGTH=1024\n",
+  "\n",
+  "def generate(prompt, gen_length, temp=1, top_k=10, top_p=None):\n",
+  "    g_cuda = torch.Generator(device=device)\n",
+  "    contexts = torch.tensor(encode_text(prompt), dtype=torch.int32).to(device)\n",
+  "\n",
+  "    model.eval()\n",
+  "    for i in range(GEN_LENGTH):\n",
+  "        transform = nn.LogSoftmax(1)\n",
+  "        x = contexts[-BLOCK_SIZE:]\n",
+  "        if x.size(0) < BLOCK_SIZE:\n",
+  "            x = F.pad(x, (BLOCK_SIZE - x.size(0), 0), \"constant\", 0).unsqueeze(0) # B*T\n",
+  "        else:\n",
+  "            x = x.unsqueeze(0)\n",
+  "\n",
+  "        preds = model.infer(x)\n",
+  "        preds = preds.squeeze(0)\n",
+  "        preds = preds / temp\n",
+  "        probs = F.softmax(preds, dim=-1)\n",
+  "\n",
+  "        if top_p is not None:\n",
+  "            # Apply top-p\n",
+  "            sorted_probs, sorted_indices = torch.sort(probs[-1, :], descending=True)\n",
+  "            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)\n",
+  "            # find cutoff\n",
+  "            idx_top_p = (cumulative_probs < top_p).sum().item()\n",
+  "            top_probs = sorted_probs[:idx_top_p]\n",
+  "            top_indices = sorted_indices[:idx_top_p]\n",
+  "            # Null case\n",
+  "            if top_probs.size(0) == 0:\n",
+  "                top_probs = sorted_probs[:1]\n",
+  "                top_indices = sorted_indices[:1]\n",
+  "\n",
+  "            next_char = torch.multinomial(top_probs, num_samples=1, generator=g_cuda)\n",
+  "            next_char = top_indices[next_char]\n",
+  "        elif top_k is not None:\n",
+  "            top_k_probs, top_k_indices = torch.topk(probs[-1, :], k=top_k)\n",
+  "            next_char = torch.multinomial(top_k_probs, num_samples=1, generator=g_cuda)\n",
+  "            next_char = top_k_indices[next_char]\n",
+  "        else:\n",
+  "            next_char = torch.multinomial(probs, num_samples=1, generator=g_cuda)\n",
+  "\n",
+  "\n",
+  "        contexts = torch.cat((contexts, next_char), dim=0)\n",
+  "        print(decode_text(next_char.cpu().numpy())[-1], end=\"\")\n",
+  "\n",
+  "    return(\"\".join(decode_text(contexts.cpu().numpy())))\n",
+  "\n",
+  "\"\".join(generate(seed, GEN_LENGTH,temp=0.8,top_p=0.9))"
+ ]
+},
+{
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "## Gradio WebUI"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": 55,
+ "metadata": {},
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "Running on local URL: http://127.0.0.1:7866\n",
+    "\n",
+    "To create a public link, set `share=True` in `launch()`.\n"
+   ]
+  },
+  {
+   "data": {
+    "text/html": [
+     "<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+    ],
+    "text/plain": [
+     "<IPython.core.display.HTML object>"
+    ]
+   },
+   "metadata": {},
+   "output_type": "display_data"
+  },
+  {
+   "data": {
+    "text/plain": []
+   },
+   "execution_count": 55,
+   "metadata": {},
+   "output_type": "execute_result"
+  },
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "\n",
+    "If you do I am sure be your father's limbs,\n",
+    "Which is not drouble Paulina,\n",
+    "Where I come to me too much be to thus\n",
+    "As in a mind sorten but here hoped--followed as hung\n",
+    "bard at the time own that that he is death;\n",
+    "The stir weors shall be a new good groans.\n",
+    "\n",
+    "OXFORD:\n",
+    "And I say him to deed your city,\n",
+    "Thy heart advantage by mine at their world,\n",
+    "And be glood much show with upon the suddenly,\n",
+    "And in this free can company from here faults.\n",
+    "\n",
+    "YORK:\n",
+    "I bring thee, now I am again: it is a Montague.\n",
+    "\n",
+    "BRUTUS:\n",
+    "We while you not be the fault of the life.\n",
+    "\n",
+    "KING RICHARD III:\n",
+    "Come, sir, do I: in put I with so,\n",
+    "The rest redies with this ballad or suppher.\n",
+    "\n",
+    "Third Servingman:\n",
+    "Now, well do with kindness of power the that\n",
+    "ence of the ambassage; but to shall here,\n",
+    "But saving as the wounds both drawn dead.\n",
+    "\n",
+    "CLARENCE:\n",
+    "Peace, sir, he adieu.\n",
+    "\n",
+    "MISTRESS OVERDONE:\n",
+    "Cheep, comes thou lovest which I do well;\n",
+    "Now now as that protectors, in this parliament--\n",
+    "That nature i' the mark of her with old her,\n",
+    "And so here doth in circle and call'd the wind,\n",
+    … (this passage repeats verbatim three more times) …
+    "If you do I have some that to be blow.\n",
+    "\n",
+    "QUEEN ELIZABETH:\n",
+    "But shall we will not a man; for all one,\n",
+    "That you will not be but a fault and house,\n",
+    "I cannot strive at once from your father's son.\n",
+    "\n",
+    "GLOUCESTER:\n",
+    "\n",
+    "CLARENCE:\n",
+    "That is the field to be put in his name.\n",
+    "\n",
+    "KING RICHARD III:\n",
+    "Ay, sir.\n",
+    "\n",
+    "QUEEN ELIZABETH:\n",
+    "How farewell.\n",
+    "\n",
+    "POLIXENES:\n",
+    "She hath stand and better me for you.\n",
+    "\n",
+    "QUEEN ELIZABETH:\n",
+    "I will be patient me.\n",
+    "\n",
+    "CLARENCE:\n",
+    "The been are the lawful strength, and he made\n",
+    "That me was upon the which he comes with a holy face.\n",
+    "\n",
+    "BUCKINGHAM:\n",
+    "The noble and of my tent.\n",
+    "\n",
+    "GLOUCESTER:\n",
+    "My gracious lord, he did soul the news,\n",
+    "Which he had of pretty with the fair traitor\n",
+    "Than a service hath some said for his prince.\n",
+    "\n",
+    "QUEEN MARGARET:\n",
+    "Stay the sad is mine execution,\n",
+    "Which can slain with you have some so death;\n",
+    "For I have sent done the many of the city\n",
+    "That you more so much hold me speak.\n",
+    "\n",
+    "GLOUCESTER:\n",
+    "\n",
+    "GLOUCESTER:\n",
+    "So doth fair lord, that be straight in his wit\n",
+    "The matter, the city distance.\n",
+    "\n",
+    "LEONTES:\n",
+    "The brother about the fire--thou sha\n",
+    "If you do I am sure be so princely false;\n",
+    "And the prince may do me for the house.\n",
+    "\n",
+    "GLOUCESTER:\n",
+    "Why, what wilt me, when we strike a mine eyes;\n",
+    "Of he doth deficers, in my fair deeds,\n",
+    "That thou wert from me hence his pardon lives.\n",
+    "\n",
+    "LORD FITZWATER:\n",
+    "Say, I will be ground to some of my body;\n",
+    "And what thou wilt part so supposed with him.\n",
+    "\n",
+    "CLARENCE:\n",
+    "You might be breathe in me?\n",
+    "\n",
+    "CAMILLO:\n",
+    "Stay now the first is no more resolved\n",
+    "With all the lawful and shed the gates word.\n",
+    "\n",
+    "CLARENCE:\n",
+    "Thou that, when my fair that I would not for heaven,\n",
+    "That fortune but not like her for the mind,\n",
+    "When I do be my tongues and her love.\n",
+    "\n",
+    "ANGELO:\n",
+    "She shall we not have done by his house.\n",
+    "\n",
+    "HENRY BOLINGBROKE:\n",
+    "Have you no more by my fortune should as heard.\n",
+    "\n",
+    "QUEEN ELIZABETH:\n",
+    "O, bring forth thy brother Gloucester,\n",
+    "But both an earthly of love from spoke me.\n",
+    "\n",
+    "QUEEN ELIZABETH:\n",
+    "Thou art thou as you shalt be dish'd in their head.\n",
+    "\n",
+    "Second Murderer:\n",
+    "The prince and my execution and heard it.\n",
+    "\n",
+    "KING RICHARD III:\n",
+    "Say, well I am a house in my heart;\n",
+    "But I am"
+   ]
+  }
+ ],
+ "source": [
+  "import gradio as gr\n",
+  "\n",
+  "demo = gr.Interface(\n",
+  "    fn=generate, \n",
+  "    inputs=[\n",
+  "        gr.Textbox(lines=2, placeholder=\"Prompt here...\"),\n",
+  "        gr.Number(value=256),\n",
+  "        gr.Number(value=0.8),\n",
+  "        gr.Slider(maximum=128,value=10),\n",
+  "        gr.Slider(maximum=1,value=1)\n",
+  "    ], \n",
+  "    outputs=\"text\",\n",
+  "    title=\"Shakespeare-GPT\",\n",
+  "    description=\"Putting theater kids out of their nonexistent jobs since 2023\"\n",
+  ")\n",
+  "\n",
+  "demo.launch()\n"
  ]
 }
],
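The rewritten generation cell replaces the old raw-multinomial sampler with temperature scaling plus top-k and top-p (nucleus) filtering. As a standalone illustration of the top-p cutoff logic it uses (a minimal sketch, not part of the commit; top_p_filter is a hypothetical helper name):

import torch

def top_p_filter(probs, top_p):
    # Sort next-token probabilities, largest first
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep the smallest prefix whose cumulative mass stays below top_p,
    # but never drop every token (mirrors the "Null case" guard above)
    cutoff = max((cumulative < top_p).sum().item(), 1)
    return sorted_probs[:cutoff], sorted_indices[:cutoff]

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])
kept_probs, kept_idx = top_p_filter(probs, top_p=0.9)
# kept_probs -> tensor([0.5000, 0.3000]); torch.multinomial renormalizes
# the kept weights, so sampling from the prefix works directly:
next_token = kept_idx[torch.multinomial(kept_probs, num_samples=1)]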
gpt.py
ADDED
@@ -0,0 +1,502 @@
import argparse
import gradio as gr
import gc
from torch.utils.tensorboard import SummaryWriter
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
import pandas as pd
import numpy as np
import math
from typing import Type


# Define the command-line arguments
parser = argparse.ArgumentParser(description='GPT CLI')
parser.add_argument('--gui', action='store_true', help='Enable Gradio UI mode')
parser.add_argument('--config', default='./gpt_config.json',
                    help='Path to the config file')
subparsers = parser.add_subparsers(dest='command', help='Choose a command')

# Define the training command
train_parser = subparsers.add_parser('train', help='Train the model')
train_parser.add_argument('--load-from-restore', action='store_true',
                          help='Load from restore path instead of training from scratch')

# Define the evaluation command
eval_parser = subparsers.add_parser('eval', help='Evaluate the model')
eval_parser.add_argument('--data', default='./data/evaluation_data.txt',
                         help='Path to the evaluation data file')

# Define the inference command
infer_parser = subparsers.add_parser(
    'infer', help='Generate text from the model')
infer_parser.add_argument('--text', type=str, required=True,
                          help='Input text for generating continuation')
infer_parser.add_argument('--length', type=int,
                          default=100, help='Number of characters to generate')

torch.manual_seed(1337)
# Set device to CUDA if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class GPTConfig:
    def __init__(self, config_file_path):
        with open(config_file_path, 'r') as f:
            config = json.load(f)

        # Architecture configuration
        architecture_config = config['architecture']
        self.embedding_dim = architecture_config['embedding_dim']
        self.vocab_size = architecture_config['vocab_size']
        self.context_size = architecture_config['context_size']
        self.num_heads = architecture_config['num_heads']
        self.num_layers = architecture_config['num_layers']

        # Training configuration
        training_config = config['training']
        self.batch_size = training_config['batch_size']
        self.training_data_path = training_config['training_data_path']
        self.learning_rate = training_config['learning_rate']
        self.num_steps = training_config['num_steps']
        self.val_interval = training_config['val_interval']

        # Checkpoint restore configuration
        self.restore_path = config['restore_path']


def encode_text(text):
    # Simple dumb ASCII character-level "encoding" since all training data is ASCII.
    # Consider better tokenization if moving off character-level
    return ([ord(t) for t in text])


def decode_text(indices):
    return ([chr(x) for x in indices])


class TextDataset(Dataset):
    def __init__(self, data_tensor, context_size):
        self.data_tensor = data_tensor
        self.context_size = context_size

    def __len__(self):
        return len(self.data_tensor) - self.context_size

    def __getitem__(self, index):
        x = self.data_tensor[index:index + self.context_size]
        y = self.data_tensor[index + 1:index + self.context_size + 1]

        return x, y


def load_dataset(data_path, val, context_size):
    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Tensorify data, put it in dataset
    data = torch.tensor(encode_text(text), dtype=torch.int32)

    test_split_idx = int(0.8 * len(data))
    val_split_idx = int(0.9 * len(data))
    train_data = data[:test_split_idx]
    test_data = data[test_split_idx:val_split_idx]
    val_data = data[val_split_idx:]
    # print(f"{len(data)} chars of data")

    train_dataset = TextDataset(train_data, context_size)
    test_dataset = TextDataset(test_data, context_size)
    val_dataset = TextDataset(val_data, context_size)  # fix: was built from test_data
    return ((train_dataset, test_dataset, val_dataset))


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, device=None, dtype=None):
        super(MultiheadAttention, self).__init__()

        # Save variables
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads

        self.Q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.K = nn.Linear(embed_dim, embed_dim, bias=False)
        self.V = nn.Linear(embed_dim, embed_dim, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, query, key, value, attn_mask=None):
        batch_size = query.size(0)

        # Apply linear layers
        q = self.Q(query)  # [B, C, E]
        k = self.K(key)    # [B, C, E]
        v = self.V(value)  # [B, C, E]

        # Mutate dimensions so the attention matmul can get rid of the inner d_k
        q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # [batch_size, num_heads, C, d_k]
        k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # [batch_size, num_heads, C, d_k]
        v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # [batch_size, num_heads, C, d_k]

        # Get raw attention scores, scaled by sqrt(d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / \
            math.sqrt(self.d_k)  # [B, num_heads, C, C]

        # Apply mask, if necessary
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float('-inf'))

        # Normalize scores into attention weights
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = attn @ v  # [B, num_heads, C, d_k]

        # Concat and project
        # Swap C and num_heads, force memory to coalesce, then fuse back num_heads and d_k together
        out = out.transpose(1, 2).contiguous().view(
            batch_size, -1, self.embed_dim)
        # Project: give attention "time to think". Maybe this should be part of a different module but whatever
        out = self.out_proj(out)
        return ((out, None))


class FeedForward(nn.Module):
    def __init__(self, embed_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return (self.net(x))


class Block(nn.Module):
    """Self-attention"""

    def __init__(self, embed_dim, num_heads, mask, dropout=0.2):
        super(Block, self).__init__()
        self.register_buffer("mask", mask)
        self.head = MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
        # self.head = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Residual connections
        x = self.ln1(x)
        attn_output, _ = self.head(x, x, x, attn_mask=self.mask)
        x = x + attn_output
        out = x + self.ffwd(self.ln2(x))
        return out


class GPT(nn.Module):
    def __init__(self, embedding_dim, vocab_size, context_size):
        super(GPT, self).__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = vocab_size
        self.context_size = context_size

        # WIP: still hard-coded; the config's num_heads/num_layers are not wired through yet
        NUM_HEADS = 4
        NUM_LAYERS = 4

        # Initialize layers
        self.tok_embed = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embed = nn.Embedding(context_size, embedding_dim)

        mask = torch.tril(torch.ones(
            self.context_size, self.context_size)).bool()
        mask = ~mask
        self.register_buffer("mask", mask)

        self.blocks = nn.Sequential(
            *[Block(embed_dim=embedding_dim, num_heads=NUM_HEADS, mask=mask, dropout=0.2) for _ in range(NUM_LAYERS)]
        )

        self.ln_f = nn.LayerNorm(self.embedding_dim)
        # Final feed-forward layer from embeddings
        self.ffwd = nn.Linear(
            embedding_dim, out_features=vocab_size, bias=False)

    def forward(self, x):
        tok_embed = self.tok_embed(x)
        pos_embed = self.pos_embed(
            torch.arange(0, self.context_size, device=x.device)  # fix: was hard-coded to "cuda"
        )
        x = tok_embed + pos_embed

        x = self.blocks(x)
        x = self.ln_f(x)

        logits = self.ffwd(x)
        return (logits)

    def infer(self, x):
        with torch.no_grad():
            self.eval()
            res = self.forward(x)
            return (res)

    def num_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def load_checkpoint(model, optimizer, path):
    """
    Loads a saved checkpoint file into the model and optimizer.

    Args:
        model (nn.Module): The PyTorch model to load the checkpoint into.
        optimizer (torch.optim.Optimizer): The PyTorch optimizer to load the checkpoint into.
        path (str): The path to the saved checkpoint file.

    Returns:
        Tuple[nn.Module, torch.optim.Optimizer]: The model and optimizer, loaded with the checkpoint state.
    """
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return (model, optimizer)


def save_checkpoint(model, optimizer, path, steps):
    """
    Saves a checkpoint of the model and optimizer to disk.

    Args:
        model (nn.Module): The PyTorch model to save the checkpoint of.
        optimizer (torch.optim.Optimizer): The PyTorch optimizer to save the checkpoint of.
        path (str): The path to save the checkpoint file.
        steps (int): The number of training steps that have been completed.

    Returns:
        None
    """
    torch.save({
        'steps': steps,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)


def compute_loss(model, criterion, x, y):
    logits = model(x)
    B, C, V = logits.shape
    logits = logits.view(B*C, V)
    y = y.view(B*C)
    loss = F.cross_entropy(logits, y.long())
    return loss


def train(model, optimizer, config: Type[GPTConfig]):
    model = model.to(device)
    criterion = F.cross_entropy

    global_step = 0

    # fix: load_dataset returns three splits, and the test split is used below
    train_dataset, test_dataset, val_dataset = load_dataset(
        config.training_data_path, None, model.context_size)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4
    )

    test_dataloader = DataLoader(
        test_dataset, batch_size=512, num_workers=4, shuffle=True)

    model.train()

    EPOCHS = 1
    STEPS = config.num_steps
    VAL_INTERVAL = config.val_interval  # fix: was hard-coded to 100

    writer = SummaryWriter()

    step = 0

    for epoch in range(EPOCHS):
        for data, target in train_dataloader:
            data = data.to(device)
            target = target.to(device)

            loss = compute_loss(model, criterion, data, target)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            writer.add_scalar(
                "Loss/train", loss.cpu().detach().numpy(), global_step)
            global_step += 1

            if step % VAL_INTERVAL == 0:
                total_loss = 0
                total_samples = 0

                with torch.no_grad():
                    model.eval()
                    for x, y in test_dataloader:
                        x = x.to(device)
                        y = y.to(device)

                        batch_loss = compute_loss(model, criterion, x, y)
                        total_loss += batch_loss.item() * 512
                        total_samples += 512
                        if total_samples > 10:
                            break

                model.train()
                average_loss = total_loss / total_samples

                print(f"Step {step}; loss: {average_loss}")
                writer.add_scalar("Loss/val", average_loss, global_step)

            step += 1
            if step >= STEPS:
                break

    writer.close()


def evaluate_model(model, val_dataset, block_size=512, max_samples=100000):
    model.eval()
    total_loss = 0.0
    total_samples = 0
    criterion = F.cross_entropy

    val_dataloader = DataLoader(
        val_dataset, batch_size=block_size, num_workers=4)
    with torch.no_grad():
        for inputs, targets in val_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            batch_loss = compute_loss(model, criterion, inputs, targets)
            total_loss += batch_loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
            if total_samples > max_samples:
                break

    average_loss = total_loss / total_samples
    return average_loss


def generate(model, config, prompt, gen_length, temp=1, top_k=10, top_p=None):
    g_cuda = torch.Generator(device=device)
    contexts = torch.tensor(encode_text(prompt), dtype=torch.int32).to(device)

    model.eval()
    for i in range(gen_length):
        x = contexts[-config.context_size:]
        if x.size(0) < config.context_size:
            x = F.pad(x, (config.context_size - x.size(0), 0),
                      "constant", 0).unsqueeze(0)  # B*T
        else:
            x = x.unsqueeze(0)

        preds = model.infer(x)
        preds = preds.squeeze(0)
        preds = preds / temp
        probs = F.softmax(preds, dim=-1)

        if top_p is not None:
            # Apply top-p
            sorted_probs, sorted_indices = torch.sort(
                probs[-1, :], descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            # find cutoff
            idx_top_p = (cumulative_probs < top_p).sum().item()
            top_probs = sorted_probs[:idx_top_p]
            top_indices = sorted_indices[:idx_top_p]
            # Null case
            if top_probs.size(0) == 0:
                top_probs = sorted_probs[:1]
                top_indices = sorted_indices[:1]

            next_char = torch.multinomial(
                top_probs, num_samples=1, generator=g_cuda)
            next_char = top_indices[next_char]
        elif top_k is not None:
            top_k_probs, top_k_indices = torch.topk(probs[-1, :], k=top_k)
            next_char = torch.multinomial(
                top_k_probs, num_samples=1, generator=g_cuda)
            next_char = top_k_indices[next_char]
        else:
            next_char = torch.multinomial(
                probs, num_samples=1, generator=g_cuda)

        contexts = torch.cat((contexts, next_char), dim=0)
        print(decode_text(next_char.cpu().numpy())[-1], end="")

    return ("".join(decode_text(contexts.cpu().numpy())))


def main():
    # Parse the command-line arguments
    args = parser.parse_args()

    config = GPTConfig(args.config)
    # Create the GPT model
    model = GPT(
        vocab_size=config.vocab_size,
        context_size=config.context_size,
        embedding_dim=config.embedding_dim
    )
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
    if args.gui:
        load_checkpoint(model, optimizer, config.restore_path)
        demo = gr.Interface(
            fn=lambda *args: generate(model, config, *args),
            inputs=[
                gr.Textbox(lines=2, placeholder="Prompt here..."),
                gr.Number(precision=0, value=256),
                gr.Number(value=0.8),
                gr.Slider(maximum=128, value=10),
                gr.Slider(maximum=1, value=1)
            ],
            outputs="text",
            title="Shakespeare-GPT",
            description="Putting theater kids out of their nonexistent jobs since 2023"
        )

        demo.launch()
    elif args.command == "train":
        if args.load_from_restore:
            load_checkpoint(model, optimizer, config.restore_path)  # fix: `path` was undefined

        train(model, optimizer, config)  # fix: the optimizer was not being passed through
    elif args.command == "eval":
        # The third split returned by load_dataset is the validation split
        _, _, val_dataset = load_dataset(
            config.training_data_path, None, model.context_size)
        average_loss = evaluate_model(model, val_dataset)
        print(f"Validation loss: {average_loss}")
    elif args.command == "infer":
        prompt = args.text
        generated_text = generate(model, config, prompt, args.length)
        print(generated_text)


if __name__ == "__main__":
    main()
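Given the argparse setup in gpt.py, the intended invocations would look roughly like this (illustrative; the checkpoint and data paths come from gpt_config.json):

python gpt.py train                            # train from scratch
python gpt.py train --load-from-restore        # resume from restore_path
python gpt.py eval                             # report loss on the held-out split
python gpt.py infer --text "ROMEO:" --length 200
python gpt.py --gui                            # load the checkpoint and launch Gradio
python gpt.py --config ./other_config.json train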
gpt_config.json
ADDED
@@ -0,0 +1,17 @@
{
    "architecture": {
        "embedding_dim": 256,
        "vocab_size": 128,
        "context_size": 256,
        "num_heads": 4,
        "num_layers": 4
    },
    "training": {
        "batch_size": 64,
        "training_data_path": "datasets/training_data.txt",
        "learning_rate": 3e-4,
        "num_steps": 5000,
        "val_interval": 100
    },
    "restore_path": "checkpoints/model.pt"
}
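Two consistency notes on these values (a sanity-check sketch, not part of the commit): encode_text() maps characters with ord(), so vocab_size must cover the full ASCII range, and MultiheadAttention splits embedding_dim evenly across num_heads.

import json

with open("gpt_config.json") as f:
    cfg = json.load(f)

arch = cfg["architecture"]
# ord() on ASCII text yields code points 0..127, which must fit the embedding table
assert arch["vocab_size"] >= 128
# d_k = embed_dim // num_heads, so the division must be exact
assert arch["embedding_dim"] % arch["num_heads"] == 0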