Spaces (status: Sleeping)

Commit 213d16f by ravi.naik · Parent: 73580b8

Added training, inference and gradio UI code
Browse files:
- README.md +55 -2
- app.py +74 -0
- checkpoints/model.pth +3 -0
- data/input.txt +0 -0
- experiments/bigram.py +107 -0
- experiments/bigram_v2.py +200 -0
- experiments/exp.ipynb +468 -0
- gpt.ipynb +211 -0
- src/__pycache__/inference.cpython-310.pyc +0 -0
- src/__pycache__/model.cpython-310.pyc +0 -0
- src/__pycache__/training.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/inference.py +9 -0
- src/model.py +120 -0
- src/training.py +53 -0
- src/utils.py +32 -0
README.md CHANGED

@@ -1,5 +1,5 @@
 ---
-title: ERA SESSION21
+title: "ERA SESSION21: GPT from scratch"
 emoji: 🌍
 colorFrom: indigo
 colorTo: blue
@@ -10,4 +10,57 @@ pinned: false
 license: mit
 ---
 
-
+### Results
+**Bigram Base model training and results**
+
+![image](https://github.com/RaviNaik/ERA-SESSION21/assets/23289802/4cc02d93-98fc-4114-a4c9-8a3c249eaad3)
+
+**GPT Model training results**
+
+![image](https://github.com/RaviNaik/ERA-SESSION21/assets/23289802/95dcde00-bf20-4853-ad20-fa67c1046f6b)
+
+#### Generation Output
+```python
+# Load the trained checkpoint and sample 1000 tokens from the prompt "hello".
+model = torch.load("checkpoints/model.pth", map_location=device)
+results = generate("hello", model, block_size, 1000, device)
+print(results)
+```
+```
+hellows thence grown from thee.
+Since thou hast raim, thou thast well were quarterned; and
+ever man tree can saw for words word from her at hour
+Whiles contrations or devoided from ere years;
+Yea, foul vice, indelice on the bird of the
+noble of Hermione.
+
+PARIS:
+Sir, adies, sir, hate no choping but to your good.
+
+HENRY BOLINGBROKE:
+Yes, to ask you might, foreweed.
+
+WARCK:
+'Tis he made moust true.
+
+RORSET:
+It is an hour fastal that cracknaf at the chase
+Upon; you are your hearing news a daughter.
+
+KING EDWARD IV:
+Tut, Lord Warwick, thou shouldst aft Rutlansps?
+Thou tust but back hild, he countemn'd my lady's seal,
+For access dead the treature moon! and the Englisting!
+Thy vage for yonder see thou be donen?
+O, count thou dost not Romeo, thou pratheeo sir,
+That sweet thou feigh with no past blood on
+Be see, here through on that find bears, if an
+pretterinctors three and aspect die meeds thou,
+Behing mine of thy denigning state lain business?
+
+SAMPSA:
+Sir, ha! but thou refused? thyself food, gr
+```
+### Gradio Interface
+![image](https://github.com/RaviNaik/ERA-SESSION21/assets/23289802/f339ec6b-17b3-4de6-bbef-14eb2b3fac84)
app.py ADDED
@@ -0,0 +1,74 @@

```python
import gradio as gr
import torch

from src.model import GPTModel
from src.inference import generate as generate_text
from src.utils import vocab_size

# Hyperparameters matching the trained checkpoint (see experiments/bigram_v2.py).
batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = "cuda:1" if torch.cuda.is_available() else "cpu"
eval_iters = 200
n_embeds = 384
n_heads = 6
n_layers = 6
dropout = 0.2


def load_model():
    # map_location=device remaps every saved tensor to the active device,
    # so the checkpoint also loads on CPU-only hosts.
    model = torch.load("checkpoints/model.pth", map_location=device)
    return model


model = load_model()


def generate(prompt, max_new_tokens):
    prompt = prompt.strip()
    out = generate_text(prompt, model, block_size, max_new_tokens, device)
    return {gpt_output: out}


with gr.Blocks() as app:
    gr.Markdown("## ERA Session21 - GPT from scratch")
    gr.Markdown(
        """This is an implementation of GPT [Let's build GPT: from scratch, in code, spelled out.](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=2s) by Andrej Karpathy.

Please find the source code and training details [here](https://github.com/RaviNaik/ERA-SESSION21).

Dataset used to train: [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
"""
    )
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(label="Initial Prompt", interactive=True)
            max_new_tokens = gr.Slider(
                minimum=10,
                maximum=2500,
                value=100,
                step=10,
                label="Select Number of Tokens to be Generated",
                interactive=True,
            )
            submit_btn = gr.Button(value="Generate")

        with gr.Column():
            gpt_output = gr.TextArea(
                label="Text Generated by GPT",
                show_label=True,
                max_lines=100,
                interactive=False,
            )

    # Wire the button to the inference function: prompt and slider in, text out.
    submit_btn.click(
        generate,
        inputs=[prompt_box, max_new_tokens],
        outputs=[gpt_output],
    )

app.launch()
```
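To exercise the same inference path outside the Gradio UI, here is a minimal sketch, assuming the LFS checkpoint has been fetched; `"ROMEO:"` is just an arbitrary example prompt, and `model.device` is reset because the pickled `GPTModel` keeps the device string it was trained with:

```python
import torch
from src.inference import generate as generate_text

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.load("checkpoints/model.pth", map_location=device)
model.device = device  # GPTModel.forward builds position ids on self.device
model.eval()           # disable dropout for sampling
print(generate_text("ROMEO:", model, block_size=256, max_new_tokens=200, device=device))
```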
checkpoints/model.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8b930ee87e1eecc6a03bc49983a81fd11aaa95f4cd5e1d64091d6107827811b
size 52698997
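Only this Git LFS pointer lives in the repository itself; after cloning, the actual weights (about 52.7 MB per the `size` field) are fetched with the standard LFS workflow: `git lfs install` once, then `git lfs pull` inside the repo.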
data/input.txt ADDED
The diff for this file is too large to render. See raw diff.
experiments/bigram.py ADDED
@@ -0,0 +1,107 @@

```python
import torch
from torch import nn
import torch.nn.functional as F

# Hyperparameters
batch_size = 32
block_size = 8
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
device = "cuda:1" if torch.cuda.is_available() else "cpu"
eval_iters = 200

torch.manual_seed(1123)

with open("input.txt") as f:
    text = f.read()

# Character-level vocabulary built from the corpus.
chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])

# 90/10 train/validation split of the encoded corpus.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]


def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


@torch.no_grad()
def estimate_loss(model: nn.Module):
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            X, Y = X.to(device), Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # Each token directly reads off the logits for the next token.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)  # (B,T,C)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)  # (B,T,C)
            logits = logits[:, -1, :]  # (B,C): keep only the last time step
            probs = F.softmax(logits, dim=-1)  # (B,C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)

        return idx


model = BigramLanguageModel(vocab_size)

model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss(model)
        print(
            f"Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )

    xb, yb = get_batch("train")
    xb, yb = xb.to(device), yb.to(device)

    logits, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


context = torch.zeros((1, 1), dtype=torch.long, device=device)
results = decode(model.generate(context, max_new_tokens=100)[0].tolist())
print(results)
```
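A sanity check on the loss values such a script prints: with 65 characters, an untrained model that assigns uniform probability should score a cross-entropy of ln(65), about 4.17, so the first measured loss should sit near that (the notebook below reports 4.5262 at initialization). A minimal sketch:

```python
import math

vocab_size = 65  # size of the tinyshakespeare character set
baseline = -math.log(1.0 / vocab_size)  # cross-entropy of a uniform guess
print(f"{baseline:.4f}")  # 4.1744
```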
experiments/bigram_v2.py ADDED
@@ -0,0 +1,200 @@

```python
import torch
from torch import nn
import torch.nn.functional as F

# Hyperparameters
batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = "cuda:1" if torch.cuda.is_available() else "cpu"
eval_iters = 200
n_embeds = 384
n_heads = 6
n_layers = 6
dropout = 0.2

torch.manual_seed(1123)

with open("input.txt") as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    return [stoi[c] for c in s]


def decode(l):
    return "".join([itos[i] for i in l])


data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]


def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


@torch.no_grad()
def estimate_loss(model: nn.Module):
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            X, Y = X.to(device), Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class Head(nn.Module):
    def __init__(self, n_embed, head_size) -> None:
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular mask: position t only attends to positions <= t.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # (B,T,hs) @ (B,hs,T) --> (B,T,T); note: scaled by C**-0.5 (C = n_embeds)
        # rather than the usual head_size**-0.5.
        wei = q @ k.transpose(-2, -1) * (C**-0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, n_embeds, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embeds, head_size) for _ in range(n_heads)])
        self.proj = nn.Linear(n_embeds, n_embeds)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Run the heads in parallel and concatenate along the channel dim.
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        x = self.dropout(x)
        return x


class FeedForward(nn.Module):
    def __init__(self, n_embeds):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embeds, 4 * n_embeds),
            nn.ReLU(),
            nn.Linear(4 * n_embeds, n_embeds),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self, n_embeds, n_heads):
        super().__init__()
        head_size = n_embeds // n_heads
        self.sa_heads = MultiHeadAttention(n_heads, n_embeds, head_size)
        self.ffwd = FeedForward(n_embeds)
        self.ln1 = nn.LayerNorm(n_embeds)
        self.ln2 = nn.LayerNorm(n_embeds)

    def forward(self, x):
        # Pre-norm residual connections around attention and the MLP.
        x = x + self.sa_heads(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embeds, block_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embeds)
        self.position_embedding_table = nn.Embedding(block_size, n_embeds)
        self.blocks = nn.Sequential(
            *[Block(n_embeds, n_heads) for _ in range(n_layers)]
        )
        self.lnf = nn.LayerNorm(n_embeds)
        self.lm_head = nn.Linear(n_embeds, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_embeds = self.token_embedding_table(idx)  # (B,T,n_embeds)
        pos_embeds = self.position_embedding_table(
            torch.arange(T, device=device)
        )  # (T,n_embeds)

        x = tok_embeds + pos_embeds  # (B,T,n_embeds)
        x = self.blocks(x)
        x = self.lnf(x)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens.
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)  # (B,T,C)
            logits = logits[:, -1, :]  # (B,C)
            probs = F.softmax(logits, dim=-1)  # (B,C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)

        return idx


model = BigramLanguageModel(vocab_size, n_embeds, block_size)

model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss(model)
        print(
            f"Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )

    xb, yb = get_batch("train")
    xb, yb = xb.to(device), yb.to(device)

    logits, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


context = torch.zeros((1, 1), dtype=torch.long, device=device)
results = decode(model.generate(context, max_new_tokens=100)[0].tolist())
print(results)
```
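The `Head` class here (and in `src/model.py` below) is standard causal scaled dot-product attention; in the notation of Vaswani et al.:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V$$

where $M$ is the causal mask ($0$ where a query may look, $-\infty$ above the diagonal), so position $t$ attends only to positions $\le t$. One deviation worth noting: this script scales by `C**-0.5` with `C = n_embeds`, whereas $d_k$ in the paper (and in the single-head demo in `experiments/exp.ipynb` below) is the per-head size.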
experiments/exp.ipynb ADDED
@@ -0,0 +1,468 @@

```python
# experiments/exp.ipynb, shown as percent-format cells; outputs quoted as comments.

# %%
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Output:
# --2023-10-27 16:11:32--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Resolving raw.githubusercontent.com ... connected.
# HTTP request sent, awaiting response... 200 OK
# Length: 1115394 (1.1M) [text/plain]
# Saving to: 'input.txt.1'
# 2023-10-27 16:11:36 (734 KB/s) - 'input.txt.1' saved [1115394/1115394]

# %%
with open("input.txt") as f:
    text = f.read()

# %%
text[:50]
# Output: 'First Citizen:\nBefore we proceed any further, hear'

# %%
chars = sorted(list(set(text)))
vocab_size = len(chars)

print("".join(chars))
print(vocab_size)
# Output:
#
#  !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
# 65

# %%
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])

print(encode("hi there"))
print(decode(encode("hi there")))
# Output:
# [46, 47, 1, 58, 46, 43, 56, 43]
# hi there

# %%
import torch

data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])
# Output:
# torch.Size([1115394]) torch.int64
# tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
#         53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
#          1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
#         57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
#          6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
#         58, 47, 64, 43, 52, 10,  0, 37, 53, 59])

# %%
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# %%
torch.manual_seed(1337)
batch_size = 4
block_size = 8


def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


xb, yb = get_batch("train")
print("Inputs:")
print(xb.shape)
print(xb)

print("-----------")
print("Targets:")
print(yb.shape)
print(yb)
# Output:
# Inputs:
# torch.Size([4, 8])
# tensor([[24, 43, 58,  5, 57,  1, 46, 43],
#         [44, 53, 56,  1, 58, 46, 39, 58],
#         [52, 58,  1, 58, 46, 39, 58,  1],
#         [25, 17, 27, 10,  0, 21,  1, 54]])
# -----------
# Targets:
# torch.Size([4, 8])
# tensor([[43, 58,  5, 57,  1, 46, 43, 39],
#         [53, 56,  1, 58, 46, 39, 58,  1],
#         [58,  1, 58, 46, 39, 58,  1, 46],
#         [17, 27, 10,  0, 21,  1, 54, 39]])

# %%
import torch.nn as nn
from torch.nn import functional as F


# First pass: logits only, no loss yet.
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.token_embedding_table(idx)
        return logits

# %%
m = BigramLanguageModel(vocab_size)
out = m(xb, yb)
print(out.shape)  # B,T,C -> 4x8x65
# Output: torch.Size([4, 8, 65])

# %%
# Second pass: add the loss and an autoregressive generate() loop.
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)  # (B,T,C)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)  # (B,T,C)
            logits = logits[:, -1, :]  # (B,C)
            probs = F.softmax(logits, dim=-1)  # (B,C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)

        return idx


m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)  # flattened to (B*T, C) = 32x65 when targets are given
print(loss)
# Output:
# torch.Size([32, 65])
# tensor(4.5262, grad_fn=<NllLossBackward0>)

# %%
idx = torch.zeros((1, 1), dtype=torch.long)

results = decode(m.generate(idx, max_new_tokens=100)[0].tolist())

print(results)
# Output (untrained model):
#
# 'JgC.JZWqUkpdtkSpmzjM-,RqzgaN?vC:hgjnAnBZDga-APqGUH!WdCbIb;$DefOYbEvcaKGMmnO'q$KdS-'ZH
# .YSqr'X!Q! d;

# %%
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

# %%
batch_size = 32

for steps in range(10000):
    xb, yb = get_batch("train")

    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())
# Output: 2.4206888675689697

# %%
idx = torch.zeros((1, 1), dtype=torch.long)

results = decode(m.generate(idx, max_new_tokens=100)[0].tolist())

print(results)
# Output:
#
# Hou'sy'ting'stis's w ys'stholealy woawhimedy it 'save,
# Too:Had wh fo an, ZCENERUCHENar ee onds, th h

# %%
# Single self-attention head, worked through by hand.
B, T, C = 4, 8, 32

x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)
q = query(x)
wei = q @ k.transpose(-2, -1) * (head_size**-0.5)  # (B,T,16) @ (B,16,T) --> (B,T,T)

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)
v = value(x)
out = wei @ v

out.shape
# Output: torch.Size([4, 8, 16])

# %%
wei[0]
# Output:
# tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.3325, 0.6675, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.3578, 0.2873, 0.3550, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.2281, 0.1964, 0.2733, 0.3022, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.2851, 0.1588, 0.2068, 0.1436, 0.2057, 0.0000, 0.0000, 0.0000],
#         [0.2429, 0.1547, 0.1550, 0.1475, 0.2049, 0.0951, 0.0000, 0.0000],
#         [0.1573, 0.1838, 0.1123, 0.1680, 0.1528, 0.1194, 0.1063, 0.0000],
#         [0.1139, 0.1704, 0.0766, 0.1134, 0.1600, 0.1466, 0.1228, 0.0963]],
#        grad_fn=<SelectBackward0>)
```
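The `wei[0]` dump above makes the causal structure visible: row t has nonzero weight only on columns 0..t. Each row should also be a probability distribution; a quick check, continuing from the notebook's variables:

```python
# Every attention row sums to 1 after the masked softmax.
assert torch.allclose(wei.sum(dim=-1), torch.ones(B, T))
```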
gpt.ipynb ADDED
@@ -0,0 +1,211 @@

```python
# gpt.ipynb, shown as percent-format cells; outputs quoted as comments.

# %% [markdown]
# ## Import Dependencies

# %%
import torch

# %%
from src.model import GPTModel
from src.training import train
from src.inference import generate
from src.utils import vocab_size

# %% [markdown]
# ## Declare Hyperparams

# %%
batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = "cuda:1" if torch.cuda.is_available() else "cpu"
eval_iters = 200
n_embeds = 384
n_heads = 6
n_layers = 6
dropout = 0.2

# %% [markdown]
# ## Initialize Model and Optimizer

# %%
model = GPTModel(vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# %% [markdown]
# ## Model Training

# %%
train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
)
# Output:
# Step 0: train loss 4.3249, val loss 4.3219
# Step 500: train loss 2.0213, val loss 2.0953
# Step 1000: train loss 1.6067, val loss 1.7813
# Step 1500: train loss 1.4462, val loss 1.6380
# Step 2000: train loss 1.3516, val loss 1.5810
# Step 2500: train loss 1.2836, val loss 1.5376
# Step 3000: train loss 1.2309, val loss 1.5148
# Step 3500: train loss 1.1910, val loss 1.4904
# Step 4000: train loss 1.1522, val loss 1.4822
# Step 4500: train loss 1.1186, val loss 1.4838

# %% [markdown]
# ## Load the model and Generate text

# %%
model = torch.load("checkpoints/model.pth", map_location=device)
results = generate("hello", model, block_size, 1000, device)
print(results)
# Output: the 1000-token sample beginning "hellows thence grown from thee."
# (reproduced in full in the README above)
```
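For a sense of scale: with these hyperparameters (6 layers, 6 heads, 384-dim embeddings, 65-token vocabulary), the model comes out to roughly 10.8M parameters by my count, in line with the roughly 10M model in Karpathy's video. A one-line check, assuming `model` from the notebook above:

```python
print(sum(p.numel() for p in model.parameters()))  # roughly 10.8M with the settings above
```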
src/__pycache__/inference.cpython-310.pyc ADDED
Binary file (527 Bytes).

src/__pycache__/model.cpython-310.pyc ADDED
Binary file (4.77 kB).

src/__pycache__/training.cpython-310.pyc ADDED
Binary file (1.27 kB).

src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.75 kB).
src/inference.py ADDED
@@ -0,0 +1,9 @@

```python
import torch
from src.utils import encode, decode


def generate(prompt, model, block_size, max_new_tokens, device):
    # Encode the prompt into token ids and truncate it to the context window.
    X = torch.tensor(encode(prompt), dtype=torch.long, device=device)
    X = X[:block_size].unsqueeze(0)  # add a batch dimension: (1, T)
    # Autoregressively sample max_new_tokens tokens and decode back to text.
    results = decode(model.generate(X, max_new_tokens=max_new_tokens)[0].tolist())
    return results
```
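One caveat worth noting: `encode` indexes `stoi`, which only contains the 65 characters present in tinyshakespeare, so a prompt containing any other character (most digits, for example) raises a `KeyError`. A defensive sketch with a hypothetical helper, not part of the commit:

```python
from src.utils import stoi


def safe_encode(prompt):
    # Hypothetical helper: silently drop characters the tinyshakespeare
    # vocabulary has never seen instead of raising KeyError in encode().
    return [stoi[c] for c in prompt if c in stoi]
```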
src/model.py ADDED
@@ -0,0 +1,120 @@

```python
import torch
from torch import nn
import torch.nn.functional as F


class Head(nn.Module):
    """One head of masked (causal) self-attention."""

    def __init__(self, n_embeds, head_size, block_size, dropout) -> None:
        super().__init__()
        self.key = nn.Linear(n_embeds, head_size, bias=False)
        self.query = nn.Linear(n_embeds, head_size, bias=False)
        self.value = nn.Linear(n_embeds, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular mask: position t only attends to positions <= t.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # (B,T,hs) @ (B,hs,T) --> (B,T,T); note the checkpoint was trained with
        # this C**-0.5 scaling (C = n_embeds), not the paper's head_size**-0.5.
        wei = q @ k.transpose(-2, -1) * (C**-0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, n_embeds, head_size, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(n_embeds, head_size, block_size, dropout) for _ in range(n_heads)]
        )
        self.proj = nn.Linear(n_embeds, n_embeds)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Run the heads in parallel and concatenate along the channel dim.
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        x = self.dropout(x)
        return x


class FeedForward(nn.Module):
    def __init__(self, n_embeds, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embeds, 4 * n_embeds),
            nn.ReLU(),
            nn.Linear(4 * n_embeds, n_embeds),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Decoder(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self, n_embeds, n_heads, block_size, dropout):
        super().__init__()
        head_size = n_embeds // n_heads
        self.sa_heads = MultiHeadAttention(
            n_heads, n_embeds, head_size, block_size, dropout
        )
        self.ffwd = FeedForward(n_embeds, dropout)
        self.ln1 = nn.LayerNorm(n_embeds)
        self.ln2 = nn.LayerNorm(n_embeds)

    def forward(self, x):
        x = x + self.sa_heads(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class GPTModel(nn.Module):
    def __init__(
        self, vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
    ):
        super().__init__()
        self.device = device
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embeds)
        self.position_embedding_table = nn.Embedding(block_size, n_embeds)
        self.blocks = nn.Sequential(
            *[Decoder(n_embeds, n_heads, block_size, dropout) for _ in range(n_layers)]
        )
        self.lnf = nn.LayerNorm(n_embeds)
        self.lm_head = nn.Linear(n_embeds, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_embeds = self.token_embedding_table(idx)  # (B,T,n_embeds)
        pos_embeds = self.position_embedding_table(
            torch.arange(T, device=self.device)
        )  # (T,n_embeds)

        x = tok_embeds + pos_embeds  # (B,T,n_embeds)
        x = self.blocks(x)
        x = self.lnf(x)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens.
            idx_cond = idx[:, -self.block_size :]
            logits, loss = self(idx_cond)  # (B,T,C)
            logits = logits[:, -1, :]  # (B,C)
            probs = F.softmax(logits, dim=-1)  # (B,C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)

        return idx
```
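A minimal shape smoke test for the module, a sketch using the same hyperparameters as elsewhere in the commit:

```python
import torch
from src.model import GPTModel

model = GPTModel(
    vocab_size=65, n_embeds=384, block_size=256,
    n_heads=6, n_layers=6, dropout=0.2, device="cpu",
)
idx = torch.randint(0, 65, (2, 16))     # batch of 2 sequences, 16 tokens each
logits, loss = model(idx, targets=idx)  # real training uses targets shifted by one
print(logits.shape)                     # torch.Size([32, 65]): flattened to (B*T, vocab)
```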
src/training.py ADDED
@@ -0,0 +1,53 @@

```python
import torch
from torch import nn

from src.utils import get_batch


@torch.no_grad()
def estimate_loss(model: nn.Module, eval_iters, block_size, batch_size, device):
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, block_size, batch_size)
            X, Y = X.to(device), Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


def train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
):
    val_loss = None
    for iter in range(max_iters):
        if iter % eval_interval == 0:
            losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
            print(
                f"Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )
            # Checkpoint whenever validation loss improves on the best seen so far.
            if val_loss is None or losses["val"] < val_loss:
                val_loss = losses["val"]
                torch.save(model, "checkpoints/model.pth")

        xb, yb = get_batch("train", block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)

        logits, loss = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
```
src/utils.py ADDED
@@ -0,0 +1,32 @@

```python
import torch

with open("data/input.txt") as f:
    text = f.read()

# Character-level vocabulary built from the tinyshakespeare corpus.
chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    return [stoi[c] for c in s]


def decode(l):
    return "".join([itos[i] for i in l])


# 90/10 train/validation split of the encoded corpus.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]


def get_batch(split, block_size, batch_size):
    # Sample batch_size random windows; targets are the inputs shifted by one.
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y
```
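A quick usage sketch of the batch sampler; the shapes follow directly from the slicing above:

```python
from src.utils import get_batch, vocab_size

x, y = get_batch("train", block_size=8, batch_size=4)
print(x.shape, y.shape)  # torch.Size([4, 8]) torch.Size([4, 8]); y[i] holds next-token targets for x[i]
print(vocab_size)        # 65 for tinyshakespeare
```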