PrarthanaTS committed
Commit ec5288c · 1 Parent(s): 233c68c

Upload 6 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ news.csv filter=lfs diff=lfs merge=lfs -text
GPT_Shakespeare.ipynb ADDED
@@ -0,0 +1,801 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm",
+ "gpuType": "A100"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "G6BvseJ-0VwS",
+ "outputId": "72cafdf8-dd7b-4412-a9bc-2cfcebfb6949"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--2023-11-03 11:10:34-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 1115394 (1.1M) [text/plain]\n",
+ "Saving to: ‘input.txt’\n",
+ "\n",
+ "input.txt 100%[===================>] 1.06M --.-KB/s in 0.02s \n",
+ "\n",
+ "2023-11-03 11:10:35 (48.3 MB/s) - ‘input.txt’ saved [1115394/1115394]\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with open('input.txt', 'r', encoding='utf-8') as f:\n",
+ "    text = f.read()"
+ ],
+ "metadata": {
+ "id": "pxZym4QU1mCq"
+ },
+ "execution_count": 11,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.nn import functional as F\n",
+ "\n",
+ "# hyperparameters\n",
+ "batch_size = 16 # how many independent sequences will we process in parallel?\n",
+ "block_size = 32 # what is the maximum context length for predictions?\n",
+ "max_iters = 5000\n",
+ "eval_interval = 100\n",
+ "learning_rate = 1e-3\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "eval_iters = 200\n",
+ "n_embd = 64\n",
+ "n_head = 4\n",
+ "n_layer = 4\n",
+ "dropout = 0.0\n",
+ "\n",
+ "torch.manual_seed(1337)\n",
+ "\n",
+ "\n",
+ "# here are all the unique characters that occur in this text\n",
+ "chars = sorted(list(set(text)))\n",
+ "vocab_size = len(chars)\n",
+ "# create a mapping from characters to integers\n",
+ "stoi = { ch:i for i,ch in enumerate(chars) }\n",
+ "itos = { i:ch for i,ch in enumerate(chars) }\n",
+ "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
+ "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
+ "\n",
+ "# Train and test splits\n",
+ "data = torch.tensor(encode(text), dtype=torch.long)\n",
+ "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
+ "train_data = data[:n]\n",
+ "val_data = data[n:]\n",
+ "\n",
+ "# data loading\n",
+ "def get_batch(split):\n",
+ "    # generate a small batch of data of inputs x and targets y\n",
+ "    data = train_data if split == 'train' else val_data\n",
+ "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
+ "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
+ "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
+ "    x, y = x.to(device), y.to(device)\n",
+ "    return x, y\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def estimate_loss():\n",
+ "    out = {}\n",
+ "    model.eval()\n",
+ "    for split in ['train', 'val']:\n",
+ "        losses = torch.zeros(eval_iters)\n",
+ "        for k in range(eval_iters):\n",
+ "            X, Y = get_batch(split)\n",
+ "            logits, loss = model(X, Y)\n",
+ "            losses[k] = loss.item()\n",
+ "        out[split] = losses.mean()\n",
+ "    model.train()\n",
+ "    return out\n",
+ "\n",
+ "class Head(nn.Module):\n",
+ "    \"\"\" one head of self-attention \"\"\"\n",
+ "\n",
+ "    def __init__(self, head_size):\n",
+ "        super().__init__()\n",
+ "        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
+ "        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
+ "        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
+ "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
+ "\n",
+ "        self.dropout = nn.Dropout(dropout)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        B,T,C = x.shape\n",
+ "        k = self.key(x) # (B,T,C)\n",
+ "        q = self.query(x) # (B,T,C)\n",
+ "        # compute attention scores (\"affinities\")\n",
+ "        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n",
+ "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
+ "        wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
+ "        wei = self.dropout(wei)\n",
+ "        # perform the weighted aggregation of the values\n",
+ "        v = self.value(x) # (B,T,C)\n",
+ "        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n",
+ "        return out\n",
+ "\n",
+ "class MultiHeadAttention(nn.Module):\n",
+ "    \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
+ "\n",
+ "    def __init__(self, num_heads, head_size):\n",
+ "        super().__init__()\n",
+ "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
+ "        self.proj = nn.Linear(n_embd, n_embd)\n",
+ "        self.dropout = nn.Dropout(dropout)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
+ "        out = self.dropout(self.proj(out))\n",
+ "        return out\n",
+ "\n",
+ "class FeedFoward(nn.Module):\n",
+ "    \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n",
+ "\n",
+ "    def __init__(self, n_embd):\n",
+ "        super().__init__()\n",
+ "        self.net = nn.Sequential(\n",
+ "            nn.Linear(n_embd, 4 * n_embd),\n",
+ "            nn.ReLU(),\n",
+ "            nn.Linear(4 * n_embd, n_embd),\n",
+ "            nn.Dropout(dropout),\n",
+ "        )\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        return self.net(x)\n",
+ "\n",
+ "class Block(nn.Module):\n",
+ "    \"\"\" Transformer block: communication followed by computation \"\"\"\n",
+ "\n",
+ "    def __init__(self, n_embd, n_head):\n",
+ "        # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
+ "        super().__init__()\n",
+ "        head_size = n_embd // n_head\n",
+ "        self.sa = MultiHeadAttention(n_head, head_size)\n",
+ "        self.ffwd = FeedFoward(n_embd)\n",
+ "        self.ln1 = nn.LayerNorm(n_embd)\n",
+ "        self.ln2 = nn.LayerNorm(n_embd)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        x = x + self.sa(self.ln1(x))\n",
+ "        x = x + self.ffwd(self.ln2(x))\n",
+ "        return x\n",
+ "\n",
+ "# super simple bigram model\n",
+ "class BigramLanguageModel(nn.Module):\n",
+ "\n",
+ "    def __init__(self):\n",
+ "        super().__init__()\n",
+ "        # each token directly reads off the logits for the next token from a lookup table\n",
+ "        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
+ "        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
+ "        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
+ "        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
+ "        self.lm_head = nn.Linear(n_embd, vocab_size)\n",
+ "\n",
+ "    def forward(self, idx, targets=None):\n",
+ "        B, T = idx.shape\n",
+ "\n",
+ "        # idx and targets are both (B,T) tensor of integers\n",
+ "        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
+ "        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
+ "        x = tok_emb + pos_emb # (B,T,C)\n",
+ "        x = self.blocks(x) # (B,T,C)\n",
+ "        x = self.ln_f(x) # (B,T,C)\n",
+ "        logits = self.lm_head(x) # (B,T,vocab_size)\n",
+ "\n",
+ "        if targets is None:\n",
+ "            loss = None\n",
+ "        else:\n",
+ "            B, T, C = logits.shape\n",
+ "            logits = logits.view(B*T, C)\n",
+ "            targets = targets.view(B*T)\n",
+ "            loss = F.cross_entropy(logits, targets)\n",
+ "\n",
+ "        return logits, loss\n",
+ "\n",
+ "    def generate(self, idx, max_new_tokens):\n",
+ "        # idx is (B, T) array of indices in the current context\n",
+ "        for _ in range(max_new_tokens):\n",
+ "            # crop idx to the last block_size tokens\n",
+ "            idx_cond = idx[:, -block_size:]\n",
+ "            # get the predictions\n",
+ "            logits, loss = self(idx_cond)\n",
+ "            # focus only on the last time step\n",
+ "            logits = logits[:, -1, :] # becomes (B, C)\n",
+ "            # apply softmax to get probabilities\n",
+ "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
+ "            # sample from the distribution\n",
+ "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
+ "            # append sampled index to the running sequence\n",
+ "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
+ "        return idx\n",
+ "\n",
+ "model = BigramLanguageModel()\n",
+ "m = model.to(device)\n",
+ "# print the number of parameters in the model\n",
+ "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
+ "\n",
+ "# create a PyTorch optimizer\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
+ "\n",
+ "for iter in range(max_iters):\n",
+ "\n",
+ "    # every once in a while evaluate the loss on train and val sets\n",
+ "    if iter % eval_interval == 0 or iter == max_iters - 1:\n",
+ "        losses = estimate_loss()\n",
+ "        print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
+ "\n",
+ "    # sample a batch of data\n",
+ "    xb, yb = get_batch('train')\n",
+ "\n",
+ "    # evaluate the loss\n",
+ "    logits, loss = model(xb, yb)\n",
+ "    optimizer.zero_grad(set_to_none=True)\n",
+ "    loss.backward()\n",
+ "    optimizer.step()\n",
+ "\n",
+ "# generate from the model\n",
+ "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
+ "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "U_mrE9Vd10Ab",
+ "outputId": "f443c391-fd7f-4c2d-dd0d-72ef25849ef6"
+ },
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "0.209729 M parameters\n",
+ "step 0: train loss 4.4116, val loss 4.4022\n",
+ "step 100: train loss 2.6568, val loss 2.6670\n",
+ "step 200: train loss 2.5091, val loss 2.5060\n",
+ "step 300: train loss 2.4199, val loss 2.4337\n",
+ "step 400: train loss 2.3500, val loss 2.3563\n",
+ "step 500: train loss 2.2961, val loss 2.3126\n",
+ "step 600: train loss 2.2408, val loss 2.2501\n",
+ "step 700: train loss 2.2053, val loss 2.2187\n",
+ "step 800: train loss 2.1636, val loss 2.1870\n",
+ "step 900: train loss 2.1226, val loss 2.1483\n",
+ "step 1000: train loss 2.1017, val loss 2.1283\n",
+ "step 1100: train loss 2.0683, val loss 2.1174\n",
+ "step 1200: train loss 2.0376, val loss 2.0798\n",
+ "step 1300: train loss 2.0256, val loss 2.0645\n",
+ "step 1400: train loss 1.9919, val loss 2.0362\n",
+ "step 1500: train loss 1.9696, val loss 2.0304\n",
+ "step 1600: train loss 1.9625, val loss 2.0470\n",
+ "step 1700: train loss 1.9402, val loss 2.0119\n",
+ "step 1800: train loss 1.9085, val loss 1.9957\n",
+ "step 1900: train loss 1.9080, val loss 1.9869\n",
+ "step 2000: train loss 1.8834, val loss 1.9941\n",
+ "step 2100: train loss 1.8727, val loss 1.9758\n",
+ "step 2200: train loss 1.8585, val loss 1.9622\n",
+ "step 2300: train loss 1.8537, val loss 1.9503\n",
+ "step 2400: train loss 1.8419, val loss 1.9424\n",
+ "step 2500: train loss 1.8153, val loss 1.9407\n",
+ "step 2600: train loss 1.8267, val loss 1.9374\n",
+ "step 2700: train loss 1.8126, val loss 1.9344\n",
+ "step 2800: train loss 1.8054, val loss 1.9230\n",
+ "step 2900: train loss 1.8045, val loss 1.9339\n",
+ "step 3000: train loss 1.7963, val loss 1.9243\n",
+ "step 3100: train loss 1.7691, val loss 1.9208\n",
+ "step 3200: train loss 1.7506, val loss 1.9092\n",
+ "step 3300: train loss 1.7548, val loss 1.9038\n",
+ "step 3400: train loss 1.7582, val loss 1.8960\n",
+ "step 3500: train loss 1.7376, val loss 1.8934\n",
+ "step 3600: train loss 1.7232, val loss 1.8888\n",
+ "step 3700: train loss 1.7280, val loss 1.8814\n",
+ "step 3800: train loss 1.7221, val loss 1.8951\n",
+ "step 3900: train loss 1.7228, val loss 1.8789\n",
+ "step 4000: train loss 1.7168, val loss 1.8635\n",
+ "step 4100: train loss 1.7168, val loss 1.8798\n",
+ "step 4200: train loss 1.7088, val loss 1.8672\n",
+ "step 4300: train loss 1.6995, val loss 1.8501\n",
+ "step 4400: train loss 1.7096, val loss 1.8686\n",
+ "step 4500: train loss 1.6907, val loss 1.8546\n",
+ "step 4600: train loss 1.6868, val loss 1.8348\n",
+ "step 4700: train loss 1.6786, val loss 1.8346\n",
+ "step 4800: train loss 1.6659, val loss 1.8445\n",
+ "step 4900: train loss 1.6711, val loss 1.8384\n",
+ "step 4999: train loss 1.6630, val loss 1.8230\n",
+ "\n",
+ "ROMEO:\n",
+ "But you far you\n",
+ "my swap with thus; come hath I uD\n",
+ "If sleemition of where's granded\n",
+ "Of their of tout the gortune upwon alond, liege man to is Iell this surpe\n",
+ "And than sleue thus mind, his by blow,\n",
+ "Virdty toward butied, Ditire spresiss with thou some not.\n",
+ "\n",
+ "LORIO:\n",
+ "I am part\n",
+ "But thou sging them but\n",
+ "shat secondes morry thou sovore.\n",
+ "\n",
+ "ISABUS:\n",
+ "What art sade but hither, thange e'en,\n",
+ "Protes as kingle me; an your tords whom are Ineal.\n",
+ "\n",
+ "MENENIUS:\n",
+ "But little sweet, hom, foust cerfort;\n",
+ "Winth hing diend enirs' tompy beds sick ways!\n",
+ "What curforself this grace. Won, passes us.\n",
+ "\n",
+ "BUCKINGHABY MARD:\n",
+ "Mether star: keep it any head which\n",
+ "He tall devioly that, out that confer old.\n",
+ "Our thy dears time.\n",
+ "Nay, the fragoly, pair, of new\n",
+ "my father, my lip Backnoward:\n",
+ "God therring for respide\n",
+ "What colvery, teminelyord, I mast,\n",
+ "While us that such differs I'll that confect I come,\n",
+ "But; man.\n",
+ "\n",
+ "VOLUMNIO:\n",
+ "Ontread confail with me. Humser dipporbried answeraw is codal one,\n",
+ "Onjestion, not or cheavess ensty with.\n",
+ "\n",
+ "GLOUCESTER:\n",
+ "\n",
+ "HENRY Mess to Lies?\n",
+ "Stand and these beguare youf stile that than war\n",
+ "offity are, I usquesch\n",
+ "Frown movhapty not duke with you addom\n",
+ "grack prowd--lost\n",
+ "But but they worse is senst my crunne undolier. But, beauts pruntaly; I stoll'ct my nor Murder, I sot, though who speak\n",
+ "Your bout told-man rathing if anyshal\n",
+ "epitence, tirre no the said he's,\n",
+ "Andis frultifs. what his lide? That mirdy this dudgetions?\n",
+ "\n",
+ "KING ARINIA:\n",
+ "I let holt not sucKether,\n",
+ "Whither, efore But lord: I, beget because at that his say\n",
+ "as to brought grave a donesmer all nobe.\n",
+ "\n",
+ "BUCKINGHUMBY:\n",
+ "Which forgeled! Came; nor thereforn's fiends strefet.\n",
+ "\n",
+ "PLORIA:\n",
+ "Yet to Capprohning, that brird\n",
+ "of say mover a desrick.\n",
+ "\n",
+ "MO\n",
+ "stompars, God the\n",
+ "citchard is high.\n",
+ "\n",
+ "Seth Second Methere:\n",
+ "Marrmat I unmale the bretcius unfoect that I would back where own thy lurges\n",
+ "And, iffillimorture:\n",
+ "As thou twand, York these that high praut.\n",
+ "Plafe merprates sure dread with her,\n",
+ "At not your must I suchon? too prant!\n",
+ "O 'hiles clight the bleave is graved before\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# generate from the model\n",
+ "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
+ "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "M0qIA2GK2qzI",
+ "outputId": "86126a68-17b1-4171-920a-1d2df6fa3f1a"
+ },
+ "execution_count": 13,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "And thou to lesserve his his know'st broy by A towe than or fuch dight none worthy'st countinne, congess\n",
+ "our ire Iname's marriatate the entrity?\n",
+ "\n",
+ "COMIOLUS:\n",
+ "Yet there me your let thy by courtary, own but I cannot, to\n",
+ "you.\n",
+ "\n",
+ "MOth Osque, while and nett; pity, brow umput;\n",
+ "He betwered's prettedy if not you arter,\n",
+ "But woman furner his good me to ambled thy follows\n",
+ "Gents for you daying this distend and he but.\n",
+ "\n",
+ "COMINANUS:\n",
+ "But you know wish the wear? whoe not to breave maste gate?\n",
+ "Not, now you read own. Lo-honour shoes\n",
+ "honordore vilibert.\n",
+ "\n",
+ "ARTOS:\n",
+ "Nay, as Is theen, God\n",
+ "Were I saying cose\n",
+ "Will there's upon and tools.\n",
+ "\n",
+ "HORSIO:\n",
+ "Pomfort life?\n",
+ "Whereform make comps hersed, my what away,\n",
+ "Go'st Your haste entens, and succe?\n",
+ "\n",
+ "LORD RIARENCE:\n",
+ "Fies my like, wifch a my nobt.\n",
+ "And!\n",
+ "And ways. Whithing death.\n",
+ "\n",
+ "CORIOLUMNO:\n",
+ "It must I have grawits.-\n",
+ "Ris Gomisty yor then thin dot this no all-donged,\n",
+ "But quarry the latter: Have me the betime twooke steed to blood\n",
+ "That his rysour grower-foldds: bnot Plond,\n",
+ "By that all wittore old the malt our liight.\n",
+ "Would for not\n",
+ "And sabet I sout ofing more in must.\n",
+ "\n",
+ "MENENIUS:\n",
+ "Gor low your I standed\n",
+ "To heavy:\n",
+ "While to caid your inswoes!\n",
+ "Thrhing the princlusior lurmeng,\n",
+ "To Whie! entred mean the not, sare.\n",
+ "\n",
+ "BRUTUS:\n",
+ "Is my partend him, if Is verys be,\n",
+ "Whim you longs,\n",
+ "Say, his me. Murselets; not with is most.\n",
+ "\n",
+ "JOLINA:\n",
+ "That it where that thluse too the hath'd\n",
+ "unsomed of our heavis'd?\n",
+ "\n",
+ "So his were Clamind:\n",
+ "Ounly mistry's soul\n",
+ "To once myser flow\n",
+ "Which, then, whet must I as not drums as the ouch are\n",
+ "burnse contreased and in Comintity?\n",
+ "\n",
+ "Mistray is I curliented:\n",
+ "Thou herew bottust, How lad you blist a wear's art?\n",
+ "What the vave--batta thing with\n",
+ "that his my urtusaed and as mine, thus not,\n",
+ "May your pohed me mhalt livy very\n",
+ "But I sham I ham kitse, pean, for\n",
+ "was, ewith woll heave in thou art, dlignt,\n",
+ "Of fair Griward that remottes must;\n",
+ "Cadyfore the not lords not, I say's gener which of your rame? Istand my hearth\n",
+ "And thou alt nenget that shame\n",
+ "She with them kinderire it put this\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "prompt = \"Once upon a time\"\n",
+ "context = torch.tensor(encode(prompt), dtype=torch.long, device=device).view(1, -1)\n",
+ "print(decode(m.generate(context, max_new_tokens=200)[0].tolist()))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Na0SThjv5-iz",
+ "outputId": "c649c8ad-42fe-4a77-a219-9dcb1857a9c0"
+ },
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Once upon a times peacts mother saclaves is 'Then of my tonguen,\n",
+ "Thus are been\n",
+ "My my behot prilatte, what you brot,\n",
+ "Speeke there is my bud the be, 'smandion from me:\n",
+ "And the barttes, rechard, where capuse,\n",
+ "Rentent, I\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "# Save the model\n",
+ "torch.save(m.state_dict(), 'GPT_Shakespeare_language_model.pth')"
+ ],
+ "metadata": {
+ "id": "sfmRYo9h6B24"
+ },
+ "execution_count": 15,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Load the model\n",
+ "loaded_model = BigramLanguageModel() # Initialize an instance of your model\n",
+ "loaded_model.load_state_dict(torch.load('GPT_Shakespeare_language_model.pth'))\n",
+ "loaded_model.to(device).eval() # Set the model to evaluation mode"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "xO9JefxH6KHS",
+ "outputId": "d7f0191c-4e02-4ed7-ff4b-b6f25a538fe5"
+ },
+ "execution_count": 17,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "BigramLanguageModel(\n",
+ "  (token_embedding_table): Embedding(65, 64)\n",
+ "  (position_embedding_table): Embedding(32, 64)\n",
+ "  (blocks): Sequential(\n",
+ "    (0): Block(\n",
+ "      (sa): MultiHeadAttention(\n",
+ "        (heads): ModuleList(\n",
+ "          (0-3): 4 x Head(\n",
+ "            (key): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (query): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (value): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "          )\n",
+ "        )\n",
+ "        (proj): Linear(in_features=64, out_features=64, bias=True)\n",
+ "        (dropout): Dropout(p=0.0, inplace=False)\n",
+ "      )\n",
+ "      (ffwd): FeedFoward(\n",
+ "        (net): Sequential(\n",
+ "          (0): Linear(in_features=64, out_features=256, bias=True)\n",
+ "          (1): ReLU()\n",
+ "          (2): Linear(in_features=256, out_features=64, bias=True)\n",
+ "          (3): Dropout(p=0.0, inplace=False)\n",
+ "        )\n",
+ "      )\n",
+ "      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ "      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ "    )\n",
+ "    (1): Block(\n",
+ "      (sa): MultiHeadAttention(\n",
+ "        (heads): ModuleList(\n",
+ "          (0-3): 4 x Head(\n",
+ "            (key): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (query): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (value): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "          )\n",
+ "        )\n",
+ "        (proj): Linear(in_features=64, out_features=64, bias=True)\n",
+ "        (dropout): Dropout(p=0.0, inplace=False)\n",
+ "      )\n",
+ "      (ffwd): FeedFoward(\n",
+ "        (net): Sequential(\n",
+ "          (0): Linear(in_features=64, out_features=256, bias=True)\n",
+ "          (1): ReLU()\n",
+ "          (2): Linear(in_features=256, out_features=64, bias=True)\n",
+ "          (3): Dropout(p=0.0, inplace=False)\n",
+ "        )\n",
+ "      )\n",
+ "      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ "      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ "    )\n",
+ "    (2): Block(\n",
+ "      (sa): MultiHeadAttention(\n",
+ "        (heads): ModuleList(\n",
+ "          (0-3): 4 x Head(\n",
+ "            (key): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (query): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (value): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "          )\n",
+ "        )\n",
+ "        (proj): Linear(in_features=64, out_features=64, bias=True)\n",
+ "        (dropout): Dropout(p=0.0, inplace=False)\n",
+ "      )\n",
+ "      (ffwd): FeedFoward(\n",
+ "        (net): Sequential(\n",
+ "          (0): Linear(in_features=64, out_features=256, bias=True)\n",
+ "          (1): ReLU()\n",
+ "          (2): Linear(in_features=256, out_features=64, bias=True)\n",
+ "          (3): Dropout(p=0.0, inplace=False)\n",
+ "        )\n",
+ "      )\n",
+ "      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ "      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ "    )\n",
+ "    (3): Block(\n",
+ "      (sa): MultiHeadAttention(\n",
+ "        (heads): ModuleList(\n",
+ "          (0-3): 4 x Head(\n",
+ "            (key): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (query): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (value): Linear(in_features=64, out_features=16, bias=False)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "          )\n",
+ "        )\n",
+ "        (proj): Linear(in_features=64, out_features=64, bias=True)\n",
+ "        (dropout): Dropout(p=0.0, inplace=False)\n",
+ "      )\n",
+ "      (ffwd): FeedFoward(\n",
+ "        (net): Sequential(\n",
+ "          (0): Linear(in_features=64, out_features=256, bias=True)\n",
+ "          (1): ReLU()\n",
+ "          (2): Linear(in_features=256, out_features=64, bias=True)\n",
+ "          (3): Dropout(p=0.0, inplace=False)\n",
+ "        )\n",
+ "      )\n",
+ "      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ "      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ "    )\n",
+ "  )\n",
+ "  (ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ "  (lm_head): Linear(in_features=64, out_features=65, bias=True)\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 17
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# generate from the model\n",
+ "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
+ "print(decode(loaded_model.generate(context, max_new_tokens=2000)[0].tolist()))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "m46OnNXq6PAV",
+ "outputId": "e547525f-98c0-4355-92ef-559f6c2ba238"
+ },
+ "execution_count": 18,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "Forntlefires, love the done, or all love tears\n",
+ "That braud the strough.\n",
+ "\n",
+ "BUCHNIO:\n",
+ "Is\n",
+ "For that I hat deam throve? we parrlignos;\n",
+ "My bregain minousiner mile into the doth,\n",
+ "Warwien not his day hath;\n",
+ "Whose basy touther ploudde metornies'drey would be themseremes to have\n",
+ "You good accarm, menot wtoo cown:\n",
+ "Is have mostil\n",
+ "Before prunces.\n",
+ "\n",
+ "Speaking A-dught:\n",
+ "Whow 'sile her fry hath acvionce,\n",
+ "Your cange, side of-day; this I seep!\n",
+ "Aher approve; I\n",
+ "drumber, any till amberd, come it suffet nexwarrans\n",
+ "To hear you that what art thim for a dish! Whiler not some men;\n",
+ "Hareth, I am broth, thenese oof.\n",
+ "Croth before wortune's hande and if brote\n",
+ "Come andmitation it. Tentess I what\n",
+ "That ascess Weringmans, te us;\n",
+ "And your Servant-thy moime, that whose.\n",
+ "\n",
+ "CORIOLANUS:\n",
+ "Now, stay to the resmorn?\n",
+ "\n",
+ "CRANGE:\n",
+ "It have pleave to some, for soul;\n",
+ "He fatelly here that you, hesseliemes five oldince\n",
+ "Our confolle, too you stay'd my being to't,\n",
+ "My lord I am then most the knows doot hid-gress.\n",
+ "\n",
+ "KING RICHARD GDITH:\n",
+ "As beconsure! So youil heart fear; and whilook my arm verpast,\n",
+ "And staven to fathy down I vir all prace,\n",
+ "And be betcasion your balt, to draying and the bottchmy,\n",
+ "The griake must worse it. As I have owle well I who stray good.\n",
+ "\n",
+ "My anviusice: andress unthat fonds of oad;\n",
+ "ne's eye the notraing and timer cimmman:\n",
+ "Heth lain. What's is the castad,\n",
+ "And their speake fatwort off.\n",
+ "\n",
+ "Shy:\n",
+ "What marry; thysele, time onge,\n",
+ "And by bown, merpety to to of crive thou secam.\n",
+ "\n",
+ "QUEEN VINCENTIO:\n",
+ "How my bold; good poson\n",
+ "I finly torthus.\n",
+ "Our you if your aware watly sweet\n",
+ "On all fair livishts thee our then plast banoting\n",
+ "What have duckn, so\n",
+ "them the hostfeive.\n",
+ "\n",
+ "HIRDIO:\n",
+ "\n",
+ "GLOUCERDIO:\n",
+ "And all capure Toncant mack.\n",
+ "\n",
+ "CAPULET:\n",
+ "O mean bodams'd my tone thy wralf thee wilth\n",
+ "And rencrown prow my ear them lovery\n",
+ "Coringlike hath in recond:\n",
+ "You will you from of God their all and not mine:\n",
+ "With doess be Sives?\n",
+ "So regort it thy mart solued sgaft world of him,\n",
+ "What'st in else agged namfutiol.\n",
+ "\n",
+ "ANGENO:\n",
+ "For that it you briave lay to your unpalssi\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "2GpnegQc8A9R"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+ }
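The training cell above reports 0.209729 M parameters. That figure can be re-derived by hand from the hyperparameters alone (vocab_size=65 for Tiny Shakespeare, n_embd=64, block_size=32, n_head=4, n_layer=4); the following standalone sketch, not part of the committed notebook, redoes the arithmetic in plain Python:

    # Illustrative parameter-count check for the model defined above.
    vocab_size, n_embd, block_size, n_head, n_layer = 65, 64, 32, 4, 4
    head_size = n_embd // n_head

    emb = vocab_size * n_embd + block_size * n_embd    # token + position tables
    sa = n_head * 3 * n_embd * head_size \
         + (n_embd * n_embd + n_embd)                  # k/q/v (no bias) + output projection
    ffwd = (n_embd * 4 * n_embd + 4 * n_embd) \
           + (4 * n_embd * n_embd + n_embd)            # two biased Linear layers
    lns = 2 * 2 * n_embd                               # ln1 + ln2, weight and bias each
    block = sa + ffwd + lns
    final = 2 * n_embd + (n_embd * vocab_size + vocab_size)  # ln_f + lm_head

    print(emb + n_layer * block + final)               # 209729, i.e. the 0.209729 M printed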
GPT_Shakespeare_language_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4d2cbe6d12d0d566b1bfb4aed76ee5ca713cfda0f68d75666cb01675edcf72a2
+ size 946452
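The three lines above are the Git LFS pointer that stands in for the binary checkpoint: `oid` is the SHA-256 digest of the real file and `size` its byte count. As a minimal sketch (not part of this commit), a downloaded copy can be checked against the recorded digest like so:

    import hashlib

    def sha256_of(path, chunk_size=1 << 20):
        # Stream in 1 MiB chunks so large checkpoints need not fit in memory.
        digest = hashlib.sha256()
        with open(path, 'rb') as f:
            for block in iter(lambda: f.read(chunk_size), b''):
                digest.update(block)
        return digest.hexdigest()

    expected = '4d2cbe6d12d0d566b1bfb4aed76ee5ca713cfda0f68d75666cb01675edcf72a2'
    assert sha256_of('GPT_Shakespeare_language_model.pth') == expected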
app.py ADDED
@@ -0,0 +1,233 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import numpy as np
+ import random
+ import re
+ import gradio as gr
+
+ # hyperparameters
+ batch_size = 16 # how many independent sequences will we process in parallel?
+ block_size = 32 # what is the maximum context length for predictions?
+ max_iters = 5000
+ eval_interval = 100
+ learning_rate = 1e-3
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ eval_iters = 200
+ n_embd = 64
+ n_head = 4
+ n_layer = 4
+ dropout = 0.0
+ # ------------
+
+ torch.manual_seed(1337)
+
+ class Head(nn.Module):
+     """ one head of self-attention """
+
+     def __init__(self, head_size):
+         super().__init__()
+         self.key = nn.Linear(n_embd, head_size, bias=False)
+         self.query = nn.Linear(n_embd, head_size, bias=False)
+         self.value = nn.Linear(n_embd, head_size, bias=False)
+         self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         B,T,C = x.shape
+         k = self.key(x) # (B,T,C)
+         q = self.query(x) # (B,T,C)
+         # compute attention scores ("affinities")
+         wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
+         wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
+         wei = F.softmax(wei, dim=-1) # (B, T, T)
+         wei = self.dropout(wei)
+         # perform the weighted aggregation of the values
+         v = self.value(x) # (B,T,C)
+         out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
+         return out
+
+ class MultiHeadAttention(nn.Module):
+     """ multiple heads of self-attention in parallel """
+
+     def __init__(self, num_heads, head_size):
+         super().__init__()
+         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
+         self.proj = nn.Linear(n_embd, n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         out = torch.cat([h(x) for h in self.heads], dim=-1)
+         out = self.dropout(self.proj(out))
+         return out
+
+ class FeedFoward(nn.Module):
+     """ a simple linear layer followed by a non-linearity """
+
+     def __init__(self, n_embd):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(n_embd, 4 * n_embd),
+             nn.ReLU(),
+             nn.Linear(4 * n_embd, n_embd),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class Block(nn.Module):
+     """ Transformer block: communication followed by computation """
+
+     def __init__(self, n_embd, n_head):
+         # n_embd: embedding dimension, n_head: the number of heads we'd like
+         super().__init__()
+         head_size = n_embd // n_head
+         self.sa = MultiHeadAttention(n_head, head_size)
+         self.ffwd = FeedFoward(n_embd)
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+
+     def forward(self, x):
+         x = x + self.sa(self.ln1(x))
+         x = x + self.ffwd(self.ln2(x))
+         return x
+
+ # super simple bigram model
+ class BigramLanguageModel(nn.Module):
+     def __init__(self, dataset_text, n_embd):
+         super().__init__()
+
+         # Compute character-related parameters
+         self.chars = sorted(list(set(dataset_text)))
+         self.vocab_size = len(self.chars)
+         self.stoi = {ch: i for i, ch in enumerate(self.chars)}
+         self.itos = {i: ch for ch, i in self.stoi.items()}
+
+         self.token_embedding_table = nn.Embedding(self.vocab_size, n_embd)
+         self.position_embedding_table = nn.Embedding(block_size, n_embd)
+         self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
+         self.ln_f = nn.LayerNorm(n_embd)
+         self.lm_head = nn.Linear(n_embd, self.vocab_size)
+         self.encode = lambda s: [self.stoi[c] for c in s] # encoder: take a string, output a list of integers
+         self.decode = lambda l: ''.join([self.itos[i] for i in l]) # decoder: take a list of integers, output a string
+
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+
+         # idx and targets are both (B,T) tensor of integers
+         tok_emb = self.token_embedding_table(idx) # (B,T,C)
+         pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
+         x = tok_emb + pos_emb # (B,T,C)
+         x = self.blocks(x) # (B,T,C)
+         x = self.ln_f(x) # (B,T,C)
+         logits = self.lm_head(x) # (B,T,vocab_size)
+
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             logits = logits.view(B*T, C)
+             targets = targets.view(B*T)
+             loss = F.cross_entropy(logits, targets)
+
+         return logits, loss
+
+     def generate(self, idx, max_new_tokens):
+         # idx is (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # crop idx to the last block_size tokens
+             idx_cond = idx[:, -block_size:]
+             # get the predictions
+             logits, loss = self(idx_cond)
+             # focus only on the last time step
+             logits = logits[:, -1, :] # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1) # (B, C)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
+         return idx
+
+ # Reading shakespeare data
+ with open('input.txt', 'r', encoding='utf-8') as f:
+     shakespeare_text = f.read()
+
+
+ # Reading wikipedia data
+ DATA_PATH = 'wikisent2.txt'
+ # load wikipedia sentences
+ with open(DATA_PATH, 'r') as f:
+     lines = f.read().splitlines()
+
+ # Selecting 250k lines from the dataset.
+ random.seed(42)
+ texts = random.choices(lines, k=250000)
+ del lines
+
+ def preprocess(text):
+     text = re.sub(r'@.*?\s+', '', text) # Remove mentions
+     text = re.sub(r'#.*?\s+', '', text) # Remove hashtags
+     text = re.sub(r'https?:\/\/.*[\r\n]*', '', text) # Remove URLs
+     text = re.sub(r'[^\w\s\'.]', '', text) # Remove special characters except for single quotes and periods
+     text = re.sub(r'\s+', ' ', text) # Replace multiple spaces with a single space
+     text = re.sub(r'^\d+\s*|^\d+\.\d+\s*|^\d+\.\d+\.\d+\s*', '', text) # Remove digits at the start of sentences
+     text = text.strip() # Remove leading and trailing whitespace
+     return text
+
+ wiki_text = [preprocess(t) for t in texts]
+ wiki_text = '\n'.join(wiki_text)
+
+ # Load the Shakespeare model
+ shakespeare_model = BigramLanguageModel(shakespeare_text, n_embd).to(device) # Initialize an instance of your model
+ shakespeare_model.load_state_dict(torch.load('shakespeaere_language_model.pth', map_location=torch.device('cpu')))
+ shakespeare_model.eval() # Set the model to evaluation mode
+
+ # Load the wikipedia model
+ wikipedia_model = BigramLanguageModel(wiki_text, n_embd).to(device) # Initialize an instance of your model
+ wikipedia_model.load_state_dict(torch.load('wikipedia_language_model.pth', map_location=torch.device('cpu')))
+ wikipedia_model.eval() # Set the model to evaluation mode
+
+
+ def generate_shakespeare_outputs(prompt=None, max_new_tokens=2000):
+     if prompt:
+         context = torch.tensor(shakespeare_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
+     else:
+         context = torch.zeros((1, 1), dtype=torch.long, device=device)
+     text_output = shakespeare_model.decode(shakespeare_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
+     return text_output
+
+
+ def generate_wikipedia_outputs(prompt=None, max_new_tokens=2000):
+     if prompt:
+         context = torch.tensor(wikipedia_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
+     else:
+         context = torch.zeros((1, 1), dtype=torch.long, device=device)
+     text_output = wikipedia_model.decode(wikipedia_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
+     return text_output
+
+
+ title = "Nano GPT"
+
+ description1 = "Nano GPT trained on the <a href='https://github.com/karpathy/char-rnn/blob/6f9487a6fe5b420b7ca9afb0d7c078e37c1d1b4e/data/tinyshakespeare/input.txt'>Shakespeare dataset</a>. It is trained on a very small amount of data to show how GPTs are trained and built. The implementation can be found <a href='https://github.com/karpathy/nanoGPT'>here</a>."
+
+ shakespeare_interface = gr.Interface(generate_shakespeare_outputs,
+                                      inputs=[gr.Textbox(label="Enter any prompt ", type="text", value="Once upon a time,"),
+                                              gr.Slider(minimum=100, maximum=5000, step=100, value=2000, label="Max new tokens")],
+                                      outputs=gr.Textbox(label="Output generated", type="text"), description=description1)
+
+ description2 = "Nano GPT trained on the <a href='https://www.kaggle.com/datasets/mikeortman/wikipedia-sentences'>Wikipedia dataset</a>. It is trained on a very small amount of data to show how GPTs are trained and built. The implementation can be found <a href='https://github.com/karpathy/nanoGPT'>here</a>."
+
+ wiki_interface = gr.Interface(generate_wikipedia_outputs,
+                               inputs=[gr.Textbox(label="Enter any prompt ", type="text", value="James Bond"),
+                                       gr.Slider(minimum=100, maximum=5000, step=100, value=2000, label="Max new tokens")],
+                               outputs=gr.Textbox(label="Output generated", type="text"), description=description2)
+
+ demo = gr.TabbedInterface([shakespeare_interface, wiki_interface], tab_names=["Shakespeare Data", "Wikipedia Data"],
+                           title=title)
+
+
+ demo.launch()
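One caveat in app.py as written: `encode` is a plain `stoi` lookup, so a prompt containing any character that never occurs in the training text raises a KeyError before generation starts. A minimal defensive sketch (the `safe_encode` helper is hypothetical, not part of app.py) would drop unknown characters first:

    # Hypothetical helper, not in app.py: keep only characters that the
    # character-level vocabulary actually contains.
    def safe_encode(model, prompt):
        return [model.stoi[c] for c in prompt if c in model.stoi]

    # Usage inside generate_shakespeare_outputs, for example:
    # context = torch.tensor(safe_encode(shakespeare_model, prompt),
    #                        dtype=torch.long, device=device).view(1, -1)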
input.txt ADDED
The diff for this file is too large to render. See raw diff
 
news.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b3df2ebf6b3b72bed68f665853caf5ab68345ba2f67618d6dae52add20a850d
+ size 63807429
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch
+ gradio