909ahmed committed
Commit 4bd0020
1 Parent(s): 3d1e7d1

adsfadf

Files changed (1):
  1. app.py (+103 -175)

app.py CHANGED
@@ -1,181 +1,109 @@
  import gradio as gr
-
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
-
- n_embd = 64
- dropout = 0.0
- block_size = 32
- vocab_size = 65
- n_head = 4
- n_layer = 4
-
- class Head(nn.Module):
-
-     def __init__(self, head_size):
-         super().__init__()
-         self.key = nn.Linear(n_embd, head_size, bias=False)
-         self.query = nn.Linear(n_embd, head_size, bias=False)
-         self.value = nn.Linear(n_embd, head_size, bias=False)
-         self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
-
-         self.dropout = nn.Dropout(dropout)
-
-     def forward(self, x):
-         B,T,C = x.shape
-         k = self.key(x)
-         q = self.query(x)
-         wei = q @ k.transpose(-2,-1) * C**-0.5
-         wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
-         wei = F.softmax(wei, dim=-1)
-         wei = self.dropout(wei)
-
-         v = self.value(x)
-         out = wei @ v
-         return out
-
- class MultiHeadAttention(nn.Module):
-
-     def __init__(self, num_heads, head_size):
-         super().__init__()
-         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
-         self.proj = nn.Linear(n_embd, n_embd)
-         self.dropout = nn.Dropout(dropout)
-
-     def forward(self, x):
-         out = torch.cat([h(x) for h in self.heads], dim=-1)
-         out = self.dropout(self.proj(out))
-         return out
-
- class FeedFoward(nn.Module):
-
-     def __init__(self, n_embd):
-         super().__init__()
-         self.net = nn.Sequential(
-             nn.Linear(n_embd, 4 * n_embd),
-             nn.ReLU(),
-             nn.Linear(4 * n_embd, n_embd),
-             nn.Dropout(dropout),
-         )
-
-     def forward(self, x):
-         return self.net(x)
-
- class Block(nn.Module):
-
-     def __init__(self, n_embd, n_head):
-         super().__init__()
-         head_size = n_embd // n_head
-         self.sa = MultiHeadAttention(n_head, head_size)
-         self.ffwd = FeedFoward(n_embd)
-         self.ln1 = nn.LayerNorm(n_embd)
-         self.ln2 = nn.LayerNorm(n_embd)
-
-     def forward(self, x):
-         x = x + self.sa(self.ln1(x))
-         x = x + self.ffwd(self.ln2(x))
-         return x
-
- class BigramLanguageModel(nn.Module):
-
      def __init__(self):
-         super().__init__()
-         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
-         self.position_embedding_table = nn.Embedding(block_size, n_embd)
-         self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
-         self.ln_f = nn.LayerNorm(n_embd)
-         self.lm_head = nn.Linear(n_embd, vocab_size)
-
-     def forward(self, idx, targets=None):
-         B, T = idx.shape
-
-         tok_emb = self.token_embedding_table(idx)
-         pos_emb = self.position_embedding_table(torch.arange(T))
-         x = tok_emb + pos_emb
-         x = self.blocks(x)
-         x = self.ln_f(x)
-         logits = self.lm_head(x)
-
-         if targets is None:
-             loss = None
-         else:
-             B, T, C = logits.shape
-             logits = logits.view(B*T, C)
-             targets = targets.view(B*T)
-             loss = F.cross_entropy(logits, targets)
-
-         return logits, loss
-
-     def generate(self, idx, max_new_tokens):
-         for _ in range(max_new_tokens):
-
-             idx_cond = idx[:, -block_size:]
-             logits, loss = self(idx_cond)
-             logits = logits[:, -1, :]
-             probs = F.softmax(logits, dim=-1)
-             idx_next = torch.multinomial(probs, num_samples=1)
-             idx = torch.cat((idx, idx_next), dim=1)
-
-         return idx
-
-
- chars = "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
- itos = { i:ch for i,ch in enumerate(chars) }
- stoi = { ch:i for i,ch in enumerate(chars) }
-
- decode = lambda l: ''.join([itos[i] for i in l])
- encode = lambda s: [stoi[c] for c in s]
-
- model = BigramLanguageModel()
-
- state_model = torch.load("output", map_location=torch.device('cpu'))
- # state_dict = state_model.state_dict()
-
- model.load_state_dict(state_model, strict=False)
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
- ):
-     messages = [{"role": "system", "content": "Cocaine"}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-     yield response
-
-     input_txt = encode(message)
-     context = torch.tensor(input_txt).unsqueeze(0)
-
-     idx = context
-     result = ""
-     for _ in range(500):
-
-         idx_cond = idx[:, -block_size:]
-         logits, loss = model(idx_cond)
-         logits = logits[:, -1, :]
-         probs = F.softmax(logits, dim=-1)
-         idx_next = torch.multinomial(probs, num_samples=1)
-         idx = torch.cat((idx, idx_next), dim=1)
-
-         # yield "I need drugs"
-         result += decode(idx_next[0].tolist())
-         yield result
-
- demo = gr.ChatInterface(
-     respond,
-     title="Sherlock doing meth again?"
- )
-
- if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
+ import regex as re
+ from tqdm import tqdm
+ import pickle

+ class Tokenizer:
+
      def __init__(self):

+         # base vocabulary: one entry per byte value
+         self.vocab = {idx : bytes([idx]) for idx in range(256)}
+         # GPT-4-style split pattern; \p{L}/\p{N} classes and possessive
+         # quantifiers require the third-party `regex` module
+         self.pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
+         self.merges = {}
+
+     def merge(self, tokens, target, new_token):
+         # rewrite `tokens`, replacing each occurrence of the pair `target`
+         # with `new_token`
+         new_tokens = []
+         i = 0
+         while i < len(tokens):
+             if i + 1 < len(tokens) and tokens[i] == target[0] and tokens[i + 1] == target[1]:
+                 i += 1
+                 new_tokens.append(new_token)
+             else:
+                 new_tokens.append(tokens[i])
+             i += 1
+
+         return new_tokens
+
+     def get_stats(self, idsList):
+         # count adjacent token pairs across all chunks
+         pairs = {}
+         if not isinstance(idsList[0], list):
+             idsList = [idsList]
+         for tokens in idsList:
+             for a, b in zip(tokens, tokens[1:]):
+                 if (a, b) not in pairs:
+                     pairs[(a, b)] = 1
+                 else:
+                     pairs[(a, b)] += 1
+         return pairs
+
+     def get_max_pair(self, idsList):
+         # most frequent adjacent pair
+         pairs = self.get_stats(idsList)
+         return sorted(pairs.items(), key=lambda item : item[1])[-1][0]
+
+     def get_min(self, idsList):
+         # pair with the lowest merge rank, i.e. the earliest learned merge
+         stats = self.get_stats(idsList)
+         pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
+         return pair
+
+     def train(self, epochs, text):
+         # split text into chunks, byte-encode each, then learn one merge per epoch
+         pat = re.compile(self.pattern)
+         textList = re.findall(pat, text)
+         idsList = [list(chunk.encode('utf-8')) for chunk in textList]
+         for epoch in tqdm(range(epochs)):
+             max_pair = self.get_max_pair(idsList)
+             new_token = 256 + epoch
+             self.merges[max_pair] = new_token
+             idsList = [self.merge(tokens, max_pair, new_token) for tokens in idsList]
+             self.vocab[new_token] = self.vocab[max_pair[0]] + self.vocab[max_pair[1]]
+
+         return [x for xs in idsList for x in xs]
+
+     def encode(self, text):
+         tokens = list(text.encode('utf-8'))
+         while len(tokens) >= 2:
+             # always apply the earliest learned merge first
+             pair = self.get_min(tokens)
+             if pair not in self.merges:
+                 break
+             idx = self.merges[pair]
+             tokens = self.merge(tokens, pair, idx)
+
+         return tokens
+
+     def decode(self, tokens):
+         # map token ids back to bytes, then decode as UTF-8
+         tokens = b"".join(self.vocab[token] for token in tokens)
+         text = tokens.decode('utf-8', errors='replace')
+         return text
+
+ title = "Ghalib doing tiktok"
+ description = "A simple Gradio interface for running the Urdu tokenizer"
+
+ tokenizer = Tokenizer()
+ # load the merge table and byte vocabulary produced by training
+ with open('merges.pkl', 'rb') as files:
+     tokenizer.merges = pickle.load(files)
+ with open('vocab.pkl', 'rb') as files:
+     tokenizer.vocab = pickle.load(files)
+
+ def inference(text):
+     return tokenizer.encode(text)
+
+ iface = gr.Interface(
+     inference,
+     inputs = ["text"],
+     outputs = ["text"],
+     title = title,
+     description = description,
+ )
+
+ iface.launch()
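
The new Tokenizer is a byte-pair-encoding tokenizer in the minbpe style: get_stats counts adjacent token pairs and merge rewrites a sequence with a new token id. A minimal sketch of those two primitives in isolation, assuming the Tokenizer class from the new app.py is in scope (the ids and the replacement token 256 are illustrative, not from the committed pickles):

tok = Tokenizer()

ids = [1, 2, 3, 1, 2]
print(tok.get_stats(ids))           # {(1, 2): 2, (2, 3): 1, (3, 1): 1}
print(tok.merge(ids, (1, 2), 256))  # [256, 3, 256]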
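
train first splits its input with self.pattern (a GPT-4-style split regex, which is why the third-party regex module is imported rather than the standard re) and only then byte-encodes each chunk, so merges never cross chunk boundaries. A quick illustration on a made-up string:

import regex as re

pat = re.compile(Tokenizer().pattern)
print(pat.findall("Mirza Ghalib, 1797!"))
# ['Mirza', ' Ghalib', ',', ' ', '179', '7', '!']

Numbers are split into runs of at most three digits, and punctuation stays separate from words.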
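
encode replays learned merges in training order (lowest merge rank first, via get_min), and decode inverts them exactly by concatenating the byte strings in vocab. A hedged end-to-end sketch that hand-installs a single merge rule instead of loading the committed pickles (the rule (104, 105) -> 256, the UTF-8 bytes of "hi", is made up for illustration):

tok = Tokenizer()
tok.merges = {(104, 105): 256}                    # hypothetical merge: b"hi" -> 256
tok.vocab[256] = tok.vocab[104] + tok.vocab[105]  # keep vocab consistent with merges

ids = tok.encode("hi hi")
print(ids)              # [256, 32, 256]
print(tok.decode(ids))  # "hi hi" -- decoding round-trips exactly

Note that, unlike train, encode does not apply the regex split; it merges directly over the raw byte stream of the input.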