dappyx committed on
Commit
75c80a0
1 Parent(s): e833099

Upload 4 files

Files changed (4)
  1. app.py +53 -0
  2. main.py +41 -0
  3. model.py +371 -0
  4. switch_transformer.pt +3 -0
app.py ADDED
@@ -0,0 +1,53 @@
+ import gradio as gr
+ from main import tokenizer, model, device
+ import torch
+
+ def qa_pipeline(text, question):
+     inputs = tokenizer(question, text, return_tensors="pt")
+     input_ids = inputs['input_ids'].to(device)
+     attention_mask = inputs['attention_mask'].to(device)
+     batch = {
+         "input_ids": input_ids,
+         "attention_mask": attention_mask
+     }
+     # SwitchTransformer.forward returns a (start_logits, end_logits, loss)
+     # tuple rather than a model-output object, so unpack it directly
+     with torch.no_grad():
+         start_logits, end_logits, _ = model(batch)
+
+     start_index = torch.argmax(start_logits, dim=-1).item()
+     end_index = torch.argmax(end_logits, dim=-1).item()
+
+     predict_answer_tokens = inputs.input_ids[0, start_index : end_index + 1]
+     return tokenizer.decode(predict_answer_tokens)
+
+ def answer_question(context, question):
+     return qa_pipeline(context, question)
+
+ example_contexts = [
+     "Қазақстанның ұлттық құрамы алуан түрлі. Халықтың басым бөлігін тұрғылықты қазақ халқы құрайды, пайыздық үлесі — 70,18%[10], орыстар — 18,42%, өзбектер — 3,29%, украиндар — 1,36%, ұйғырлар — 1,48%, татарлар — 1,06%, басқа халықтар 5,38%.[11] Халықтың 75% астамын мұсылмандар құрайды, православты христиандар — 21%, қалғаны басқа да дін өкілдері.[12]",
+     "Қазақстан бес мемлекетпен шекаралас, соның ішінде әлемдегі құрлықтағы ең ұзын шекара, солтүстігінде және батысында Ресеймен — 7591 км құрайды. Оңтүстігінде: Түрікменстан — 426 км, Өзбекстан — 2354 км және Қырғызстан — 1241 км, ал шығысында: Қытаймен — 1782 км шектеседі. Жалпы құрлық шекарасының ұзындығы — 13394 км. Батыста Каспий көлімен (2000 км), оңтүстік батыста Арал теңізімен шайылады.[9] 2024 жылдың 1 наурыздағы елдегі тұрғындар саны — 20 075 271[4], бұл әлем бойынша 64-орын. Жер көлемі жағынан әлем елдерінің ішінде 9-орын алады (2 724 902 км²).",
+     "Қазақстан — 1995 жылғы 30 тамыздағы республикалық референдумда қабылданған Конституция бойынша — өзін демократиялы, зайырлы, құқықты және әлеуметті мемлекет ретінде орнықтырды. Қазақстан Республикасы – президенттік басқару формасындағы біртұтас мемлекет. Республиканың ең жоғарғы өкілді органы — Парламент. Ол республиканың заң шығару құзіретін жүзеге асырады."
+ ]
+ example_questions = [
+     "Қазақстанның халқы неше пайызды қазақтар құрайды?",
+     "Қазақстан нешеу мемлекетпен шекаралас?",
+     "Қазақстандағы басқару формасы қандай?",
+ ]
+
+ examples = [[context, question] for context, question in zip(example_contexts, example_questions)]
+
+ # Build the Gradio interface
+ iface = gr.Interface(
+     fn=answer_question,
+     inputs=[
+         gr.Textbox(lines=10, label="Context"),
+         gr.Textbox(lines=2, label="Question")
+     ],
+     outputs="text",
+     title="Question Answering Model",
+     description="Enter a context and ask a question to get the answer.",
+     examples=examples
+ )
+
+ # Launch the interface
+ iface.launch()
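
For a quick smoke test without the UI, qa_pipeline can be called directly. A minimal sketch, assuming it is run at the bottom of app.py before iface.launch() (all names come from this file):

    context = example_contexts[0]
    question = example_questions[0]
    print(qa_pipeline(context, question))  # prints the decoded answer span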
main.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ from model import (
+     SwitchTransformer,
+     SwitchTransformerLayer,
+     MultiHeadAttention,
+     SwitchFeedForward,
+     FeedForward,
+ )
+ from transformers import AutoTokenizer
+
+ device = 'cpu'
+
+ ff = FeedForward(768, 768 * 4)
+ attn = MultiHeadAttention(8, 768, 0.2)
+ st_ff = SwitchFeedForward(
+     capacity_factor=1.25,
+     drop_tokens=False,
+     n_experts=4,
+     expert=ff,
+     d_model=768,
+     is_scale_prob=True,
+ )
+ st_layer = SwitchTransformerLayer(
+     d_model=768,
+     attn=attn,
+     feed_forward=st_ff,
+     dropout_prob=0.2
+ )
+ model = SwitchTransformer(
+     layer=st_layer,
+     n_layers=4,
+     n_experts=4,
+     device=device,
+     load_balancing_loss_coef=0.05,
+ ).to(device)
+
+ # map_location keeps a GPU-saved checkpoint loadable on the CPU-only device above
+ model.load_state_dict(torch.load("switch_transformer.pt", map_location=device))
+ model.eval()  # disable dropout for inference
+ tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")
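
A minimal sanity check of the assembled model (illustrative only; the question/context pair is arbitrary):

    inputs = tokenizer("Қазақстанның астанасы қай қала?", "Астана — Қазақстанның астанасы.", return_tensors="pt")
    batch = {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]}
    with torch.no_grad():
        start_logits, end_logits, loss = model(batch)
    # start_logits and end_logits have shape (1, seq_len); with no label
    # positions in the batch, loss holds only the load-balancing term
    print(start_logits.shape, end_logits.shape, loss.item())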
model.py ADDED
@@ -0,0 +1,371 @@
+ import torch
+ import torch.nn as nn
+ import copy
+ import math
+ from transformers import AutoModelForQuestionAnswering
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, n_heads, dim, dropout_prob):
+         super().__init__()
+
+         self.n_heads = n_heads
+         self.dim = dim
+         self.dropout = nn.Dropout(p=dropout_prob)
+
+         assert self.dim % self.n_heads == 0
+         self.q_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
+         self.k_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
+         self.v_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
+         self.out_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
+
+     def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
+         """
+         Parameters:
+             query: torch.tensor(bs, seq_length, dim)
+             key: torch.tensor(bs, seq_length, dim)
+             value: torch.tensor(bs, seq_length, dim)
+             mask: torch.tensor(bs, seq_length)
+         Returns:
+             context: torch.tensor(bs, seq_length, dim) Contextualized layer.
+             weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention
+                 weights; returned only if `output_attentions=True`.
+         """
+         bs, q_length, dim = query.size()
+         k_length = key.size(1)
+
+         dim_per_head = self.dim // self.n_heads
+
+         mask_reshp = (bs, 1, 1, k_length)
+
+         def shape(x):
+             """separate heads"""
+             return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
+
+         def unshape(x):
+             """group heads"""
+             return (
+                 x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
+             )
+
+         q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
+         k = shape(self.k_lin(key))    # (bs, n_heads, k_length, dim_per_head)
+         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
+
+         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
+         scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
+         mask = (
+             (mask == 0).view(mask_reshp).expand_as(scores)
+         )  # (bs, n_heads, q_length, k_length)
+         scores = scores.masked_fill(
+             mask, -float("inf")
+         )  # (bs, n_heads, q_length, k_length)
+
+         weights = nn.functional.softmax(
+             scores, dim=-1
+         )  # (bs, n_heads, q_length, k_length)
+         weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)
+
+         # Mask heads if we want to
+         if head_mask is not None:
+             weights = weights * head_mask
+
+         context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
+         context = unshape(context)  # (bs, q_length, dim)
+         context = self.out_lin(context)  # (bs, q_length, dim)
+
+         if output_attentions:
+             return (context, weights)
+         else:
+             return context
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim_input: int = 768, dim_feedforward: int = 4 * 768):
+         super().__init__()
+
+         self.linear1 = nn.Linear(dim_input, dim_feedforward)
+         self.relu = nn.ReLU()
+         self.linear2 = nn.Linear(dim_feedforward, dim_input)
+
+     def forward(self, x):
+         return self.linear2(self.relu(self.linear1(x)))
+
+
+ class SwitchFeedForward(nn.Module):
+     """
+     ## Routing among multiple FFNs
+     """
+
+     def __init__(
+         self,
+         *,
+         capacity_factor: float,
+         drop_tokens: bool,
+         is_scale_prob: bool,
+         n_experts: int,
+         expert: FeedForward,
+         d_model: int
+     ):
+         """
+         * `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
+         * `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than its capacity
+         * `is_scale_prob` specifies whether to scale the expert output by the routing probability
+         * `n_experts` is the number of experts
+         * `expert` is the expert layer, a `FeedForward` module
+         * `d_model` is the number of features in a token embedding
+         """
+         super().__init__()
+
+         self.capacity_factor = capacity_factor
+         self.is_scale_prob = is_scale_prob
+         self.n_experts = n_experts
+         self.drop_tokens = drop_tokens
+
+         # Make copies of the FFNs
+         self.experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(n_experts)])
+         # Routing layer and softmax
+         self.switch = nn.Linear(d_model, n_experts)
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, x: torch.Tensor):
+         """
+         * `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`
+         """
+
+         # Capture the shape to change shapes later
+         seq_len, batch_size, d_model = x.shape
+         # Flatten the sequence and batch dimensions
+         x = x.view(-1, d_model)
+
+         # Get routing probabilities for each of the tokens:
+         # $$p_i(x) = \frac{e^{h(x)_i}}{\sum^N_j e^{h(x)_j}}$$
+         # where $N$ is the number of experts `n_experts` and
+         # $h(\cdot)$ is the linear transformation of token embeddings.
+         route_prob = self.softmax(self.switch(x))
+
+         # Get the maximum routing probabilities and the routes.
+         # We route each token to the expert with the highest probability.
+         route_prob_max, routes = torch.max(route_prob, dim=-1)
+
+         # Get indexes of tokens going to each expert
+         indexes_list = [
+             torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)
+         ]
+
+         # Initialize an empty tensor to store outputs
+         final_output = x.new_zeros(x.shape)
+
+         # Capacity of each expert:
+         # $$\mathrm{expert\;capacity} =
+         # \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}}
+         # \times \mathrm{capacity\;factor}$$
+         capacity = int(self.capacity_factor * len(x) / self.n_experts)
+         # Number of tokens routed to each expert
+         counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])
+
+         # Initialize an empty list of dropped tokens
+         dropped = []
+         # Only drop tokens if `drop_tokens` is `True`
+         if self.drop_tokens:
+             # Drop tokens in each of the experts
+             for i in range(self.n_experts):
+                 # Ignore if the expert is not over capacity
+                 if len(indexes_list[i]) <= capacity:
+                     continue
+                 # Shuffle indexes before dropping
+                 indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
+                 # Collect the tokens over capacity as dropped tokens
+                 dropped.append(indexes_list[i][capacity:])
+                 # Keep only the tokens up to the capacity of the expert
+                 indexes_list[i] = indexes_list[i][:capacity]
+
+         # Get outputs of the expert FFNs
+         expert_output = [
+             self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)
+         ]
+
+         # Assign to final output
+         for i in range(self.n_experts):
+             final_output[indexes_list[i], :] = expert_output[i]
+
+         # Pass through the dropped tokens unchanged
+         if dropped:
+             dropped = torch.cat(dropped)
+             final_output[dropped, :] = x[dropped, :]
+
+         if self.is_scale_prob:
+             # Scale the expert outputs by the routing probabilities $y = p_i(x) E_i(x)$
+             final_output = final_output * route_prob_max.view(-1, 1)
+         else:
+             # Don't scale the values, but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
+             # (this is something we experimented with).
+             final_output = final_output * (
+                 route_prob_max / route_prob_max.detach()
+             ).view(-1, 1)
+
+         # Change the shape of the final output back to `[seq_len, batch_size, d_model]`
+         final_output = final_output.view(seq_len, batch_size, d_model)
+
+         # Return
+         #
+         # * the final output
+         # * the number of tokens routed to each expert
+         # * the sum of probabilities for each expert
+         # * the number of tokens dropped
+         # * the routing probabilities of the selected experts
+         #
+         # These are used for the load balancing loss and logging
+         return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
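+
+ # Illustrative routing arithmetic (the numbers are assumed, not taken from
+ # this repo): with the configuration in main.py (n_experts=4,
+ # capacity_factor=1.25) and a flattened batch of seq_len * batch_size = 512
+ # tokens, each expert's capacity is int(1.25 * 512 / 4) = 160 tokens.
+ # Since main.py sets drop_tokens=False, this capacity is computed but never
+ # enforced, so no tokens are dropped at inference time.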
+
+
+ class SwitchTransformerLayer(nn.Module):
+     """
+     # Switch Transformer Block
+     This is the same as a normal transformer block,
+     but it handles the extra outputs of the switch feed-forward module.
+     """
+
+     def __init__(
+         self,
+         *,
+         d_model: int,
+         attn: MultiHeadAttention,
+         feed_forward: SwitchFeedForward,
+         dropout_prob: float
+     ):
+         """
+         * `d_model` is the token embedding size
+         * `attn` is the attention module
+         * `feed_forward` is the feed-forward module (the switching module in this case)
+         * `dropout_prob` is the probability of dropping out after self attention and the FFN
+         """
+         super().__init__()
+         self.size = d_model
+         self.attn = attn
+         self.feed_forward = feed_forward
+         self.dropout = nn.Dropout(dropout_prob)
+         self.norm_self_attn = nn.LayerNorm([d_model])
+         self.norm_ff = nn.LayerNorm([d_model])
+
+     def forward(self, *, x: torch.Tensor, mask: torch.Tensor):
+         # Normalize the vectors before doing self attention (pre-norm)
+         z = self.norm_self_attn(x)
+         # Run through self attention, i.e. keys and values are from self
+         self_attn = self.attn(query=z, key=z, value=z, mask=mask)
+         # Add the self attention results (residual connection)
+         x = x + self.dropout(self_attn)
+
+         # Normalize for the feed-forward network
+         z = self.norm_ff(x)
+         # Pass through the switching feed-forward network
+         ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)
+         # Add the feed-forward results back
+         x = x + self.dropout(ff)
+
+         return x, counts, route_prob, n_dropped, route_prob_max
+
+
+ class SwitchTransformer(nn.Module):
+     """
+     ## Switch Transformer
+     """
+
+     def __init__(self, layer, n_layers, n_experts, device, load_balancing_loss_coef):
+         super().__init__()
+         # Make copies of the transformer layer
+         self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
+         # Final normalization layer
+         self.norm = nn.LayerNorm([layer.size])
+         self.qa_outputs = nn.Linear(768, 2)
+         self.base_model = AutoModelForQuestionAnswering.from_pretrained(
+             "Kyrmasch/kaz-roberta-squad2-kaz"
+         ).to(device)
+         self.device = device
+         self.load_balancing_loss_coef = load_balancing_loss_coef
+         self.n_experts = n_experts  # used to calculate the load-balancing loss
+
+     def freeze_base_model(self):
+         for param in self.base_model.parameters():
+             param.requires_grad = False
+
+     def freeze_experts(self):
+         # TODO: find out how to freeze the experts in the SwitchTransformer
+         pass
+
+     def forward(self, batch):
+         input_ids = batch["input_ids"].to(self.device)
+         attention_mask = batch["attention_mask"].to(self.device)
+         start_positions = (
+             batch["start_positions"].to(self.device)
+             if "start_positions" in batch
+             else None
+         )
+         end_positions = (
+             batch["end_positions"].to(self.device)
+             if "end_positions" in batch
+             else None
+         )
+
+         # Run the base QA model only to get its last hidden states
+         outputs = self.base_model(
+             input_ids,
+             attention_mask=attention_mask,
+             start_positions=None,
+             end_positions=None,
+             output_hidden_states=True,
+         )
+         x = outputs.hidden_states[-1]
+         # Run through each transformer layer, collecting routing statistics
+         counts, route_prob, n_dropped, route_prob_max = [], [], [], []
+         for layer in self.layers:
+             x, f, p, n_d, p_max = layer(x=x, mask=attention_mask)
+             counts.append(f)
+             route_prob.append(p)
+             n_dropped.append(n_d)
+             route_prob_max.append(p_max)
+         # Finally, normalize the vectors
+         output = self.norm(x)
+
+         logits = self.qa_outputs(output)
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1).contiguous()  # (bs, seq_length)
+         end_logits = end_logits.squeeze(-1).contiguous()  # (bs, seq_length)
+
+         loss = None
+         if start_positions is not None and end_positions is not None:
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1)
+             # Sometimes the start/end positions are outside our model inputs; ignore these terms
+             ignored_index = start_logits.size(1)
+             start_positions = start_positions.clamp(0, ignored_index)
+             end_positions = end_positions.clamp(0, ignored_index)
+
+             loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+             loss = (start_loss + end_loss) / 2
+         # Load-balancing loss $N \sum_i f_i P_i$, where $f_i$ is the fraction of
+         # tokens routed to expert $i$ and $P_i$ is its mean routing probability
+         counts = torch.stack(counts)
+         route_prob = torch.stack(route_prob)
+         route_prob_max = torch.stack(route_prob_max)
+         total = counts.sum(dim=-1, keepdim=True)
+         route_frac = counts / total
+         route_prob = route_prob / total
+         load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
+         loss = (
+             load_balancing_loss
+             if loss is None
+             else loss + self.load_balancing_loss_coef * load_balancing_loss
+         )
+         return start_logits, end_logits, loss
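
To make the load-balancing term concrete, here is a toy computation with made-up routing statistics for a single layer with 4 experts and 8 tokens (all numbers are illustrative):

    import torch

    counts = torch.tensor([[4., 2., 1., 1.]])          # tokens routed to each expert
    route_prob = torch.tensor([[3.2, 1.8, 1.6, 1.4]])  # per-expert sums of routing probabilities (sum to 8)
    total = counts.sum(dim=-1, keepdim=True)           # 8 tokens in total
    route_frac = counts / total                        # [0.500, 0.250, 0.125, 0.125]
    mean_prob = route_prob / total                     # [0.400, 0.225, 0.200, 0.175]
    loss = 4 * (route_frac * mean_prob).sum()          # ~1.21; the minimum of 1.0 is reached at perfect balance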
switch_transformer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:18db93cbc33e8aab35f5583010b67d2ca0c44cd93445e0bfd5d886382708d9ba
+ size 671685785