Sin2pi committed on
Commit 2ec825f · verified · 1 Parent(s): e22b4d3

Create echo.py

Files changed (1)
  1. echo.py +1032 -0
echo.py ADDED
@@ -0,0 +1,1032 @@
import math, os, random, warnings
import numpy as np
import torch
import torch.nn.functional as F
import transformers
import evaluate
from torch import Tensor, nn
from torch.nn.functional import scaled_dot_product_attention
from torch.optim import Optimizer
from torch.utils.checkpoint import checkpoint
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Dict, Optional, Union, List, Any
from dataclasses import dataclass
from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments, PretrainedConfig,
                          TrainerCallback, WhisperProcessor, WhisperFeatureExtractor,
                          WhisperTokenizerFast)
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
from datasets import load_dataset, IterableDatasetDict, Audio

transformers.utils.logging.set_verbosity_error()
warnings.filterwarnings(action="ignore")
warnings.warn = lambda *args, **kwargs: None

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:  # type: ignore
        return F.linear(x, self.weight.to(x.dtype),
                        None if self.bias is None else self.bias.to(x.dtype))

class Conv1d(nn.Conv1d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:  # type: ignore
        return super()._conv_forward(x, weight.to(x.dtype),
                                     None if bias is None else bias.to(x.dtype))

class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:  # type: ignore
        return super().forward(x.float()).type(x.dtype)

class CombinedRotaryEmbedding(nn.Module):
    def __init__(self, base, dims, head, theta_learnable=True, rot_learnable=True,
                 matrix_learnable=False, freq_learnable=True):
        super().__init__()

        self.base = base
        self.dims = dims
        self.head = head

        self.h_dim = self.dims // self.head
        self.rot = (self.dims // self.head) // 2

        self.thetas = nn.Parameter(torch.zeros(self.rot))
        self.r_pairs = nn.Parameter(data=torch.rand(self.rot, 2) * self.h_dim)

        self.theta_scale = nn.Parameter(torch.ones(1), requires_grad=theta_learnable)
        self.rot_scale = nn.Parameter(torch.ones(1), requires_grad=rot_learnable)

        self.r_matrix = nn.Parameter(torch.eye(n=self.h_dim), requires_grad=matrix_learnable)

        freq_data = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))
        self.inv_freq = nn.Parameter(freq_data, requires_grad=freq_learnable)

        self.orthogonal_reg_weight = 0.01

    def blended_rotation_matrix(self, dims, i, j, theta):
        # Givens rotation in the (i, j) plane
        G = torch.eye(dims).to(theta.device)
        G[i, i] = torch.cos(theta)
        G[i, j] = -torch.sin(theta)
        G[j, i] = torch.sin(theta)
        G[j, j] = torch.cos(theta)

        # Householder reflection along a direction in the same plane
        v = torch.zeros(dims).to(theta.device)
        v[i] = torch.cos(theta)
        v[j] = torch.sin(theta)
        H = torch.eye(dims).to(theta.device) - 2 * torch.outer(v, v) / torch.dot(v, v)

        # plain rotation (identical to G, so the blend weights rotation twice)
        R = torch.eye(dims).to(theta.device)
        R[i, i] = torch.cos(theta)
        R[i, j] = -torch.sin(theta)
        R[j, i] = torch.sin(theta)
        R[j, j] = torch.cos(theta)

        return (G + H + R) / 3

    def apply_blended_rotation(self, x):
        adjusted_rot = int(torch.round(self.rot_scale * self.rot))
        for k in range(adjusted_rot):
            i, j = self.r_pairs[k].long()
            theta = self.thetas[k] * self.theta_scale
            B = self.blended_rotation_matrix(dims=self.h_dim, i=i, j=j, theta=theta)
            x = torch.matmul(input=x, other=B)
        return x

    def update_base(self, new_base):
        if new_base is not None and new_base != self.base:
            self.base = new_base
            inv_freq = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))
            self.inv_freq.data.copy_(inv_freq)
            self.update_pairs()

    def reset_parameters(self):
        nn.init.orthogonal_(self.r_matrix)
        nn.init.zeros_(self.thetas)
        nn.init.zeros_(self.r_pairs)
        nn.init.ones_(self.theta_scale)
        nn.init.ones_(self.rot_scale)

    def orthogonal_regularization_term(self):
        loss = torch.tensor(0.0, device=self.r_matrix.device)
        if self.r_matrix.requires_grad:
            product = torch.matmul(self.r_matrix, self.r_matrix.t())
            identity = torch.eye(self.r_matrix.size(0)).to(self.r_matrix.device)
            loss = ((product - identity) ** 2).sum()
        return self.orthogonal_reg_weight * loss

    def update_pairs(self):
        pairs = []
        while len(pairs) < self.rot:
            # draw plain ints so the tuple membership tests below are reliable
            i, j = torch.randint(0, self.h_dim, (2,)).tolist()
            if i != j and (i, j) not in pairs and (j, i) not in pairs:
                pairs.append((i, j))
        self.r_pairs.data.copy_(torch.tensor(pairs, dtype=torch.float32))

    def forward(self, x, global_step=None):
        if x.dim() not in [3, 4]:
            raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")

        batch_size, seq_len, *rest = x.size()

        if x.dim() == 3:
            dims = rest[0]
            if dims != self.head * self.h_dim:
                raise ValueError(f"Expected dims ({dims}) to equal head ({self.head}) * h_dim ({self.h_dim}) = {self.head * self.h_dim}")
        else:
            head, h_dim = rest
            if head != self.head or h_dim != self.h_dim:
                raise ValueError(f"For 4D input, expected head {self.head} and h_dim {self.h_dim}, but got head {head} and h_dim {h_dim}")

        x = x.view(batch_size, seq_len, self.head, self.h_dim)
        x = x.reshape(-1, self.h_dim)

        x = self.apply_blended_rotation(x)

        x = torch.matmul(input=x, other=self.r_matrix)

        x = x.view(batch_size, seq_len, self.head, self.h_dim)

        sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(end=seq_len, device=x.device), self.inv_freq.to(device=x.device))
        sin = sinusoid_inp.sin()[None, :, None, :]
        cos = sinusoid_inp.cos()[None, :, None, :]

        x1, x2 = x[..., ::2], x[..., 1::2]
        x = torch.cat(tensors=[x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        x = x.view(batch_size, seq_len, self.dims)

        return x
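
# A minimal shape sketch (illustrative sizes; defined but never called, so the
# training run below is unchanged): the rotary module only rotates feature
# pairs, so it must preserve the (batch, seq, dims) shape of its input.
def _rotary_shape_sketch():
    rot = CombinedRotaryEmbedding(base=10000, dims=64, head=4)
    x = torch.randn(2, 10, 64)
    assert rot(x).shape == (2, 10, 64)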

class SinusoidalEmbedding(nn.Module):
    def __init__(self, n_ctx, dims, checkpoint):
        super().__init__()
        self.n_ctx = n_ctx
        self.dims = dims
        self.checkpoint = checkpoint

        position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dims, 2).float() * -(math.log(10000.0) / dims))
        features = torch.zeros(n_ctx, dims)
        features[:, 0::2] = torch.sin(position * div_term)
        features[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('my_big_toe', features)
        self.pos_embeds = nn.Parameter(self.my_big_toe.clone())

    def forward(self, positions):
        if self.checkpoint:
            position_embeddings = checkpoint(lambda x: self.pos_embeds[x], positions)
        else:
            position_embeddings = self.pos_embeds[positions]
        return F.normalize(position_embeddings, p=2, dim=-1)

class CombinedPositionalEmbedding(nn.Module):
    def __init__(self, base, dims, head, n_ctx, theta_learnable=True, rot_learnable=True,
                 matrix_learnable=False, freq_learnable=True, checkpoint=False):
        super().__init__()
        self.rotary_embedding = CombinedRotaryEmbedding(base, dims, head, theta_learnable,
                                                        rot_learnable, matrix_learnable, freq_learnable)
        self.sinusoidal_embedding = SinusoidalEmbedding(n_ctx, dims, checkpoint)

    def forward(self, x, positions, global_step=None):
        rotary_embed = self.rotary_embedding(x, global_step)
        sinusoidal_embed = self.sinusoidal_embedding(positions)

        combined_embedding = rotary_embed + sinusoidal_embed
        return combined_embedding
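
# Sanity sketch (illustrative sizes, not called): SinusoidalEmbedding L2-normalizes
# its output, so every returned position vector should have unit norm.
def _sinusoidal_norm_sketch():
    emb = SinusoidalEmbedding(n_ctx=448, dims=64, checkpoint=False)
    out = emb(torch.arange(5))
    assert torch.allclose(out.norm(dim=-1), torch.ones(5), atol=1e-5)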

class MultiheadAttention(nn.Module):
    use_sdpa = True

    def __init__(self, base, dims, head, max_dist):
        super().__init__()
        assert dims % head == 0, "dims must be divisible by head"
        self.head = head
        self.h_dim = dims // head
        assert self.h_dim % 2 == 0, "Head dimension must be even for rotary embeddings"

        self.query = nn.Linear(dims, dims)
        self.key = nn.Linear(dims, dims, bias=False)
        self.value = nn.Linear(dims, dims)
        self.out = nn.Linear(dims, dims)

        # self.givens_rotary = CombinedRotaryEmbedding(base=base, dims=dims, head=head)

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        # q = self.givens_rotary(q)
        # k = self.givens_rotary(k)

        wv, qk = self.qkv_attention(q=q, k=k, v=v, mask=mask)

        out = self.out(wv)
        return out, qk

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        n_batch, n_ctx, dims = q.shape
        scale = (dims // self.head) ** -0.25
        q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)

        if MultiheadAttention.use_sdpa:
            a = scaled_dot_product_attention(query=q, key=k, value=v, is_causal=mask is not None and n_ctx > 1)
            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = None
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None:
                qk = qk + mask[:n_ctx, :n_ctx]
            qk = qk.float()

            w = F.softmax(qk, dim=-1).to(dtype=q.dtype)
            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = qk.detach()

        return out, qk
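
# Usage sketch (illustrative sizes, not called): the same module does self-attention
# when xa is omitted and cross-attention when xa is given; output length follows x.
def _mha_usage_sketch():
    mha = MultiheadAttention(base=10000, dims=64, head=4, max_dist=64)
    x, xa = torch.randn(2, 10, 64), torch.randn(2, 20, 64)
    self_out, _ = mha(x)          # keys/values from x
    cross_out, _ = mha(x, xa=xa)  # keys/values from xa
    assert self_out.shape == cross_out.shape == (2, 10, 64)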

class AdaptiveSpanAttention(nn.Module):
    def __init__(self, base, dims, head, max_dist, sharpen, win_size, max_span, temp_scale=0.01):
        super().__init__()
        self.max_dist = max_dist
        self.win_size = win_size
        self.max_span = max_span
        self.temp_scale = temp_scale
        self.multihead_attn = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)
        self.span_scale = nn.Parameter(torch.tensor(1.0))
        self.sharpen = sharpen

    def forward(self, query, key, value, span_scale):
        span_len = int(self.max_span * span_scale.mean().item())
        span_len = min(span_len, query.shape[1], key.shape[1], value.shape[1])
        eff_span = min(span_len, self.max_dist)

        q_span = query[:, :eff_span, :]
        k_span = key[:, :eff_span, :]
        v_span = value[:, :eff_span, :]

        batch_size, _, dims = query.shape
        scale = (dims // self.multihead_attn.head) ** -0.25

        q = q_span.view(q_span.shape[0], q_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)
        k = k_span.view(k_span.shape[0], k_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)
        v = v_span.view(v_span.shape[0], v_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)

        if self.sharpen:
            temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
        else:
            temperature = 0.5 + self.temp_scale * span_scale.mean().item()

        attn_scores = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = torch.softmax((attn_scores / temperature) * scale, dim=-1)
        attn_out = torch.matmul(attn_weights, v)
        attn_out = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
        attn_out = attn_out.contiguous().view(batch_size, eff_span, dims)

        return attn_out, attn_weights

class SpanPredictor(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.linear = nn.Linear(in_features=dims, out_features=1)

    def forward(self, global_out):
        scale = torch.sigmoid(self.linear(global_out))
        return scale
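
# Worked example of the span arithmetic above (plain numbers, not called):
# with max_span=16 and a mean span_scale of 0.5, the attended span is 8,
# further clamped by max_dist and the actual sequence lengths.
def _span_arith_sketch():
    max_span, max_dist = 16, 16
    span_scale = torch.tensor([0.5])
    span_len = int(max_span * span_scale.mean().item())
    eff_span = min(span_len, max_dist)
    assert (span_len, eff_span) == (8, 8)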

class HybridAttention(nn.Module):
    def __init__(self, base, dims, head, max_dist, sharpen, win_size=32, max_span=32, slid_win=32):
        super().__init__()
        self.max_dist = max_dist
        self.win_size = win_size
        self.max_span = max_span
        self.slid_win = slid_win

        self.span_pred = SpanPredictor(dims=dims)
        self.dist_local = max_dist
        self.dist_global = max_dist

        self.attn_local = AdaptiveSpanAttention(base=base, dims=dims, head=head, max_dist=max_dist, sharpen=sharpen, win_size=win_size, max_span=max_span)
        self.attn_global = MultiheadAttention(base=base, dims=dims, head=head, max_dist=self.dist_global)
        self.ln_local = LayerNorm(normalized_shape=dims)
        self.ln_global = LayerNorm(normalized_shape=dims)
        self.projection = Linear(in_features=2 * dims, out_features=dims)

    def update_window(self, new_window):
        # called by Echo.update_window; keep the window an integer >= 1
        new_window = max(1, int(round(new_window)))
        self.win_size = new_window
        self.slid_win = new_window

    def forward(self, x, new_dist=None, new_base=None, xa=None, mask=None, kv_cache=None):
        local = self.ln_local(x)
        globe = self.ln_global(x)

        # self-attention over the global branch; extra positional args would
        # land in the xa/mask slots of MultiheadAttention.forward
        globe_out, _ = self.attn_global(globe)

        span_scale = self.span_pred(globe_out.mean(dim=1))

        win_size = max(1, int(self.slid_win * span_scale.mean().item()))
        span_len = max(1, int(self.max_span * span_scale.mean().item()))

        effective_max_dist = min(self.max_dist, local.size(1))
        local_max_dist = min(self.dist_local, span_len, win_size)
        globe_max_dist = effective_max_dist

        self.attn_local.max_dist = local_max_dist
        self.attn_global.max_dist = globe_max_dist

        local_out = self.slide_win(x=local, win_size=win_size, span_len=span_len, span_scale=span_scale)

        combined = torch.cat(tensors=[local_out, globe_out], dim=-1)
        x = self.projection(combined)

        return x

    def slide_win(self, x, win_size, span_len, span_scale):
        batch_size, seq_len, dims = x.size()
        out = torch.zeros_like(x, device=x.device)

        for i in range(0, seq_len, win_size):
            end = min(i + win_size, seq_len)
            query = x[:, i:end, :]

            start = max(0, i - span_len + win_size)
            key = x[:, start:i + span_len, :]
            value = x[:, start:i + span_len, :]
            attn_out, _ = self.attn_local(query, key, value, span_scale)
            # attn_out may be shorter than the window when the effective span truncates it
            out[:, i:i + attn_out.shape[1], :] = attn_out

        return out
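
# Sketch of the window layout slide_win produces (plain numbers, not called):
# for seq_len=10 and win_size=4 the queries are processed in chunks
# [0:4], [4:8], [8:10].
def _slide_win_layout_sketch():
    seq_len, win_size = 10, 4
    windows = [(i, min(i + win_size, seq_len)) for i in range(0, seq_len, win_size)]
    assert windows == [(0, 4), (4, 8), (8, 10)]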

class ResidualAttention(nn.Module):
    def __init__(self, base, dims, head, max_dist, win_size, max_span, hybrid, checkpoint, cross, sharpen):
        super().__init__()

        if hybrid:
            self.attn = HybridAttention(base=base, dims=dims, head=head, max_dist=max_dist, sharpen=sharpen)
            self.attn_ln = LayerNorm(normalized_shape=dims)
        else:
            self.attn = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)
            self.attn_ln = LayerNorm(normalized_shape=dims)

        n_mlp = dims * 4
        self.mlp = nn.Sequential(Linear(in_features=dims, out_features=n_mlp), nn.GELU(), Linear(in_features=n_mlp, out_features=dims))
        self.mlp_ln = LayerNorm(normalized_shape=dims)

    def forward(self, x, mask=None, kv_cache=None):
        x = self._attn_forward(x=x, mask=mask, kv_cache=kv_cache)
        x = self._mlp_forward(x=x)
        return x

    def _attn_forward(self, x, mask=None, kv_cache=None):
        residual = x
        x = self.attn_ln(x)

        if isinstance(self.attn, HybridAttention):
            attn_output = self.attn(x)
            x = residual + attn_output
        else:
            attn_output, _ = self.attn(x, mask=mask, kv_cache=kv_cache)
            x = residual + attn_output
        return x

    def _mlp_forward(self, x):
        residual = x
        x = self.mlp_ln(x)
        return residual + self.mlp(x)
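
# Shape sketch for one pre-norm residual block (illustrative sizes, not called):
# both the attention and MLP sub-layers add back into the residual stream,
# so the block is shape-preserving.
def _residual_block_sketch():
    block = ResidualAttention(base=10000, dims=64, head=4, max_dist=16, win_size=16,
                              max_span=16, hybrid=False, checkpoint=False,
                              cross=False, sharpen=False)
    x = torch.randn(2, 10, 64)
    assert block(x).shape == (2, 10, 64)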

class AudioEncoder(nn.Module):
    def __init__(self, base, mels, dims, head, n_layer, n_ctx, max_dist,
                 win_size, max_span, hybrid, checkpoint, cross, sharpen):
        super().__init__()
        self.conv1 = Conv1d(in_channels=mels, out_channels=dims, kernel_size=3, padding=1)
        self.conv2 = Conv1d(in_channels=dims, out_channels=dims, kernel_size=3, stride=2, padding=1)
        self.pos_embed = SinusoidalEmbedding(n_ctx=n_ctx, dims=dims, checkpoint=checkpoint)
        self.checkpoint = checkpoint

        self.givens_rotary = CombinedRotaryEmbedding(base=base, dims=dims, head=head)
        # self.combine = CombinedPositionalEmbedding(base=base, dims=dims, head=head)
        self.blocks = nn.ModuleList(modules=[ResidualAttention(base=base, dims=dims, head=head, max_dist=max_dist, win_size=win_size, max_span=max_span, hybrid=hybrid, checkpoint=checkpoint, cross=cross, sharpen=sharpen) for _ in range(n_layer)])
        self.ln_post = LayerNorm(normalized_shape=dims)

    def forward(self, x):
        if self.checkpoint:
            x = checkpoint(self._conv_forward, x)
        else:
            x = self._conv_forward(x)

        for block in self.blocks:
            if self.checkpoint:
                x = checkpoint(block, x)
            else:
                x = block(x)
        return self.ln_post(x)

    def _conv_forward(self, x):
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        p = self.pos_embed(torch.arange(end=x.size(dim=1), device=x.device)).unsqueeze(0)
        x = (x + p).to(x.dtype)
        x = self.givens_rotary(x)
        # x = self.combine(x)
        return x
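
# Downsampling sketch (illustrative sizes, not called): conv2 has stride 2, so the
# encoder halves the mel frame axis; n_ctx must match the post-conv length.
def _encoder_downsample_sketch():
    enc = AudioEncoder(base=10000, mels=80, dims=64, head=4, n_layer=1, n_ctx=50,
                       max_dist=16, win_size=16, max_span=16, hybrid=False,
                       checkpoint=False, cross=False, sharpen=False)
    mel = torch.randn(1, 80, 100)         # (batch, mels, frames)
    assert enc(mel).shape == (1, 50, 64)  # frames halved by the stride-2 conv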

class TextDecoder(nn.Module):
    def __init__(self, base, vocab, dims, head, n_layer, n_ctx, max_dist,
                 win_size, max_span, hybrid, checkpoint, cross, sharpen):
        super().__init__()

        self.tok_embed = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
        self.pos_embed = SinusoidalEmbedding(n_ctx=n_ctx, dims=dims, checkpoint=checkpoint)
        self.checkpoint = checkpoint

        self.givens_rotary = CombinedRotaryEmbedding(base=base, dims=dims, head=head)

        self.blocks = nn.ModuleList(modules=[ResidualAttention(base=base, dims=dims, head=head, max_dist=max_dist, win_size=win_size, max_span=max_span, hybrid=hybrid, checkpoint=checkpoint, cross=cross, sharpen=sharpen) for _ in range(n_layer)])

        self.ln_post = LayerNorm(normalized_shape=dims)
        self.ln = LayerNorm(normalized_shape=dims)

        mask = torch.empty(n_ctx, n_ctx).fill_(value=-np.inf).triu_(diagonal=1)
        self.register_buffer(name="mask", tensor=mask, persistent=False)

    def forward(self, x, xa, kv_cache=None):
        if self.checkpoint:
            x = checkpoint(self._embedding_forward, x, xa, kv_cache)
        else:
            x = self._embedding_forward(x=x, xa=xa, kv_cache=kv_cache)

        for block in self.blocks:
            if self.checkpoint:
                x = checkpoint(block, x, self.mask, kv_cache)
            else:
                x = block(x, self.mask, kv_cache)

        x = self.ln(x)
        # tied output projection: logits come from the token-embedding matrix
        x = (x @ torch.transpose(input=self.tok_embed.weight.to(dtype=x.dtype), dim0=0, dim1=1)).float()
        return x

    def _embedding_forward(self, x, xa, kv_cache):
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        positions = torch.arange(x.shape[1], device=x.device) + offset
        pos_emb = self.pos_embed(positions).unsqueeze(0)
        x = self.tok_embed(x) + pos_emb
        x = self.givens_rotary(x)
        return x
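
# Tied-embedding sketch (illustrative sizes, not called): logits are produced by
# multiplying against tok_embed.weight.T, so the last logits dim equals vocab.
def _decoder_logits_sketch():
    dec = TextDecoder(base=10000, vocab=100, dims=64, head=4, n_layer=1, n_ctx=32,
                      max_dist=16, win_size=16, max_span=16, hybrid=False,
                      checkpoint=False, cross=False, sharpen=False)
    tokens = torch.randint(0, 100, (2, 8))
    xa = torch.randn(2, 10, 64)  # unused by the self-attention-only blocks here
    assert dec(tokens, xa).shape == (2, 8, 100)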

class EchoConfig(PretrainedConfig):
    model_type = "Echo"

    def __init__(
        self,
        checkpoint=False,
        cross=False,
        hybrid=False,
        sharpen=False,
        a_ctx=1500,
        a_head=16,
        a_layer=8,
        a_dims=1024,
        mels=128,
        t_ctx=448,
        t_head=8,
        t_layer=8,
        t_dims=1024,
        win_size=64,
        max_span=64,
        max_dist=64,
        base=10000,
        pad_token_id=50257,
        unk_token_id=50257,
        vocab=51865,
        eos_token_id=50257,
        bos_token_id=50257,
        decoder_start_token_id=50258,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base = base
        self.bos_token_id = bos_token_id
        self.checkpoint = checkpoint
        self.cross = cross
        self.decoder_start_token_id = decoder_start_token_id
        self.eos_token_id = eos_token_id
        self.hybrid = hybrid
        self.max_dist = max_dist
        self.max_span = max_span
        self.a_ctx = a_ctx
        self.a_head = a_head
        self.a_layer = a_layer
        self.a_dims = a_dims
        self.mels = mels
        self.t_ctx = t_ctx
        self.t_head = t_head
        self.t_layer = t_layer
        self.t_dims = t_dims
        self.pad_token_id = pad_token_id
        self.unk_token_id = unk_token_id
        self.vocab = vocab
        self.win_size = win_size
        self.sharpen = sharpen

class Echo(nn.Module):
    def __init__(self, config: EchoConfig):
        super().__init__()
        self.config = config

        self.encoder = AudioEncoder(
            base=self.config.base,
            mels=self.config.mels,
            dims=self.config.a_dims,
            head=self.config.a_head,
            n_layer=self.config.a_layer,
            n_ctx=self.config.a_ctx,
            max_dist=self.config.max_dist,
            win_size=self.config.win_size,
            max_span=self.config.max_span,
            hybrid=self.config.hybrid,
            checkpoint=self.config.checkpoint,
            cross=self.config.cross,
            sharpen=self.config.sharpen,
        )

        self.decoder = TextDecoder(
            base=self.config.base,
            vocab=self.config.vocab,
            dims=self.config.t_dims,
            head=self.config.t_head,
            n_layer=self.config.t_layer,
            n_ctx=self.config.t_ctx,
            max_dist=self.config.max_dist,
            win_size=self.config.win_size,
            max_span=self.config.max_span,
            hybrid=self.config.hybrid,
            checkpoint=self.config.checkpoint,
            cross=self.config.cross,
            sharpen=self.config.sharpen,
        )

        all_heads = torch.zeros(self.config.t_layer, self.config.t_head, dtype=torch.bool)
        all_heads[self.config.t_layer // 2:] = True
        self.register_buffer(name="alignment_heads", tensor=all_heads.to_sparse(), persistent=False)

        self.base = self.config.base
        self.win_size = self.config.win_size
        self.adjust_counter = 0
        self.best_loss = float('inf')
        self.kv_cache = {}

    @property
    def device(self):
        return next(self.parameters()).device

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def update_window(self, new_window):
        self.win_size = new_window
        for module in self.modules():
            if isinstance(module, HybridAttention):
                module.update_window(self.win_size)

    def adjust_window(self, loss, factor=1.00005):
        if self.adjust_counter % 10 == 0:
            if loss < self.best_loss:
                new_window = self.win_size * factor
            else:
                new_window = self.win_size / factor
            self.update_window(new_window=new_window)
            self.best_loss = loss
            self.adjust_counter += 1
            return new_window
        self.adjust_counter += 1
        return self.win_size

    def adjust_base(self, loss, factor=1.0025) -> float | int:
        if self.adjust_counter % 25 == 0:
            if loss < self.best_loss:
                new_base = self.base * factor
            else:
                new_base = self.base / factor
            self.update_base(new_base=new_base)
            self.base = new_base
            self.best_loss = loss
        self.adjust_counter += 1
        return self.base

    def update_base(self, new_base):
        self.new_base = new_base
        for name, module in self.encoder.named_modules():
            if isinstance(module, CombinedRotaryEmbedding):
                module.update_base(new_base=self.new_base)

    @staticmethod
    def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward(self, input_features, labels=None, dec_input_ids=None) -> dict[str, Any | None]:
        if labels is not None:
            if dec_input_ids is None:
                dec_input_ids = self.shift_tokens_right(
                    input_ids=labels, pad_token_id=self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
                )

        encoded_features = self.encoder(input_features).to(self.device)
        logits = self.decoder(dec_input_ids, encoded_features)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(logits.device).long()
            loss = loss_fct(logits.view(-1, self.config.vocab), labels.view(-1))

            self.adjust_window(loss.item())
            # self.adjust_base(loss=loss.item())
        return {"loss": loss, "logits": logits}

    def reset_parameters(self):
        for name, module in self.encoder.named_modules():
            if isinstance(module, CombinedRotaryEmbedding):
                module.reset_parameters()

    def _initialize_weights(self, module):
        nn.init.normal_(tensor=self.decoder.tok_embed.weight, mean=0.0, std=0.02)
        nn.init.constant_(tensor=self.decoder.ln.weight, val=1)
        nn.init.constant_(tensor=self.decoder.ln.bias, val=0)
        nn.init.xavier_normal_(tensor=self.encoder.conv1.weight)
        nn.init.zeros_(tensor=self.encoder.conv1.bias)
        nn.init.kaiming_normal_(tensor=self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')
        nn.init.zeros_(tensor=self.encoder.conv2.bias)
        nn.init.constant_(tensor=self.encoder.ln_post.weight, val=1)
        nn.init.constant_(tensor=self.encoder.ln_post.bias, val=0)

        for block in self.decoder.blocks:
            for layer in block.children():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(tensor=layer.weight)
                    nn.init.zeros_(tensor=layer.bias)
                if isinstance(layer, LayerNorm):
                    nn.init.constant_(tensor=layer.weight, val=1)

        for block in self.encoder.blocks:
            for layer in block.children():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(tensor=layer.weight)
                    nn.init.zeros_(tensor=layer.bias)
                if isinstance(layer, LayerNorm):
                    nn.init.constant_(tensor=layer.weight, val=1)

        # named_modules() yields (name, module) pairs; unpack both so the
        # isinstance check actually sees the module
        for name, module in self.encoder.named_modules():
            if isinstance(module, CombinedRotaryEmbedding):
                nn.init.constant_(tensor=module.thetas, val=1)
                nn.init.constant_(tensor=module.r_matrix, val=1)
                nn.init.constant_(tensor=module.r_pairs, val=1)
                nn.init.constant_(tensor=module.inv_freq, val=1)

    def apply_initialization(self, module):
        self._initialize_weights(module=module)
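
# End-to-end forward sketch with a tiny config (illustrative values, not called):
# the token ids are shrunk to fit vocab=100 so shift_tokens_right stays in range.
def _echo_forward_sketch():
    tiny = EchoConfig(a_ctx=50, a_head=4, a_layer=1, a_dims=64, mels=80,
                      t_ctx=32, t_head=4, t_layer=1, t_dims=64,
                      win_size=8, max_span=8, max_dist=8, vocab=100,
                      pad_token_id=0, unk_token_id=0, eos_token_id=0,
                      bos_token_id=0, decoder_start_token_id=1)
    tiny_model = Echo(tiny)
    mel = torch.randn(1, 80, 100)  # -> (1, 50, 64) after the encoder
    labels = torch.randint(2, 100, (1, 8))
    out = tiny_model(input_features=mel, labels=labels)
    assert out["logits"].shape == (1, 8, 100) and out["loss"] is not None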

from datetime import datetime
log_dir = os.path.join('./output/Echo/', datetime.now().strftime(format='%m-%d_%H'))
os.makedirs(name=log_dir, exist_ok=True)

# EchoConfig takes the short a_*/t_* names; the audio_*/text_* spellings would
# fall into **kwargs and silently leave the 1024-dim defaults in place
config = EchoConfig(
    checkpoint=False,
    cross=False,
    hybrid=False,
    sharpen=False,
    a_ctx=1500,
    a_head=4,
    a_layer=4,
    a_dims=512,
    mels=128,
    t_ctx=448,
    t_head=4,
    t_layer=4,
    t_dims=512,
    win_size=16,
    max_span=16,
    max_dist=16,
    base=50000,
    pad_token_id=50257,
    unk_token_id=50257,
    vocab=51865,
    eos_token_id=50257,
    bos_token_id=50257,
    decoder_start_token_id=50258,
)

model = Echo(config=config).to(device=device)
model.apply_initialization(module=model)

# WhisperFeatureExtractor takes `sampling_rate`; Whisper audio is 16 kHz
feature_extractor = WhisperFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path="openai/whisper-small",
    feature_size=128, sampling_rate=16000, do_normalize=True)

tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path="openai/whisper-small",
    language="en", task="transcribe")

processor = WhisperProcessor.from_pretrained(
    pretrained_model_name_or_path="openai/whisper-small",
    feature_size=128, sampling_rate=16000, do_normalize=True,
    language="en", task="transcribe")

class GradientClippingCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        torch.nn.utils.clip_grad_norm_(parameters=kwargs["model"].parameters(), max_norm=0.98)

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
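
# Masking sketch (plain numbers, not called): the collator replaces padded label
# positions with -100 so CrossEntropyLoss(ignore_index=-100) skips them.
def _label_mask_sketch():
    input_ids = torch.tensor([[50258, 7, 8, 50257]])
    attention_mask = torch.tensor([[1, 1, 1, 0]])
    labels = input_ids.masked_fill(attention_mask.ne(1), -100)
    assert labels.tolist() == [[50258, 7, 8, -100]]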

def get_length_of_dataset(dataset):
    length = 0
    for item in dataset:
        length += len(item["audio"]["array"]) / item["audio"]["sampling_rate"]
    return length / 3600

def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, decoder_start_token_id=config.decoder_start_token_id)

datasets = IterableDatasetDict()

datasets["train"] = load_dataset(
    path="mozilla-foundation/common_voice_17_0", token="",
    name="en", split="train", streaming=True, trust_remote_code=True).take(10000)

datasets["test"] = load_dataset(
    path="mozilla-foundation/common_voice_17_0", token="",
    name="en", split="test", streaming=True, trust_remote_code=True).take(100)

dataset = datasets.cast_column(column="audio", feature=Audio(sampling_rate=16000))

dataset = dataset.map(function=prepare_dataset,
                      remove_columns=list(next(iter(dataset.values())).features)).with_format(type="torch")

class MetricsCallback(TrainerCallback):
    def __init__(self, tb_writer, tokenizer, metric, optimizer, scheduler, log_every_n_steps=1):
        super().__init__()
        self.tb_writer = tb_writer
        self.tokenizer = tokenizer
        self.metric = metric
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.log_every_n_steps = log_every_n_steps
        self.eval_loss = None
        self.predictions = None
        self.label_ids = None

    def compute_wer(self, pred_str, label_str):
        wer = 100 * self.metric.compute(predictions=pred_str, references=label_str)
        return wer

    def on_evaluate(self, args, state, control, model, metrics=None, **kwargs):
        if metrics is not None:
            self.eval_loss = metrics.get('eval_loss')

            current_learning_rate = self.optimizer.param_groups[0]['lr']
            if state.global_step % self.log_every_n_steps == 0:
                self.tb_writer.add_scalar('learning_rate', current_learning_rate, state.global_step)
                print(f"Learning Rate: {current_learning_rate:.8f}")
                self.tb_writer.add_scalar('eval_loss', self.eval_loss, state.global_step)

            for key, value in metrics.items():
                if key.startswith("eval_"):
                    self.tb_writer.add_scalar(key, value, state.global_step)

        if self.predictions is not None and self.label_ids is not None:
            pred_str = self.tokenizer.batch_decode(self.predictions, skip_special_tokens=True)
            label_str = self.tokenizer.batch_decode(self.label_ids, skip_special_tokens=True)

            if state.global_step % self.log_every_n_steps == 0:
                total_samples = len(pred_str)
                random_indices = random.sample(range(total_samples), 1)

                for sample_index in random_indices:
                    self.tb_writer.add_text(f"Prediction_{sample_index}", pred_str[sample_index], state.global_step)
                    self.tb_writer.add_text(f"Label_{sample_index}", label_str[sample_index], state.global_step)
                    print(f"Evaluation: - Step {state.global_step} - Loss: {self.eval_loss:.2f}")
                    print(f"Prediction: {pred_str[sample_index]}")
                    print(f"Label: {label_str[sample_index]}")
                    print("-" * 10)

        self.predictions = None
        self.label_ids = None

def create_compute_metrics(callback_instance):
    def compute_metrics(eval_pred):
        pred_logits = eval_pred.predictions
        label_ids = eval_pred.label_ids

        if isinstance(pred_logits, tuple):
            pred_ids = pred_logits[0]
        else:
            pred_ids = pred_logits
        if pred_ids.ndim == 3:
            pred_ids = np.argmax(pred_ids, axis=-1)

        label_ids[label_ids == -100] = callback_instance.tokenizer.pad_token_id
        callback_instance.predictions = pred_ids
        callback_instance.label_ids = label_ids
        pred_str = callback_instance.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = callback_instance.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        wer = 100 * callback_instance.metric.compute(predictions=pred_str, references=label_str)
        pred_flat = pred_ids.flatten()
        labels_flat = label_ids.flatten()
        mask = labels_flat != callback_instance.tokenizer.pad_token_id

        accuracy = accuracy_score(y_true=labels_flat[mask], y_pred=pred_flat[mask])
        precision = precision_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)
        recall = recall_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)
        f1 = f1_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)
        return {"wer": wer, "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
    return compute_metrics
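
# WER sketch (not called; loads the metric on demand): one substituted word out
# of two gives a word error rate of 0.5 before the *100 scaling used above.
def _wer_sketch():
    wer_metric = evaluate.load("wer")
    score = wer_metric.compute(predictions=["hello word"], references=["hello world"])
    assert abs(score - 0.5) < 1e-9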

metric = evaluate.load(path="wer")
tb_writer = SummaryWriter(log_dir=log_dir)

training_args = Seq2SeqTrainingArguments(
    output_dir=log_dir,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    eval_accumulation_steps=1,
    tf32=True,
    bf16=True,
    eval_strategy="steps",
    save_strategy="steps",
    max_steps=10000,
    save_steps=10000,
    eval_steps=100,
    warmup_steps=100,
    logging_steps=10,
    logging_dir=log_dir + "/logs_hf",
    report_to=["tensorboard"],
    load_best_model_at_end=False,
    metric_for_best_model="loss",
    greater_is_better=False,
    push_to_hub=False,
    disable_tqdm=False,
    save_total_limit=1,
    remove_unused_columns=False,
    label_names=["labels"],
    eval_on_start=True,
)

class MaxFactor(Optimizer):
    def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(None, 1e-3), d=1.0,
                 weight_decay=0.0, gamma=0.99, eps_rms=1e-8, maximize=False):

        defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay,
                        gamma=gamma, eps_rms=eps_rms, maximize=maximize)

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
            eps1, eps2 = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0, dtype=torch.float32)
                    if p.grad.dim() > 1:
                        row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)
                        row_shape[-1], col_shape[-2] = 1, 1
                        state["row_var"], state["col_var"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)
                    state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                row_vars.append(state.get("row_var", None))
                col_vars.append(state.get("col_var", None))
                v.append(state["v"])
                state_steps.append(state["step"])
                params_with_grad.append(p)
                grads.append(grad)

            for i, param in enumerate(params_with_grad):
                grad = grads[i]

                if group["maximize"]:
                    grad = -grad
                step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]

                if eps1 is None:
                    eps1 = torch.finfo(param.dtype).eps

                step_t += 1
                step_float = step_t.item()
                one_minus_beta2_t = step_float ** group["beta2_decay"]
                rho_t = min(group["lr"], 1 / (step_float ** 0.5))
                alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t

                if group["weight_decay"] != 0:
                    param.mul_(1 - group["lr"] * group["weight_decay"])

                if grad.dim() > 1:
                    # factored second-moment estimate: per-row and per-column mean squares
                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
                    row_var.lerp_(row_mean, one_minus_beta2_t)
                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))
                    col_var.lerp_(col_mean, one_minus_beta2_t)
                    var_estimate = row_var @ col_var
                    max_row_var = row_var.max(dim=-2, keepdim=True)[0]
                    var_estimate.div_(max_row_var.clamp_(min=eps1))
                else:
                    # the deprecated add_(scalar, tensor) form is replaced by the alpha kwarg
                    vi.mul_(group["gamma"]).add_(grad.square(), alpha=1 - group["gamma"])
                    var_estimate = vi

                update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
                update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))
                denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
                param.add_(-alpha / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])

        return loss
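
# Factored second-moment sketch (illustrative sizes, not called): for a matrix
# gradient, MaxFactor keeps only per-row and per-column mean squares, whose
# outer product reconstructs a full-shape (rank-1) variance estimate.
def _factored_var_sketch():
    grad = torch.randn(4, 6)
    row_var = grad.square().mean(dim=-1, keepdim=True)  # (4, 1)
    col_var = grad.square().mean(dim=-2, keepdim=True)  # (1, 6)
    var_estimate = row_var @ col_var                    # (4, 6)
    assert var_estimate.shape == grad.shape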

optimizer = MaxFactor(
    model.parameters(),
    lr=0.025,
    beta2_decay=-0.8,
    eps=(None, 1e-4),
    d=1.0,
    weight_decay=0.0025,
    gamma=0.99,
    eps_rms=1e-8,
    maximize=False,
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=training_args.max_steps,
    eta_min=0.0,
    last_epoch=-1
)

metrics_callback = MetricsCallback(tb_writer=tb_writer, tokenizer=tokenizer, metric=metric, optimizer=optimizer, scheduler=scheduler, log_every_n_steps=10)
compute_metrics = create_compute_metrics(callback_instance=metrics_callback)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=feature_extractor,
    callbacks=[metrics_callback],
    optimizers=(optimizer, scheduler)
)

trainer.train(resume_from_checkpoint=False)

from tensorboard import program
log_dir = "D:/new/tensorboard3"
tb = program.TensorBoard()
tb.configure(argv=[None, '--logdir', log_dir])
url = tb.launch()
print(f"TensorBoard started at {url}")