vincentiusyoshuac committed
Commit d2a9674 · verified · 1 Parent(s): 2f72c6a

Create app.py

Files changed (1)
  app.py  +336  -0
app.py ADDED
@@ -0,0 +1,336 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import gradio as gr
+ import math
+ from transformers import PreTrainedTokenizerFast
+ import numpy as np
+ from typing import Optional, List, Dict
+
+ class NeuromodulatedAttention(nn.Module):
+     def __init__(self, d_model: int, num_heads: int):
+         super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+
+         self.qkv = nn.Linear(d_model, 3 * d_model)
+         self.out_proj = nn.Linear(d_model, d_model)
+
+         # Neuromodulation gates (one scalar per head)
+         self.dopamine_gate = nn.Linear(d_model, num_heads)
+         self.serotonin_gate = nn.Linear(d_model, num_heads)
+         self.memory_decay = nn.Parameter(torch.ones(num_heads) * 0.99)
+         self.forget_gate = nn.Linear(d_model, num_heads)
+         self.attention_mask = nn.Parameter(torch.ones(num_heads))
+
+         # Persistent memory: one slot per head, shape (1, num_heads, 1, head_dim)
+         self.register_buffer('memory_state', torch.zeros(1, num_heads, 1, self.head_dim))
+
+     def update_memory(self, new_info: torch.Tensor, dopamine: torch.Tensor, forget: torch.Tensor):
+         # new_info: (batch, heads, head_dim); dopamine, forget: (batch, heads).
+         # Average over the batch so the memory keeps its (1, heads, 1, head_dim)
+         # shape, and detach so the persistent state does not retain old graphs
+         # across training steps.
+         gated_info = (dopamine.unsqueeze(-1) * new_info).mean(dim=0, keepdim=True).unsqueeze(2)
+         retain = self.memory_decay.view(1, -1, 1, 1) * (1 - forget.mean(dim=0).view(1, -1, 1, 1))
+         self.memory_state = (self.memory_state * retain + gated_info).detach()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         batch_size, seq_length, _ = x.shape
+
+         # Neuromodulators, computed from the mean token representation
+         dopamine = torch.sigmoid(self.dopamine_gate(x.mean(dim=1)))
+         serotonin = torch.sigmoid(self.serotonin_gate(x.mean(dim=1)))
+         forget = torch.sigmoid(self.forget_gate(x.mean(dim=1)))
+
+         # Attention computation
+         qkv = self.qkv(x)
+         qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim)
+         qkv = qkv.permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         # Append the persistent memory slot to the keys and values
+         k = torch.cat([k, self.memory_state.expand(batch_size, -1, -1, -1)], dim=2)
+         v = torch.cat([v, self.memory_state.expand(batch_size, -1, -1, -1)], dim=2)
+
+         # Attention with neuromodulation
+         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+         scores = scores * serotonin.view(batch_size, self.num_heads, 1, 1)
+         scores = scores * self.attention_mask.view(1, -1, 1, 1)
+
+         attention = F.softmax(scores, dim=-1)
+         x = torch.matmul(attention, v)
+
+         # Update memory from the per-head mean of the attended values
+         self.update_memory(x.mean(dim=2), dopamine, forget)
+
+         x = x.transpose(1, 2).reshape(batch_size, seq_length, self.d_model)
+         return self.out_proj(x)
+
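+ # Shape sketch (assuming the NeuroTransformer defaults d_model=256, num_heads=4):
+ # an input of shape (batch, seq, 256) yields q/k/v of shape (batch, 4, seq, 64);
+ # the persistent memory adds one key/value slot per head, so the attention
+ # scores have shape (batch, 4, seq, seq + 1).
+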
+ class TransformerBlock(nn.Module):
+     def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
+         super().__init__()
+         self.attention = NeuromodulatedAttention(d_model, num_heads)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.ff = nn.Sequential(
+             nn.Linear(d_model, d_ff),
+             nn.ReLU(),
+             nn.Linear(d_ff, d_model),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.attention(self.norm1(x))
+         x = x + self.ff(self.norm2(x))
+         return x
+
+ class NeuroTransformer(nn.Module):
+     def __init__(
+         self,
+         vocab_size: int,
+         d_model: int = 256,
+         num_heads: int = 4,
+         num_layers: int = 3,
+         d_ff: int = 512,
+         dropout: float = 0.1,
+         max_seq_length: int = 128
+     ):
+         super().__init__()
+         self.d_model = d_model
+         self.embedding = nn.Embedding(vocab_size, d_model)
+         self.pos_encoding = self._create_positional_encoding(max_seq_length, d_model)
+         self.layers = nn.ModuleList([
+             TransformerBlock(d_model, num_heads, d_ff, dropout)
+             for _ in range(num_layers)
+         ])
+         self.final_layer = nn.Linear(d_model, vocab_size)
+         self.dropout = nn.Dropout(dropout)
+
+     def _create_positional_encoding(self, max_seq_length: int, d_model: int) -> torch.Tensor:
+         # Standard sinusoidal encoding: PE[pos, 2i] = sin(pos / 10000^(2i/d_model)),
+         # PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
+         pos_encoding = torch.zeros(max_seq_length, d_model)
+         position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pos_encoding[:, 0::2] = torch.sin(position * div_term)
+         pos_encoding[:, 1::2] = torch.cos(position * div_term)
+         return pos_encoding.unsqueeze(0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.embedding(x) * math.sqrt(self.d_model)
+         x = x + self.pos_encoding[:, :x.size(1)].to(x.device)
+         x = self.dropout(x)
+
+         for layer in self.layers:
+             x = layer(x)
+
+         return self.final_layer(x)
+
+     def generate(
+         self,
+         tokenizer: PreTrainedTokenizerFast,
+         prompt: str,
+         max_length: int = 100,
+         temperature: float = 0.7,
+         top_k: int = 50,
+         top_p: float = 0.9
+     ) -> str:
+         self.eval()
+         input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+         with torch.no_grad():
+             for _ in range(max_length):
+                 # Stop once the sequence reaches the positional-encoding limit
+                 if input_ids.size(1) >= self.pos_encoding.size(1):
+                     break
+
+                 outputs = self(input_ids)
+                 next_token_logits = outputs[:, -1, :] / temperature
+
+                 # Top-k filtering
+                 if top_k > 0:
+                     indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                     next_token_logits[indices_to_remove] = float('-inf')
+
+                 # Top-p (nucleus) filtering
+                 if top_p < 1.0:
+                     sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                     sorted_indices_to_remove = cumulative_probs > top_p
+                     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                     sorted_indices_to_remove[..., 0] = 0
+                     indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                     next_token_logits[indices_to_remove] = float('-inf')
+
+                 probs = F.softmax(next_token_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+
+                 if next_token.item() == tokenizer.eos_token_id:
+                     break
+
+                 input_ids = torch.cat([input_ids, next_token], dim=1)
+
+         return tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+ class TextGenerator:
+     def __init__(self):
+         self.tokenizer = PreTrainedTokenizerFast.from_pretrained('gpt2')
+         # GPT-2 ships without a pad token; reuse EOS so padding works during training
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token or "<|endoftext|>"
+         self.model = NeuroTransformer(vocab_size=self.tokenizer.vocab_size)
+
+     def train_on_text(
+         self,
+         text: str,
+         epochs: int,
+         learning_rate: float,
+         batch_size: int,
+         progress=gr.Progress()
+     ) -> str:
+         # Keep sequences within the model's positional-encoding length (128)
+         encodings = self.tokenizer(text, truncation=True, max_length=128, padding=True, return_tensors="pt")
+         input_ids = encodings['input_ids']
+
+         dataset = torch.utils.data.TensorDataset(input_ids, input_ids)
+         dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+         optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
+         criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
+
+         logs = []
+         self.model.train()
+
+         for epoch in progress.tqdm(range(epochs)):
+             total_loss = 0
+             for batch in dataloader:
+                 optimizer.zero_grad()
+                 input_ids, labels = batch
+                 outputs = self.model(input_ids)
+                 # Shift so each position predicts the next token instead of copying itself
+                 logits = outputs[:, :-1, :].contiguous()
+                 targets = labels[:, 1:].contiguous()
+                 loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
+                 loss.backward()
+                 optimizer.step()
+                 total_loss += loss.item()
+
+             avg_loss = total_loss / len(dataloader)
+             logs.append(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
+
+         return "\n".join(logs)
+
+     def generate(
+         self,
+         prompt: str,
+         max_length: int,
+         temperature: float,
+         top_k: int,
+         top_p: float
+     ) -> str:
+         return self.model.generate(
+             self.tokenizer,
+             prompt,
+             max_length=max_length,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p
+         )
+
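+ # Note: the attention memory_state buffer is never reset, so it carries over
+ # from training into generation (and across successive UI calls).
+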
+ # Create Gradio interface
+ generator = TextGenerator()
+
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown("# Neuromodulated Text Generator")
+
+     with gr.Tab("Train"):
+         with gr.Row():
+             with gr.Column():
+                 train_input = gr.Textbox(
+                     label="Training Text",
+                     placeholder="Enter text to train on...",
+                     lines=5
+                 )
+                 train_button = gr.Button("Train Model")
+
+             with gr.Column():
+                 epochs_slider = gr.Slider(
+                     label="Epochs",
+                     minimum=1,
+                     maximum=50,
+                     value=10,
+                     step=1
+                 )
+                 lr_slider = gr.Slider(
+                     label="Learning Rate",
+                     minimum=1e-5,
+                     maximum=1e-3,
+                     value=1e-4,
+                     step=1e-5
+                 )
+                 batch_slider = gr.Slider(
+                     label="Batch Size",
+                     minimum=1,
+                     maximum=32,
+                     value=4,
+                     step=1
+                 )
+
+         train_output = gr.Textbox(label="Training Log")
+
+     with gr.Tab("Generate"):
+         with gr.Row():
+             with gr.Column():
+                 prompt_input = gr.Textbox(
+                     label="Prompt",
+                     placeholder="Enter text prompt...",
+                     lines=2
+                 )
+                 generate_button = gr.Button("Generate Text")
+
+             with gr.Column():
+                 length_slider = gr.Slider(
+                     label="Max Length",
+                     minimum=10,
+                     maximum=500,
+                     value=100,
+                     step=10
+                 )
+                 temp_slider = gr.Slider(
+                     label="Temperature",
+                     minimum=0.1,
+                     maximum=2.0,
+                     value=0.7,
+                     step=0.1
+                 )
+                 topk_slider = gr.Slider(
+                     label="Top-k",
+                     minimum=0,
+                     maximum=100,
+                     value=50,
+                     step=1
+                 )
+                 topp_slider = gr.Slider(
+                     label="Top-p",
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.9,
+                     step=0.05
+                 )
+
+         generate_output = gr.Textbox(label="Generated Text")
+
+     train_button.click(
+         fn=generator.train_on_text,
+         inputs=[
+             train_input,
+             epochs_slider,
+             lr_slider,
+             batch_slider
+         ],
+         outputs=train_output
+     )
+
+     generate_button.click(
+         fn=generator.generate,
+         inputs=[
+             prompt_input,
+             length_slider,
+             temp_slider,
+             topk_slider,
+             topp_slider
+         ],
+         outputs=generate_output
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
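
A minimal usage sketch (not part of the commit; it assumes app.py is importable from the working directory and that the untrained model is only being smoke-tested):

    from app import TextGenerator

    gen = TextGenerator()  # downloads the GPT-2 tokenizer on first use
    sample = gen.generate("Once upon a time", max_length=30,
                          temperature=0.7, top_k=50, top_p=0.9)
    print(sample)  # untrained weights, so expect noisy output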