Sin2pi committed on
Commit
82d5bc7
·
verified ·
1 Parent(s): 058020e

Upload echo.ipynb

Browse files
Files changed (1) hide show
  1. echo.ipynb +1160 -0
echo.ipynb ADDED
@@ -0,0 +1,1160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import base64, gzip, math, os, functools, warnings, numpy as np, torch, transformers, aiohttp, torch.nn.functional as F, evaluate, json, random\n",
10
+ "from torch import Tensor, amp, optim, nn\n",
11
+ "from torch.utils.checkpoint import checkpoint\n",
12
+ "from torch.utils.tensorboard.writer import SummaryWriter\n",
13
+ "from threading import Thread\n",
14
+ "from typing import Dict, Optional, Tuple, Union, List, Any\n",
15
+ "from transformers.modeling_utils import PreTrainedModel\n",
16
+ "from dataclasses import dataclass\n",
17
+ "from transformers.optimization import Adafactor, AdafactorSchedule\n",
18
+ "from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments, PretrainedConfig, TrainerCallback, WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizerFast)\n",
19
+ "from torch.optim import Optimizer\n",
20
+ "import evaluate\n",
21
+ "from evaluate import module\n",
22
+ "from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score\n",
23
+ "from sklearn.model_selection import KFold, train_test_split\n",
24
+ "from datasets import load_dataset, Dataset, concatenate_datasets, IterableDatasetDict, Audio, load_from_disk\n",
25
+ "from torch.nn.functional import scaled_dot_product_attention\n",
26
+ "\n",
27
+ "from accelerate import Accelerator\n",
28
+ "import matplotlib.pyplot as plt\n",
29
+ "transformers.utils.logging.set_verbosity_error()\n",
30
+ "warnings.filterwarnings(action=\"ignore\")\n",
31
+ "warnings.warn = lambda *args, **kwargs: None\n",
32
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
33
+ "dtype = torch.float32\n",
34
+ "torch_dtype = torch.float32\n",
35
+ "torch.set_default_dtype(dtype)\n"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "\n",
45
+ "class Linear(nn.Linear):\n",
46
+ " def forward(self, x: Tensor) -> Tensor:# type: ignore\n",
47
+ " return F.linear(x, self.weight.to(x.dtype),\n",
48
+ " None if self.bias is None else self.bias.to(x.dtype))\n",
49
+ "\n",
50
+ "class Conv1d(nn.Conv1d):\n",
51
+ " def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:# type: ignore\n",
52
+ " return super()._conv_forward(x, weight.to(x.dtype),\n",
53
+ " None if bias is None else bias.to(x.dtype))\n",
54
+ "\n",
55
+ "class LayerNorm(nn.LayerNorm):\n",
56
+ " def forward(self, x: Tensor) -> Tensor: # type: ignore\n",
57
+ " return super().forward(x.float()).type(x.dtype) "
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "class CombinedRotaryEmbedding(nn.Module):\n",
67
+ " def __init__(self, base, dims, head, theta_learnable=True, rot_learnable=True,\n",
68
+ " matrix_learnable=False, freq_learnable=True):\n",
69
+ " super(CombinedRotaryEmbedding, self).__init__()\n",
70
+ "\n",
71
+ " self.base = base\n",
72
+ " self.dims = dims\n",
73
+ " self.head = head\n",
74
+ "\n",
75
+ " self.h_dim = self.dims // self.head\n",
76
+ " self.rot = (self.dims // self.head) // 2\n",
77
+ "\n",
78
+ " self.thetas = nn.Parameter(torch.zeros(self.rot))\n",
79
+ " self.r_pairs = nn.Parameter(data=torch.rand(self.rot, 2) * self.h_dim)\n",
80
+ "\n",
81
+ " self.theta_scale = nn.Parameter(torch.ones(1), requires_grad=theta_learnable)\n",
82
+ " self.rot_scale = nn.Parameter(torch.ones(1), requires_grad=rot_learnable)\n",
83
+ "\n",
84
+ " self.r_matrix = nn.Parameter(torch.eye(n=self.h_dim), requires_grad=matrix_learnable)\n",
85
+ "\n",
86
+ " freq_data = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))\n",
87
+ " self.inv_freq = nn.Parameter(freq_data, requires_grad=freq_learnable)\n",
88
+ "\n",
89
+ " self.orthogonal_reg_weight = 0.01\n",
90
+ "\n",
91
+ " def blended_rotation_matrix(self, dims, i, j, theta):\n",
92
+ " G = torch.eye(dims).to(theta.device)\n",
93
+ " G[i, i] = torch.cos(theta)\n",
94
+ " G[i, j] = -torch.sin(theta)\n",
95
+ " G[j, i] = torch.sin(theta)\n",
96
+ " G[j, j] = torch.cos(theta)\n",
97
+ "\n",
98
+ " v = torch.zeros(dims).to(theta.device)\n",
99
+ " v[i] = torch.cos(theta)\n",
100
+ " v[j] = torch.sin(theta)\n",
101
+ " H = torch.eye(dims).to(theta.device) - 2 * torch.outer(v, v) / torch.dot(v, v)\n",
102
+ "\n",
103
+ " R = torch.eye(dims).to(theta.device)\n",
104
+ " R[i, i] = torch.cos(theta)\n",
105
+ " R[i, j] = -torch.sin(theta)\n",
106
+ " R[j, i] = torch.sin(theta)\n",
107
+ " R[j, j] = torch.cos(theta)\n",
108
+ "\n",
109
+ " return (G + H + R) / 3\n",
110
+ "\n",
111
+ " def apply_blended_rotation(self, x):\n",
112
+ " adjusted_rot = int(torch.round(self.rot_scale * self.rot))\n",
113
+ " for k in range(adjusted_rot):\n",
114
+ " i, j = self.r_pairs[k].long()\n",
115
+ " theta = self.thetas[k] * self.theta_scale\n",
116
+ " B = self.blended_rotation_matrix(dims=self.h_dim, i=i, j=j, theta=theta)\n",
117
+ " x = torch.matmul(input=x, other=B)\n",
118
+ " return x\n",
119
+ "\n",
120
+ " def update_base(self, new_base):\n",
121
+ " if new_base is not None and new_base != self.base:\n",
122
+ " self.base = new_base\n",
123
+ " inv_freq = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))\n",
124
+ " self.inv_freq.data.copy_(inv_freq)\n",
125
+ " self.update_pairs()\n",
126
+ "\n",
127
+ " def reset_parameters(self):\n",
128
+ " nn.init.orthogonal_(self.r_matrix)\n",
129
+ " nn.init.zeros_(self.thetas)\n",
130
+ " nn.init.zeros_(self.r_pairs)\n",
131
+ " nn.init.ones_(self.theta_scale)\n",
132
+ " nn.init.ones_(self.rot_scale)\n",
133
+ "\n",
134
+ " def orthogonal_regularization_term(self):\n",
135
+ " loss = torch.tensor(0.0, device=self.r_matrix.device)\n",
136
+ " if self.r_matrix.requires_grad:\n",
137
+ " product = torch.matmul(self.r_matrix, self.r_matrix.t())\n",
138
+ " identity = torch.eye(self.r_matrix.size(0)).to(self.r_matrix.device)\n",
139
+ " loss = ((product - identity) ** 2).sum()\n",
140
+ " return self.orthogonal_reg_weight * loss\n",
141
+ "\n",
142
+ " def update_pairs(self):\n",
143
+ " pairs = []\n",
144
+ " while len(pairs) < self.rot:\n",
145
+ " i, j = torch.randint(0, self.h_dim - 1, (2,))\n",
146
+ " if i != j and (i, j) not in pairs and (j, i) not in pairs:\n",
147
+ " pairs.append((i, j))\n",
148
+ " self.r_pairs.data.copy_(torch.tensor(pairs, dtype=torch.float32))\n",
149
+ "\n",
150
+ " def forward(self, x, global_step=None):\n",
151
+ " if x.dim() not in [3, 4]:\n",
152
+ " raise ValueError(f\"Expected input tensor to be 3D or 4D, but got {x.dim()}D\")\n",
153
+ "\n",
154
+ " batch_size, seq_len, *rest = x.size()\n",
155
+ "\n",
156
+ " if x.dim() == 3:\n",
157
+ " dims = rest[0]\n",
158
+ " if dims != self.head * self.h_dim:\n",
159
+ " raise ValueError(f\"Expected dims ({dims}) to be compatible with head ({self.head}) * h_dim ({self.h_dim}={self.head * self.h_dim})\")\n",
160
+ " else:\n",
161
+ " head, h_dim = rest\n",
162
+ " if head != self.head or h_dim != self.h_dim:\n",
163
+ " raise ValueError(f\"For 4D input, expected head {self.head} and h_dim {self.h_dim}, but got head {head} and h_dim {h_dim}\")\n",
164
+ "\n",
165
+ " x = x.view(batch_size, seq_len, self.head, self.h_dim)\n",
166
+ " x = x.reshape(-1, self.h_dim)\n",
167
+ "\n",
168
+ " x = self.apply_blended_rotation(x)\n",
169
+ "\n",
170
+ " x = torch.matmul(input=x, other=self.r_matrix)\n",
171
+ "\n",
172
+ " x = x.view(batch_size, seq_len, self.head, self.h_dim)\n",
173
+ "\n",
174
+ " sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(end=seq_len, device=x.device), self.inv_freq.to(device=x.device))\n",
175
+ " sin = sinusoid_inp.sin()[None, :, None, :]\n",
176
+ " cos = sinusoid_inp.cos()[None, :, None, :]\n",
177
+ "\n",
178
+ " x1, x2 = x[..., ::2], x[..., 1::2]\n",
179
+ " x = torch.cat(tensors=[x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n",
180
+ " x = x.view(batch_size, seq_len, self.dims)\n",
181
+ "\n",
182
+ " return x\n",
183
+ "\n",
184
+ "class SinusoidalEmbedding(nn.Module):\n",
185
+ " def __init__(self, n_ctx, dims, checkpoint):\n",
186
+ " super().__init__()\n",
187
+ " self.n_ctx = n_ctx\n",
188
+ " self.dims = dims\n",
189
+ " self.checkpoint = checkpoint\n",
190
+ "\n",
191
+ " position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)\n",
192
+ " div_term = torch.exp(torch.arange(0, dims, 2).float() * -(math.log(10000.0) / dims))\n",
193
+ " features = torch.zeros(n_ctx, dims)\n",
194
+ " features[:, 0::2] = torch.sin(position * div_term)\n",
195
+ " features[:, 1::2] = torch.cos(position * div_term)\n",
196
+ " self.register_buffer('my_big_toe', features)\n",
197
+ " self.pos_embeds = nn.Parameter(self.my_big_toe.clone())\n",
198
+ "\n",
199
+ " def forward(self, positions):\n",
200
+ " if self.checkpoint:\n",
201
+ " position_embeddings = checkpoint(lambda x: self.pos_embeds[x], positions)\n",
202
+ " else:\n",
203
+ " position_embeddings = self.pos_embeds[positions]\n",
204
+ " return F.normalize(position_embeddings, p=2, dim=-1) \n",
205
+ "\n",
206
+ "class CombinedPositionalEmbedding(nn.Module):\n",
207
+ " def __init__(self, base, dims, head, n_ctx, theta_learnable=True, rot_learnable=True, \n",
208
+ " matrix_learnable=False, freq_learnable=True, checkpoint=False):\n",
209
+ " super().__init__()\n",
210
+ " self.rotary_embedding = CombinedRotaryEmbedding(base, dims, head, theta_learnable, \n",
211
+ " rot_learnable, matrix_learnable, freq_learnable)\n",
212
+ " self.sinusoidal_embedding = SinusoidalEmbedding(n_ctx, dims, checkpoint)\n",
213
+ "\n",
214
+ " def forward(self, x, positions, global_step=None):\n",
215
+ " rotary_embed = self.rotary_embedding(x, global_step)\n",
216
+ " sinusoidal_embed = self.sinusoidal_embedding(positions)\n",
217
+ " \n",
218
+ " combined_embedding = rotary_embed + sinusoidal_embed\n",
219
+ " return combined_embedding"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": null,
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "\n",
229
+ "class MultiheadAttention(nn.Module):\n",
230
+ " use_sdpa = True\n",
231
+ "\n",
232
+ " def __init__(self, base, dims, head, max_dist):\n",
233
+ " super().__init__()\n",
234
+ " assert dims % head == 0, \"dims must be divisible by head\"\n",
235
+ " self.head = head\n",
236
+ " self.h_dim = dims // head\n",
237
+ " assert self.h_dim % 2 == 0, \"Head dimension must be even for rotary embeddings\"\n",
238
+ "\n",
239
+ " self.query = nn.Linear(dims, dims)\n",
240
+ " self.key = nn.Linear(dims, dims, bias=False)\n",
241
+ " self.value = nn.Linear(dims, dims)\n",
242
+ " self.out = nn.Linear(dims, dims)\n",
243
+ "\n",
244
+ " def forward(self, x, xa = None, mask = None, kv_cache = None):\n",
245
+ "\n",
246
+ " q = self.query(x)\n",
247
+ "\n",
248
+ " if kv_cache is None or xa is None or self.key not in kv_cache:\n",
249
+ " k = self.key(x if xa is None else xa)\n",
250
+ " v = self.value(x if xa is None else xa)\n",
251
+ "\n",
252
+ " else:\n",
253
+ " k = kv_cache[self.key]\n",
254
+ " v = kv_cache[self.value]\n",
255
+ " wv, qk = self.qkv_attention(q=q, k=k, v=v, mask=mask)\n",
256
+ "\n",
257
+ " out = self.out(wv)\n",
258
+ " return out, qk\n",
259
+ " \n",
260
+ " def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):\n",
261
+ " \n",
262
+ " n_batch, n_ctx, dims = q.shape\n",
263
+ " scale = (dims // self.head) ** -0.25\n",
264
+ " q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)\n",
265
+ " k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)\n",
266
+ " v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)\n",
267
+ "\n",
268
+ " if MultiheadAttention.use_sdpa:\n",
269
+ " a = scaled_dot_product_attention(query=q, key=k, value=v, is_causal=mask is not None and n_ctx > 1)\n",
270
+ " out = a.permute(0, 2, 1, 3).flatten(start_dim=2)\n",
271
+ " qk = None\n",
272
+ " else:\n",
273
+ " qk = (q * scale) @ (k * scale).transpose(-1, -2)\n",
274
+ " if mask is not None:\n",
275
+ " qk = qk + mask[:n_ctx, :n_ctx]\n",
276
+ " qk = qk.float()\n",
277
+ "\n",
278
+ " w = F.softmax(qk, dim=-1).to(dtype=q.dtype)\n",
279
+ " out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)\n",
280
+ " qk = qk.detach()\n",
281
+ "\n",
282
+ " return out, qk\n",
283
+ " "
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": null,
289
+ "metadata": {},
290
+ "outputs": [],
291
+ "source": [
292
+ "\n",
293
+ "class AdaptiveSpanAttention(nn.Module):\n",
294
+ " def __init__(self, base, dims, head, max_dist, sharpen, win_size, max_span, temp_scale=0.01):\n",
295
+ " super().__init__()\n",
296
+ " self.max_dist = max_dist\n",
297
+ " self.win_size = win_size\n",
298
+ " self.max_span = max_span\n",
299
+ " self.temp_scale = temp_scale\n",
300
+ " self.multihead_attn = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)\n",
301
+ " self.span_scale = nn.Parameter(torch.tensor(1.0))\n",
302
+ " self.sharpen = sharpen\n",
303
+ "\n",
304
+ " def forward(self, query, key, value, span_scale):\n",
305
+ " span_len = int(self.max_span * span_scale.mean().item())\n",
306
+ " span_len = min(span_len, query.shape[1], key.shape[1], value.shape[1])\n",
307
+ " eff_span = min(span_len, self.max_dist)\n",
308
+ "\n",
309
+ " q_span = query[:, :eff_span, :]\n",
310
+ " k_span = key[:, :eff_span, :]\n",
311
+ " v_span = value[:, :eff_span, :]\n",
312
+ "\n",
313
+ " batch_size, _, dims = query.shape\n",
314
+ " scale = (dims // self.multihead_attn.head) ** -0.25\n",
315
+ "\n",
316
+ " q = q_span.view(q_span.shape[0], q_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)\n",
317
+ " k = k_span.view(k_span.shape[0], k_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)\n",
318
+ " v = v_span.view(v_span.shape[0], v_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)\n",
319
+ "\n",
320
+ " if self.sharpen:\n",
321
+ " temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())\n",
322
+ " else:\n",
323
+ " temperature = 0.5 + self.temp_scale * span_scale.mean().item()\n",
324
+ "\n",
325
+ " attn_scores = torch.matmul(q, k.transpose(-2, -1))\n",
326
+ " attn_weights = torch.softmax((attn_scores / temperature) * scale, dim=-1)\n",
327
+ " attn_out = torch.matmul(attn_weights, v)\n",
328
+ " attn_out = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)\n",
329
+ " attn_out = attn_out.contiguous().view(batch_size, eff_span, dims)\n",
330
+ "\n",
331
+ " return attn_out, attn_weights\n",
332
+ "\n",
333
+ "\n",
334
+ "class SpanPredictor(nn.Module):\n",
335
+ " def __init__(self, dims):\n",
336
+ " super().__init__()\n",
337
+ " self.linear = nn.Linear(in_features=dims, out_features=1)\n",
338
+ "\n",
339
+ " def forward(self, global_out):\n",
340
+ " scale = torch.sigmoid(self.linear(global_out))\n",
341
+ " return scale\n",
342
+ "\n",
343
+ "\n",
344
+ "class HybridAttention(nn.Module):\n",
345
+ " def __init__(self, base, dims, head, max_dist, sharpen, win_size=32, max_span=32, slid_win=32):\n",
346
+ " super().__init__()\n",
347
+ " self.max_dist = max_dist\n",
348
+ " self.win_size = win_size\n",
349
+ " self.max_span = max_span\n",
350
+ " self.slid_win = slid_win\n",
351
+ "\n",
352
+ " self.span_pred = SpanPredictor(dims=dims)\n",
353
+ " self.dist_local = max_dist\n",
354
+ " self.dist_global = max_dist\n",
355
+ "\n",
356
+ " self.attn_local = AdaptiveSpanAttention(base=base, dims=dims, head=head, max_dist=max_dist, sharpen=sharpen, win_size=win_size, max_span=max_span)\n",
357
+ " self.attn_global = MultiheadAttention(base=base, dims=dims, head=head, max_dist=self.dist_global)\n",
358
+ " self.ln_local = LayerNorm(normalized_shape=dims)\n",
359
+ " self.ln_global = LayerNorm(normalized_shape=dims)\n",
360
+ " self.projection = Linear(in_features=2 * dims, out_features=dims)\n",
361
+ "\n",
362
+ " def forward(self, x, new_dist=None, new_base=None, xa=None, mask=None, kv_cache=None):\n",
363
+ " local = self.ln_local(x)\n",
364
+ " globe = self.ln_global(x)\n",
365
+ "\n",
366
+ " globe_out, _ = self.attn_global(globe, globe, globe)\n",
367
+ "\n",
368
+ " span_scale = self.span_pred(globe_out.mean(dim=1))\n",
369
+ "\n",
370
+ " win_size = max(1, int(self.slid_win * span_scale.mean().item()))\n",
371
+ " span_len = max(1, int(self.max_span * span_scale.mean().item()))\n",
372
+ "\n",
373
+ " effective_max_dist = min(self.max_dist, local.size(1))\n",
374
+ " local_max_dist = min(self.dist_local, span_len, win_size)\n",
375
+ " globe_max_dist = effective_max_dist\n",
376
+ "\n",
377
+ " self.attn_local.max_dist = local_max_dist\n",
378
+ " self.attn_global.max_dist = globe_max_dist\n",
379
+ "\n",
380
+ " local_out = self.slide_win(x=local, win_size=win_size, span_len=span_len, span_scale=span_scale)\n",
381
+ "\n",
382
+ " combined = torch.cat(tensors=[local_out, globe_out], dim=-1)\n",
383
+ " x = self.projection(combined)\n",
384
+ "\n",
385
+ " return x\n",
386
+ "\n",
387
+ " def slide_win(self, x, win_size, span_len, span_scale):\n",
388
+ " batch_size, seq_len, dims = x.size()\n",
389
+ " out = torch.zeros_like(x, device=x.device)\n",
390
+ "\n",
391
+ " for i in range(0, seq_len, win_size):\n",
392
+ " end = min(i + win_size, seq_len)\n",
393
+ " query = x[:, i:end, :]\n",
394
+ "\n",
395
+ " start = max(0, i - span_len + win_size)\n",
396
+ " key = x[:, start:i + span_len, :]\n",
397
+ " value = x[:, start:i + span_len, :]\n",
398
+ " attn_out, _ = self.attn_local(query, key, value, span_scale)\n",
399
+ " out[:, i:end, :] = attn_out\n",
400
+ "\n",
401
+ " return out\n",
402
+ "\n",
403
+ "\n"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "metadata": {},
410
+ "outputs": [],
411
+ "source": [
412
+ "\n",
413
+ "class ResidualAttention(nn.Module):\n",
414
+ " def __init__(self, base, dims, head, max_dist, win_size, max_span, hybrid, checkpoint, cross, sharpen):\n",
415
+ " super().__init__()\n",
416
+ "\n",
417
+ " if hybrid:\n",
418
+ " # print(\"HybridDrive ON\")\n",
419
+ " self.attn = HybridAttention(base=base, dims=dims, head=head, max_dist=max_dist, sharpen=sharpen)\n",
420
+ " self.attn_ln = LayerNorm(normalized_shape=dims)\n",
421
+ " else:\n",
422
+ " self.attn = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)\n",
423
+ " self.attn_ln = LayerNorm(normalized_shape=dims)\n",
424
+ "\n",
425
+ " n_mlp = dims * 4\n",
426
+ " self.mlp = nn.Sequential(Linear(in_features=dims, out_features=n_mlp), nn.GELU(), Linear(in_features=n_mlp, out_features=dims))\n",
427
+ " self.mlp_ln = LayerNorm(normalized_shape=dims)\n",
428
+ "\n",
429
+ " def forward(self, x, mask=None, kv_cache=None):\n",
430
+ " x = self._attn_forward(x=x, mask=mask, kv_cache=kv_cache)\n",
431
+ " x = self._mlp_forward(x=x)\n",
432
+ " return x\n",
433
+ "\n",
434
+ " def _attn_forward(self, x, mask=None, kv_cache=None):\n",
435
+ " residual = x\n",
436
+ " x = self.attn_ln(x)\n",
437
+ "\n",
438
+ " if isinstance(self.attn, HybridAttention):\n",
439
+ " attn_output = self.attn(x) \n",
440
+ "\n",
441
+ " x = residual + attn_output\n",
442
+ " else:\n",
443
+ " attn_output, _ = self.attn(x, mask=mask, kv_cache=kv_cache) \n",
444
+ " x = residual + attn_output\n",
445
+ " return x\n",
446
+ "\n",
447
+ " def _mlp_forward(self, x):\n",
448
+ " residual = x\n",
449
+ " x = self.mlp_ln(x)\n",
450
+ " return residual + self.mlp(x)\n"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "execution_count": null,
456
+ "metadata": {},
457
+ "outputs": [],
458
+ "source": [
459
+ "\n",
460
+ "class AudioEncoder(nn.Module):\n",
461
+ " def __init__(self, base, mels, dims, head, n_layer, n_ctx, max_dist,\n",
462
+ " win_size, max_span, hybrid, checkpoint, cross, sharpen):\n",
463
+ " super().__init__()\n",
464
+ " self.conv1 = Conv1d(in_channels=mels, out_channels=dims, kernel_size=3, padding=1)\n",
465
+ " self.conv2 = Conv1d(in_channels=dims, out_channels=dims, kernel_size=3, stride=2, padding=1)\n",
466
+ " self.pos_embed = SinusoidalEmbedding(n_ctx=n_ctx, dims=dims, checkpoint=checkpoint)\n",
467
+ " self.checkpoint = checkpoint\n",
468
+ "\n",
469
+ " self.givens_rotary = CombinedRotaryEmbedding(base=base, dims=dims, head=head)\n",
470
+ "\n",
471
+ " self.blocks = nn.ModuleList(modules=[ResidualAttention(base=base, dims=dims, head=head, max_dist=max_dist, win_size=win_size, max_span=max_span, hybrid=hybrid, checkpoint=checkpoint, cross=cross, sharpen=sharpen) for _ in range(n_layer)])\n",
472
+ "\n",
473
+ " self.ln_post = LayerNorm(normalized_shape=dims)\n",
474
+ "\n",
475
+ " def forward(self, x):\n",
476
+ " if self.checkpoint:\n",
477
+ " x = checkpoint(self._conv_forward, x)\n",
478
+ " else:\n",
479
+ " x = self._conv_forward(x)\n",
480
+ "\n",
481
+ " for block in self.blocks:\n",
482
+ " if self.checkpoint:\n",
483
+ " x = checkpoint(block, x)\n",
484
+ " else:\n",
485
+ " x = block(x)\n",
486
+ " return self.ln_post(x)\n",
487
+ "\n",
488
+ " def _conv_forward(self, x):\n",
489
+ " x = F.gelu(self.conv1(x))\n",
490
+ " x = F.gelu(self.conv2(x))\n",
491
+ " x = x.permute(0, 2, 1)\n",
492
+ " \n",
493
+ " p = self.pos_embed(torch.arange(end=x.size(dim=1), device=x.device)).unsqueeze(0)\n",
494
+ " x = (x + p).to(x.dtype)\n",
495
+ " x = self.givens_rotary(x)\n",
496
+ " return x\n"
497
+ ]
498
+ },
499
+ {
500
+ "cell_type": "code",
501
+ "execution_count": null,
502
+ "metadata": {},
503
+ "outputs": [],
504
+ "source": [
505
+ "\n",
506
+ "\n",
507
+ "class TextDecoder(nn.Module):\n",
508
+ " def __init__(self, base, vocab, dims, head, n_layer, n_ctx, max_dist,\n",
509
+ " win_size, max_span, hybrid, checkpoint, cross, sharpen):\n",
510
+ " super().__init__()\n",
511
+ " \n",
512
+ " self.tok_embed = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)\n",
513
+ " self.pos_embed = SinusoidalEmbedding(n_ctx=n_ctx, dims=dims, checkpoint=checkpoint)\n",
514
+ " self.checkpoint = checkpoint\n",
515
+ "\n",
516
+ " self.givens_rotary = CombinedRotaryEmbedding(base=base, dims=dims, head=head)\n",
517
+ "\n",
518
+ " self.blocks = nn.ModuleList(modules=[ResidualAttention(base=base, dims=dims, head=head, max_dist=max_dist, win_size=win_size, max_span=max_span, hybrid=hybrid, checkpoint=checkpoint, cross=cross, sharpen=sharpen) for _ in range(n_layer)])\n",
519
+ "\n",
520
+ " self.ln_post = LayerNorm(normalized_shape=dims)\n",
521
+ " self.ln = LayerNorm(normalized_shape=dims)\n",
522
+ "\n",
523
+ " mask = torch.empty(n_ctx, n_ctx).fill_(value=-np.inf).triu_(diagonal=1)\n",
524
+ " self.register_buffer(name=\"mask\", tensor=mask, persistent=False)\n",
525
+ " self.mask=mask\n",
526
+ "\n",
527
+ " def forward(self, x, xa, kv_cache=None):\n",
528
+ " if self.checkpoint:\n",
529
+ " x = checkpoint(self._embedding_forward, x, xa, kv_cache)\n",
530
+ " else:\n",
531
+ " x = self._embedding_forward(x=x, xa=xa, kv_cache=kv_cache)\n",
532
+ "\n",
533
+ " for block in self.blocks:\n",
534
+ " if self.checkpoint:\n",
535
+ " x = checkpoint(block, x, self.mask, kv_cache)\n",
536
+ " else:\n",
537
+ " x = block(x, self.mask, kv_cache)\n",
538
+ "\n",
539
+ " x = self.ln(x)\n",
540
+ " x = (x @ torch.transpose(input=self.tok_embed.weight.to(dtype=x.dtype), dim0=0, dim1=1)).float()\n",
541
+ " return x\n",
542
+ " \n",
543
+ " def _embedding_forward(self, x, xa, kv_cache):\n",
544
+ " offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0\n",
545
+ " positions = torch.arange(x.shape[1], device=x.device) + offset\n",
546
+ " pos_emb = self.pos_embed(positions).unsqueeze(0)\n",
547
+ " x = self.tok_embed(x) + pos_emb\n",
548
+ " x = self.givens_rotary(x)\n",
549
+ " return x"
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "code",
554
+ "execution_count": null,
555
+ "metadata": {},
556
+ "outputs": [],
557
+ "source": [
558
+ "class EchoConfig(PretrainedConfig):\n",
559
+ " model_type = \"Echo\"\n",
560
+ " def __init__(\n",
561
+ " self,\n",
562
+ " checkpoint=False,\n",
563
+ " cross=False,\n",
564
+ " hybrid=False,\n",
565
+ " sharpen=False,\n",
566
+ " a_ctx=1500,\n",
567
+ " a_head=16,\n",
568
+ " a_layer=8,\n",
569
+ " a_dims=1024,\n",
570
+ " mels=128,\n",
571
+ " t_ctx=448,\n",
572
+ " t_head=8,\n",
573
+ " t_layer=8,\n",
574
+ " t_dims=1024,\n",
575
+ " win_size=64,\n",
576
+ " max_span=64,\n",
577
+ " max_dist=64,\n",
578
+ " base=10000,\n",
579
+ " pad_token_id=50257,\n",
580
+ " unk_token_id=50257,\n",
581
+ " vocab=51865,\n",
582
+ " eos_token_id=50257,\n",
583
+ " bos_token_id=50257,\n",
584
+ " decoder_start_token_id=50258,\n",
585
+ " **kwargs,\n",
586
+ " ):\n",
587
+ " \n",
588
+ " super().__init__(**kwargs) \n",
589
+ " self.base = base\n",
590
+ " self.bos_token_id = bos_token_id\n",
591
+ " self.checkpoint = checkpoint\n",
592
+ " self.cross = cross\n",
593
+ " self.decoder_start_token_id = decoder_start_token_id\n",
594
+ " self.eos_token_id = eos_token_id\n",
595
+ " self.hybrid = hybrid\n",
596
+ " self.max_dist = max_dist\n",
597
+ " self.max_span = max_span\n",
598
+ " self.a_ctx = a_ctx\n",
599
+ " self.a_head = a_head\n",
600
+ " self.a_layer = a_layer\n",
601
+ " self.a_dims = a_dims\n",
602
+ " self.mels = mels\n",
603
+ " self.t_ctx = t_ctx\n",
604
+ " self.t_head = t_head\n",
605
+ " self.t_layer = t_layer\n",
606
+ " self.t_dims = t_dims\n",
607
+ " self.pad_token_id = pad_token_id\n",
608
+ " self.unk_token_id = unk_token_id\n",
609
+ " self.vocab = vocab\n",
610
+ " self.win_size = win_size\n",
611
+ " self.sharpen=sharpen\n",
612
+ "\n",
613
+ "class Echo(nn.Module):\n",
614
+ " def __init__(self, config: EchoConfig):\n",
615
+ " super().__init__()\n",
616
+ " self.config = config\n",
617
+ " \n",
618
+ " self.encoder = AudioEncoder(\n",
619
+ " base=self.config.base,\n",
620
+ " mels=self.config.mels,\n",
621
+ " dims=self.config.a_dims, \n",
622
+ " head=self.config.a_head,\n",
623
+ " n_layer=self.config.a_layer,\n",
624
+ " n_ctx=self.config.a_ctx,\n",
625
+ " max_dist=self.config.max_dist,\n",
626
+ " win_size=self.config.win_size, \n",
627
+ " max_span=self.config.max_span,\n",
628
+ " hybrid=self.config.hybrid,\n",
629
+ " checkpoint=self.config.checkpoint,\n",
630
+ " cross=self.config.cross,\n",
631
+ " sharpen=self.config.sharpen,\n",
632
+ " )\n",
633
+ "\n",
634
+ " self.decoder = TextDecoder(\n",
635
+ " base=self.config.base,\n",
636
+ " vocab=self.config.vocab,\n",
637
+ " dims=self.config.t_dims, \n",
638
+ " head=self.config.t_head,\n",
639
+ " n_layer=self.config.t_layer,\n",
640
+ " n_ctx=self.config.t_ctx,\n",
641
+ " max_dist=self.config.max_dist,\n",
642
+ " win_size=self.config.win_size, \n",
643
+ " max_span=self.config.max_span,\n",
644
+ " hybrid=self.config.hybrid,\n",
645
+ " checkpoint=self.config.checkpoint,\n",
646
+ " cross=self.config.cross,\n",
647
+ " sharpen=self.config.sharpen,\n",
648
+ " )\n",
649
+ "\n",
650
+ "\n",
651
+ " all_heads = torch.zeros(self.config.t_layer, self.config.t_head, dtype=torch.bool) \n",
652
+ " all_heads[self.config.t_layer // 2:] = True\n",
653
+ " self.register_buffer(name=\"alignment_heads\", tensor=all_heads.to_sparse(), persistent=False)\n",
654
+ "\n",
655
+ " self.base = self.config.base\n",
656
+ " self.win_size = self.config.win_size\n",
657
+ " self.adjust_counter = 0\n",
658
+ " self.best_loss = float('inf')\n",
659
+ " self.kv_cache = {}\n",
660
+ "\n",
661
+ "\n",
662
+ " @property\n",
663
+ " def device(self):\n",
664
+ " return next(self.parameters()).device\n",
665
+ "\n",
666
+ " def embed_audio(self, mel: torch.Tensor):\n",
667
+ " return self.encoder(mel)\n",
668
+ "\n",
669
+ " def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):\n",
670
+ " return self.decoder(tokens, audio_features)\n",
671
+ "\n",
672
+ " def update_window(self, new_window):\n",
673
+ " self.win_size = new_window\n",
674
+ " for module in self.modules(): \n",
675
+ " if isinstance(module, HybridAttention):\n",
676
+ " module.update_window(self.win_size)\n",
677
+ "\n",
678
+ " def adjust_window(self, loss, factor=1.00005):\n",
679
+ " if self.adjust_counter % 10 == 0:\n",
680
+ " if loss < self.best_loss:\n",
681
+ " new_window = self.win_size * factor\n",
682
+ " else:\n",
683
+ " new_window = self.win_size / factor\n",
684
+ " self.update_window(new_window=new_window)\n",
685
+ " self.best_loss = loss\n",
686
+ " self.adjust_counter += 1\n",
687
+ " return new_window\n",
688
+ " return self.win_size\n",
689
+ "\n",
690
+ " def adjust_base(self, loss, factor=1.0025) -> float | int:\n",
691
+ " if self.adjust_counter % 25 == 0:\n",
692
+ " if loss < self.best_loss:\n",
693
+ " new_base=self.base*factor\n",
694
+ " else:\n",
695
+ " new_base=self.base/factor\n",
696
+ " self.update_base(new_base=new_base)\n",
697
+ " self.base=new_base\n",
698
+ " self.best_loss=loss\n",
699
+ " self.adjust_counter += 1\n",
700
+ " return self.base\n",
701
+ " \n",
702
+ " def update_base(self, new_base):\n",
703
+ " self.new_base=new_base\n",
704
+ " for name, module in self.encoder.named_modules():\n",
705
+ " if isinstance(module, (CombinedRotaryEmbedding)):\n",
706
+ " module.update_base(new_base=self.new_base)\n",
707
+ "\n",
708
+ " @staticmethod\n",
709
+ " def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):\n",
710
+ " shifted_input_ids = input_ids.new_zeros(input_ids.shape)\n",
711
+ " shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() \n",
712
+ " shifted_input_ids[:, 0] = decoder_start_token_id\n",
713
+ " shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)\n",
714
+ " return shifted_input_ids\n",
715
+ "\n",
716
+ " def forward(self, input_features, labels=None, dec_input_ids=None) -> dict[str, Any | None]:\n",
717
+ " if labels is not None:\n",
718
+ " if dec_input_ids is None:\n",
719
+ " dec_input_ids = self.shift_tokens_right(\n",
720
+ " input_ids=labels, pad_token_id=self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id\n",
721
+ " )\n",
722
+ "\n",
723
+ " encoded_features = self.encoder(input_features).to(self.device) \n",
724
+ " logits = self.decoder(dec_input_ids, encoded_features)\n",
725
+ "\n",
726
+ " loss = None\n",
727
+ " if labels is not None:\n",
728
+ " loss_fct = nn.CrossEntropyLoss(ignore_index=-100)\n",
729
+ " labels = labels.to(logits.device).long()\n",
730
+ " loss = loss_fct(logits.view(-1, self.config.vocab), labels.view(-1))\n",
731
+ "\n",
732
+ " self.adjust_window(loss.item())\n",
733
+ " # self.adjust_base(loss=loss.item())\n",
734
+ " return {\"loss\": loss, \"logits\": logits}\n",
735
+ "\n",
736
+ " def reset_parameters(self):\n",
737
+ " for name, module in self.encoder.named_modules():\n",
738
+ " if isinstance(module, CombinedRotaryEmbedding):\n",
739
+ " module.reset_parameters()\n",
740
+ " \n",
741
+ " def _initialize_weights(self, module):\n",
742
+ " nn.init.normal_(tensor=self.decoder.tok_embed.weight, mean=0.0, std=0.02)\n",
743
+ " nn.init.constant_(tensor=self.decoder.ln.weight, val=1)\n",
744
+ " nn.init.constant_(tensor=self.decoder.ln.bias, val=0)\n",
745
+ " nn.init.xavier_normal_(tensor=self.encoder.conv1.weight)\n",
746
+ " nn.init.zeros_(tensor=self.encoder.conv1.bias)\n",
747
+ " nn.init.kaiming_normal_(tensor=self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')\n",
748
+ " nn.init.zeros_(tensor=self.encoder.conv2.bias)\n",
749
+ " nn.init.constant_(tensor=self.encoder.ln_post.weight, val=1)\n",
750
+ " nn.init.constant_(tensor=self.encoder.ln_post.bias, val=0)\n",
751
+ "\n",
752
+ " for block in self.decoder.blocks:\n",
753
+ " for layer in block.children():\n",
754
+ " if isinstance(layer, nn.Linear):\n",
755
+ " nn.init.xavier_normal_(tensor=layer.weight)\n",
756
+ " nn.init.zeros_(tensor=layer.bias)\n",
757
+ " if isinstance(layer, LayerNorm):\n",
758
+ " nn.init.constant_(tensor=layer.weight, val=1)\n",
759
+ " \n",
760
+ " for block in self.encoder.blocks:\n",
761
+ " for layer in block.children():\n",
762
+ " if isinstance(layer, nn.Linear):\n",
763
+ " nn.init.xavier_normal_(tensor=layer.weight)\n",
764
+ " nn.init.zeros_(tensor=layer.bias)\n",
765
+ " if isinstance(layer, LayerNorm):\n",
766
+ " nn.init.constant_(tensor=layer.weight, val=1)\n",
767
+ "\n",
768
+ " for module in self.encoder.named_modules():\n",
769
+ " if isinstance(module, CombinedRotaryEmbedding):\n",
770
+ " nn.init.constant_(tensor=module.thetas, val=1)\n",
771
+ " nn.init.constant_(tensor=module.r_matrix, val=1)\n",
772
+ " nn.init.constant_(tensor=module.r_pairs, val=1)\n",
773
+ " nn.init.constant_(tensor=module.inv_freq, val=1)\n",
774
+ "\n",
775
+ " def apply_initialization(self, module):\n",
776
+ " self._initialize_weights(module=module)\n"
777
+ ]
778
+ },
779
+ {
780
+ "cell_type": "code",
781
+ "execution_count": null,
782
+ "metadata": {},
783
+ "outputs": [],
784
+ "source": [
785
+ "\n",
786
+ "from datetime import datetime\n",
787
+ "log_dir = os.path.join('./output/Echo/', datetime.now().strftime(format='%m-%d_%H'))\n",
788
+ "os.makedirs(name=log_dir, exist_ok=True)\n",
789
+ "\n",
790
+ "config = EchoConfig(\n",
791
+ " checkpoint=False,\n",
792
+ " cross=False,\n",
793
+ " hybrid=False,\n",
794
+ " sharpen=False,\n",
795
+ " audio_ctx=1500,\n",
796
+ " audio_head=4,\n",
797
+ " audio_layer=4,\n",
798
+ " audio_dims=512,\n",
799
+ " mels=128,\n",
800
+ " text_ctx=448,\n",
801
+ " text_head=4,\n",
802
+ " text_layer=4,\n",
803
+ " text_dims=512,\n",
804
+ " win_size=16,\n",
805
+ " max_span=16,\n",
806
+ " max_dist=16,\n",
807
+ " base=50000,\n",
808
+ " pad_token_id=50257,\n",
809
+ " unk_token_id=50257,\n",
810
+ " vocab=51865,\n",
811
+ " eos_token_id=50257,\n",
812
+ " bos_token_id=50257,\n",
813
+ " decoder_start_token_id=50258,\n",
814
+ ")\n",
815
+ "\n",
816
+ "model = Echo(config=config).to(device=device)\n",
817
+ "model.apply_initialization(module=model)"
818
+ ]
819
+ },
820
+ {
821
+ "cell_type": "code",
822
+ "execution_count": null,
823
+ "metadata": {},
824
+ "outputs": [],
825
+ "source": [
826
+ "feature_extractor = WhisperFeatureExtractor.from_pretrained(\n",
827
+ " pretrained_model_name_or_path=\"openai/whisper-small\", \n",
828
+ "    feature_size=128, sampling_rate=16000, do_normalize=True)\n",
829
+ "\n",
830
+ "tokenizer = WhisperTokenizerFast.from_pretrained(\n",
831
+ " pretrained_model_name_or_path=\"openai/whisper-small\", \n",
832
+ " language=\"en\", task=\"transcribe\")\n",
833
+ "\n",
834
+ "processor = WhisperProcessor.from_pretrained(\n",
835
+ " pretrained_model_name_or_path=\"openai/whisper-small\", \n",
836
+ "    feature_size=128, sampling_rate=16000, do_normalize=True, \n",
837
+ " language=\"en\", task=\"transcribe\")\n",
838
+ "\n",
839
+ "class GradientClippingCallback(TrainerCallback):\n",
840
+ " def on_step_end(self, args, dims, control, **kwargs):\n",
841
+ " torch.nn.utils.clip_grad_norm_(parameters=kwargs[\"model\"].parameters(), max_norm=0.98)\n",
842
+ "\n",
843
+ "@dataclass\n",
844
+ "class DataCollatorSpeechSeq2SeqWithPadding:\n",
845
+ " processor: Any\n",
846
+ " decoder_start_token_id: int\n",
847
+ "\n",
848
+ " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
849
+ " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
850
+ " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
851
+ " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
852
+ " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
853
+ " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
854
+ " if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():\n",
855
+ " labels = labels[:, 1:]\n",
856
+ " batch[\"labels\"] = labels\n",
857
+ " return batch\n",
858
+ "\n",
859
+ "def get_length_of_dataset(dataset):\n",
860
+ " length = 0\n",
861
+ " for item in dataset:\n",
862
+ " length += len(item[\"audio\"][\"array\"]) / item[\"audio\"][\"sampling_rate\"]\n",
863
+ " return length / 3600 \n",
864
+ "\n",
865
+ "def prepare_dataset(batch):\n",
866
+ " audio = batch[\"audio\"]\n",
867
+ " batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
868
+ " batch[\"labels\"] = tokenizer(batch[\"sentence\"]).input_ids\n",
869
+ " return batch\n",
870
+ "\n",
871
+ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, decoder_start_token_id=config.decoder_start_token_id)\n",
872
+ "\n",
873
+ "datasets = IterableDatasetDict()\n",
874
+ "\n",
875
+ "datasets[\"train\"] = load_dataset(\n",
876
+ " path=\"mozilla-foundation/common_voice_17_0\", token=\"\",\n",
877
+ " name=\"en\", split=\"train\", streaming=True, trust_remote_code=True).take(10000)\n",
878
+ "\n",
879
+ "datasets[\"test\"] = load_dataset(\n",
880
+ " path=\"mozilla-foundation/common_voice_17_0\", token=\"\", \n",
881
+ " name=\"en\", split=\"test\", streaming=True, trust_remote_code=True).take(100)\n",
882
+ "\n",
883
+ "dataset = datasets.cast_column(column=\"audio\", feature=Audio(sampling_rate=16000))\n",
884
+ "\n",
885
+ "dataset = dataset.map(function=prepare_dataset, \n",
886
+ " remove_columns=list(next(iter(dataset.values())).features)).with_format(type=\"torch\")\n",
887
+ "\n",
888
+ "class MetricsCallback(TrainerCallback):\n",
889
+ " def __init__(self, tb_writer, tokenizer, metric, optimizer, scheduler, log_every_n_steps=1):\n",
890
+ " super().__init__()\n",
891
+ " self.tb_writer = tb_writer\n",
892
+ " self.tokenizer = tokenizer\n",
893
+ " self.metric = metric\n",
894
+ " self.optimizer = optimizer\n",
895
+ " self.scheduler = scheduler\n",
896
+ " self.log_every_n_steps = log_every_n_steps\n",
897
+ " self.predictions = None\n",
898
+ " self.label_ids = None\n",
899
+ "\n",
900
+ " def compute_wer(self, pred_str, label_str):\n",
901
+ " wer = 100 * self.metric.compute(predictions=pred_str, references=label_str)\n",
902
+ " return wer\n",
903
+ "\n",
904
+ " def on_evaluate(self, args, state, control, model, metrics=None, **kwargs):\n",
905
+ " if metrics is not None:\n",
906
+ " self.eval_loss = metrics.get('eval_loss')\n",
907
+ "\n",
908
+ " current_learning_rate = self.optimizer.param_groups[0]['lr']\n",
909
+ " if state.global_step % self.log_every_n_steps == 0:\n",
910
+ " self.tb_writer.add_scalar('learning_rate', current_learning_rate, state.global_step)\n",
911
+ " print(f\"Learning Rate: {current_learning_rate:.8f}\")\n",
912
+ "\n",
913
+ " self.tb_writer.add_scalar('eval_loss', self.eval_loss, state.global_step)\n",
914
+ "\n",
915
+ " for key, value in metrics.items():\n",
916
+ " if key.startswith(\"eval_\"):\n",
917
+ " self.tb_writer.add_scalar(key, value, state.global_step)\n",
918
+ "\n",
919
+ " if self.predictions is not None and self.label_ids is not None:\n",
920
+ " pred_str = self.tokenizer.batch_decode(self.predictions, skip_special_tokens=True)\n",
921
+ " label_str = self.tokenizer.batch_decode(self.label_ids, skip_special_tokens=True)\n",
922
+ "\n",
923
+ " if state.global_step % self.log_every_n_steps == 0:\n",
924
+ " total_samples = len(pred_str)\n",
925
+ " random_indices = random.sample(range(total_samples), 1)\n",
926
+ "\n",
927
+ " for sample_index in random_indices:\n",
928
+ " self.tb_writer.add_text(f\"Prediction_{sample_index}\", pred_str[sample_index], state.global_step)\n",
929
+ " self.tb_writer.add_text(f\"Label_{sample_index}\", label_str[sample_index], state.global_step)\n",
930
+ " print(f\"Evaluation: - Step {state.global_step} - Loss: {self.eval_loss:.2f}\")\n",
931
+ " print(f\"Prediction: {pred_str[sample_index]}\")\n",
932
+ " print(f\"Label: {label_str[sample_index]}\")\n",
933
+ " print(\"-\" * 10)\n",
934
+ "\n",
935
+ " self.predictions = None\n",
936
+ " self.label_ids = None\n",
937
+ "\n",
938
+ "def create_compute_metrics(callback_instance):\n",
939
+ " def compute_metrics(eval_pred):\n",
940
+ " pred_logits = eval_pred.predictions\n",
941
+ " label_ids = eval_pred.label_ids\n",
942
+ "\n",
943
+ " if isinstance(pred_logits, tuple):\n",
944
+ " pred_ids = pred_logits[0]\n",
945
+ " else:\n",
946
+ " pred_ids = pred_logits\n",
947
+ " if pred_ids.ndim == 3:\n",
948
+ " pred_ids = np.argmax(pred_ids, axis=-1)\n",
949
+ "\n",
950
+ " label_ids[label_ids == -100] = callback_instance.tokenizer.pad_token_id\n",
951
+ " callback_instance.predictions = pred_ids\n",
952
+ " callback_instance.label_ids = label_ids\n",
953
+ " pred_str = callback_instance.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
954
+ " label_str = callback_instance.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
955
+ " wer = 100 * callback_instance.metric.compute(predictions=pred_str, references=label_str)\n",
956
+ " pred_flat = pred_ids.flatten()\n",
957
+ " labels_flat = label_ids.flatten()\n",
958
+ " mask = labels_flat != callback_instance.tokenizer.pad_token_id\n",
959
+ "\n",
960
+ " accuracy = accuracy_score(y_true=labels_flat[mask], y_pred=pred_flat[mask])\n",
961
+ " precision = precision_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)\n",
962
+ " recall = recall_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)\n",
963
+ " f1 = f1_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)\n",
964
+ " return {\"wer\": wer, \"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n",
965
+ " return compute_metrics\n",
966
+ "\n",
967
+ "metric = evaluate.load(path=\"wer\")\n",
968
+ "tb_writer = SummaryWriter(log_dir=log_dir)\n",
969
+ "\n",
970
+ "training_args = Seq2SeqTrainingArguments(\n",
971
+ " output_dir=log_dir,\n",
972
+ " per_device_train_batch_size=1,\n",
973
+ " per_device_eval_batch_size=1,\n",
974
+ " gradient_accumulation_steps=1,\n",
975
+ " eval_accumulation_steps=1,\n",
976
+ " tf32=True,\n",
977
+ " bf16=True,\n",
978
+ " eval_strategy=\"steps\",\n",
979
+ " save_strategy=\"steps\",\n",
980
+ " max_steps=10000,\n",
981
+ " save_steps=10000,\n",
982
+ " eval_steps=100,\n",
983
+ " warmup_steps=100,\n",
984
+ " logging_steps=10,\n",
985
+ " logging_dir=log_dir + \"/logs_hf\",\n",
986
+ " report_to=[\"tensorboard\"],\n",
987
+ " load_best_model_at_end=False,\n",
988
+ " metric_for_best_model=\"loss\",\n",
989
+ " greater_is_better=False,\n",
990
+ " push_to_hub=False,\n",
991
+ " disable_tqdm=False,\n",
992
+ " save_total_limit=1,\n",
993
+ " remove_unused_columns=False,\n",
994
+ " label_names=[\"labels\"],\n",
995
+ " eval_on_start=True,\n",
996
+ ")\n",
997
+ "\n",
998
+ "class MaxFactor(Optimizer):\n",
999
+ " def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(None, 1e-3), d=1.0, \n",
1000
+ " weight_decay=0.0, gamma=0.99, eps_rms=1e-8, maximize=False):\n",
1001
+ " \n",
1002
+ " defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay, \n",
1003
+ " gamma=gamma, eps_rms=eps_rms, maximize=maximize)\n",
1004
+ "\n",
1005
+ " super().__init__(params, defaults)\n",
1006
+ "\n",
1007
+ " @torch.no_grad()\n",
1008
+ " def step(self, closure=None):\n",
1009
+ " loss = None\n",
1010
+ " if closure is not None:\n",
1011
+ " with torch.enable_grad():\n",
1012
+ " loss = closure()\n",
1013
+ "\n",
1014
+ " for group in self.param_groups:\n",
1015
+ " params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []\n",
1016
+ " eps1, eps2 = group[\"eps\"]\n",
1017
+ " for p in group[\"params\"]:\n",
1018
+ " if p.grad is None:\n",
1019
+ " continue\n",
1020
+ " grad = p.grad\n",
1021
+ " if grad.dtype in {torch.float16, torch.bfloat16}:\n",
1022
+ " grad = grad.float()\n",
1023
+ "\n",
1024
+ " state = self.state[p]\n",
1025
+ " if len(state) == 0:\n",
1026
+ " state[\"step\"] = torch.tensor(0.0, dtype=torch.float32)\n",
1027
+ " if p.grad.dim() > 1:\n",
1028
+ " row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)\n",
1029
+ " row_shape[-1], col_shape[-2] = 1, 1\n",
1030
+ " state[\"row_var\"], state[\"col_var\"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)\n",
1031
+ " state[\"v\"] = torch.zeros_like(p, memory_format=torch.preserve_format)\n",
1032
+ "\n",
1033
+ " row_vars.append(state.get(\"row_var\", None))\n",
1034
+ " col_vars.append(state.get(\"col_var\", None))\n",
1035
+ " v.append(state[\"v\"])\n",
1036
+ " state_steps.append(state[\"step\"])\n",
1037
+ " params_with_grad.append(p)\n",
1038
+ " grads.append(grad)\n",
1039
+ "\n",
1040
+ " for i, param in enumerate(params_with_grad):\n",
1041
+ " grad = grads[i]\n",
1042
+ "\n",
1043
+ " if group[\"maximize\"]:\n",
1044
+ " grad = -grad\n",
1045
+ " step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]\n",
1046
+ "\n",
1047
+ " if eps1 is None:\n",
1048
+ " eps1 = torch.finfo(param.dtype).eps\n",
1049
+ " \n",
1050
+ " step_t += 1\n",
1051
+ " step_float = step_t.item()\n",
1052
+ " one_minus_beta2_t = step_float ** group[\"beta2_decay\"]\n",
1053
+ " rho_t = min(group[\"lr\"], 1 / (step_float ** 0.5))\n",
1054
+ " alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t\n",
1055
+ "\n",
1056
+ " if group[\"weight_decay\"]!= 0:\n",
1057
+ " param.mul_(1 - group[\"lr\"] * group[\"weight_decay\"])\n",
1058
+ "\n",
1059
+ " if grad.dim() > 1:\n",
1060
+ " row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))\n",
1061
+ " row_var.lerp_(row_mean, one_minus_beta2_t)\n",
1062
+ " col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))\n",
1063
+ " col_var.lerp_(col_mean, one_minus_beta2_t)\n",
1064
+ " var_estimate = row_var @ col_var\n",
1065
+ " max_row_var = row_var.max(dim=-2, keepdim=True)[0] \n",
1066
+ " var_estimate.div_(max_row_var.clamp_(min=eps1))\n",
1067
+ "\n",
1068
+ " else:\n",
1069
+ "                    vi.mul_(group[\"gamma\"]).add_(grad ** 2, alpha=1 - group[\"gamma\"])\n",
1070
+ "                    var_estimate = vi.clone()\n",
1071
+ " \n",
1072
+ " update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)\n",
1073
+ " update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))\n",
1074
+ " denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group[\"d\"]))\n",
1075
+ " param.add_(-alpha / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])\n",
1076
+ "\n",
1077
+ " return loss\n",
1078
+ " \n",
1079
+ "optimizer = MaxFactor(\n",
1080
+ " model.parameters(), \n",
1081
+ " lr=0.025, \n",
1082
+ " beta2_decay=-0.8,\n",
1083
+ " eps=(None, 1e-4),\n",
1084
+ " d=1.0,\n",
1085
+ " weight_decay=0.0025,\n",
1086
+ " gamma=0.99, \n",
1087
+ " eps_rms=1e-8,\n",
1088
+ " maximize=False,\n",
1089
+ " )\n",
1090
+ "\n",
1091
+ "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
1092
+ " optimizer=optimizer,\n",
1093
+ " T_max=training_args.max_steps,\n",
1094
+ " eta_min=0.0,\n",
1095
+ " last_epoch=-1 \n",
1096
+ ")\n",
1097
+ "\n",
1098
+ "metrics_callback = MetricsCallback(tb_writer=tb_writer, tokenizer=tokenizer, metric=metric, optimizer=optimizer, scheduler=scheduler, log_every_n_steps=10)\n",
1099
+ "compute_metrics = create_compute_metrics(callback_instance=metrics_callback)\n",
1100
+ "\n",
1101
+ "trainer = Seq2SeqTrainer(\n",
1102
+ " args=training_args,\n",
1103
+ " model=model,\n",
1104
+ " train_dataset=dataset[\"train\"],\n",
1105
+ " eval_dataset=dataset[\"test\"],\n",
1106
+ " data_collator=data_collator,\n",
1107
+ " compute_metrics=compute_metrics,\n",
1108
+ " processing_class=feature_extractor,\n",
1109
+ " callbacks=[metrics_callback],\n",
1110
+ " optimizers=(optimizer, scheduler)\n",
1111
+ ")"
1112
+ ]
1113
+ },
1114
+ {
1115
+ "cell_type": "code",
1116
+ "execution_count": null,
1117
+ "metadata": {},
1118
+ "outputs": [],
1119
+ "source": [
1120
+ "\n",
1121
+ "trainer.train(resume_from_checkpoint=False)"
1122
+ ]
1123
+ },
1124
+ {
1125
+ "cell_type": "code",
1126
+ "execution_count": null,
1127
+ "metadata": {},
1128
+ "outputs": [],
1129
+ "source": [
1130
+ "from tensorboard import program\n",
1131
+ "log_dir = \"D:/new/tensorboard3\" \n",
1132
+ "tb = program.TensorBoard()\n",
1133
+ "tb.configure(argv=[None, '--logdir', log_dir])\n",
1134
+ "url = tb.launch()\n",
1135
+ "print(f\"TensorBoard started at {url}\")"
1136
+ ]
1137
+ }
1138
+ ],
1139
+ "metadata": {
1140
+ "kernelspec": {
1141
+ "display_name": "Python 3",
1142
+ "language": "python",
1143
+ "name": "python3"
1144
+ },
1145
+ "language_info": {
1146
+ "codemirror_mode": {
1147
+ "name": "ipython",
1148
+ "version": 3
1149
+ },
1150
+ "file_extension": ".py",
1151
+ "mimetype": "text/x-python",
1152
+ "name": "python",
1153
+ "nbconvert_exporter": "python",
1154
+ "pygments_lexer": "ipython3",
1155
+ "version": "3.12.8"
1156
+ }
1157
+ },
1158
+ "nbformat": 4,
1159
+ "nbformat_minor": 2
1160
+ }