ravi.naik commited on
Commit
213d16f
·
1 Parent(s): 73580b8

Added training, inference and gradio UI code

Browse files
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: ERA SESSION21
3
  emoji: 🌍
4
  colorFrom: indigo
5
  colorTo: blue
@@ -10,4 +10,57 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: "ERA SESSION21: GPT from scratch"
3
  emoji: 🌍
4
  colorFrom: indigo
5
  colorTo: blue
 
10
  license: mit
11
  ---
12
 
13
+ ### Results
14
+ **Bigram Base model training and results**
15
+
16
+ ![image](https://github.com/RaviNaik/ERA-SESSION21/assets/23289802/4cc02d93-98fc-4114-a4c9-8a3c249eaad3)
17
+
18
+ **GPT Model training results**
19
+
20
+ ![image](https://github.com/RaviNaik/ERA-SESSION21/assets/23289802/95dcde00-bf20-4853-ad20-fa67c1046f6b)
21
+
22
+ #### Generation Output:
23
+ ```python
24
+ model = torch.load("checkpoints/model.pth", map_location={"cpu": device})
25
+ results = generate("hello", model, block_size, 1000, device)
26
+ print(results)
27
+ ```
28
+ ```
29
+ hellows thence grown from thee.
30
+ Since thou hast raim, thou thast well were quarterned; and
31
+ ever man tree can saw for words word from her at hour
32
+ Whiles contrations or devoided from ere years;
33
+ Yea, foul vice, indelice on the bird of the
34
+ noble of Hermione.
35
+
36
+ PARIS:
37
+ Sir, adies, sir, hate no choping but to your good.
38
+
39
+ HENRY BOLINGBROKE:
40
+ Yes, to ask you might, foreweed.
41
+
42
+ WARCK:
43
+ 'Tis he made moust true.
44
+
45
+ RORSET:
46
+ It is an hour fastal that cracknaf at the chase
47
+ Upon; you are your hearing news a daughter.
48
+
49
+ KING EDWARD IV:
50
+ Tut, Lord Warwick, thou shouldst aft Rutlansps?
51
+ Thou tust but back hild, he countemn'd my lady's seal,
52
+ For access dead the treature moon! and the Englisting!
53
+ Thy vage for yonder see thou be donen?
54
+ O, count thou dost not Romeo, thou pratheeo sir,
55
+ That sweet thou feigh with no past blood on
56
+ Be see, here through on that find bears, if an
57
+ pretterinctors three and aspect die meeds thou,
58
+ Behing mine of thy denigning state lain business?
59
+
60
+ SAMPSA:
61
+ Sir, ha! but thou refused? thyself food, gr
62
+ ```
63
+ ### Gradio Interface
64
+ ![image](https://github.com/RaviNaik/ERA-SESSION21/assets/23289802/f339ec6b-17b3-4de6-bbef-14eb2b3fac84)
65
+
66
+
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
+ import torch
4
+ import pathlib
5
+
6
+ from src.model import GPTModel
7
+ from src.inference import generate as generate_text
8
+ from src.utils import vocab_size
9
+
10
+ batch_size = 64
11
+ block_size = 256
12
+ max_iters = 5000
13
+ eval_interval = 500
14
+ learning_rate = 3e-4
15
+ device = "cuda:1" if torch.cuda.is_available() else "cpu"
16
+ eval_iters = 200
17
+ n_embeds = 384
18
+ n_heads = 6
19
+ n_layers = 6
20
+ dropout = 0.2
21
+
22
+
23
+ def load_model():
24
+ model = torch.load("checkpoints/model.pth", map_location={"cpu": device})
25
+ return model
26
+
27
+
28
+ model = load_model()
29
+
30
+
31
+ def generate(prompt, max_new_tokens):
32
+ prompt = prompt.strip()
33
+ out = generate_text(prompt, model, block_size, max_new_tokens, device)
34
+ return {gpt_output: out}
35
+
36
+
37
+ with gr.Blocks() as app:
38
+ gr.Markdown("## ERA Session21 - GPT from scratch")
39
+ gr.Markdown(
40
+ """This is an implementation of GPT [Let's build GPT: from scratch, in code, spelled out.](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=2s) by Andrej Karpathy.
41
+
42
+ Please find the source code and training details [here](https://github.com/RaviNaik/ERA-SESSION21).
43
+
44
+ Dataset used to train: [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
45
+ """
46
+ )
47
+ with gr.Row():
48
+ with gr.Column():
49
+ prompt_box = gr.Textbox(label="Initial Prompt", interactive=True)
50
+ max_new_tokens = gr.Slider(
51
+ minimum=10,
52
+ maximum=2500,
53
+ value=100,
54
+ step=10,
55
+ label="Select Number of Tokens to be Generated",
56
+ interactive=True,
57
+ )
58
+ submit_btn = gr.Button(value="Generate")
59
+
60
+ with gr.Column():
61
+ gpt_output = gr.TextArea(
62
+ label="Text Generated by GPT",
63
+ show_label=True,
64
+ max_lines=100,
65
+ interactive=False,
66
+ )
67
+
68
+ submit_btn.click(
69
+ generate,
70
+ inputs=[prompt_box, max_new_tokens],
71
+ outputs=[gpt_output],
72
+ )
73
+
74
+ app.launch()
checkpoints/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8b930ee87e1eecc6a03bc49983a81fd11aaa95f4cd5e1d64091d6107827811b
3
+ size 52698997
data/input.txt ADDED
The diff for this file is too large to render. See raw diff
 
experiments/bigram.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ batch_size = 32
6
+ block_size = 8
7
+ max_iters = 3000
8
+ eval_interval = 300
9
+ learning_rate = 1e-2
10
+ device = "cuda:1" if torch.cuda.is_available() else "cpu"
11
+ eval_iters = 200
12
+
13
+ torch.manual_seed(1123)
14
+
15
+ with open("input.txt") as f:
16
+ text = f.read()
17
+
18
+ chars = sorted(list(set(text)))
19
+ vocab_size = len(chars)
20
+
21
+ stoi = {ch: i for i, ch in enumerate(chars)}
22
+ itos = {i: ch for i, ch in enumerate(chars)}
23
+
24
+ encode = lambda s: [stoi[c] for c in s]
25
+ decode = lambda l: "".join([itos[i] for i in l])
26
+
27
+ data = torch.tensor(encode(text), dtype=torch.long)
28
+ n = int(0.9 * len(data))
29
+ train_data = data[:n]
30
+ val_data = data[n:]
31
+
32
+
33
+ def get_batch(split):
34
+ data = train_data if split == "train" else val_data
35
+ ix = torch.randint(len(data) - block_size, (batch_size,))
36
+ x = torch.stack([data[i : i + block_size] for i in ix])
37
+ y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
38
+ return x, y
39
+
40
+
41
+ @torch.no_grad()
42
+ def estimate_loss(model: nn.Module):
43
+ out = {}
44
+ model.eval()
45
+ for split in ["train", "val"]:
46
+ losses = torch.zeros(eval_iters)
47
+ for k in range(eval_iters):
48
+ X, Y = get_batch(split)
49
+ X, Y = X.to(device), Y.to(device)
50
+ logits, loss = model(X, Y)
51
+ losses[k] = loss.item()
52
+ out[split] = losses.mean()
53
+ model.train()
54
+ return out
55
+
56
+
57
+ class BigramLanguageModel(nn.Module):
58
+ def __init__(self, vocab_size):
59
+ super().__init__()
60
+ self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
61
+
62
+ def forward(self, idx, targets=None):
63
+ logits = self.token_embedding_table(idx) # BTC
64
+ loss = None
65
+ if targets is not None:
66
+ B, T, C = logits.shape
67
+ logits = logits.view(B * T, C)
68
+ targets = targets.view(B * T)
69
+ loss = F.cross_entropy(logits, targets)
70
+ return logits, loss
71
+
72
+ def generate(self, idx, max_new_tokens):
73
+ for _ in range(max_new_tokens):
74
+ logits, loss = self(idx) # BxTxC
75
+ logits = logits[:, -1, :] # BxC
76
+ probs = F.softmax(logits, dim=-1) # BxC
77
+ idx_next = torch.multinomial(probs, num_samples=1) # Bx1
78
+ idx = torch.cat((idx, idx_next), dim=1) # BxT+1
79
+
80
+ return idx
81
+
82
+
83
+ model = BigramLanguageModel(vocab_size)
84
+
85
+ model = model.to(device)
86
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
87
+
88
+ for iter in range(max_iters):
89
+ if iter % eval_interval == 0:
90
+ losses = estimate_loss(model)
91
+ print(
92
+ f"Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
93
+ )
94
+
95
+ xb, yb = get_batch("train")
96
+ xb, yb = xb.to(device), yb.to(device)
97
+
98
+ logits, loss = model(xb, yb)
99
+
100
+ optimizer.zero_grad(set_to_none=True)
101
+ loss.backward()
102
+ optimizer.step()
103
+
104
+
105
+ context = torch.zeros((1, 1), dtype=torch.long, device=device)
106
+ results = decode(model.generate(context, max_new_tokens=100)[0].tolist())
107
+ print(results)
experiments/bigram_v2.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ batch_size = 64
6
+ block_size = 256
7
+ max_iters = 5000
8
+ eval_interval = 500
9
+ learning_rate = 3e-4
10
+ device = "cuda:1" if torch.cuda.is_available() else "cpu"
11
+ eval_iters = 200
12
+ n_embeds = 384
13
+ n_heads = 6
14
+ n_layers = 6
15
+ dropout = 0.2
16
+
17
+ torch.manual_seed(1123)
18
+
19
+ with open("input.txt") as f:
20
+ text = f.read()
21
+
22
+ chars = sorted(list(set(text)))
23
+ vocab_size = len(chars)
24
+
25
+ stoi = {ch: i for i, ch in enumerate(chars)}
26
+ itos = {i: ch for i, ch in enumerate(chars)}
27
+
28
+
29
+ def encode(s):
30
+ return [stoi[c] for c in s]
31
+
32
+
33
+ def decode(l):
34
+ return "".join([itos[i] for i in l])
35
+
36
+
37
+ data = torch.tensor(encode(text), dtype=torch.long)
38
+ n = int(0.9 * len(data))
39
+ train_data = data[:n]
40
+ val_data = data[n:]
41
+
42
+
43
+ def get_batch(split):
44
+ data = train_data if split == "train" else val_data
45
+ ix = torch.randint(len(data) - block_size, (batch_size,))
46
+ x = torch.stack([data[i : i + block_size] for i in ix])
47
+ y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
48
+ return x, y
49
+
50
+
51
+ @torch.no_grad()
52
+ def estimate_loss(model: nn.Module):
53
+ out = {}
54
+ model.eval()
55
+ for split in ["train", "val"]:
56
+ losses = torch.zeros(eval_iters)
57
+ for k in range(eval_iters):
58
+ X, Y = get_batch(split)
59
+ X, Y = X.to(device), Y.to(device)
60
+ logits, loss = model(X, Y)
61
+ losses[k] = loss.item()
62
+ out[split] = losses.mean()
63
+ model.train()
64
+ return out
65
+
66
+
67
+ class Head(nn.Module):
68
+ def __init__(self, n_embed, head_size) -> None:
69
+ super().__init__()
70
+ self.key = nn.Linear(n_embed, head_size, bias=False)
71
+ self.query = nn.Linear(n_embed, head_size, bias=False)
72
+ self.value = nn.Linear(n_embed, head_size, bias=False)
73
+ self.dropout = nn.Dropout(dropout)
74
+ self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
75
+
76
+ def forward(self, x):
77
+ B, T, C = x.shape
78
+ k = self.key(x)
79
+ q = self.query(x)
80
+ wei = q @ k.transpose(-2, -1) * (C**-0.5) # (B,T,16) @ (B,16,T) --> (B,T,T)
81
+ wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
82
+ wei = F.softmax(wei, dim=-1)
83
+ wei = self.dropout(wei)
84
+ v = self.value(x)
85
+ out = wei @ v
86
+ return out
87
+
88
+
89
+ class MultiHeadAttention(nn.Module):
90
+ def __init__(self, n_heads, n_embeds, head_size):
91
+ super().__init__()
92
+ self.heads = nn.ModuleList([Head(n_embeds, head_size) for _ in range(n_heads)])
93
+ self.proj = nn.Linear(n_embeds, n_embeds)
94
+ self.dropout = nn.Dropout(dropout)
95
+
96
+ def forward(self, x):
97
+ x = torch.cat([h(x) for h in self.heads], dim=-1)
98
+ x = self.proj(x)
99
+ x = self.dropout(x)
100
+ return x
101
+
102
+
103
+ class FeedForward(nn.Module):
104
+ def __init__(self, n_embeds):
105
+ super().__init__()
106
+ self.net = nn.Sequential(
107
+ nn.Linear(n_embeds, 4 * n_embeds),
108
+ nn.ReLU(),
109
+ nn.Linear(4 * n_embeds, n_embeds),
110
+ nn.Dropout(dropout),
111
+ )
112
+
113
+ def forward(self, x):
114
+ return self.net(x)
115
+
116
+
117
+ class Block(nn.Module):
118
+ def __init__(self, n_embeds, n_heads):
119
+ super().__init__()
120
+ head_size = n_embeds // n_heads
121
+ self.sa_heads = MultiHeadAttention(n_heads, n_embeds, head_size)
122
+ self.ffwd = FeedForward(n_embeds)
123
+ self.ln1 = nn.LayerNorm(n_embeds)
124
+ self.ln2 = nn.LayerNorm(n_embeds)
125
+
126
+ def forward(self, x):
127
+ x = x + self.sa_heads(self.ln1(x))
128
+ x = x + self.ffwd(self.ln2(x))
129
+ return x
130
+
131
+
132
+ class BigramLanguageModel(nn.Module):
133
+ def __init__(self, vocab_size, n_embeds, block_size):
134
+ super().__init__()
135
+ self.token_embedding_table = nn.Embedding(vocab_size, n_embeds)
136
+ self.position_embedding_table = nn.Embedding(block_size, n_embeds)
137
+ self.blocks = nn.Sequential(
138
+ *[Block(n_embeds, n_heads) for _ in range(n_layers)]
139
+ )
140
+ self.lnf = nn.LayerNorm(n_embeds)
141
+ self.lm_head = nn.Linear(n_embeds, vocab_size)
142
+
143
+ def forward(self, idx, targets=None):
144
+ B, T = idx.shape
145
+
146
+ tok_embeds = self.token_embedding_table(idx) # BxTxNemb
147
+ pos_embeds = self.position_embedding_table(
148
+ torch.arange(T, device=device)
149
+ ) # TXNemb
150
+
151
+ x = tok_embeds + pos_embeds # BxTxNemb
152
+ x = self.blocks(x)
153
+ x = self.lnf(x)
154
+ logits = self.lm_head(x) # BxTxVocabSize
155
+
156
+ loss = None
157
+ if targets is not None:
158
+ B, T, C = logits.shape
159
+ logits = logits.view(B * T, C)
160
+ targets = targets.view(B * T)
161
+ loss = F.cross_entropy(logits, targets)
162
+ return logits, loss
163
+
164
+ def generate(self, idx, max_new_tokens):
165
+ for _ in range(max_new_tokens):
166
+ idx_cond = idx[:, -block_size:]
167
+ logits, loss = self(idx_cond) # BxTxC
168
+ logits = logits[:, -1, :] # BxC
169
+ probs = F.softmax(logits, dim=-1) # BxC
170
+ idx_next = torch.multinomial(probs, num_samples=1) # Bx1
171
+ idx = torch.cat((idx, idx_next), dim=1) # BxT+1
172
+
173
+ return idx
174
+
175
+
176
+ model = BigramLanguageModel(vocab_size, n_embeds, block_size)
177
+
178
+ model = model.to(device)
179
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
180
+
181
+ for iter in range(max_iters):
182
+ if iter % eval_interval == 0:
183
+ losses = estimate_loss(model)
184
+ print(
185
+ f"Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
186
+ )
187
+
188
+ xb, yb = get_batch("train")
189
+ xb, yb = xb.to(device), yb.to(device)
190
+
191
+ logits, loss = model(xb, yb)
192
+
193
+ optimizer.zero_grad(set_to_none=True)
194
+ loss.backward()
195
+ optimizer.step()
196
+
197
+
198
+ context = torch.zeros((1, 1), dtype=torch.long, device=device)
199
+ results = decode(model.generate(context, max_new_tokens=100)[0].tolist())
200
+ print(results)
experiments/exp.ipynb ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "--2023-10-27 16:11:32-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
13
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...\n",
14
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... "
15
+ ]
16
+ },
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "connected.\n",
22
+ "HTTP request sent, awaiting response... 200 OK\n",
23
+ "Length: 1115394 (1.1M) [text/plain]\n",
24
+ "Saving to: ‘input.txt.1’\n",
25
+ "\n",
26
+ "input.txt.1 100%[===================>] 1.06M 734KB/s in 1.5s \n",
27
+ "\n",
28
+ "2023-10-27 16:11:36 (734 KB/s) - ‘input.txt.1’ saved [1115394/1115394]\n",
29
+ "\n"
30
+ ]
31
+ }
32
+ ],
33
+ "source": [
34
+ "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 2,
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "with open(\"input.txt\") as f:\n",
44
+ " text = f.read()"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 4,
50
+ "metadata": {},
51
+ "outputs": [
52
+ {
53
+ "data": {
54
+ "text/plain": [
55
+ "'First Citizen:\\nBefore we proceed any further, hear'"
56
+ ]
57
+ },
58
+ "execution_count": 4,
59
+ "metadata": {},
60
+ "output_type": "execute_result"
61
+ }
62
+ ],
63
+ "source": [
64
+ "text[:50]"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 5,
70
+ "metadata": {},
71
+ "outputs": [
72
+ {
73
+ "name": "stdout",
74
+ "output_type": "stream",
75
+ "text": [
76
+ "\n",
77
+ " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
78
+ "65\n"
79
+ ]
80
+ }
81
+ ],
82
+ "source": [
83
+ "chars = sorted(list(set(text)))\n",
84
+ "vocab_size = len(chars)\n",
85
+ "\n",
86
+ "print(\"\".join(chars))\n",
87
+ "print(vocab_size)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 7,
93
+ "metadata": {},
94
+ "outputs": [
95
+ {
96
+ "name": "stdout",
97
+ "output_type": "stream",
98
+ "text": [
99
+ "[46, 47, 1, 58, 46, 43, 56, 43]\n",
100
+ "hi there\n"
101
+ ]
102
+ }
103
+ ],
104
+ "source": [
105
+ "stoi = {ch: i for i, ch in enumerate(chars)}\n",
106
+ "itos = {i: ch for i, ch in enumerate(chars)}\n",
107
+ "\n",
108
+ "encode = lambda s: [stoi[c] for c in s]\n",
109
+ "decode = lambda l: \"\".join([itos[i] for i in l])\n",
110
+ "\n",
111
+ "print(encode(\"hi there\"))\n",
112
+ "\n",
113
+ "print(decode(encode(\"hi there\")))"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": 8,
119
+ "metadata": {},
120
+ "outputs": [
121
+ {
122
+ "name": "stdout",
123
+ "output_type": "stream",
124
+ "text": [
125
+ "torch.Size([1115394]) torch.int64\n",
126
+ "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44,\n",
127
+ " 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63,\n",
128
+ " 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1,\n",
129
+ " 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49,\n",
130
+ " 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47,\n",
131
+ " 58, 47, 64, 43, 52, 10, 0, 37, 53, 59])\n"
132
+ ]
133
+ }
134
+ ],
135
+ "source": [
136
+ "import torch\n",
137
+ "\n",
138
+ "data = torch.tensor(encode(text), dtype=torch.long)\n",
139
+ "print(data.shape, data.dtype)\n",
140
+ "print(data[:100])"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 9,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "n = int(0.9 * len(data))\n",
150
+ "train_data = data[:n]\n",
151
+ "val_data = data[n:]"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": 10,
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "name": "stdout",
161
+ "output_type": "stream",
162
+ "text": [
163
+ "Inputs:\n",
164
+ "torch.Size([4, 8])\n",
165
+ "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
166
+ " [44, 53, 56, 1, 58, 46, 39, 58],\n",
167
+ " [52, 58, 1, 58, 46, 39, 58, 1],\n",
168
+ " [25, 17, 27, 10, 0, 21, 1, 54]])\n",
169
+ "-----------\n",
170
+ "Targets:\n",
171
+ "torch.Size([4, 8])\n",
172
+ "tensor([[43, 58, 5, 57, 1, 46, 43, 39],\n",
173
+ " [53, 56, 1, 58, 46, 39, 58, 1],\n",
174
+ " [58, 1, 58, 46, 39, 58, 1, 46],\n",
175
+ " [17, 27, 10, 0, 21, 1, 54, 39]])\n"
176
+ ]
177
+ }
178
+ ],
179
+ "source": [
180
+ "torch.manual_seed(1337)\n",
181
+ "batch_size = 4\n",
182
+ "block_size = 8\n",
183
+ "\n",
184
+ "\n",
185
+ "def get_batch(split):\n",
186
+ " data = train_data if split == \"train\" else val_data\n",
187
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
188
+ " x = torch.stack([data[i : i + block_size] for i in ix])\n",
189
+ " y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])\n",
190
+ " return x, y\n",
191
+ "\n",
192
+ "\n",
193
+ "xb, yb = get_batch(\"train\")\n",
194
+ "print(\"Inputs:\")\n",
195
+ "print(xb.shape)\n",
196
+ "print(xb)\n",
197
+ "\n",
198
+ "print(\"-----------\")\n",
199
+ "print(\"Targets:\")\n",
200
+ "print(yb.shape)\n",
201
+ "print(yb)"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": 11,
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "import torch.nn as nn\n",
211
+ "from torch.nn import functional as F\n",
212
+ "\n",
213
+ "\n",
214
+ "class BigramLanguageModel(nn.Module):\n",
215
+ " def __init__(self, vocab_size):\n",
216
+ " super().__init__()\n",
217
+ " self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
218
+ "\n",
219
+ " def forward(self, idx, targets):\n",
220
+ " logits = self.token_embedding_table(idx)\n",
221
+ "\n",
222
+ " return logits"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": 12,
228
+ "metadata": {},
229
+ "outputs": [
230
+ {
231
+ "name": "stdout",
232
+ "output_type": "stream",
233
+ "text": [
234
+ "torch.Size([4, 8, 65])\n"
235
+ ]
236
+ }
237
+ ],
238
+ "source": [
239
+ "m = BigramLanguageModel(vocab_size)\n",
240
+ "out = m(xb, yb)\n",
241
+ "print(out.shape) # B,T,C -> 4X8X65"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": 15,
247
+ "metadata": {},
248
+ "outputs": [
249
+ {
250
+ "name": "stdout",
251
+ "output_type": "stream",
252
+ "text": [
253
+ "torch.Size([32, 65])\n",
254
+ "tensor(4.5262, grad_fn=<NllLossBackward0>)\n"
255
+ ]
256
+ }
257
+ ],
258
+ "source": [
259
+ "class BigramLanguageModel(nn.Module):\n",
260
+ " def __init__(self, vocab_size):\n",
261
+ " super().__init__()\n",
262
+ " self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
263
+ "\n",
264
+ " def forward(self, idx, targets=None):\n",
265
+ " logits = self.token_embedding_table(idx) # BTC\n",
266
+ " loss = None\n",
267
+ " if targets is not None:\n",
268
+ " B, T, C = logits.shape\n",
269
+ " logits = logits.view(B * T, C)\n",
270
+ " targets = targets.view(B * T)\n",
271
+ " loss = F.cross_entropy(logits, targets)\n",
272
+ " return logits, loss\n",
273
+ "\n",
274
+ " def generate(self, idx, max_new_tokens):\n",
275
+ " for _ in range(max_new_tokens):\n",
276
+ " logits, loss = self(idx) # BxTxC\n",
277
+ " logits = logits[:, -1, :] # BxC\n",
278
+ " probs = F.softmax(logits, dim=-1) # BxC\n",
279
+ " idx_next = torch.multinomial(probs, num_samples=1) # Bx1\n",
280
+ " idx = torch.cat((idx, idx_next), dim=1) # BxT+1\n",
281
+ "\n",
282
+ " return idx\n",
283
+ "\n",
284
+ "\n",
285
+ "m = BigramLanguageModel(vocab_size)\n",
286
+ "logits, loss = m(xb, yb)\n",
287
+ "print(logits.shape) # B,T,C -> 4X8X65\n",
288
+ "print(loss)"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 16,
294
+ "metadata": {},
295
+ "outputs": [
296
+ {
297
+ "name": "stdout",
298
+ "output_type": "stream",
299
+ "text": [
300
+ "\n",
301
+ "'JgC.JZWqUkpdtkSpmzjM-,RqzgaN?vC:hgjnAnBZDga-APqGUH!WdCbIb;$DefOYbEvcaKGMmnO'q$KdS-'ZH\n",
302
+ ".YSqr'X!Q! d;\n"
303
+ ]
304
+ }
305
+ ],
306
+ "source": [
307
+ "idx = torch.zeros((1, 1), dtype=torch.long)\n",
308
+ "\n",
309
+ "results = decode(m.generate(idx, max_new_tokens=100)[0].tolist())\n",
310
+ "\n",
311
+ "print(results)"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": 17,
317
+ "metadata": {},
318
+ "outputs": [],
319
+ "source": [
320
+ "optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": 19,
326
+ "metadata": {},
327
+ "outputs": [
328
+ {
329
+ "name": "stdout",
330
+ "output_type": "stream",
331
+ "text": [
332
+ "2.4206888675689697\n"
333
+ ]
334
+ }
335
+ ],
336
+ "source": [
337
+ "batch_size = 32\n",
338
+ "\n",
339
+ "for steps in range(10000):\n",
340
+ " xb, yb = get_batch(\"train\")\n",
341
+ "\n",
342
+ " logits, loss = m(xb, yb)\n",
343
+ " optimizer.zero_grad(set_to_none=True)\n",
344
+ " loss.backward()\n",
345
+ " optimizer.step()\n",
346
+ "\n",
347
+ "print(loss.item())"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": 20,
353
+ "metadata": {},
354
+ "outputs": [
355
+ {
356
+ "name": "stdout",
357
+ "output_type": "stream",
358
+ "text": [
359
+ "\n",
360
+ "Hou'sy'ting'stis's w ys'stholealy woawhimedy it 'save,\n",
361
+ "Too:Had wh fo an, ZCENERUCHENar ee onds, th h\n"
362
+ ]
363
+ }
364
+ ],
365
+ "source": [
366
+ "idx = torch.zeros((1, 1), dtype=torch.long)\n",
367
+ "\n",
368
+ "results = decode(m.generate(idx, max_new_tokens=100)[0].tolist())\n",
369
+ "\n",
370
+ "print(results)"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": 28,
376
+ "metadata": {},
377
+ "outputs": [
378
+ {
379
+ "data": {
380
+ "text/plain": [
381
+ "torch.Size([4, 8, 16])"
382
+ ]
383
+ },
384
+ "execution_count": 28,
385
+ "metadata": {},
386
+ "output_type": "execute_result"
387
+ }
388
+ ],
389
+ "source": [
390
+ "B, T, C = 4, 8, 32\n",
391
+ "\n",
392
+ "x = torch.randn(B, T, C)\n",
393
+ "\n",
394
+ "head_size = 16\n",
395
+ "key = nn.Linear(C, head_size, bias=False)\n",
396
+ "query = nn.Linear(C, head_size, bias=False)\n",
397
+ "value = nn.Linear(C, head_size, bias=False)\n",
398
+ "k = key(x)\n",
399
+ "q = query(x)\n",
400
+ "wei = q @ k.transpose(-2, -1) * (head_size**-0.5) # (B,T,16) @ (B,16,T) --> (B,T,T)\n",
401
+ "\n",
402
+ "tril = torch.tril(torch.ones(T, T))\n",
403
+ "wei = wei.masked_fill(tril == 0, float(\"-inf\"))\n",
404
+ "wei = F.softmax(wei, dim=-1)\n",
405
+ "v = value(x)\n",
406
+ "out = wei @ v\n",
407
+ "\n",
408
+ "out.shape\n"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": 29,
414
+ "metadata": {},
415
+ "outputs": [
416
+ {
417
+ "data": {
418
+ "text/plain": [
419
+ "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
420
+ " [0.3325, 0.6675, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
421
+ " [0.3578, 0.2873, 0.3550, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
422
+ " [0.2281, 0.1964, 0.2733, 0.3022, 0.0000, 0.0000, 0.0000, 0.0000],\n",
423
+ " [0.2851, 0.1588, 0.2068, 0.1436, 0.2057, 0.0000, 0.0000, 0.0000],\n",
424
+ " [0.2429, 0.1547, 0.1550, 0.1475, 0.2049, 0.0951, 0.0000, 0.0000],\n",
425
+ " [0.1573, 0.1838, 0.1123, 0.1680, 0.1528, 0.1194, 0.1063, 0.0000],\n",
426
+ " [0.1139, 0.1704, 0.0766, 0.1134, 0.1600, 0.1466, 0.1228, 0.0963]],\n",
427
+ " grad_fn=<SelectBackward0>)"
428
+ ]
429
+ },
430
+ "execution_count": 29,
431
+ "metadata": {},
432
+ "output_type": "execute_result"
433
+ }
434
+ ],
435
+ "source": [
436
+ "wei[0]\n"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "metadata": {},
443
+ "outputs": [],
444
+ "source": []
445
+ }
446
+ ],
447
+ "metadata": {
448
+ "kernelspec": {
449
+ "display_name": "Python 3",
450
+ "language": "python",
451
+ "name": "python3"
452
+ },
453
+ "language_info": {
454
+ "codemirror_mode": {
455
+ "name": "ipython",
456
+ "version": 3
457
+ },
458
+ "file_extension": ".py",
459
+ "mimetype": "text/x-python",
460
+ "name": "python",
461
+ "nbconvert_exporter": "python",
462
+ "pygments_lexer": "ipython3",
463
+ "version": "3.10.12"
464
+ }
465
+ },
466
+ "nbformat": 4,
467
+ "nbformat_minor": 2
468
+ }
gpt.ipynb ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Import Dependencies"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 1,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import torch\n"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "from src.model import GPTModel\n",
26
+ "from src.training import train\n",
27
+ "from src.inference import generate\n",
28
+ "from src.utils import vocab_size\n"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "metadata": {},
34
+ "source": [
35
+ "## Decalre Hyperparams"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 3,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "batch_size = 64\n",
45
+ "block_size = 256\n",
46
+ "max_iters = 5000\n",
47
+ "eval_interval = 500\n",
48
+ "learning_rate = 3e-4\n",
49
+ "device = \"cuda:1\" if torch.cuda.is_available() else \"cpu\"\n",
50
+ "eval_iters = 200\n",
51
+ "n_embeds = 384\n",
52
+ "n_heads = 6\n",
53
+ "n_layers = 6\n",
54
+ "dropout = 0.2"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "metadata": {},
60
+ "source": [
61
+ "## Initialize Model and Optimizer"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 6,
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "model = GPTModel(vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device)\n",
71
+ "model = model.to(device)\n",
72
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "metadata": {},
78
+ "source": [
79
+ "## Model Training"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": 7,
85
+ "metadata": {},
86
+ "outputs": [
87
+ {
88
+ "name": "stdout",
89
+ "output_type": "stream",
90
+ "text": [
91
+ "Step 0: train loss 4.3249, val loss 4.3219\n",
92
+ "Step 500: train loss 2.0213, val loss 2.0953\n",
93
+ "Step 1000: train loss 1.6067, val loss 1.7813\n",
94
+ "Step 1500: train loss 1.4462, val loss 1.6380\n",
95
+ "Step 2000: train loss 1.3516, val loss 1.5810\n",
96
+ "Step 2500: train loss 1.2836, val loss 1.5376\n",
97
+ "Step 3000: train loss 1.2309, val loss 1.5148\n",
98
+ "Step 3500: train loss 1.1910, val loss 1.4904\n",
99
+ "Step 4000: train loss 1.1522, val loss 1.4822\n",
100
+ "Step 4500: train loss 1.1186, val loss 1.4838\n"
101
+ ]
102
+ }
103
+ ],
104
+ "source": [
105
+ "train(\n",
106
+ " model,\n",
107
+ " optimizer,\n",
108
+ " max_iters,\n",
109
+ " eval_interval,\n",
110
+ " eval_iters,\n",
111
+ " block_size,\n",
112
+ " batch_size,\n",
113
+ " device,\n",
114
+ ")\n"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "markdown",
119
+ "metadata": {},
120
+ "source": [
121
+ "## Load the model and Generate text"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 4,
127
+ "metadata": {},
128
+ "outputs": [
129
+ {
130
+ "name": "stdout",
131
+ "output_type": "stream",
132
+ "text": [
133
+ "hellows thence grown from thee.\n",
134
+ "Since thou hast raim, thou thast well were quarterned; and\n",
135
+ "ever man tree can saw for words word from her at hour\n",
136
+ "Whiles contrations or devoided from ere years;\n",
137
+ "Yea, foul vice, indelice on the bird of the\n",
138
+ "noble of Hermione.\n",
139
+ "\n",
140
+ "PARIS:\n",
141
+ "Sir, adies, sir, hate no choping but to your good.\n",
142
+ "\n",
143
+ "HENRY BOLINGBROKE:\n",
144
+ "Yes, to ask you might, foreweed.\n",
145
+ "\n",
146
+ "WARCK:\n",
147
+ "'Tis he made moust true.\n",
148
+ "\n",
149
+ "RORSET:\n",
150
+ "It is an hour fastal that cracknaf at the chase\n",
151
+ "Upon; you are your hearing news a daughter.\n",
152
+ "\n",
153
+ "KING EDWARD IV:\n",
154
+ "Tut, Lord Warwick, thou shouldst aft Rutlansps?\n",
155
+ "Thou tust but back hild, he countemn'd my lady's seal,\n",
156
+ "For access dead the treature moon! and the Englisting!\n",
157
+ "Thy vage for yonder see thou be donen?\n",
158
+ "O, count thou dost not Romeo, thou pratheeo sir,\n",
159
+ "That sweet thou feigh with no past blood on\n",
160
+ "Be see, here through on that find bears, if an\n",
161
+ "pretterinctors three and aspect die meeds thou,\n",
162
+ "Behing mine of thy denigning state lain business?\n",
163
+ "\n",
164
+ "SAMPSA:\n",
165
+ "Sir, ha! but thou refused? thyself food, gr\n"
166
+ ]
167
+ }
168
+ ],
169
+ "source": [
170
+ "model = torch.load(\"checkpoints/model.pth\", map_location={\"cpu\": device})\n",
171
+ "results = generate(\"hello\", model, block_size, 1000, device)\n",
172
+ "print(results)"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": []
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": []
188
+ }
189
+ ],
190
+ "metadata": {
191
+ "kernelspec": {
192
+ "display_name": "Python 3",
193
+ "language": "python",
194
+ "name": "python3"
195
+ },
196
+ "language_info": {
197
+ "codemirror_mode": {
198
+ "name": "ipython",
199
+ "version": 3
200
+ },
201
+ "file_extension": ".py",
202
+ "mimetype": "text/x-python",
203
+ "name": "python",
204
+ "nbconvert_exporter": "python",
205
+ "pygments_lexer": "ipython3",
206
+ "version": "3.10.12"
207
+ }
208
+ },
209
+ "nbformat": 4,
210
+ "nbformat_minor": 2
211
+ }
src/__pycache__/inference.cpython-310.pyc ADDED
Binary file (527 Bytes). View file
 
src/__pycache__/model.cpython-310.pyc ADDED
Binary file (4.77 kB). View file
 
src/__pycache__/training.cpython-310.pyc ADDED
Binary file (1.27 kB). View file
 
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.75 kB). View file
 
src/inference.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from src.utils import encode, decode
3
+
4
+
5
+ def generate(prompt, model, block_size, max_new_tokens, device):
6
+ X = torch.tensor(encode(prompt), dtype=torch.long, device=device)
7
+ X = X[:block_size].unsqueeze(0)
8
+ results = decode(model.generate(X, max_new_tokens=max_new_tokens)[0].tolist())
9
+ return results
src/model.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class Head(nn.Module):
7
+ def __init__(self, n_embeds, head_size, block_size, dropout) -> None:
8
+ super().__init__()
9
+ self.key = nn.Linear(n_embeds, head_size, bias=False)
10
+ self.query = nn.Linear(n_embeds, head_size, bias=False)
11
+ self.value = nn.Linear(n_embeds, head_size, bias=False)
12
+ self.dropout = nn.Dropout(dropout)
13
+ self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
14
+
15
+ def forward(self, x):
16
+ B, T, C = x.shape
17
+ k = self.key(x)
18
+ q = self.query(x)
19
+ wei = q @ k.transpose(-2, -1) * (C**-0.5) # (B,T,16) @ (B,16,T) --> (B,T,T)
20
+ wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
21
+ wei = F.softmax(wei, dim=-1)
22
+ wei = self.dropout(wei)
23
+ v = self.value(x)
24
+ out = wei @ v
25
+ return out
26
+
27
+
28
+ class MultiHeadAttention(nn.Module):
29
+ def __init__(self, n_heads, n_embeds, head_size, block_size, dropout):
30
+ super().__init__()
31
+ self.heads = nn.ModuleList(
32
+ [Head(n_embeds, head_size, block_size, dropout) for _ in range(n_heads)]
33
+ )
34
+ self.proj = nn.Linear(n_embeds, n_embeds)
35
+ self.dropout = nn.Dropout(dropout)
36
+
37
+ def forward(self, x):
38
+ x = torch.cat([h(x) for h in self.heads], dim=-1)
39
+ x = self.proj(x)
40
+ x = self.dropout(x)
41
+ return x
42
+
43
+
44
+ class FeedForward(nn.Module):
45
+ def __init__(self, n_embeds, dropout):
46
+ super().__init__()
47
+ self.net = nn.Sequential(
48
+ nn.Linear(n_embeds, 4 * n_embeds),
49
+ nn.ReLU(),
50
+ nn.Linear(4 * n_embeds, n_embeds),
51
+ nn.Dropout(dropout),
52
+ )
53
+
54
+ def forward(self, x):
55
+ return self.net(x)
56
+
57
+
58
+ class Decoder(nn.Module):
59
+ def __init__(self, n_embeds, n_heads, block_size, dropout):
60
+ super().__init__()
61
+ head_size = n_embeds // n_heads
62
+ self.sa_heads = MultiHeadAttention(
63
+ n_heads, n_embeds, head_size, block_size, dropout
64
+ )
65
+ self.ffwd = FeedForward(n_embeds, dropout)
66
+ self.ln1 = nn.LayerNorm(n_embeds)
67
+ self.ln2 = nn.LayerNorm(n_embeds)
68
+
69
+ def forward(self, x):
70
+ x = x + self.sa_heads(self.ln1(x))
71
+ x = x + self.ffwd(self.ln2(x))
72
+ return x
73
+
74
+
75
+ class GPTModel(nn.Module):
76
+ def __init__(
77
+ self, vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
78
+ ):
79
+ super().__init__()
80
+ self.device = device
81
+ self.block_size = block_size
82
+ self.token_embedding_table = nn.Embedding(vocab_size, n_embeds)
83
+ self.position_embedding_table = nn.Embedding(block_size, n_embeds)
84
+ self.blocks = nn.Sequential(
85
+ *[Decoder(n_embeds, n_heads, block_size, dropout) for _ in range(n_layers)]
86
+ )
87
+ self.lnf = nn.LayerNorm(n_embeds)
88
+ self.lm_head = nn.Linear(n_embeds, vocab_size)
89
+
90
+ def forward(self, idx, targets=None):
91
+ B, T = idx.shape
92
+
93
+ tok_embeds = self.token_embedding_table(idx) # BxTxNemb
94
+ pos_embeds = self.position_embedding_table(
95
+ torch.arange(T, device=self.device)
96
+ ) # TXNemb
97
+
98
+ x = tok_embeds + pos_embeds # BxTxNemb
99
+ x = self.blocks(x)
100
+ x = self.lnf(x)
101
+ logits = self.lm_head(x) # BxTxVocabSize
102
+
103
+ loss = None
104
+ if targets is not None:
105
+ B, T, C = logits.shape
106
+ logits = logits.view(B * T, C)
107
+ targets = targets.view(B * T)
108
+ loss = F.cross_entropy(logits, targets)
109
+ return logits, loss
110
+
111
+ def generate(self, idx, max_new_tokens):
112
+ for _ in range(max_new_tokens):
113
+ idx_cond = idx[:, -self.block_size :]
114
+ logits, loss = self(idx_cond) # BxTxC
115
+ logits = logits[:, -1, :] # BxC
116
+ probs = F.softmax(logits, dim=-1) # BxC
117
+ idx_next = torch.multinomial(probs, num_samples=1) # Bx1
118
+ idx = torch.cat((idx, idx_next), dim=1) # BxT+1
119
+
120
+ return idx
src/training.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from src.utils import get_batch
5
+
6
+
7
+ @torch.no_grad()
8
+ def estimate_loss(model: nn.Module, eval_iters, block_size, batch_size, device):
9
+ out = {}
10
+ model.eval()
11
+ for split in ["train", "val"]:
12
+ losses = torch.zeros(eval_iters)
13
+ for k in range(eval_iters):
14
+ X, Y = get_batch(split, block_size, batch_size)
15
+ X, Y = X.to(device), Y.to(device)
16
+ logits, loss = model(X, Y)
17
+ losses[k] = loss.item()
18
+ out[split] = losses.mean()
19
+ model.train()
20
+ return out
21
+
22
+
23
+ def train(
24
+ model,
25
+ optimizer,
26
+ max_iters,
27
+ eval_interval,
28
+ eval_iters,
29
+ block_size,
30
+ batch_size,
31
+ device,
32
+ ):
33
+ val_loss = None
34
+ for iter in range(max_iters):
35
+ if iter % eval_interval == 0:
36
+ losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
37
+ print(
38
+ f"Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
39
+ )
40
+ if val_loss is not None:
41
+ if losses["val"] < val_loss:
42
+ torch.save(model, "checkpoints/model.pth")
43
+ else:
44
+ val_loss = losses["val"]
45
+
46
+ xb, yb = get_batch("train", block_size, batch_size)
47
+ xb, yb = xb.to(device), yb.to(device)
48
+
49
+ logits, loss = model(xb, yb)
50
+
51
+ optimizer.zero_grad(set_to_none=True)
52
+ loss.backward()
53
+ optimizer.step()
src/utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ with open("data/input.txt") as f:
4
+ text = f.read()
5
+
6
+ chars = sorted(list(set(text)))
7
+ vocab_size = len(chars)
8
+
9
+ stoi = {ch: i for i, ch in enumerate(chars)}
10
+ itos = {i: ch for i, ch in enumerate(chars)}
11
+
12
+
13
+ def encode(s):
14
+ return [stoi[c] for c in s]
15
+
16
+
17
+ def decode(l):
18
+ return "".join([itos[i] for i in l])
19
+
20
+
21
+ data = torch.tensor(encode(text), dtype=torch.long)
22
+ n = int(0.9 * len(data))
23
+ train_data = data[:n]
24
+ val_data = data[n:]
25
+
26
+
27
+ def get_batch(split, block_size, batch_size):
28
+ data = train_data if split == "train" else val_data
29
+ ix = torch.randint(len(data) - block_size, (batch_size,))
30
+ x = torch.stack([data[i : i + block_size] for i in ix])
31
+ y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
32
+ return x, y