Sin2pi committed on
Commit 9a03329 · verified · 1 Parent(s): ebbfb14

Create focus.py

Files changed (1)
  1. focus.py +257 -0
focus.py ADDED
@@ -0,0 +1,257 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import LayerNorm, Linear
+
+ # MultiheadAttention below is this project's own attention class (it takes
+ # base/dims/head/max_dist), not torch.nn.MultiheadAttention; it is assumed
+ # to be importable from elsewhere in the repo.
+
+ class AdaptiveFocus(nn.Module):
+     def __init__(self, base, dims, head, max_dist, sharpen, win_size, max_span, temp_scale=0.01, num_iterations=3):
+         super().__init__()
+         self.max_dist = max_dist
+         self.win_size = win_size
+         self.max_span = max_span
+         self.temp_scale = temp_scale
+         self.multihead_attn = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)
+         self.span_scale = nn.Parameter(torch.tensor(1.0))
+         self.sharpen = sharpen
+         self.num_iterations = num_iterations
+         self.base_threshold = 1e-4
+         self.scaling_factor = 0.1
+
+     def _focus(self, query, key, value, span_scale):
+         max_iterations = self.num_iterations
+         iteration = 0
+         prev_attn_out = None
+         attn_out = torch.zeros_like(query)
+         attn_weights = None
+
+         while iteration < max_iterations:
+             span_len = int(self.max_span * span_scale.mean().item())
+             span_len = min(span_len, query.shape[1], key.shape[1], value.shape[1])
+             eff_span = min(span_len, self.max_dist)
+
+             if eff_span == 0:
+                 break
+
+             q_span = query[:, :eff_span, :]
+             k_span = key[:, :eff_span, :]
+             v_span = value[:, :eff_span, :]
+
+             batch_size, seq_len, dims = q_span.size()
+             scale = (dims // self.multihead_attn.head) ** -0.25
+
+             q = q_span.view(batch_size, seq_len, self.multihead_attn.head, -1).permute(0, 2, 1, 3)
+             k = k_span.view(batch_size, seq_len, self.multihead_attn.head, -1).permute(0, 2, 1, 3)
+             v = v_span.view(batch_size, seq_len, self.multihead_attn.head, -1).permute(0, 2, 1, 3)
+
+             # Softmax temperature tracks the span scale: roughly 1.0 when
+             # sharpen=True, roughly 0.5 otherwise (temp_scale keeps the
+             # modulation small).
+             if self.sharpen:
+                 temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
+             else:
+                 temperature = 0.5 + self.temp_scale * span_scale.mean().item()
+
+             attn_scores = torch.matmul(q, k.transpose(-2, -1))
+             attn_weights = torch.softmax((attn_scores / temperature) * scale, dim=-1)
+             attn_out = torch.matmul(attn_weights, v)
+             attn_out = attn_out.permute(0, 2, 1, 3).contiguous().view(batch_size, eff_span, dims)
+
+             # Convergence check over the attended span only; attn_out covers
+             # eff_span positions, which may be fewer than the query holds.
+             if prev_attn_out is not None:
+                 diff = torch.abs(attn_out - prev_attn_out).mean()
+                 dynamic_threshold = self.base_threshold + self.scaling_factor * diff
+                 if diff < dynamic_threshold:
+                     break
+
+             prev_attn_out = attn_out
+             # Feed the refined span back into the query for the next pass,
+             # updating only the first eff_span positions so shapes stay valid.
+             query = query.clone()
+             query[:, :eff_span, :] = query[:, :eff_span, :] + attn_out
+             iteration += 1
+
+         return attn_out, attn_weights
+
+     def forward(self, query, key, value, span_scale):
+         return self._focus(query, key, value, span_scale)
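+
+ # Note on the stopping rule in _focus: with dynamic_threshold =
+ # base_threshold + scaling_factor * diff, the test "diff < dynamic_threshold"
+ # is algebraically diff < base_threshold / (1 - scaling_factor), i.e. a fixed
+ # cutoff of roughly 1.11e-4 under the defaults above.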
+
+
+ class SpanPredictor(nn.Module):
+     def __init__(self, dims):
+         super().__init__()
+         self.linear = nn.Linear(in_features=dims, out_features=1)
+
+     def forward(self, global_out):
+         scale = torch.sigmoid(self.linear(global_out))
+         return scale
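+
+ # The sigmoid keeps the predicted scale in (0, 1); FocusedAttention.forward
+ # multiplies it into slid_win and max_span, so e.g. an (illustrative) scale
+ # of 0.5 with max_span=32 yields span_len = 16.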
+
+ class FocusedAttention(nn.Module):
+     def __init__(self, base, dims, head, max_dist, sharpen, win_size=32, max_span=32, slid_win=32, temp_scale=0.01, num_iterations=3):
+         super().__init__()
+         self.max_dist = max_dist
+         self.win_size = win_size
+         self.max_span = max_span
+         self.slid_win = slid_win
+
+         self.span_pred = SpanPredictor(dims=dims)
+         self.dist_local = max_dist
+         self.dist_global = max_dist
+
+         self.attn_local = AdaptiveFocus(base=base, dims=dims, head=head, max_dist=max_dist, sharpen=sharpen, win_size=win_size, max_span=max_span, temp_scale=temp_scale, num_iterations=num_iterations)
+         self.attn_global = MultiheadAttention(base=base, dims=dims, head=head, max_dist=self.dist_global)
+         self.ln_local = LayerNorm(normalized_shape=dims)
+         self.ln_global = LayerNorm(normalized_shape=dims)
+         self.projection = Linear(in_features=2 * dims, out_features=dims)
+
+     def forward(self, x, new_dist=None, new_base=None, xa=None, mask=None, kv_cache=None):
+         # new_dist, new_base, xa, mask and kv_cache are unused here; they are
+         # presumably kept for interface compatibility with sibling blocks.
+         local = self.ln_local(x)
+         globe = self.ln_global(x)
+
+         # Global pass first: its pooled output drives the span prediction.
+         globe_out, _ = self.attn_global(globe, globe, globe)
+
+         span_scale = self.span_pred(globe_out.mean(dim=1))
+
+         win_size = max(1, int(self.slid_win * span_scale.mean().item()))
+         span_len = max(1, int(self.max_span * span_scale.mean().item()))
+
+         effective_max_dist = min(self.max_dist, local.size(1))
+         local_max_dist = min(self.dist_local, span_len, win_size)
+         globe_max_dist = effective_max_dist
+
+         # Tighten both attention modules to the predicted reach.
+         self.attn_local.max_dist = local_max_dist
+         self.attn_global.max_dist = globe_max_dist
+
+         local_out = self.slide_win(x=local, win_size=win_size, span_len=span_len, span_scale=span_scale)
+
+         combined = torch.cat(tensors=[local_out, globe_out], dim=-1)
+         x = self.projection(combined)
+
+         return x
+
+     def slide_win(self, x, win_size, span_len, span_scale):
+         batch_size, seq_len, dims = x.size()
+         out = torch.zeros_like(x)
+
+         for i in range(0, seq_len, win_size):
+             end = min(i + win_size, seq_len)
+             query = x[:, i:end, :]
+
+             # Keys/values come from a neighborhood around the window, e.g.
+             # i=32, win_size=32, span_len=48 gives key = x[:, 16:80, :].
+             start = max(0, i - span_len + win_size)
+             key = x[:, start:i + span_len, :]
+             value = x[:, start:i + span_len, :]
+
+             attn_out, _ = self.attn_local(query, key, value, span_scale)
+             # attn_out may cover fewer positions than the window when the
+             # effective span is clipped in _focus, so index by its length.
+             out[:, i:i + attn_out.size(1), :] = attn_out
+
+         return out
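+
+ # Usage sketch (illustrative sizes; assumes the project's MultiheadAttention,
+ # LayerNorm and Linear are in scope, and that MultiheadAttention returns an
+ # (output, weights) tuple as used above):
+ #   attn = FocusedAttention(base=10000, dims=512, head=8, max_dist=256, sharpen=True)
+ #   x = torch.randn(2, 128, 512)   # (batch, seq_len, dims)
+ #   y = attn(x)                    # -> (2, 128, 512)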
+
+ ## different version
+ # class FocusedAttention(nn.Module):
+ #     def __init__(self, base, dims, head, max_dist, sharpen, win_size=32, max_span=32, slid_win=32, temp_scale=0.01):
+ #         super().__init__()
+ #         self.base = base
+ #         self.dims = dims
+ #         self.head = head
+ #         self.max_dist = max_dist
+ #         self.sharpen = sharpen
+ #         self.win_size = win_size
+ #         self.max_span = max_span
+ #         self.slid_win = slid_win
+ #         self.temp_scale = temp_scale
+
+ #         self.span_scale_param = nn.Parameter(torch.tensor(1.0))
+ #         self.span_predictor = nn.Linear(in_features=dims, out_features=1)
+
+ #         self.multihead_attn_local = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)
+ #         self.multihead_attn_global = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)
+
+ #         self.ln_local = LayerNorm(normalized_shape=dims)
+ #         self.ln_global = LayerNorm(normalized_shape=dims)
+ #         self.projection = Linear(in_features=2 * dims, out_features=dims)
+
+ #     def forward(self, x):
+ #         local = self.ln_local(x)
+ #         global_ = self.ln_global(x)
+
+ #         globe_out, _ = self.multihead_attn_global(global_, global_, global_)
+
+ #         span_scale = torch.sigmoid(self.span_predictor(globe_out.mean(dim=1)))
+
+ #         win_size = max(1, int(self.slid_win * span_scale.mean().item()))
+ #         span_len = max(1, int(self.max_span * span_scale.mean().item()))
+
+ #         effective_max_dist = min(self.max_dist, local.size(1))
+ #         local_max_dist = min(self.max_dist, span_len, win_size)
+ #         globe_max_dist = effective_max_dist
+
+ #         self.multihead_attn_local.max_dist = local_max_dist
+ #         self.multihead_attn_global.max_dist = globe_max_dist
+
+ #         local_out = self._window(local, win_size, span_len, span_scale)
+
+ #         combined = torch.cat([local_out, globe_out], dim=-1)
+ #         x = self.projection(combined)
+
+ #         return x
+
+ #     def _window(self, x, win_size, span_len, span_scale):
+ #         batch_size, seq_len, dims = x.size()
+ #         output = torch.zeros_like(x, device=x.device)
+
+ #         for i in range(0, seq_len, win_size):
+ #             end = min(i + win_size, seq_len)
+ #             query = x[:, i:end, :]
+
+ #             start = max(0, i - span_len + win_size)
+ #             key = x[:, start:i + span_len, :]
+ #             value = x[:, start:i + span_len, :]
+
+ #             attn_out, _ = self._focus(query, key, value, span_scale)
+ #             output[:, i:end, :] = attn_out
+
+ #         return output
+
+ #     def _focus(self, query, key, value, span_scale):
+ #         span_len = int(self.max_span * span_scale.mean().item())
+ #         span_len = min(span_len, query.size(1), key.size(1), value.size(1))
+ #         eff_span = min(span_len, self.max_dist)
+
+ #         q_span = query[:, :eff_span, :]
+ #         k_span = key[:, :eff_span, :]
+ #         v_span = value[:, :eff_span, :]
+
+ #         batch_size, seq_len, dims = q_span.size()
+ #         scale_factor = (dims // self.head) ** -0.25
+
+ #         q = q_span.view(batch_size, seq_len, self.head, -1).permute(0, 2, 1, 3)
+ #         k = k_span.view(batch_size, seq_len, self.head, -1).permute(0, 2, 1, 3)
+ #         v = v_span.view(batch_size, seq_len, self.head, -1).permute(0, 2, 1, 3)
+
+ #         if self.sharpen:
+ #             temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
+ #         else:
+ #             temperature = 0.5 + self.temp_scale * span_scale.mean().item()
+
+ #         attn_scores = torch.matmul(q, k.transpose(-2, -1))
+ #         attn_weights = torch.softmax((attn_scores / temperature) * scale_factor, dim=-1)
+ #         attn_out = torch.matmul(attn_weights, v)
+
+ #         attn_out = attn_out.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, -1)
+
+ #         return attn_out, attn_weights
+
+
+ # #Batch:
+
+ #     def _window(self, x, win_size, span_len, span_scale):
+ #         batch_size, seq_len, dims = x.size()
+ #         num_windows = (seq_len + win_size - 1) // win_size  # ceiling division
+
+ #         # Create a tensor to store the outputs
+ #         output = torch.zeros_like(x, device=x.device)
+
+ #         # Iterate over the windows by index
+ #         for i in range(num_windows):
+ #             start_idx = i * win_size
+ #             end_idx = min((i + 1) * win_size, seq_len)
+ #             query = x[:, start_idx:end_idx, :]
+
+ #             # Define the range of keys and values
+ #             key_start = max(0, start_idx - span_len + win_size)
+ #             key_end = min(start_idx + span_len, seq_len)
+ #             key = x[:, key_start:key_end, :]
+ #             value = x[:, key_start:key_end, :]
+
+ #             attn_out, _ = self._focus(query, key, value, span_scale)
+ #             output[:, start_idx:end_idx, :] = attn_out
+
+ #         return output
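+
+ # Note: despite the "Batch" label, this variant still visits windows one at a
+ # time (indexed by window count rather than start position); it is not
+ # vectorized across windows.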