jkeisling commited on
Commit
b74a784
·
1 Parent(s): 343c085

WIP Parameterize gpt script, add CLI

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. gpt.ipynb +460 -166
  3. gpt.py +502 -0
  4. gpt_config.json +17 -0
.gitignore CHANGED
@@ -4,6 +4,7 @@ datasets/
4
 
5
  # Training Tensorboard runs
6
  runs/
 
7
 
8
  # Byte-compiled / optimized / DLL files
9
  __pycache__/
 
4
 
5
  # Training Tensorboard runs
6
  runs/
7
+ flagged/
8
 
9
  # Byte-compiled / optimized / DLL files
10
  __pycache__/
gpt.ipynb CHANGED
@@ -10,7 +10,7 @@
10
  },
11
  {
12
  "cell_type": "code",
13
- "execution_count": 39,
14
  "metadata": {},
15
  "outputs": [],
16
  "source": [
@@ -22,7 +22,7 @@
22
  },
23
  {
24
  "cell_type": "code",
25
- "execution_count": 40,
26
  "metadata": {},
27
  "outputs": [],
28
  "source": [
@@ -32,7 +32,7 @@
32
  },
33
  {
34
  "cell_type": "code",
35
- "execution_count": 41,
36
  "metadata": {},
37
  "outputs": [],
38
  "source": [
@@ -55,90 +55,18 @@
55
  },
56
  {
57
  "cell_type": "code",
58
- "execution_count": 42,
59
  "metadata": {},
60
- "outputs": [
61
- {
62
- "name": "stdout",
63
- "output_type": "stream",
64
- "text": [
65
- "Requirement already satisfied: torch in ./venv/lib/python3.10/site-packages (2.0.0)\n",
66
- "Requirement already satisfied: pandas in ./venv/lib/python3.10/site-packages (1.5.3)\n",
67
- "Requirement already satisfied: numpy in ./venv/lib/python3.10/site-packages (1.24.1)\n",
68
- "Requirement already satisfied: tensorboard in ./venv/lib/python3.10/site-packages (2.12.0)\n",
69
- "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in ./venv/lib/python3.10/site-packages (from torch) (2.14.3)\n",
70
- "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in ./venv/lib/python3.10/site-packages (from torch) (8.5.0.96)\n",
71
- "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in ./venv/lib/python3.10/site-packages (from torch) (11.4.0.1)\n",
72
- "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in ./venv/lib/python3.10/site-packages (from torch) (11.7.99)\n",
73
- "Requirement already satisfied: networkx in ./venv/lib/python3.10/site-packages (from torch) (3.0)\n",
74
- "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in ./venv/lib/python3.10/site-packages (from torch) (10.2.10.91)\n",
75
- "Requirement already satisfied: filelock in ./venv/lib/python3.10/site-packages (from torch) (3.10.4)\n",
76
- "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in ./venv/lib/python3.10/site-packages (from torch) (11.7.91)\n",
77
- "Requirement already satisfied: typing-extensions in ./venv/lib/python3.10/site-packages (from torch) (4.5.0)\n",
78
- "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in ./venv/lib/python3.10/site-packages (from torch) (11.10.3.66)\n",
79
- "Requirement already satisfied: sympy in ./venv/lib/python3.10/site-packages (from torch) (1.11.1)\n",
80
- "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in ./venv/lib/python3.10/site-packages (from torch) (11.7.99)\n",
81
- "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in ./venv/lib/python3.10/site-packages (from torch) (10.9.0.58)\n",
82
- "Requirement already satisfied: jinja2 in ./venv/lib/python3.10/site-packages (from torch) (3.1.2)\n",
83
- "Requirement already satisfied: triton==2.0.0 in ./venv/lib/python3.10/site-packages (from torch) (2.0.0)\n",
84
- "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in ./venv/lib/python3.10/site-packages (from torch) (11.7.101)\n",
85
- "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in ./venv/lib/python3.10/site-packages (from torch) (11.7.4.91)\n",
86
- "Requirement already satisfied: wheel in ./venv/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (0.40.0)\n",
87
- "Requirement already satisfied: setuptools in ./venv/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (65.5.0)\n",
88
- "Requirement already satisfied: lit in ./venv/lib/python3.10/site-packages (from triton==2.0.0->torch) (16.0.0)\n",
89
- "Requirement already satisfied: cmake in ./venv/lib/python3.10/site-packages (from triton==2.0.0->torch) (3.26.1)\n",
90
- "Requirement already satisfied: python-dateutil>=2.8.1 in ./venv/lib/python3.10/site-packages (from pandas) (2.8.2)\n",
91
- "Requirement already satisfied: pytz>=2020.1 in ./venv/lib/python3.10/site-packages (from pandas) (2023.2)\n",
92
- "Requirement already satisfied: requests<3,>=2.21.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.28.2)\n",
93
- "Requirement already satisfied: werkzeug>=1.0.1 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.2.3)\n",
94
- "Requirement already satisfied: google-auth<3,>=1.6.3 in ./venv/lib/python3.10/site-packages (from tensorboard) (2.16.3)\n",
95
- "Requirement already satisfied: protobuf>=3.19.6 in ./venv/lib/python3.10/site-packages (from tensorboard) (4.22.1)\n",
96
- "Requirement already satisfied: markdown>=2.6.8 in ./venv/lib/python3.10/site-packages (from tensorboard) (3.4.3)\n",
97
- "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in ./venv/lib/python3.10/site-packages (from tensorboard) (0.4.6)\n",
98
- "Requirement already satisfied: grpcio>=1.48.2 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.51.3)\n",
99
- "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (0.7.0)\n",
100
- "Requirement already satisfied: absl-py>=0.4 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.4.0)\n",
101
- "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in ./venv/lib/python3.10/site-packages (from tensorboard) (1.8.1)\n",
102
- "Requirement already satisfied: cachetools<6.0,>=2.0.0 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (5.3.0)\n",
103
- "Requirement already satisfied: pyasn1-modules>=0.2.1 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (0.2.8)\n",
104
- "Requirement already satisfied: six>=1.9.0 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (1.16.0)\n",
105
- "Requirement already satisfied: rsa<5,>=3.1.4 in ./venv/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard) (4.9)\n",
106
- "Requirement already satisfied: requests-oauthlib>=0.7.0 in ./venv/lib/python3.10/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (1.3.1)\n",
107
- "Requirement already satisfied: idna<4,>=2.5 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (3.4)\n",
108
- "Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (1.26.15)\n",
109
- "Requirement already satisfied: charset-normalizer<4,>=2 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (3.1.0)\n",
110
- "Requirement already satisfied: certifi>=2017.4.17 in ./venv/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard) (2022.12.7)\n",
111
- "Requirement already satisfied: MarkupSafe>=2.1.1 in ./venv/lib/python3.10/site-packages (from werkzeug>=1.0.1->tensorboard) (2.1.2)\n",
112
- "Requirement already satisfied: mpmath>=0.19 in ./venv/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n",
113
- "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in ./venv/lib/python3.10/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard) (0.4.8)\n",
114
- "Requirement already satisfied: oauthlib>=3.0.0 in ./venv/lib/python3.10/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (3.2.2)\n",
115
- "\n",
116
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n",
117
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
118
- "Note: you may need to restart the kernel to use updated packages.\n"
119
- ]
120
- }
121
- ],
122
  "source": [
123
- "%pip install torch pandas numpy tensorboard"
124
  ]
125
  },
126
  {
127
  "cell_type": "code",
128
- "execution_count": 43,
129
  "metadata": {},
130
- "outputs": [
131
- {
132
- "data": {
133
- "text/plain": [
134
- "<torch._C.Generator at 0x7fef50768610>"
135
- ]
136
- },
137
- "execution_count": 43,
138
- "metadata": {},
139
- "output_type": "execute_result"
140
- }
141
- ],
142
  "source": [
143
  "import torch\n",
144
  "import torch.nn as nn\n",
@@ -150,12 +78,14 @@
150
  "import numpy as np\n",
151
  "import math\n",
152
  "\n",
153
- "torch.manual_seed(1337)"
 
 
154
  ]
155
  },
156
  {
157
  "cell_type": "code",
158
- "execution_count": 44,
159
  "metadata": {},
160
  "outputs": [],
161
  "source": [
@@ -169,7 +99,7 @@
169
  },
170
  {
171
  "cell_type": "code",
172
- "execution_count": 45,
173
  "metadata": {},
174
  "outputs": [
175
  {
@@ -202,7 +132,7 @@
202
  },
203
  {
204
  "cell_type": "code",
205
- "execution_count": 46,
206
  "metadata": {},
207
  "outputs": [],
208
  "source": [
@@ -231,7 +161,7 @@
231
  },
232
  {
233
  "cell_type": "code",
234
- "execution_count": 66,
235
  "metadata": {},
236
  "outputs": [],
237
  "source": [
@@ -298,7 +228,7 @@
298
  },
299
  {
300
  "cell_type": "code",
301
- "execution_count": 48,
302
  "metadata": {},
303
  "outputs": [],
304
  "source": [
@@ -318,7 +248,7 @@
318
  },
319
  {
320
  "cell_type": "code",
321
- "execution_count": 60,
322
  "metadata": {},
323
  "outputs": [],
324
  "source": [
@@ -344,7 +274,7 @@
344
  },
345
  {
346
  "cell_type": "code",
347
- "execution_count": 50,
348
  "metadata": {},
349
  "outputs": [],
350
  "source": [
@@ -404,7 +334,7 @@
404
  },
405
  {
406
  "cell_type": "code",
407
- "execution_count": 51,
408
  "metadata": {},
409
  "outputs": [],
410
  "source": [
@@ -419,7 +349,7 @@
419
  },
420
  {
421
  "cell_type": "code",
422
- "execution_count": 67,
423
  "metadata": {},
424
  "outputs": [],
425
  "source": [
@@ -435,7 +365,7 @@
435
  " context_size=BLOCK_SIZE,\n",
436
  " )\n",
437
  "\n",
438
- "model = model.to('cuda')\n",
439
  "optimizer = optim.AdamW(model.parameters(), lr=LR)\n",
440
  "#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)\n",
441
  "criterion = F.cross_entropy\n",
@@ -529,8 +459,8 @@
529
  "\n",
530
  "for epoch in range(EPOCHS):\n",
531
  " for data, target in train_dataloader:\n",
532
- " data = data.to('cuda')\n",
533
- " target = target.to('cuda')\n",
534
  "\n",
535
  " loss = compute_loss(model, criterion, data, target)\n",
536
  "\n",
@@ -551,8 +481,8 @@
551
  " with torch.no_grad():\n",
552
  " model.eval()\n",
553
  " for x, y in test_dataloader:\n",
554
- " x = x.to(\"cuda\")\n",
555
- " y = y.to(\"cuda\")\n",
556
  "\n",
557
  " batch_loss = compute_loss(model, criterion, x, y)\n",
558
  " total_loss += batch_loss.item() * 512\n",
@@ -576,7 +506,7 @@
576
  },
577
  {
578
  "cell_type": "code",
579
- "execution_count": 69,
580
  "metadata": {},
581
  "outputs": [],
582
  "source": [
@@ -600,7 +530,7 @@
600
  },
601
  {
602
  "cell_type": "code",
603
- "execution_count": 18,
604
  "metadata": {},
605
  "outputs": [],
606
  "source": [
@@ -619,16 +549,16 @@
619
  },
620
  {
621
  "cell_type": "code",
622
- "execution_count": 26,
623
  "metadata": {},
624
  "outputs": [
625
  {
626
  "data": {
627
  "text/plain": [
628
- "841"
629
  ]
630
  },
631
- "execution_count": 26,
632
  "metadata": {},
633
  "output_type": "execute_result"
634
  }
@@ -660,8 +590,8 @@
660
  "val_dataloader = DataLoader(val_dataset, batch_size=512, num_workers=4)\n",
661
  "with torch.no_grad():\n",
662
  " for x, y in val_dataloader:\n",
663
- " x = x.to(\"cuda\")\n",
664
- " y = y.to(\"cuda\")\n",
665
  "\n",
666
  " batch_loss = compute_loss(model, criterion, x, y)\n",
667
  " total_loss += batch_loss.item() * x.size(0)\n",
@@ -704,94 +634,458 @@
704
  },
705
  {
706
  "cell_type": "code",
707
- "execution_count": 58,
708
  "metadata": {},
709
  "outputs": [
710
  {
711
  "name": "stdout",
712
  "output_type": "stream",
713
  "text": [
714
- "Tutus, to Marcius, noble Marcius\n",
715
- "Made to my voices! doing and hangs upon them!\n",
716
- "Take it to down our foes and hates with stain,\n",
717
- "Which thus follows slay with on I meland,\n",
718
- "What I am after her to her fearful haunt it?\n",
719
  "\n",
720
- "PAULINA:\n",
721
- "But you are well to hold the king.\n",
 
722
  "\n",
723
- "ISABELLA:\n",
724
- "And I will not go royalty to thy hand.\n",
 
 
725
  "\n",
726
- "LUCIO:\n",
727
- "Since I do not well in such goodly talk of.\n",
728
- "I think I have a stay of it!\n",
729
  "\n",
730
- "HENRY BOLINGBROKE:\n",
731
- "Who say I hate been a day's mind;\n",
732
- "Till we here and so very little and way,\n",
733
- "And wash the city has nest seen the feast.\n",
 
 
734
  "\n",
735
- "DUCHESS OF YORK:\n",
736
- "No, by the matter.\n",
737
  "\n",
738
- "ISABELLA:\n",
739
- "Flitter than desire never yet looks so.\n",
 
 
 
740
  "\n",
741
- "HENRY BOLINGBROKE:\n",
742
- "I am not possible perceived\n",
743
- "And both place, where I may not rafes,\n",
744
- "And like me one air. What you'll your love day?\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745
  "\n",
746
- "KING RICHARD II:\n",
747
- "Then be thou--\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
748
  "\n",
749
  "GLOUCESTER:\n",
750
- "No, Lord Hastings:\n",
751
- "Else queen, though my trowbers grands me to-morrow\n",
752
- "Here to Bolingbroke's match;\n",
753
- "When the your life and spur at homely speak.\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  "\n",
755
  "BUCKINGHAM:\n",
756
- "My father was I follow: if you be your your kingdom,\n",
757
- "My approbations an"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758
  ]
759
  }
760
  ],
761
  "source": [
762
- "g_cuda = torch.Generator(device='cuda')\n",
763
  "\n",
764
- "seed = \"\"\"\n",
765
- "Plot histograms of the gradient values during training. If you notice a significant number of gradients are near zero (vanishing gradients) or very large values (exploding gradients), it could be a problem. TensorBoard is a useful tool for visualizing these histograms.\n",
766
- "\"\"\"\n",
767
- "\n",
768
- "contexts = torch.tensor(encode_text(seed), dtype=torch.int32).to('cuda')\n",
769
- "GEN_LENGTH=1024\n",
 
 
 
 
 
 
 
770
  "\n",
771
- "model.eval()\n",
772
- "for i in range(GEN_LENGTH):\n",
773
- " transform = nn.LogSoftmax(1)\n",
774
- " # What happens if GEN_LENGTH > CONTEXT? don't worry about it\n",
775
- " #x = F.pad(contexts[:, -BLOCK_SIZE:], (0, BLOCK_SIZE - contexts.size(0)), \"constant\", 0)\n",
776
- " x = contexts[-BLOCK_SIZE:]\n",
777
- " if x.size(0) < BLOCK_SIZE:\n",
778
- " x = F.pad(x, (0, BLOCK_SIZE - x.size(0)), \"constant\", 0).unsqueeze(0) # B*T\n",
779
- " else:\n",
780
- " x = x.unsqueeze(0)\n",
781
- "\n",
782
- " preds = model.infer(x)\n",
783
- " preds = preds.squeeze(0)\n",
784
- " probs = torch.softmax(preds, dim=-1)\n",
785
- "\n",
786
- " # TODO: Broken because of bug with the trailing 0s. FIX THIS\n",
787
- " # next_char = torch.multinomial(torch.exp(preds[(-1 if i >= BLOCK_SIZE else i), :]), num_samples=1, generator=g_cuda)\n",
788
- " next_char = torch.multinomial(torch.exp(preds[-1, :]), num_samples=1, generator=g_cuda)\n",
789
- "\n",
790
- " #context = torch.cat(context, next_char)\n",
791
- " contexts = torch.cat((contexts, next_char), dim=0)\n",
792
- " print(decode_text(next_char.cpu().numpy())[-1], end=\"\")\n",
793
- "\n",
794
- "#print(\"\".join(decode_text(contexts.cpu().numpy())))"
795
  ]
796
  }
797
  ],
 
10
  },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 2,
14
  "metadata": {},
15
  "outputs": [],
16
  "source": [
 
22
  },
23
  {
24
  "cell_type": "code",
25
+ "execution_count": 3,
26
  "metadata": {},
27
  "outputs": [],
28
  "source": [
 
32
  },
33
  {
34
  "cell_type": "code",
35
+ "execution_count": 4,
36
  "metadata": {},
37
  "outputs": [],
38
  "source": [
 
55
  },
56
  {
57
  "cell_type": "code",
58
+ "execution_count": 5,
59
  "metadata": {},
60
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  "source": [
62
+ "#%pip install torch pandas numpy tensorboard gradio"
63
  ]
64
  },
65
  {
66
  "cell_type": "code",
67
+ "execution_count": 6,
68
  "metadata": {},
69
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
70
  "source": [
71
  "import torch\n",
72
  "import torch.nn as nn\n",
 
78
  "import numpy as np\n",
79
  "import math\n",
80
  "\n",
81
+ "torch.manual_seed(1337)\n",
82
+ "# Set device to CUDA if available, otherwise use CPU\n",
83
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
84
  ]
85
  },
86
  {
87
  "cell_type": "code",
88
+ "execution_count": 7,
89
  "metadata": {},
90
  "outputs": [],
91
  "source": [
 
99
  },
100
  {
101
  "cell_type": "code",
102
+ "execution_count": 8,
103
  "metadata": {},
104
  "outputs": [
105
  {
 
132
  },
133
  {
134
  "cell_type": "code",
135
+ "execution_count": 9,
136
  "metadata": {},
137
  "outputs": [],
138
  "source": [
 
161
  },
162
  {
163
  "cell_type": "code",
164
+ "execution_count": 10,
165
  "metadata": {},
166
  "outputs": [],
167
  "source": [
 
228
  },
229
  {
230
  "cell_type": "code",
231
+ "execution_count": 11,
232
  "metadata": {},
233
  "outputs": [],
234
  "source": [
 
248
  },
249
  {
250
  "cell_type": "code",
251
+ "execution_count": 12,
252
  "metadata": {},
253
  "outputs": [],
254
  "source": [
 
274
  },
275
  {
276
  "cell_type": "code",
277
+ "execution_count": 13,
278
  "metadata": {},
279
  "outputs": [],
280
  "source": [
 
334
  },
335
  {
336
  "cell_type": "code",
337
+ "execution_count": 14,
338
  "metadata": {},
339
  "outputs": [],
340
  "source": [
 
349
  },
350
  {
351
  "cell_type": "code",
352
+ "execution_count": 15,
353
  "metadata": {},
354
  "outputs": [],
355
  "source": [
 
365
  " context_size=BLOCK_SIZE,\n",
366
  " )\n",
367
  "\n",
368
+ "model = model.to(device)\n",
369
  "optimizer = optim.AdamW(model.parameters(), lr=LR)\n",
370
  "#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)\n",
371
  "criterion = F.cross_entropy\n",
 
459
  "\n",
460
  "for epoch in range(EPOCHS):\n",
461
  " for data, target in train_dataloader:\n",
462
+ " data = data.to(device)\n",
463
+ " target = target.to(device)\n",
464
  "\n",
465
  " loss = compute_loss(model, criterion, data, target)\n",
466
  "\n",
 
481
  " with torch.no_grad():\n",
482
  " model.eval()\n",
483
  " for x, y in test_dataloader:\n",
484
+ " x = x.to(device)\n",
485
+ " y = y.to(device)\n",
486
  "\n",
487
  " batch_loss = compute_loss(model, criterion, x, y)\n",
488
  " total_loss += batch_loss.item() * 512\n",
 
506
  },
507
  {
508
  "cell_type": "code",
509
+ "execution_count": 16,
510
  "metadata": {},
511
  "outputs": [],
512
  "source": [
 
530
  },
531
  {
532
  "cell_type": "code",
533
+ "execution_count": 17,
534
  "metadata": {},
535
  "outputs": [],
536
  "source": [
 
549
  },
550
  {
551
  "cell_type": "code",
552
+ "execution_count": 21,
553
  "metadata": {},
554
  "outputs": [
555
  {
556
  "data": {
557
  "text/plain": [
558
+ "3"
559
  ]
560
  },
561
+ "execution_count": 21,
562
  "metadata": {},
563
  "output_type": "execute_result"
564
  }
 
590
  "val_dataloader = DataLoader(val_dataset, batch_size=512, num_workers=4)\n",
591
  "with torch.no_grad():\n",
592
  " for x, y in val_dataloader:\n",
593
+ " x = x.to(device)\n",
594
+ " y = y.to(device)\n",
595
  "\n",
596
  " batch_loss = compute_loss(model, criterion, x, y)\n",
597
  " total_loss += batch_loss.item() * x.size(0)\n",
 
634
  },
635
  {
636
  "cell_type": "code",
637
+ "execution_count": 54,
638
  "metadata": {},
639
  "outputs": [
640
  {
641
  "name": "stdout",
642
  "output_type": "stream",
643
  "text": [
 
 
 
 
 
644
  "\n",
645
+ "Shepherd:\n",
646
+ "What should be gone to your deeds to prince\n",
647
+ "That let us away the house of Lancasters.\n",
648
  "\n",
649
+ "Second Murderer:\n",
650
+ "Alas, if I be not the fire rogued man\n",
651
+ "That hath dead by the fault of the good stars;\n",
652
+ "And the business of some many that be sounded.\n",
653
  "\n",
654
+ "AUTOLYCUS:\n",
655
+ "My soul are but not so dishonours.\n",
 
656
  "\n",
657
+ "BRUTUS:\n",
658
+ "I will do not be point, but it is a soul,\n",
659
+ "I would I see my heart fair state,\n",
660
+ "That he is in the grief and be contented\n",
661
+ "That you may made the father, the day doth nothing\n",
662
+ "And lie at like in his world moved me at him.\n",
663
  "\n",
664
+ "First Senator:\n",
665
+ "The rotten far the points of our honour.\n",
666
  "\n",
667
+ "AUTOLYCUS:\n",
668
+ "The king hath been you think the duke of the foes,\n",
669
+ "And had so long and confessal abroad,\n",
670
+ "And said in the light save of men so meeting\n",
671
+ "May not shall do be as we as here.\n",
672
  "\n",
673
+ "Second Servingman:\n",
674
+ "You are grace as the child and fight again,\n",
675
+ "If you do not shall not be much of the thing;\n",
676
+ "One but so the creature of the charge is were\n",
677
+ "Than the point of your several motion\n",
678
+ "May seem as her like a pride and as would all grave\n",
679
+ "The which of the blood"
680
+ ]
681
+ },
682
+ {
683
+ "data": {
684
+ "text/plain": [
685
+ "'\\nPlot histograms of the gradient values during training. If you notice a significant number of gradients are near zero (vanishing gradients) or very large values (exploding gradients), it could be a problem. TensorBoard is a useful tool for visualizing these histograms.\\n\\nShepherd:\\nWhat should be gone to your deeds to prince\\nThat let us away the house of Lancasters.\\n\\nSecond Murderer:\\nAlas, if I be not the fire rogued man\\nThat hath dead by the fault of the good stars;\\nAnd the business of some many that be sounded.\\n\\nAUTOLYCUS:\\nMy soul are but not so dishonours.\\n\\nBRUTUS:\\nI will do not be point, but it is a soul,\\nI would I see my heart fair state,\\nThat he is in the grief and be contented\\nThat you may made the father, the day doth nothing\\nAnd lie at like in his world moved me at him.\\n\\nFirst Senator:\\nThe rotten far the points of our honour.\\n\\nAUTOLYCUS:\\nThe king hath been you think the duke of the foes,\\nAnd had so long and confessal abroad,\\nAnd said in the light save of men so meeting\\nMay not shall do be as we as here.\\n\\nSecond Servingman:\\nYou are grace as the child and fight again,\\nIf you do not shall not be much of the thing;\\nOne but so the creature of the charge is were\\nThan the point of your several motion\\nMay seem as her like a pride and as would all grave\\nThe which of the blood'"
686
+ ]
687
+ },
688
+ "execution_count": 54,
689
+ "metadata": {},
690
+ "output_type": "execute_result"
691
+ }
692
+ ],
693
+ "source": [
694
+ "seed = \"\"\"\n",
695
+ "Plot histograms of the gradient values during training. If you notice a significant number of gradients are near zero (vanishing gradients) or very large values (exploding gradients), it could be a problem. TensorBoard is a useful tool for visualizing these histograms.\n",
696
+ "\"\"\"\n",
697
+ "GEN_LENGTH=1024\n",
698
+ "\n",
699
+ "def generate(prompt, gen_length, temp=1, top_k=10, top_p=None):\n",
700
+ " g_cuda = torch.Generator(device=device)\n",
701
+ " contexts = torch.tensor(encode_text(prompt), dtype=torch.int32).to(device)\n",
702
+ "\n",
703
+ " model.eval()\n",
704
+ " for i in range(GEN_LENGTH):\n",
705
+ " transform = nn.LogSoftmax(1)\n",
706
+ " x = contexts[-BLOCK_SIZE:]\n",
707
+ " if x.size(0) < BLOCK_SIZE:\n",
708
+ " x = F.pad(x, (BLOCK_SIZE - x.size(0), 0), \"constant\", 0).unsqueeze(0) # B*T\n",
709
+ " else:\n",
710
+ " x = x.unsqueeze(0)\n",
711
+ "\n",
712
+ " preds = model.infer(x)\n",
713
+ " preds = preds.squeeze(0)\n",
714
+ " preds = preds / temp\n",
715
+ " probs = F.softmax(preds, dim=-1)\n",
716
+ "\n",
717
+ " if top_p is not None:\n",
718
+ " # Apply top-p\n",
719
+ " sorted_probs, sorted_indices = torch.sort(probs[-1, :], descending=True)\n",
720
+ " cumulative_probs = torch.cumsum(sorted_probs, dim=-1)\n",
721
+ " # find cutoff\n",
722
+ " idx_top_p = (cumulative_probs < top_p).sum().item()\n",
723
+ " top_probs = sorted_probs[:idx_top_p]\n",
724
+ " top_indices = sorted_indices[:idx_top_p]\n",
725
+ " # Null case\n",
726
+ " if top_probs.size(0) == 0:\n",
727
+ " top_probs = sorted_probs[:1]\n",
728
+ " top_indices = sorted_indices[:1]\n",
729
+ " \n",
730
+ " next_char = torch.multinomial(top_probs, num_samples=1, generator=g_cuda)\n",
731
+ " next_char = top_indices[next_char]\n",
732
+ " elif top_k is not None:\n",
733
+ " top_k_probs, top_k_indices = torch.topk(probs[-1, :], k=top_k)\n",
734
+ " next_char = torch.multinomial(top_k_probs, num_samples=1, generator=g_cuda)\n",
735
+ " next_char = top_k_indices[next_char]\n",
736
+ " else:\n",
737
+ " next_char = torch.multinomial(probs, num_samples=1, generator=g_cuda)\n",
738
+ "\n",
739
+ "\n",
740
+ " contexts = torch.cat((contexts, next_char), dim=0)\n",
741
+ " print(decode_text(next_char.cpu().numpy())[-1], end=\"\")\n",
742
+ " \n",
743
+ " return(\"\".join(decode_text(contexts.cpu().numpy())))\n",
744
+ "\n",
745
+ "\"\".join(generate(seed, GEN_LENGTH,temp=0.8,top_p=0.9))"
746
+ ]
747
+ },
748
+ {
749
+ "attachments": {},
750
+ "cell_type": "markdown",
751
+ "metadata": {},
752
+ "source": [
753
+ "## Gradio WebUI"
754
+ ]
755
+ },
756
+ {
757
+ "cell_type": "code",
758
+ "execution_count": 55,
759
+ "metadata": {},
760
+ "outputs": [
761
+ {
762
+ "name": "stdout",
763
+ "output_type": "stream",
764
+ "text": [
765
+ "Running on local URL: http://127.0.0.1:7866\n",
766
  "\n",
767
+ "To create a public link, set `share=True` in `launch()`.\n"
768
+ ]
769
+ },
770
+ {
771
+ "data": {
772
+ "text/html": [
773
+ "<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
774
+ ],
775
+ "text/plain": [
776
+ "<IPython.core.display.HTML object>"
777
+ ]
778
+ },
779
+ "metadata": {},
780
+ "output_type": "display_data"
781
+ },
782
+ {
783
+ "data": {
784
+ "text/plain": []
785
+ },
786
+ "execution_count": 55,
787
+ "metadata": {},
788
+ "output_type": "execute_result"
789
+ },
790
+ {
791
+ "name": "stdout",
792
+ "output_type": "stream",
793
+ "text": [
794
+ "\n",
795
+ "If you do I am sure be your father's limbs,\n",
796
+ "Which is not drouble Paulina,\n",
797
+ "Where I come to me too much be to thus\n",
798
+ "As in a mind sorten but here hoped--followed as hung\n",
799
+ "bard at the time own that that he is death;\n",
800
+ "The stir weors shall be a new good groans.\n",
801
+ "\n",
802
+ "OXFORD:\n",
803
+ "And I say him to deed your city,\n",
804
+ "Thy heart advantage by mine at their world,\n",
805
+ "And be glood much show with upon the suddenly,\n",
806
+ "And in this free can company from here faults.\n",
807
+ "\n",
808
+ "YORK:\n",
809
+ "I bring thee, now I am again: it is a Montague.\n",
810
+ "\n",
811
+ "BRUTUS:\n",
812
+ "We while you not be the fault of the life.\n",
813
+ "\n",
814
+ "KING RICHARD III:\n",
815
+ "Come, sir, do I: in put I with so,\n",
816
+ "The rest redies with this ballad or suppher.\n",
817
+ "\n",
818
+ "Third Servingman:\n",
819
+ "Now, well do with kindness of power the that\n",
820
+ "ence of the ambassage; but to shall here,\n",
821
+ "But saving as the wounds both drawn dead.\n",
822
+ "\n",
823
+ "CLARENCE:\n",
824
+ "Peace, sir, he adieu.\n",
825
+ "\n",
826
+ "MISTRESS OVERDONE:\n",
827
+ "Cheep, comes thou lovest which I do well;\n",
828
+ "Now now as that protectors, in this parliament--\n",
829
+ "That nature i' the mark of her with old her,\n",
830
+ "And so here doth in circle and call'd the wind,\n",
831
+ "If you do I am sure be your father's limbs,\n",
832
+ "Which is not drouble Paulina,\n",
833
+ "Where I come to me too much be to thus\n",
834
+ "As in a mind sorten but here hoped--followed as hung\n",
835
+ "bard at the time own that that he is death;\n",
836
+ "The stir weors shall be a new good groans.\n",
837
+ "\n",
838
+ "OXFORD:\n",
839
+ "And I say him to deed your city,\n",
840
+ "Thy heart advantage by mine at their world,\n",
841
+ "And be glood much show with upon the suddenly,\n",
842
+ "And in this free can company from here faults.\n",
843
+ "\n",
844
+ "YORK:\n",
845
+ "I bring thee, now I am again: it is a Montague.\n",
846
+ "\n",
847
+ "BRUTUS:\n",
848
+ "We while you not be the fault of the life.\n",
849
+ "\n",
850
+ "KING RICHARD III:\n",
851
+ "Come, sir, do I: in put I with so,\n",
852
+ "The rest redies with this ballad or suppher.\n",
853
+ "\n",
854
+ "Third Servingman:\n",
855
+ "Now, well do with kindness of power the that\n",
856
+ "ence of the ambassage; but to shall here,\n",
857
+ "But saving as the wounds both drawn dead.\n",
858
+ "\n",
859
+ "CLARENCE:\n",
860
+ "Peace, sir, he adieu.\n",
861
+ "\n",
862
+ "MISTRESS OVERDONE:\n",
863
+ "Cheep, comes thou lovest which I do well;\n",
864
+ "Now now as that protectors, in this parliament--\n",
865
+ "That nature i' the mark of her with old her,\n",
866
+ "And so here doth in circle and call'd the wind,\n",
867
+ "If you do I am sure be your father's limbs,\n",
868
+ "Which is not drouble Paulina,\n",
869
+ "Where I come to me too much be to thus\n",
870
+ "As in a mind sorten but here hoped--followed as hung\n",
871
+ "bard at the time own that that he is death;\n",
872
+ "The stir weors shall be a new good groans.\n",
873
+ "\n",
874
+ "OXFORD:\n",
875
+ "And I say him to deed your city,\n",
876
+ "Thy heart advantage by mine at their world,\n",
877
+ "And be glood much show with upon the suddenly,\n",
878
+ "And in this free can company from here faults.\n",
879
+ "\n",
880
+ "YORK:\n",
881
+ "I bring thee, now I am again: it is a Montague.\n",
882
+ "\n",
883
+ "BRUTUS:\n",
884
+ "We while you not be the fault of the life.\n",
885
+ "\n",
886
+ "KING RICHARD III:\n",
887
+ "Come, sir, do I: in put I with so,\n",
888
+ "The rest redies with this ballad or suppher.\n",
889
+ "\n",
890
+ "Third Servingman:\n",
891
+ "Now, well do with kindness of power the that\n",
892
+ "ence of the ambassage; but to shall here,\n",
893
+ "But saving as the wounds both drawn dead.\n",
894
+ "\n",
895
+ "CLARENCE:\n",
896
+ "Peace, sir, he adieu.\n",
897
+ "\n",
898
+ "MISTRESS OVERDONE:\n",
899
+ "Cheep, comes thou lovest which I do well;\n",
900
+ "Now now as that protectors, in this parliament--\n",
901
+ "That nature i' the mark of her with old her,\n",
902
+ "And so here doth in circle and call'd the wind,\n",
903
+ "If you do I am sure be your father's limbs,\n",
904
+ "Which is not drouble Paulina,\n",
905
+ "Where I come to me too much be to thus\n",
906
+ "As in a mind sorten but here hoped--followed as hung\n",
907
+ "bard at the time own that that he is death;\n",
908
+ "The stir weors shall be a new good groans.\n",
909
+ "\n",
910
+ "OXFORD:\n",
911
+ "And I say him to deed your city,\n",
912
+ "Thy heart advantage by mine at their world,\n",
913
+ "And be glood much show with upon the suddenly,\n",
914
+ "And in this free can company from here faults.\n",
915
+ "\n",
916
+ "YORK:\n",
917
+ "I bring thee, now I am again: it is a Montague.\n",
918
+ "\n",
919
+ "BRUTUS:\n",
920
+ "We while you not be the fault of the life.\n",
921
+ "\n",
922
+ "KING RICHARD III:\n",
923
+ "Come, sir, do I: in put I with so,\n",
924
+ "The rest redies with this ballad or suppher.\n",
925
+ "\n",
926
+ "Third Servingman:\n",
927
+ "Now, well do with kindness of power the that\n",
928
+ "ence of the ambassage; but to shall here,\n",
929
+ "But saving as the wounds both drawn dead.\n",
930
+ "\n",
931
+ "CLARENCE:\n",
932
+ "Peace, sir, he adieu.\n",
933
+ "\n",
934
+ "MISTRESS OVERDONE:\n",
935
+ "Cheep, comes thou lovest which I do well;\n",
936
+ "Now now as that protectors, in this parliament--\n",
937
+ "That nature i' the mark of her with old her,\n",
938
+ "And so here doth in circle and call'd the wind,\n",
939
+ "If you do I am sure be your father's limbs,\n",
940
+ "Which is not drouble Paulina,\n",
941
+ "Where I come to me too much be to thus\n",
942
+ "As in a mind sorten but here hoped--followed as hung\n",
943
+ "bard at the time own that that he is death;\n",
944
+ "The stir weors shall be a new good groans.\n",
945
+ "\n",
946
+ "OXFORD:\n",
947
+ "And I say him to deed your city,\n",
948
+ "Thy heart advantage by mine at their world,\n",
949
+ "And be glood much show with upon the suddenly,\n",
950
+ "And in this free can company from here faults.\n",
951
+ "\n",
952
+ "YORK:\n",
953
+ "I bring thee, now I am again: it is a Montague.\n",
954
+ "\n",
955
+ "BRUTUS:\n",
956
+ "We while you not be the fault of the life.\n",
957
+ "\n",
958
+ "KING RICHARD III:\n",
959
+ "Come, sir, do I: in put I with so,\n",
960
+ "The rest redies with this ballad or suppher.\n",
961
+ "\n",
962
+ "Third Servingman:\n",
963
+ "Now, well do with kindness of power the that\n",
964
+ "ence of the ambassage; but to shall here,\n",
965
+ "But saving as the wounds both drawn dead.\n",
966
+ "\n",
967
+ "CLARENCE:\n",
968
+ "Peace, sir, he adieu.\n",
969
+ "\n",
970
+ "MISTRESS OVERDONE:\n",
971
+ "Cheep, comes thou lovest which I do well;\n",
972
+ "Now now as that protectors, in this parliament--\n",
973
+ "That nature i' the mark of her with old her,\n",
974
+ "And so here doth in circle and call'd the wind,\n",
975
+ "If you do I have some that to be blow.\n",
976
+ "\n",
977
+ "QUEEN ELIZABETH:\n",
978
+ "But shall we will not a man; for all one,\n",
979
+ "That you will not be but a fault and house,\n",
980
+ "I cannot strive at once from your father's son.\n",
981
  "\n",
982
  "GLOUCESTER:\n",
983
+ "\n",
984
+ "CLARENCE:\n",
985
+ "That is the field to be put in his name.\n",
986
+ "\n",
987
+ "KING RICHARD III:\n",
988
+ "Ay, sir.\n",
989
+ "\n",
990
+ "QUEEN ELIZABETH:\n",
991
+ "How farewell.\n",
992
+ "\n",
993
+ "POLIXENES:\n",
994
+ "She hath stand and better me for you.\n",
995
+ "\n",
996
+ "QUEEN ELIZABETH:\n",
997
+ "I will be patient me.\n",
998
+ "\n",
999
+ "CLARENCE:\n",
1000
+ "The been are the lawful strength, and he made\n",
1001
+ "That me was upon the which he comes with a holy face.\n",
1002
  "\n",
1003
  "BUCKINGHAM:\n",
1004
+ "The noble and of my tent.\n",
1005
+ "\n",
1006
+ "GLOUCESTER:\n",
1007
+ "My gracious lord, he did soul the news,\n",
1008
+ "Which he had of pretty with the fair traitor\n",
1009
+ "Than a service hath some said for his prince.\n",
1010
+ "\n",
1011
+ "QUEEN MARGARET:\n",
1012
+ "Stay the sad is mine execution,\n",
1013
+ "Which can slain with you have some so death;\n",
1014
+ "For I have sent done the many of the city\n",
1015
+ "That you more so much hold me speak.\n",
1016
+ "\n",
1017
+ "GLOUCESTER:\n",
1018
+ "\n",
1019
+ "GLOUCESTER:\n",
1020
+ "So doth fair lord, that be straight in his wit\n",
1021
+ "The matter, the city distance.\n",
1022
+ "\n",
1023
+ "LEONTES:\n",
1024
+ "The brother about the fire--thou sha\n",
1025
+ "If you do I am sure be so princely false;\n",
1026
+ "And the prince may do me for the house.\n",
1027
+ "\n",
1028
+ "GLOUCESTER:\n",
1029
+ "Why, what wilt me, when we strike a mine eyes;\n",
1030
+ "Of he doth deficers, in my fair deeds,\n",
1031
+ "That thou wert from me hence his pardon lives.\n",
1032
+ "\n",
1033
+ "LORD FITZWATER:\n",
1034
+ "Say, I will be ground to some of my body;\n",
1035
+ "And what thou wilt part so supposed with him.\n",
1036
+ "\n",
1037
+ "CLARENCE:\n",
1038
+ "You might be breathe in me?\n",
1039
+ "\n",
1040
+ "CAMILLO:\n",
1041
+ "Stay now the first is no more resolved\n",
1042
+ "With all the lawful and shed the gates word.\n",
1043
+ "\n",
1044
+ "CLARENCE:\n",
1045
+ "Thou that, when my fair that I would not for heaven,\n",
1046
+ "That fortune but not like her for the mind,\n",
1047
+ "When I do be my tongues and her love.\n",
1048
+ "\n",
1049
+ "ANGELO:\n",
1050
+ "She shall we not have done by his house.\n",
1051
+ "\n",
1052
+ "HENRY BOLINGBROKE:\n",
1053
+ "Have you no more by my fortune should as heard.\n",
1054
+ "\n",
1055
+ "QUEEN ELIZABETH:\n",
1056
+ "O, bring forth thy brother Gloucester,\n",
1057
+ "But both an earthly of love from spoke me.\n",
1058
+ "\n",
1059
+ "QUEEN ELIZABETH:\n",
1060
+ "Thou art thou as you shalt be dish'd in their head.\n",
1061
+ "\n",
1062
+ "Second Murderer:\n",
1063
+ "The prince and my execution and heard it.\n",
1064
+ "\n",
1065
+ "KING RICHARD III:\n",
1066
+ "Say, well I am a house in my heart;\n",
1067
+ "But I am"
1068
  ]
1069
  }
1070
  ],
1071
  "source": [
1072
+ "import gradio as gr\n",
1073
  "\n",
1074
+ "demo = gr.Interface(\n",
1075
+ " fn=generate, \n",
1076
+ " inputs=[\n",
1077
+ " gr.Textbox(lines=2, placeholder=\"Prompt here...\"),\n",
1078
+ " gr.Number(value=256),\n",
1079
+ " gr.Number(value=0.8),\n",
1080
+ " gr.Slider(maximum=128,value=10),\n",
1081
+ " gr.Slider(maximum=1,value=1)\n",
1082
+ " ], \n",
1083
+ " outputs=\"text\",\n",
1084
+ " title=\"Shakespeare-GPT\",\n",
1085
+ " description=\"Putting theater kids out of their nonexistent jobs since 2023\"\n",
1086
+ ")\n",
1087
  "\n",
1088
+ "demo.launch()\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1089
  ]
1090
  }
1091
  ],
gpt.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gradio as gr
3
+ import gc
4
+ from torch.utils.tensorboard import SummaryWriter
5
+ import os
6
+ import json
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+ import torch.nn.functional as F
11
+ from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
12
+ import pandas as pd
13
+ import numpy as np
14
+ import math
15
+ from typing import Type
16
+
17
+
18
+ # Define the command-line arguments
19
+ parser = argparse.ArgumentParser(description='GPT CLI')
20
+ parser.add_argument('--gui', action='store_true', help='Enable Gradio UI mode')
21
+ parser.add_argument('--config', default='./gpt_config.json',
22
+ help='Path to the config file')
23
+ subparsers = parser.add_subparsers(dest='command', help='Choose a command')
24
+
25
+ # Define the training command
26
+ train_parser = subparsers.add_parser('train', help='Train the model')
27
+ train_parser.add_argument('--load-from-restore', action='store_true',
28
+ help='Load from restore path instead of training from scratch')
29
+
30
+ # Define the evaluation command
31
+ eval_parser = subparsers.add_parser('eval', help='Evaluate the model')
32
+ eval_parser.add_argument('--data', default='./data/evaluation_data.txt',
33
+ help='Path to the evaluation data file')
34
+
35
+ # Define the inference command
36
+ infer_parser = subparsers.add_parser(
37
+ 'infer', help='Generate text from the model')
38
+ infer_parser.add_argument('--text', type=str, required=True,
39
+ help='Input text for generating continuation')
40
+ infer_parser.add_argument('--length', type=int,
41
+ default=100, help='Number of characters to generate')
42
+
43
+ torch.manual_seed(1337)
44
+ # Set device to CUDA if available, otherwise use CPU
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+
47
+
48
+ class GPTConfig:
49
+ def __init__(self, config_file_path):
50
+ with open(config_file_path, 'r') as f:
51
+ config = json.load(f)
52
+
53
+ # Architecture configuration
54
+ architecture_config = config['architecture']
55
+ self.embedding_dim = architecture_config['embedding_dim']
56
+ self.vocab_size = architecture_config['vocab_size']
57
+ self.context_size = architecture_config['context_size']
58
+ self.num_heads = architecture_config['num_heads']
59
+ self.num_layers = architecture_config['num_layers']
60
+
61
+ # Training configuration
62
+ training_config = config['training']
63
+ self.batch_size = training_config['batch_size']
64
+ self.training_data_path = training_config['training_data_path']
65
+ self.learning_rate = training_config['learning_rate']
66
+ self.num_steps = training_config['num_steps']
67
+ self.val_interval = training_config['val_interval']
68
+
69
+ # Checkpoint restore configuration
70
+ self.restore_path = config['restore_path']
71
+
72
+
73
+ def encode_text(text):
74
+ # Simple dumb ASCII character-level "encoding" since all training data is ASCII.
75
+ # Consider better tokenization if moving off character-level
76
+ return ([ord(t) for t in text])
77
+
78
+
79
+ def decode_text(indices):
80
+ return ([chr(x) for x in indices])
81
+
82
+
83
+ class TextDataset(Dataset):
84
+ def __init__(self, data_tensor, context_size):
85
+ self.data_tensor = data_tensor
86
+ self.context_size = context_size
87
+
88
+ def __len__(self):
89
+ return len(self.data_tensor) - self.context_size
90
+
91
+ def __getitem__(self, index):
92
+ x = self.data_tensor[index:index + self.context_size]
93
+ y = self.data_tensor[index + 1:index + self.context_size + 1]
94
+
95
+ return x, y
96
+
97
+
98
+ def load_dataset(data_path, val, context_size):
99
+ with open(data_path, 'r', encoding='utf-8') as f:
100
+ text = f.read()
101
+
102
+ # Tensorify data, put it in dataset
103
+ data = torch.tensor(encode_text(text), dtype=torch.int32)
104
+
105
+ test_split_idx = int(0.8 * len(data))
106
+ val_split_idx = int(0.9 * len(data))
107
+ train_data = data[:test_split_idx]
108
+ test_data = data[test_split_idx:val_split_idx]
109
+ val_data = data[val_split_idx:]
110
+ # print(f"{len(data)} chars of data")
111
+
112
+ train_dataset = TextDataset(train_data, context_size)
113
+ test_dataset = TextDataset(test_data, context_size)
114
+ val_dataset = TextDataset(test_data, context_size)
115
+ return ((train_dataset, test_dataset, val_dataset))
116
+
117
+
118
+ class MultiheadAttention(nn.Module):
119
+ def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, device=None, dtype=None):
120
+ super(MultiheadAttention, self).__init__()
121
+
122
+ # Save variables
123
+ self.embed_dim = embed_dim
124
+ self.num_heads = num_heads
125
+ self.d_k = embed_dim // num_heads
126
+
127
+ self.Q = nn.Linear(embed_dim, embed_dim, bias=False)
128
+ self.K = nn.Linear(embed_dim, embed_dim, bias=False)
129
+ self.V = nn.Linear(embed_dim, embed_dim, bias=False)
130
+
131
+ self.dropout = nn.Dropout(dropout)
132
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
133
+
134
+ def forward(self, query, key, value, attn_mask=None):
135
+ batch_size = query.size(0)
136
+
137
+ # Apply linear layers
138
+ q = self.Q(query) # [B, C, E]
139
+ k = self.K(key) # [B, C, E]
140
+ v = self.V(value) # [B, C, E]
141
+
142
+ # Mutate dimensions so the attention matmul can get rid of the inner d_k
143
+ # [batch_size, num_heads, C, d_k]
144
+ q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
145
+ # [batch_size, num_heads, C, d_k]
146
+ k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
147
+ # [batch_size, num_heads, C, d_k]
148
+ v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
149
+
150
+ # Get raw attention scores
151
+ scores = torch.matmul(q, k.transpose(-2, -1)) / \
152
+ math.sqrt(self.d_k) # [B, num_heads, C, C]
153
+
154
+ # Apply mask, if necessary
155
+ if attn_mask is not None:
156
+ scores = scores.masked_fill(attn_mask, float('-inf'))
157
+
158
+ # Scale by sqrt(k)
159
+ attn = F.softmax(scores, dim=-1)
160
+ attn = self.dropout(attn)
161
+ out = attn @ v # [B, num_heads, C, d_k]
162
+
163
+ # Concat and project
164
+ # Swap C and num_heads, force memory to coalesce, then fuse back num_heads and d_k together
165
+ out = out.transpose(1, 2).contiguous().view(
166
+ batch_size, -1, self.embed_dim)
167
+ # Project: give attention "time to think". Maybe this should be part of a different module but whatever
168
+ out = self.out_proj(out)
169
+ return ((out, None))
170
+
171
+
172
+ class FeedForward(nn.Module):
173
+ def __init__(self, embed_dim, dropout):
174
+ super().__init__()
175
+ self.net = nn.Sequential(
176
+ nn.Linear(embed_dim, 4 * embed_dim),
177
+ nn.GELU(),
178
+ nn.Linear(4 * embed_dim, embed_dim),
179
+ nn.Dropout(dropout),
180
+ )
181
+
182
+ def forward(self, x):
183
+ return (self.net(x))
184
+
185
+
186
+ class Block(nn.Module):
187
+ """Self-attention"""
188
+
189
+ def __init__(self, embed_dim, num_heads, mask, dropout=0.2):
190
+ super(Block, self).__init__()
191
+ self.register_buffer("mask", mask)
192
+ self.head = MultiheadAttention(
193
+ embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
194
+ # self.head = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
195
+ self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)
196
+ self.ln1 = nn.LayerNorm(embed_dim)
197
+ self.ln2 = nn.LayerNorm(embed_dim)
198
+
199
+ def forward(self, x):
200
+ # Residual connections
201
+ x = self.ln1(x)
202
+ attn_output, _ = self.head(x, x, x, attn_mask=self.mask)
203
+ x = x + attn_output
204
+ out = x + self.ffwd(self.ln2(x))
205
+ return out
206
+
207
+
208
+ class GPT(nn.Module):
209
+ def __init__(self, embedding_dim, vocab_size, context_size):
210
+ super(GPT, self).__init__()
211
+
212
+ self.embedding_dim = embedding_dim
213
+ self.output_dim = vocab_size
214
+ self.context_size = context_size
215
+
216
+ NUM_HEADS = 4
217
+ NUM_LAYERS = 4
218
+
219
+ # Initialize layers
220
+ self.tok_embed = nn.Embedding(vocab_size, embedding_dim)
221
+ self.pos_embed = nn.Embedding(context_size, embedding_dim)
222
+
223
+ mask = torch.tril(torch.ones(
224
+ self.context_size, self.context_size)).bool()
225
+ mask = ~mask
226
+ self.register_buffer("mask", mask)
227
+
228
+ self.blocks = nn.Sequential(
229
+ *[Block(embed_dim=embedding_dim, num_heads=NUM_HEADS, mask=mask, dropout=0.2) for _ in range(NUM_LAYERS)]
230
+ )
231
+
232
+ self.ln_f = nn.LayerNorm(self.embedding_dim)
233
+ # Final feed-forward layer from embeddings
234
+ self.ffwd = nn.Linear(
235
+ embedding_dim, out_features=vocab_size, bias=False)
236
+
237
+ def forward(self, x):
238
+ tok_embed = self.tok_embed(x)
239
+ pos_embed = self.pos_embed(
240
+ torch.arange(0, self.context_size, device="cuda")
241
+ )
242
+ x = tok_embed + pos_embed
243
+
244
+ x = self.blocks(x)
245
+ x = self.ln_f(x)
246
+
247
+ logits = self.ffwd(x)
248
+ return (logits)
249
+
250
+ def infer(self, x):
251
+ with torch.no_grad():
252
+ self.eval()
253
+ res = self.forward(x)
254
+ return (res)
255
+
256
+ def num_params(self):
257
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
258
+
259
+
260
+ def load_checkpoint(model, optimizer, path):
261
+ """
262
+ Loads a saved checkpoint file into the model and optimizer.
263
+
264
+ Args:
265
+ model (nn.Module): The PyTorch model to load the checkpoint into.
266
+ optimizer (torch.optim.Optimizer): The PyTorch optimizer to load the checkpoint into.
267
+ path (str): The path to the saved checkpoint file.
268
+
269
+ Returns:
270
+ Tuple[nn.Module, torch.optim.Optimizer]: The model and optimizer, loaded with the checkpoint state.
271
+ """
272
+ checkpoint = torch.load(path)
273
+ model.load_state_dict(checkpoint['model_state_dict'])
274
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
275
+ return (model, optimizer)
276
+
277
+
278
+ def save_checkpoint(model, optimizer, path, steps):
279
+ """
280
+ Saves a checkpoint of the model and optimizer to disk.
281
+
282
+ Args:
283
+ model (nn.Module): The PyTorch model to save the checkpoint of.
284
+ optimizer (torch.optim.Optimizer): The PyTorch optimizer to save the checkpoint of.
285
+ path (str): The path to save the checkpoint file.
286
+ steps (int): The number of training steps that have been completed.
287
+
288
+ Returns:
289
+ None
290
+ """
291
+ torch.save({
292
+ 'steps': steps,
293
+ 'model_state_dict': model.state_dict(),
294
+ 'optimizer_state_dict': optimizer.state_dict(),
295
+ }, path)
296
+
297
+
298
+ def compute_loss(model, criterion, x, y):
299
+ logits = model(x)
300
+ B, C, V = logits.shape
301
+ logits = logits.view(B*C, V)
302
+ y = y.view(B*C)
303
+ loss = F.cross_entropy(logits, y.long())
304
+ return loss
305
+
306
+
307
+ def train(model, optimizer, config: Type[GPTConfig]):
308
+ model = model.to(device)
309
+ criterion = F.cross_entropy
310
+
311
+ global_step = 0
312
+
313
+ train_dataset, val_dataset = load_dataset(
314
+ config.training_data_path, None, model.context_size)
315
+
316
+ train_dataloader = DataLoader(
317
+ train_dataset,
318
+ batch_size=config.batch_size,
319
+ shuffle=True,
320
+ num_workers=4
321
+ )
322
+
323
+ test_dataloader = DataLoader(
324
+ test_dataset, batch_size=512, num_workers=4, shuffle=True)
325
+
326
+ model.train()
327
+
328
+ EPOCHS = 1
329
+ STEPS = config.num_steps
330
+ VAL_INTERVAL = 100
331
+
332
+ writer = SummaryWriter()
333
+
334
+ step = 0
335
+
336
+ for epoch in range(EPOCHS):
337
+ for data, target in train_dataloader:
338
+ data = data.to(device)
339
+ target = target.to(device)
340
+
341
+ loss = compute_loss(model, criterion, data, target)
342
+
343
+ # Backward pass
344
+ optimizer.zero_grad()
345
+ loss.backward()
346
+ optimizer.step()
347
+
348
+ writer.add_scalar(
349
+ "Loss/train", loss.cpu().detach().numpy(), global_step)
350
+ global_step += 1
351
+
352
+ if step % VAL_INTERVAL == 0:
353
+ total_loss = 0
354
+ total_samples = 0
355
+
356
+ with torch.no_grad():
357
+ model.eval()
358
+ for x, y in test_dataloader:
359
+ x = x.to(device)
360
+ y = y.to(device)
361
+
362
+ batch_loss = compute_loss(model, criterion, x, y)
363
+ total_loss += batch_loss.item() * 512
364
+ total_samples += 512
365
+ if total_samples > 10:
366
+ break
367
+
368
+ model.train()
369
+ average_loss = total_loss / total_samples
370
+
371
+ print(f"Step {step}; loss: {average_loss}")
372
+ writer.add_scalar("Loss/val", average_loss, global_step)
373
+
374
+ step += 1
375
+ if step >= STEPS:
376
+ break
377
+
378
+ writer.close()
379
+
380
+
381
+ def evaluate_model(model, val_dataset, block_size=512, max_samples=100000):
382
+ model.eval()
383
+ total_loss = 0.0
384
+ total_samples = 0
385
+ criterion = F.cross_entropy
386
+
387
+ val_dataloader = DataLoader(
388
+ val_dataset, batch_size=block_size, num_workers=4)
389
+ with torch.no_grad():
390
+ for inputs, targets in val_dataloader:
391
+ inputs = inputs.to(device)
392
+ targets = targets.to(device)
393
+
394
+ batch_loss = compute_loss(model, criterion, inputs, targets)
395
+ total_loss += batch_loss.item() * inputs.size(0)
396
+ total_samples += inputs.size(0)
397
+ if total_samples > max_samples:
398
+ break
399
+
400
+ average_loss = total_loss / total_samples
401
+ return average_loss
402
+
403
+
404
+ def generate(model, config, prompt, gen_length, temp=1, top_k=10, top_p=None):
405
+ g_cuda = torch.Generator(device=device)
406
+ contexts = torch.tensor(encode_text(prompt), dtype=torch.int32).to(device)
407
+
408
+ model.eval()
409
+ for i in range(gen_length):
410
+ transform = nn.LogSoftmax(1)
411
+ x = contexts[-config.context_size:]
412
+ if x.size(0) < config.context_size:
413
+ x = F.pad(x, (config.context_size - x.size(0), 0),
414
+ "constant", 0).unsqueeze(0) # B*T
415
+ else:
416
+ x = x.unsqueeze(0)
417
+
418
+ preds = model.infer(x)
419
+ preds = preds.squeeze(0)
420
+ preds = preds / temp
421
+ probs = F.softmax(preds, dim=-1)
422
+
423
+ if top_p is not None:
424
+ # Apply top-p
425
+ sorted_probs, sorted_indices = torch.sort(
426
+ probs[-1, :], descending=True)
427
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
428
+ # find cutoff
429
+ idx_top_p = (cumulative_probs < top_p).sum().item()
430
+ top_probs = sorted_probs[:idx_top_p]
431
+ top_indices = sorted_indices[:idx_top_p]
432
+ # Null case
433
+ if top_probs.size(0) == 0:
434
+ top_probs = sorted_probs[:1]
435
+ top_indices = sorted_indices[:1]
436
+
437
+ next_char = torch.multinomial(
438
+ top_probs, num_samples=1, generator=g_cuda)
439
+ next_char = top_indices[next_char]
440
+ elif top_k is not None:
441
+ top_k_probs, top_k_indices = torch.topk(probs[-1, :], k=top_k)
442
+ next_char = torch.multinomial(
443
+ top_k_probs, num_samples=1, generator=g_cuda)
444
+ next_char = top_k_indices[next_char]
445
+ else:
446
+ next_char = torch.multinomial(
447
+ probs, num_samples=1, generator=g_cuda)
448
+
449
+ contexts = torch.cat((contexts, next_char), dim=0)
450
+ print(decode_text(next_char.cpu().numpy())[-1], end="")
451
+
452
+ return ("".join(decode_text(contexts.cpu().numpy())))
453
+
454
+
455
+ def main():
456
+ # Parse the command-line arguments
457
+ args = parser.parse_args()
458
+
459
+ config = GPTConfig(args.config)
460
+ # Create the GPT model
461
+ model = GPT(
462
+ vocab_size=config.vocab_size,
463
+ context_size=config.context_size,
464
+ embedding_dim=config.embedding_dim
465
+ )
466
+ model.to(device)
467
+
468
+ optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
469
+ if args.gui:
470
+ load_checkpoint(model, optimizer, config.restore_path)
471
+ demo = gr.Interface(
472
+ fn=lambda *args: generate(model, config, *args),
473
+ inputs=[
474
+ gr.Textbox(lines=2, placeholder="Prompt here..."),
475
+ gr.Number(precision=0, value=256),
476
+ gr.Number(value=0.8),
477
+ gr.Slider(maximum=128, value=10),
478
+ gr.Slider(maximum=1, value=1)
479
+ ],
480
+ outputs="text",
481
+ title="Shakespeare-GPT",
482
+ description="Putting theater kids out of their nonexistent jobs since 2023"
483
+ )
484
+
485
+ demo.launch()
486
+ elif args.command == "train":
487
+ if args.load_from_restore:
488
+ load_checkpoint(model, optimizer, path)
489
+
490
+ train(model, config)
491
+ elif args.command == "eval":
492
+ _, _, test_dataset = load_dataset(
493
+ config.training_data_path, None, model.context_size)
494
+ evaluate_model(model, test_dataset)
495
+ elif args.command == "infer":
496
+ prompt = args.text
497
+ generated_text = generate(model, config, prompt, args.length)
498
+ print(generated_text)
499
+
500
+
501
+ if __name__ == "__main__":
502
+ main()
gpt_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architecture": {
3
+ "embedding_dim": 256,
4
+ "vocab_size": 128,
5
+ "context_size": 256,
6
+ "num_heads": 4,
7
+ "num_layers": 4
8
+ },
9
+ "training": {
10
+ "batch_size": 64,
11
+ "training_data_path": "datasets/training_data.txt",
12
+ "learning_rate": 3e-4,
13
+ "num_steps": 5000,
14
+ "val_interval": 100
15
+ },
16
+ "restore_path": "checkpoints/model.pt"
17
+ }