eliot committed
Commit 787be42
1 Parent(s): 7048965

training file and weights

Files changed (2)
  1. bigram.py +172 -0
  2. transformer.pth +3 -0
bigram.py ADDED
@@ -0,0 +1,172 @@
+
+ import torch
+ import argparse
+ from torch.nn import functional as F
+ import time
+ from attention_head import AttentionHead, Head, MultiHeadAttention, TransFormerBlock
+ torch.manual_seed(1337)
+
+ def get_batch(batch_size, dataset, block_size):
+     # pick batch_size random offsets, then slice out contiguous input/target blocks
+     sample = torch.randint(high=len(dataset) - (block_size + 1), size=(batch_size, 1))
+     xb = torch.zeros(batch_size, block_size, dtype=torch.long)
+     yb = torch.zeros(batch_size, block_size, dtype=torch.long)
+     for idx, sample_index in enumerate(sample):
+         xb[idx, :] = dataset[sample_index:sample_index + block_size]
+         yb[idx, :] = dataset[sample_index + 1:sample_index + block_size + 1]
+     return xb, yb
+
+ @torch.no_grad()
+ def eval(model, batch_size, block_size, dataset):
+     xb, yb = get_batch(batch_size, dataset, block_size)
+     logits, loss = model(xb, yb)
+     return loss.item()
+
+ def train(model, optimizer, batch_size, block_size, train_ds, val_ds, steps):
+     sumloss = 0
+     for step in range(1, steps + 1):
+         xb, yb = get_batch(batch_size, train_ds, block_size)
+         logits, loss = model(xb, yb)
+         sumloss += loss.item()
+         optimizer.zero_grad(set_to_none=True)
+         loss.backward()
+         optimizer.step()
+         if step % 1000 == 0:
+             val_loss = eval(model, 30, block_size, val_ds)
+             print(f"step {step} || train loss: {sumloss / 1000} , val loss: {val_loss}")
+             sumloss = 0
+
+ class Transformer(torch.nn.Module):
+     def __init__(self, vocab_size, n_tf=3, block_size=8, token_embed_dim=16) -> None:
+         super().__init__()
+         self.block_size = block_size
+         self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
+         self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
+         self.tf_blocks = torch.nn.Sequential(
+             *[TransFormerBlock(token_embed_dim, block_size, 16, 8) for _ in range(n_tf)]
+         )
+         self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)  # project back to vocabulary logits
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+         token_embed = self.token_embedding_table(idx)                  # (B, T, embed_dim)
+         positional_embed = self.positional_embedding(torch.arange(T))  # (T, embed_dim)
+         x = token_embed + positional_embed
+         x = self.tf_blocks(x)
+         logits = self.lm_head(x)
+
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             logits = logits.view(B * T, C)
+             targets = targets.view(B * T)
+             loss = F.cross_entropy(logits, targets)
+         return logits, loss
+
+     def generate(self, idx, max_new_tokens):
+         # idx is (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # get the predictions for the last block_size tokens
+             logits, loss = self(idx[:, -self.block_size:])
+             # focus only on the last time step
+             logits = logits[:, -1, :]  # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1)  # (B, C)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
+         return idx
+
+ class BigramLanguageModel(torch.nn.Module):
+     def __init__(self, vocab_size, block_size=8, token_embed_dim=16):
+         super().__init__()
+         self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
+         self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
+         self.attention_head = MultiHeadAttention(n_embed=token_embed_dim,
+                                                  timesteps=block_size,
+                                                  head_size=token_embed_dim // 4,  # head_size == token_embed_dim // n_heads so the concatenated heads match the embedding dim
+                                                  n_heads=4)  # (in = (B, T, C), out = (B, T, C))
+         self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)  # (in (B, T, C), out (B, T, vocab_size); linear applied over C)
+         self.block_size = block_size
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+         # idx and targets are both (B, T) tensors of integers
+         token_embedding = self.token_embedding_table(idx)  # (B, T) in, (B, T, embed_dim) out
+         positional_embedding = self.positional_embedding(torch.arange(T, dtype=torch.long))  # (T, embed_dim)
+         x = token_embedding + positional_embedding  # (B, T, embed_dim)
+         x = self.attention_head(x)  # (B, T, embed_dim)
+         logits = self.lm_head(x)
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             logits = logits.view(B * T, C)
+             targets = targets.view(B * T)
+             loss = F.cross_entropy(logits, targets)
+         return logits, loss
+
+     def generate(self, idx, max_new_tokens):
+         # idx is (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # get the predictions for the last block_size tokens
+             logits, loss = self(idx[:, -self.block_size:])
+             # focus only on the last time step
+             logits = logits[:, -1, :]  # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1)  # (B, C)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
+         return idx
+
+ def main():
+     ########################
+     # PARAMS ###############
+     batch_size = 32
+     block_size = 128
+     n_embed = 128
+     n_tf = 3
+     n_heads = 8
+     head_size = 16
+     vocab_size = 65
+     ########################
+     parser = argparse.ArgumentParser(
+         description='Train a bigram language model'
+     )
+     parser.add_argument('-c', '--cont', action='store_true')
+     parser.add_argument('-e', '--eval', action='store_true')
+     parser.add_argument('-v', '--verbose', action='store_true')
+     args = parser.parse_args()
+     text = open('input.txt').read()
+     characters = sorted(list(set(text)))
+     decoder = dict(enumerate(characters))
+     encoder = {v: k for k, v in decoder.items()}
+     encode = lambda x: encoder[x]
+     decode = lambda x: decoder[x]
+     text_tensor = torch.tensor([encode(c) for c in text])
+     train_tensor = text_tensor[:int(len(text_tensor) * 0.8)]
+     val_tensor = text_tensor[int(len(text_tensor) * 0.8):]
+     model = Transformer(vocab_size=vocab_size, n_tf=n_tf, block_size=block_size, token_embed_dim=n_embed)
+     if args.verbose:
+         print(model)
+         num_params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
+         print('parameters:', num_params)
+     # if -c is passed we will load the model from the file
+     if args.cont:
+         state_dict = torch.load('transformer.pth')
+         model.load_state_dict(state_dict)
+     optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
+     s = time.time()
+     if not args.eval:
+         try:
+             train(model, optimizer, batch_size=batch_size, block_size=block_size, train_ds=train_tensor, val_ds=val_tensor, steps=100000)
+         except KeyboardInterrupt:
+             torch.save(model.state_dict(), 'transformer.pth')
+             exit()
+         if args.verbose:
+             print('training time: ', time.time() - s)
+         torch.save(model.state_dict(), 'transformer.pth')
+     model.eval()
+     print(''.join([decode(c) for c in model.generate(torch.zeros(1, 32, dtype=torch.long), 1000)[0].tolist()[32:]]))
+     # 2.57 adam
+
+ if __name__ == '__main__':
+     main()
+
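Note that bigram.py imports AttentionHead, Head, MultiHeadAttention, and TransFormerBlock from a local attention_head module that is not part of this commit. The sketch below is only an assumption about that module's interface, inferred from the call sites in bigram.py (MultiHeadAttention(n_embed=..., timesteps=..., head_size=..., n_heads=...) and TransFormerBlock(n_embed, timesteps, head_size, n_heads)); the committed implementation may differ.

# Hypothetical sketch of the interface bigram.py appears to assume from attention_head.py.
# Class names come from the import line; constructor arguments are inferred, not confirmed.
import torch
from torch.nn import functional as F

class Head(torch.nn.Module):
    # one head of masked (causal) self-attention
    def __init__(self, n_embed, timesteps, head_size):
        super().__init__()
        self.key = torch.nn.Linear(n_embed, head_size, bias=False)
        self.query = torch.nn.Linear(n_embed, head_size, bias=False)
        self.value = torch.nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(timesteps, timesteps)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5          # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # causal mask
        wei = F.softmax(wei, dim=-1)
        return wei @ v                                                # (B, T, head_size)

class MultiHeadAttention(torch.nn.Module):
    # n_heads heads in parallel, concatenated and projected back to n_embed
    def __init__(self, n_embed, timesteps, head_size, n_heads):
        super().__init__()
        self.heads = torch.nn.ModuleList(Head(n_embed, timesteps, head_size) for _ in range(n_heads))
        self.proj = torch.nn.Linear(head_size * n_heads, n_embed)

    def forward(self, x):
        return self.proj(torch.cat([h(x) for h in self.heads], dim=-1))

class TransFormerBlock(torch.nn.Module):
    # pre-norm attention + feed-forward, each with a residual connection
    def __init__(self, n_embed, timesteps, head_size, n_heads):
        super().__init__()
        self.sa = MultiHeadAttention(n_embed, timesteps, head_size, n_heads)
        self.ffwd = torch.nn.Sequential(
            torch.nn.Linear(n_embed, 4 * n_embed),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * n_embed, n_embed),
        )
        self.ln1 = torch.nn.LayerNorm(n_embed)
        self.ln2 = torch.nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x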
transformer.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b89e10dd4ab50a8ae82d6340d7cddc6c4953035194a30c871c0ea9ee90ab0848
+ size 2543221
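transformer.pth is stored as a Git LFS pointer (about 2.5 MB of weights). A minimal sketch of loading this checkpoint for sampling, assuming the hyperparameters hard-coded in bigram.py's main() and the same input.txt used for training:

import torch
from bigram import Transformer

# hyperparameters must match the ones used when transformer.pth was saved (see main() in bigram.py)
model = Transformer(vocab_size=65, n_tf=3, block_size=128, token_embed_dim=128)
model.load_state_dict(torch.load('transformer.pth'))
model.eval()

# rebuild the character-level codec from the same input.txt used for training
text = open('input.txt').read()
characters = sorted(list(set(text)))
decoder = dict(enumerate(characters))

# seed the context with zeros and sample 500 new tokens, as main() does
context = torch.zeros(1, 32, dtype=torch.long)
out = model.generate(context, 500)[0].tolist()[32:]
print(''.join(decoder[c] for c in out))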