Sin2pi committed · Commit ebbfb14 · verified · 1 Parent(s): 82d5bc7

Update README.md

Files changed (1)
  1. README.md +100 -112
README.md CHANGED
@@ -55,116 +55,104 @@ Scenario: If the model were to be deployed on low power devices.
 
  Reasoning: You want to create as sparse a weight distribution as possible, and this is done by a lower temperature.
 
- #### There are reasons why one might want the opposite to be true when it comes to focus, and that can be changed with the toggle sharpen_longer=False in your model config.
-
- class AdaptiveSpanAttention(nn.Module):
-     def __init__(self, base, dims, head, max_dist, win_size, max_span, temp_scale=0.01, sharpen_longer=False):
-         super().__init__()
-
-         self.max_dist = max_dist
-         self.win_size = win_size
-         self.max_span = max_span
-         self.temp_scale = temp_scale
-         self.multihead_attn = MultiheadAttention(base, dims, head, max_dist)
-         self.span_scale = nn.Parameter(torch.tensor(1.0))
-         self.sharpen_longer = sharpen_longer
-
-     def forward(self, query, key, value, span_scale):
-         span_len = int(self.max_span * span_scale.mean().item())
-         span_len = min(span_len, query.shape[1], key.shape[1], value.shape[1])
-         eff_span = min(span_len, self.max_dist)
-
-         q_span = query[:, :eff_span, :]
-         k_span = key[:, :eff_span, :]
-         v_span = value[:, :eff_span, :]
-
-         attn_out, attn_weights = self.multihead_attn(q_span, k_span, v_span)
-
-         if self.sharpen_longer:
-             temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())  # Sharper for longer spans
-         else:
-             temperature = 0.5 + self.temp_scale * span_scale.mean().item()  # Sharper for shorter spans
-
-         batch_size, _, dims = query.shape
-         scale = (dims // self.multihead_attn.head) ** -0.25
-
-         q = q_span.view(q_span.shape[0], q_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)
-         k = k_span.view(k_span.shape[0], k_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)
-         v = v_span.view(v_span.shape[0], v_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)
-
-         attn_scores = torch.matmul(q, k.transpose(-2, -1))
-         attn_weights = torch.softmax((attn_scores / temperature) * scale, dim=-1)
-         attn_out = torch.matmul(attn_weights, v)
-         attn_out = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
-         attn_out = attn_out.contiguous().view(batch_size, eff_span, dims)
-
-         return attn_out, attn_weights
 
- class SpanPredictor(nn.Module):
-     def __init__(self, dims):
-         super().__init__()
-         self.linear = nn.Linear(dims, 1)
-
-     def forward(self, global_out):
-         scale = torch.sigmoid(self.linear(global_out))
-         return scale
-
- class HybridAttention(nn.Module):
-     def __init__(self, base, dims, head, max_dist, win_size=32, max_span=32, slid_win=32, sharpen_longer=False):
-         super().__init__()
-         self.max_dist = max_dist
-         self.win_size = win_size
-         self.max_span = max_span
-         self.slid_win = slid_win
-
-         self.span_pred = SpanPredictor(dims)
-         self.dist_local = max_dist
-         self.dist_global = max_dist
-         self.attn_local = AdaptiveSpanAttention(base, dims, head, self.dist_local, win_size, max_span, sharpen_longer=sharpen_longer)
-         self.attn_global = MultiheadAttention(base, dims, head, self.dist_global)
-         self.ln_local = LayerNorm(dims)
-         self.ln_global = LayerNorm(dims)
-         self.projection = Linear(2 * dims, dims)
-
-     def forward(self, x, new_dist=None, new_base=None, xa=None, mask=None, kv_cache=None):
-         local = self.ln_local(x)
-         globe = self.ln_global(x)
-
-         globe_out, _ = self.attn_global(globe, globe, globe)
-
-         span_scale = self.span_pred(globe_out.mean(dim=1))
-
-         win_size = max(1, int(self.slid_win * span_scale.mean().item()))
-         span_len = max(1, int(self.max_span * span_scale.mean().item()))
-
-         effective_max_dist = min(self.max_dist, local.size(1))
-         local_max_dist = min(self.dist_local, span_len, win_size)
-         globe_max_dist = effective_max_dist
-
-         # DYNAMICALLY UPDATE max_dist:
-         self.attn_local.max_dist = local_max_dist
-         self.attn_global.max_dist = globe_max_dist
-
-         local_out = self.slide_win(local, win_size, span_len, span_scale)
-
-         combined = torch.cat([local_out, globe_out], dim=-1)
-         x = self.projection(combined)
-
-         return x
-
-     def slide_win(self, x, win_size, span_len, span_scale):
-         batch_size, seq_len, dims = x.size()
-         out = torch.zeros_like(x, device=x.device)
-
-         for i in range(0, seq_len, win_size):
-             end = min(i + win_size, seq_len)
-             query = x[:, i:end, :]
-
-             start = max(0, i - span_len + win_size)
-             key = x[:, start:i + span_len, :]
-             value = x[:, start:i + span_len, :]
-             attn_out, _ = self.attn_local(query, key, value, span_scale)
-             out[:, i:end, :] = attn_out
-         return out
+ ### Focus block:
+
+ class FocusAttention(nn.Module):
+     def __init__(self, base, dims, head, max_dist, sharpen, win_size=32, max_span=32, slid_win=32, temp_scale=0.01):
+         super().__init__()
+         self.base = base
+         self.dims = dims
+         self.head = head
+         self.max_dist = max_dist
+         self.sharpen = sharpen
+         self.win_size = win_size
+         self.max_span = max_span
+         self.slid_win = slid_win
+         self.temp_scale = temp_scale
+
+         self.span_scale_param = nn.Parameter(torch.tensor(1.0))
+         self.span_predictor = nn.Linear(in_features=dims, out_features=1)
+
+         self.multihead_attn_local = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)
+         self.multihead_attn_global = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)
+
+         self.ln_local = LayerNorm(normalized_shape=dims)
+         self.ln_global = LayerNorm(normalized_shape=dims)
+         self.projection = Linear(in_features=2 * dims, out_features=dims)
+
+     def forward(self, x):
+         local = self.ln_local(x)
+         global_ = self.ln_global(x)
+
+         globe_out, _ = self.multihead_attn_global(global_, global_, global_)
+
+         span_scale = torch.sigmoid(self.span_predictor(globe_out.mean(dim=1)))
+
+         win_size = max(1, int(self.slid_win * span_scale.mean().item()))
+         span_len = max(1, int(self.max_span * span_scale.mean().item()))
+
+         effective_max_dist = min(self.max_dist, local.size(1))
+         local_max_dist = min(self.max_dist, span_len, win_size)
+         globe_max_dist = effective_max_dist
+
+         self.multihead_attn_local.max_dist = local_max_dist
+         self.multihead_attn_global.max_dist = globe_max_dist
+
+         local_out = self._window(local, win_size, span_len, span_scale)
+
+         combined = torch.cat([local_out, globe_out], dim=-1)
+         x = self.projection(combined)
+
+         return x
+
+     def _window(self, x, win_size, span_len, span_scale):
+         batch_size, seq_len, dims = x.size()
+         num_windows = (seq_len + win_size - 1) // win_size
+
+         output = torch.zeros_like(x, device=x.device)
+
+         for i in range(num_windows):
+             start_idx = i * win_size
+             end_idx = min((i + 1) * win_size, seq_len)
+             query = x[:, start_idx:end_idx, :]
+
+             key_start = max(0, start_idx - span_len + win_size)
+             key_end = min(start_idx + span_len, seq_len)
+             key = x[:, key_start:key_end, :]
+             value = x[:, key_start:key_end, :]
+
+             attn_out, _ = self._focus(query, key, value, span_scale)
+             output[:, start_idx:end_idx, :] = attn_out
+
+         return output
+
+     def _focus(self, query, key, value, span_scale):
+         span_len = int(self.max_span * span_scale.mean().item())
+         span_len = min(span_len, query.size(1), key.size(1), value.size(1))
+         eff_span = min(span_len, self.max_dist)
+
+         q_span = query[:, :eff_span, :]
+         k_span = key[:, :eff_span, :]
+         v_span = value[:, :eff_span, :]
+
+         batch_size, seq_len, dims = q_span.size()
+         scale_factor = (dims // self.head) ** -0.25
+
+         q = q_span.view(batch_size, seq_len, self.head, -1).permute(0, 2, 1, 3)
+         k = k_span.view(batch_size, seq_len, self.head, -1).permute(0, 2, 1, 3)
+         v = v_span.view(batch_size, seq_len, self.head, -1).permute(0, 2, 1, 3)
+
+         if self.sharpen:
+             temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
+         else:
+             temperature = 0.5 + self.temp_scale * span_scale.mean().item()
+
+         attn_scores = torch.matmul(q, k.transpose(-2, -1))
+         attn_weights = torch.softmax((attn_scores / temperature) * scale_factor, dim=-1)
+         attn_out = torch.matmul(attn_weights, v)
+
+         attn_out = attn_out.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, -1)
+
+         return attn_out, attn_weights
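
The reasoning above (a lower temperature gives a sparser weight distribution) and the new `sharpen` flag both come down to the two temperature formulas in `_focus`. Below is a minimal standalone sketch, not part of the commit, that plugs illustrative numbers into those formulas to show the effect:

```python
# Minimal sketch (not part of the commit): how the temperature used in _focus
# changes the sharpness/sparsity of the softmax. Score values are illustrative only.
import torch

scores = torch.tensor([2.0, 1.0, 0.5, 0.1])  # hypothetical raw attention scores
temp_scale = 0.01

for span_scale in (0.25, 1.0):  # short predicted span vs. full span
    # sharpen=True: temperature moves toward 1.0 as span_scale grows -> sharper for longer spans
    t_sharpen = 1.0 + temp_scale * (1.0 - span_scale)
    # sharpen=False: temperature stays near the lower 0.5 base -> sharpest for short spans
    t_default = 0.5 + temp_scale * span_scale
    for name, t in (("sharpen=True", t_sharpen), ("sharpen=False", t_default)):
        weights = torch.softmax(scores / t, dim=-1)
        print(f"span_scale={span_scale:.2f} {name:13s} T={t:.3f} weights={weights.numpy().round(3)}")

# Dividing the scores by a smaller temperature widens the gaps between logits,
# so softmax concentrates its mass on fewer positions (a sparser distribution).
```

In the merged block the choice is made at construction time: `sharpen=True` lowers the temperature toward 1.0 as the predicted span grows (sharper for longer spans), while `sharpen=False` starts from the lower 0.5 base and is therefore sharpest on short spans.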
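
`_window` slides a query window across the sequence and lets each window attend over a key/value slice that starts `span_len - win_size` positions before the window (clipped at 0) and ends `span_len` positions after the window start (clipped at `seq_len`). A small sketch with hypothetical sizes makes the index arithmetic visible:

```python
# Standalone sketch of the index arithmetic in _window (illustrative sizes only).
seq_len, win_size, span_len = 10, 4, 6

num_windows = (seq_len + win_size - 1) // win_size  # ceil(seq_len / win_size)
for i in range(num_windows):
    start_idx = i * win_size
    end_idx = min((i + 1) * win_size, seq_len)
    key_start = max(0, start_idx - span_len + win_size)  # span_len - win_size positions of lookback
    key_end = min(start_idx + span_len, seq_len)          # span_len positions past the window start
    print(f"queries [{start_idx}:{end_idx}] attend over keys/values [{key_start}:{key_end}]")
```

Note that `_focus` then truncates query, key, and value to `eff_span = min(span_len, ..., max_dist)`, so the slice actually attended over can be shorter than the one gathered here.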