Create echo.py
echo.py
ADDED
@@ -0,0 +1,1032 @@
import base64, gzip, math, os, functools, warnings, numpy as np, torch, transformers, aiohttp, torch.nn.functional as F, evaluate, json, random
from torch import Tensor, amp, optim, nn
from torch.utils.checkpoint import checkpoint
from torch.utils.tensorboard.writer import SummaryWriter
from threading import Thread
from typing import Dict, Optional, Tuple, Union, List, Any
from dataclasses import dataclass
from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments, PretrainedConfig, TrainerCallback, WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizerFast)
from torch.optim import Optimizer
import evaluate
from evaluate import module
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
from datasets import load_dataset, IterableDatasetDict, Audio, load_from_disk
from torch.nn.functional import scaled_dot_product_attention

transformers.utils.logging.set_verbosity_error()
warnings.filterwarnings(action="ignore")
warnings.warn = lambda *args, **kwargs: None
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:  # type: ignore
        return F.linear(x, self.weight.to(x.dtype),
                        None if self.bias is None else self.bias.to(x.dtype))

class Conv1d(nn.Conv1d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:  # type: ignore
        return super()._conv_forward(x, weight.to(x.dtype),
                                     None if bias is None else bias.to(x.dtype))

class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:  # type: ignore
        return super().forward(x.float()).type(x.dtype)

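# The three wrappers above cast weights and biases to the input's dtype at call time so
# the layers stay usable under mixed precision; this mirrors the corresponding wrappers
# in OpenAI's Whisper reference implementation.
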
class CombinedRotaryEmbedding(nn.Module):
    def __init__(self, base, dims, head, theta_learnable=True, rot_learnable=True,
                 matrix_learnable=False, freq_learnable=True):
        super(CombinedRotaryEmbedding, self).__init__()

        self.base = base
        self.dims = dims
        self.head = head

        self.h_dim = self.dims // self.head
        self.rot = (self.dims // self.head) // 2

        self.thetas = nn.Parameter(torch.zeros(self.rot))
        self.r_pairs = nn.Parameter(data=torch.rand(self.rot, 2) * self.h_dim)

        self.theta_scale = nn.Parameter(torch.ones(1), requires_grad=theta_learnable)
        self.rot_scale = nn.Parameter(torch.ones(1), requires_grad=rot_learnable)

        self.r_matrix = nn.Parameter(torch.eye(n=self.h_dim), requires_grad=matrix_learnable)

        freq_data = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))
        self.inv_freq = nn.Parameter(freq_data, requires_grad=freq_learnable)

        self.orthogonal_reg_weight = 0.01

    def blended_rotation_matrix(self, dims, i, j, theta):
        G = torch.eye(dims).to(theta.device)
        G[i, i] = torch.cos(theta)
        G[i, j] = -torch.sin(theta)
        G[j, i] = torch.sin(theta)
        G[j, j] = torch.cos(theta)

        v = torch.zeros(dims).to(theta.device)
        v[i] = torch.cos(theta)
        v[j] = torch.sin(theta)
        H = torch.eye(dims).to(theta.device) - 2 * torch.outer(v, v) / torch.dot(v, v)

        R = torch.eye(dims).to(theta.device)
        R[i, i] = torch.cos(theta)
        R[i, j] = -torch.sin(theta)
        R[j, i] = torch.sin(theta)
        R[j, j] = torch.cos(theta)

        return (G + H + R) / 3

    def apply_blended_rotation(self, x):
        adjusted_rot = int(torch.round(self.rot_scale * self.rot))
        for k in range(adjusted_rot):
            i, j = self.r_pairs[k].long()
            theta = self.thetas[k] * self.theta_scale
            B = self.blended_rotation_matrix(dims=self.h_dim, i=i, j=j, theta=theta)
            x = torch.matmul(input=x, other=B)
        return x

    def update_base(self, new_base):
        if new_base is not None and new_base != self.base:
            self.base = new_base
            inv_freq = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))
            self.inv_freq.data.copy_(inv_freq)
            self.update_pairs()

    def reset_parameters(self):
        nn.init.orthogonal_(self.r_matrix)
        nn.init.zeros_(self.thetas)
        nn.init.zeros_(self.r_pairs)
        nn.init.ones_(self.theta_scale)
        nn.init.ones_(self.rot_scale)

    def orthogonal_regularization_term(self):
        loss = torch.tensor(0.0, device=self.r_matrix.device)
        if self.r_matrix.requires_grad:
            product = torch.matmul(self.r_matrix, self.r_matrix.t())
            identity = torch.eye(self.r_matrix.size(0)).to(self.r_matrix.device)
            loss = ((product - identity) ** 2).sum()
        return self.orthogonal_reg_weight * loss

    def update_pairs(self):
        pairs = []
        while len(pairs) < self.rot:
            i, j = torch.randint(0, self.h_dim - 1, (2,))
            if i != j and (i, j) not in pairs and (j, i) not in pairs:
                pairs.append((i, j))
        self.r_pairs.data.copy_(torch.tensor(pairs, dtype=torch.float32))

    def forward(self, x, global_step=None):
        if x.dim() not in [3, 4]:
            raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")

        batch_size, seq_len, *rest = x.size()

        if x.dim() == 3:
            dims = rest[0]
            if dims != self.head * self.h_dim:
                raise ValueError(f"Expected dims ({dims}) to be compatible with head ({self.head}) * h_dim ({self.h_dim}={self.head * self.h_dim})")
        else:
            head, h_dim = rest
            if head != self.head or h_dim != self.h_dim:
                raise ValueError(f"For 4D input, expected head {self.head} and h_dim {self.h_dim}, but got head {head} and h_dim {h_dim}")

        x = x.view(batch_size, seq_len, self.head, self.h_dim)
        x = x.reshape(-1, self.h_dim)

        x = self.apply_blended_rotation(x)

        x = torch.matmul(input=x, other=self.r_matrix)

        x = x.view(batch_size, seq_len, self.head, self.h_dim)

        sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(end=seq_len, device=x.device), self.inv_freq.to(device=x.device))
        sin = sinusoid_inp.sin()[None, :, None, :]
        cos = sinusoid_inp.cos()[None, :, None, :]

        x1, x2 = x[..., ::2], x[..., 1::2]
        x = torch.cat(tensors=[x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        x = x.view(batch_size, seq_len, self.dims)

        return x

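# Minimal usage sketch for the rotary module (illustrative values only): it accepts
# (batch, seq, dims) or (batch, seq, head, h_dim) input and returns (batch, seq, dims).
#
#   rotary = CombinedRotaryEmbedding(base=10000, dims=512, head=4)
#   y = rotary(torch.randn(2, 100, 512))   # -> torch.Size([2, 100, 512])
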
class SinusoidalEmbedding(nn.Module):
    def __init__(self, n_ctx, dims, checkpoint):
        super().__init__()
        self.n_ctx = n_ctx
        self.dims = dims
        self.checkpoint = checkpoint

        position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dims, 2).float() * -(math.log(10000.0) / dims))
        features = torch.zeros(n_ctx, dims)
        features[:, 0::2] = torch.sin(position * div_term)
        features[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('my_big_toe', features)
        self.pos_embeds = nn.Parameter(self.my_big_toe.clone())

    def forward(self, positions):
        if self.checkpoint:
            position_embeddings = checkpoint(lambda x: self.pos_embeds[x], positions)
        else:
            position_embeddings = self.pos_embeds[positions]
        return F.normalize(position_embeddings, p=2, dim=-1)

class CombinedPositionalEmbedding(nn.Module):
    def __init__(self, base, dims, head, n_ctx, theta_learnable=True, rot_learnable=True,
                 matrix_learnable=False, freq_learnable=True, checkpoint=False):
        super().__init__()
        self.rotary_embedding = CombinedRotaryEmbedding(base, dims, head, theta_learnable,
                                                        rot_learnable, matrix_learnable, freq_learnable)
        self.sinusoidal_embedding = SinusoidalEmbedding(n_ctx, dims, checkpoint)

    def forward(self, x, positions, global_step=None):
        rotary_embed = self.rotary_embedding(x, global_step)
        sinusoidal_embed = self.sinusoidal_embedding(positions)

        combined_embedding = rotary_embed + sinusoidal_embed
        return combined_embedding

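# Sketch: the sinusoidal table is a learnable copy of the fixed sin/cos features, indexed
# by integer positions and L2-normalized, e.g. (values arbitrary):
#
#   pos = SinusoidalEmbedding(n_ctx=448, dims=512, checkpoint=False)
#   emb = pos(torch.arange(10))   # -> torch.Size([10, 512]), rows have unit L2 norm
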
class MultiheadAttention(nn.Module):
    use_sdpa = True

    def __init__(self, base, dims, head, max_dist):
        super().__init__()
        assert dims % head == 0, "dims must be divisible by head"
        self.head = head
        self.h_dim = dims // head
        assert self.h_dim % 2 == 0, "Head dimension must be even for rotary embeddings"

        self.query = nn.Linear(dims, dims)
        self.key = nn.Linear(dims, dims, bias=False)
        self.value = nn.Linear(dims, dims)
        self.out = nn.Linear(dims, dims)

        # self.givens_rotary = CombinedRotaryEmbedding(base=base, dims=dims, head=head)

    def forward(self, x, xa=None, mask=None, kv_cache=None):

        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)

        else:
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        # q = self.givens_rotary(q)
        # k = self.givens_rotary(k)

        wv, qk = self.qkv_attention(q=q, k=k, v=v, mask=mask)

        out = self.out(wv)
        return out, qk

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):

        n_batch, n_ctx, dims = q.shape
        scale = (dims // self.head) ** -0.25
        q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)

        if MultiheadAttention.use_sdpa:
            a = scaled_dot_product_attention(query=q, key=k, value=v, is_causal=mask is not None and n_ctx > 1)
            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = None
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None:
                qk = qk + mask[:n_ctx, :n_ctx]
            qk = qk.float()

            w = F.softmax(qk, dim=-1).to(dtype=q.dtype)
            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = qk.detach()

        return out, qk

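# Attention call sketch (shapes only; dims must be divisible by head):
#
#   attn = MultiheadAttention(base=10000, dims=512, head=4, max_dist=16)
#   out, qk = attn(torch.randn(2, 50, 512))   # self-attention, out: (2, 50, 512)
#   # qk is None on the SDPA path, otherwise the detached attention logits.
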
class AdaptiveSpanAttention(nn.Module):
    def __init__(self, base, dims, head, max_dist, sharpen, win_size, max_span, temp_scale=0.01):
        super().__init__()
        self.max_dist = max_dist
        self.win_size = win_size
        self.max_span = max_span
        self.temp_scale = temp_scale
        self.multihead_attn = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)
        self.span_scale = nn.Parameter(torch.tensor(1.0))
        self.sharpen = sharpen

    def forward(self, query, key, value, span_scale):
        span_len = int(self.max_span * span_scale.mean().item())
        span_len = min(span_len, query.shape[1], key.shape[1], value.shape[1])
        eff_span = min(span_len, self.max_dist)

        q_span = query[:, :eff_span, :]
        k_span = key[:, :eff_span, :]
        v_span = value[:, :eff_span, :]

        batch_size, _, dims = query.shape
        scale = (dims // self.multihead_attn.head) ** -0.25

        q = q_span.view(q_span.shape[0], q_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)
        k = k_span.view(k_span.shape[0], k_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)
        v = v_span.view(v_span.shape[0], v_span.shape[1], self.multihead_attn.head, -1).permute(0, 2, 1, 3)

        if self.sharpen:
            temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
        else:
            temperature = 0.5 + self.temp_scale * span_scale.mean().item()

        attn_scores = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = torch.softmax((attn_scores / temperature) * scale, dim=-1)
        attn_out = torch.matmul(attn_weights, v)
        attn_out = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
        attn_out = attn_out.contiguous().view(batch_size, eff_span, dims)

        return attn_out, attn_weights

class SpanPredictor(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.linear = nn.Linear(in_features=dims, out_features=1)

    def forward(self, global_out):
        scale = torch.sigmoid(self.linear(global_out))
        return scale

class HybridAttention(nn.Module):
    def __init__(self, base, dims, head, max_dist, sharpen, win_size=32, max_span=32, slid_win=32):
        super().__init__()
        self.max_dist = max_dist
        self.win_size = win_size
        self.max_span = max_span
        self.slid_win = slid_win

        self.span_pred = SpanPredictor(dims=dims)
        self.dist_local = max_dist
        self.dist_global = max_dist

        self.attn_local = AdaptiveSpanAttention(base=base, dims=dims, head=head, max_dist=max_dist, sharpen=sharpen, win_size=win_size, max_span=max_span)
        self.attn_global = MultiheadAttention(base=base, dims=dims, head=head, max_dist=self.dist_global)
        self.ln_local = LayerNorm(normalized_shape=dims)
        self.ln_global = LayerNorm(normalized_shape=dims)
        self.projection = Linear(in_features=2 * dims, out_features=dims)

    def forward(self, x, new_dist=None, new_base=None, xa=None, mask=None, kv_cache=None):
        local = self.ln_local(x)
        globe = self.ln_global(x)

        globe_out, _ = self.attn_global(globe, globe, globe)

        span_scale = self.span_pred(globe_out.mean(dim=1))

        win_size = max(1, int(self.slid_win * span_scale.mean().item()))
        span_len = max(1, int(self.max_span * span_scale.mean().item()))

        effective_max_dist = min(self.max_dist, local.size(1))
        local_max_dist = min(self.dist_local, span_len, win_size)
        globe_max_dist = effective_max_dist

        self.attn_local.max_dist = local_max_dist
        self.attn_global.max_dist = globe_max_dist

        local_out = self.slide_win(x=local, win_size=win_size, span_len=span_len, span_scale=span_scale)

        combined = torch.cat(tensors=[local_out, globe_out], dim=-1)
        x = self.projection(combined)

        return x

    def slide_win(self, x, win_size, span_len, span_scale):
        batch_size, seq_len, dims = x.size()
        out = torch.zeros_like(x, device=x.device)

        for i in range(0, seq_len, win_size):
            end = min(i + win_size, seq_len)
            query = x[:, i:end, :]

            start = max(0, i - span_len + win_size)
            key = x[:, start:i + span_len, :]
            value = x[:, start:i + span_len, :]
            attn_out, _ = self.attn_local(query, key, value, span_scale)
            out[:, i:end, :] = attn_out

        return out

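# Span arithmetic sketch: the predictor maps the pooled global output to a scale in
# (0, 1), which shrinks both the sliding-window size and the attended span. For example,
# with slid_win=16 and max_span=16, a predicted scale of 0.5 gives win_size = 8 and
# span_len = 8; the effective span is then further capped by max_dist and by the actual
# sequence length.
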
class ResidualAttention(nn.Module):
    def __init__(self, base, dims, head, max_dist, win_size, max_span, hybrid, checkpoint, cross, sharpen):
        super().__init__()

        if hybrid:
            self.attn = HybridAttention(base=base, dims=dims, head=head, max_dist=max_dist, sharpen=sharpen)
            self.attn_ln = LayerNorm(normalized_shape=dims)
        else:
            self.attn = MultiheadAttention(base=base, dims=dims, head=head, max_dist=max_dist)
            self.attn_ln = LayerNorm(normalized_shape=dims)

        n_mlp = dims * 4
        self.mlp = nn.Sequential(Linear(in_features=dims, out_features=n_mlp), nn.GELU(), Linear(in_features=n_mlp, out_features=dims))
        self.mlp_ln = LayerNorm(normalized_shape=dims)

    def forward(self, x, mask=None, kv_cache=None):
        x = self._attn_forward(x=x, mask=mask, kv_cache=kv_cache)
        x = self._mlp_forward(x=x)
        return x

    def _attn_forward(self, x, mask=None, kv_cache=None):
        residual = x
        x = self.attn_ln(x)

        if isinstance(self.attn, HybridAttention):
            attn_output = self.attn(x)
            x = residual + attn_output
        else:
            attn_output, _ = self.attn(x, mask=mask, kv_cache=kv_cache)
            x = residual + attn_output
        return x

    def _mlp_forward(self, x):
        residual = x
        x = self.mlp_ln(x)
        return residual + self.mlp(x)

class AudioEncoder(nn.Module):
    def __init__(self, base, mels, dims, head, n_layer, n_ctx, max_dist,
                 win_size, max_span, hybrid, checkpoint, cross, sharpen):
        super().__init__()
        self.conv1 = Conv1d(in_channels=mels, out_channels=dims, kernel_size=3, padding=1)
        self.conv2 = Conv1d(in_channels=dims, out_channels=dims, kernel_size=3, stride=2, padding=1)
        self.pos_embed = SinusoidalEmbedding(n_ctx=n_ctx, dims=dims, checkpoint=checkpoint)
        self.checkpoint = checkpoint

        self.givens_rotary = CombinedRotaryEmbedding(base=base, dims=dims, head=head)
        # self.combine = CombinedPositionalEmbedding(base=base, dims=dims, head=head)
        self.blocks = nn.ModuleList(modules=[ResidualAttention(base=base, dims=dims, head=head, max_dist=max_dist, win_size=win_size, max_span=max_span, hybrid=hybrid, checkpoint=checkpoint, cross=cross, sharpen=sharpen) for _ in range(n_layer)])
        self.ln_post = LayerNorm(normalized_shape=dims)

    def forward(self, x):
        if self.checkpoint:
            x = checkpoint(self._conv_forward, x)
        else:
            x = self._conv_forward(x)

        for block in self.blocks:
            if self.checkpoint:
                x = checkpoint(block, x)
            else:
                x = block(x)
        return self.ln_post(x)

    def _conv_forward(self, x):
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        p = self.pos_embed(torch.arange(end=x.size(dim=1), device=x.device)).unsqueeze(0)
        x = (x + p).to(x.dtype)
        x = self.givens_rotary(x)
        # x = self.combine(x)
        return x

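# Encoder shape sketch (assuming the usual Whisper front end of 30 s audio at 16 kHz):
# log-mel input (batch, mels=128, 3000) -> conv1 keeps 3000 frames -> conv2 with stride 2
# halves it -> (batch, 1500, dims), which is why a_ctx is set to 1500 in the config below.
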
class TextDecoder(nn.Module):
    def __init__(self, base, vocab, dims, head, n_layer, n_ctx, max_dist,
                 win_size, max_span, hybrid, checkpoint, cross, sharpen):
        super().__init__()

        self.tok_embed = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
        self.pos_embed = SinusoidalEmbedding(n_ctx=n_ctx, dims=dims, checkpoint=checkpoint)
        self.checkpoint = checkpoint

        self.givens_rotary = CombinedRotaryEmbedding(base=base, dims=dims, head=head)

        self.blocks = nn.ModuleList(modules=[ResidualAttention(base=base, dims=dims, head=head, max_dist=max_dist, win_size=win_size, max_span=max_span, hybrid=hybrid, checkpoint=checkpoint, cross=cross, sharpen=sharpen) for _ in range(n_layer)])

        self.ln_post = LayerNorm(normalized_shape=dims)
        self.ln = LayerNorm(normalized_shape=dims)

        mask = torch.empty(n_ctx, n_ctx).fill_(value=-np.inf).triu_(diagonal=1)
        self.register_buffer(name="mask", tensor=mask, persistent=False)
        self.mask = mask

    def forward(self, x, xa, kv_cache=None):
        if self.checkpoint:
            x = checkpoint(self._embedding_forward, x, xa, kv_cache)
        else:
            x = self._embedding_forward(x=x, xa=xa, kv_cache=kv_cache)

        for block in self.blocks:
            if self.checkpoint:
                x = checkpoint(block, x, self.mask, kv_cache)
            else:
                x = block(x, self.mask, kv_cache)

        x = self.ln(x)
        x = (x @ torch.transpose(input=self.tok_embed.weight.to(dtype=x.dtype), dim0=0, dim1=1)).float()
        return x

    def _embedding_forward(self, x, xa, kv_cache):
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        positions = torch.arange(x.shape[1], device=x.device) + offset
        pos_emb = self.pos_embed(positions).unsqueeze(0)
        x = self.tok_embed(x) + pos_emb
        x = self.givens_rotary(x)
        return x

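# Decoder sketch: token ids (batch, seq) are embedded, combined with sinusoidal and rotary
# positions, passed through the blocks under the causal mask, and finally projected back
# onto the vocabulary by reusing the token-embedding matrix (weight tying), yielding
# logits of shape (batch, seq, vocab).
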
class EchoConfig(PretrainedConfig):
    model_type = "Echo"
    def __init__(
        self,
        checkpoint=False,
        cross=False,
        hybrid=False,
        sharpen=False,
        a_ctx=1500,
        a_head=16,
        a_layer=8,
        a_dims=1024,
        mels=128,
        t_ctx=448,
        t_head=8,
        t_layer=8,
        t_dims=1024,
        win_size=64,
        max_span=64,
        max_dist=64,
        base=10000,
        pad_token_id=50257,
        unk_token_id=50257,
        vocab=51865,
        eos_token_id=50257,
        bos_token_id=50257,
        decoder_start_token_id=50258,
        **kwargs,
    ):

        super().__init__(**kwargs)
        self.base = base
        self.bos_token_id = bos_token_id
        self.checkpoint = checkpoint
        self.cross = cross
        self.decoder_start_token_id = decoder_start_token_id
        self.eos_token_id = eos_token_id
        self.hybrid = hybrid
        self.max_dist = max_dist
        self.max_span = max_span
        self.a_ctx = a_ctx
        self.a_head = a_head
        self.a_layer = a_layer
        self.a_dims = a_dims
        self.mels = mels
        self.t_ctx = t_ctx
        self.t_head = t_head
        self.t_layer = t_layer
        self.t_dims = t_dims
        self.pad_token_id = pad_token_id
        self.unk_token_id = unk_token_id
        self.vocab = vocab
        self.win_size = win_size
        self.sharpen = sharpen

class Echo(nn.Module):
    def __init__(self, config: EchoConfig):
        super().__init__()
        self.config = config

        self.encoder = AudioEncoder(
            base=self.config.base,
            mels=self.config.mels,
            dims=self.config.a_dims,
            head=self.config.a_head,
            n_layer=self.config.a_layer,
            n_ctx=self.config.a_ctx,
            max_dist=self.config.max_dist,
            win_size=self.config.win_size,
            max_span=self.config.max_span,
            hybrid=self.config.hybrid,
            checkpoint=self.config.checkpoint,
            cross=self.config.cross,
            sharpen=self.config.sharpen,
        )

        self.decoder = TextDecoder(
            base=self.config.base,
            vocab=self.config.vocab,
            dims=self.config.t_dims,
            head=self.config.t_head,
            n_layer=self.config.t_layer,
            n_ctx=self.config.t_ctx,
            max_dist=self.config.max_dist,
            win_size=self.config.win_size,
            max_span=self.config.max_span,
            hybrid=self.config.hybrid,
            checkpoint=self.config.checkpoint,
            cross=self.config.cross,
            sharpen=self.config.sharpen,
        )

        all_heads = torch.zeros(self.config.t_layer, self.config.t_head, dtype=torch.bool)
        all_heads[self.config.t_layer // 2:] = True
        self.register_buffer(name="alignment_heads", tensor=all_heads.to_sparse(), persistent=False)

        self.base = self.config.base
        self.win_size = self.config.win_size
        self.adjust_counter = 0
        self.best_loss = float('inf')
        self.kv_cache = {}

    @property
    def device(self):
        return next(self.parameters()).device

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def update_window(self, new_window):
        self.win_size = new_window
        for module in self.modules():
            if isinstance(module, HybridAttention):
                module.update_window(self.win_size)

    def adjust_window(self, loss, factor=1.00005):
        # count every call so the periodic adjustment keeps firing
        self.adjust_counter += 1
        if self.adjust_counter % 10 == 0:
            if loss < self.best_loss:
                new_window = self.win_size * factor
            else:
                new_window = self.win_size / factor
            self.update_window(new_window=new_window)
            self.best_loss = loss
            return new_window
        return self.win_size

    def adjust_base(self, loss, factor=1.0025) -> float | int:
        self.adjust_counter += 1
        if self.adjust_counter % 25 == 0:
            if loss < self.best_loss:
                new_base = self.base * factor
            else:
                new_base = self.base / factor
            self.update_base(new_base=new_base)
            self.base = new_base
            self.best_loss = loss
        return self.base

    def update_base(self, new_base):
        self.new_base = new_base
        for name, module in self.encoder.named_modules():
            if isinstance(module, (CombinedRotaryEmbedding)):
                module.update_base(new_base=self.new_base)

    @staticmethod
    def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward(self, input_features, labels=None, dec_input_ids=None) -> dict[str, Any | None]:
        if labels is not None:
            if dec_input_ids is None:
                dec_input_ids = self.shift_tokens_right(
                    input_ids=labels, pad_token_id=self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
                )

        encoded_features = self.encoder(input_features).to(self.device)
        logits = self.decoder(dec_input_ids, encoded_features)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(logits.device).long()
            loss = loss_fct(logits.view(-1, self.config.vocab), labels.view(-1))

            self.adjust_window(loss.item())
            # self.adjust_base(loss=loss.item())
        return {"loss": loss, "logits": logits}

    def reset_parameters(self):
        for name, module in self.encoder.named_modules():
            if isinstance(module, CombinedRotaryEmbedding):
                module.reset_parameters()

    def _initialize_weights(self, module):
        nn.init.normal_(tensor=self.decoder.tok_embed.weight, mean=0.0, std=0.02)
        nn.init.constant_(tensor=self.decoder.ln.weight, val=1)
        nn.init.constant_(tensor=self.decoder.ln.bias, val=0)
        nn.init.xavier_normal_(tensor=self.encoder.conv1.weight)
        nn.init.zeros_(tensor=self.encoder.conv1.bias)
        nn.init.kaiming_normal_(tensor=self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')
        nn.init.zeros_(tensor=self.encoder.conv2.bias)
        nn.init.constant_(tensor=self.encoder.ln_post.weight, val=1)
        nn.init.constant_(tensor=self.encoder.ln_post.bias, val=0)

        for block in self.decoder.blocks:
            for layer in block.children():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(tensor=layer.weight)
                    nn.init.zeros_(tensor=layer.bias)
                if isinstance(layer, LayerNorm):
                    nn.init.constant_(tensor=layer.weight, val=1)

        for block in self.encoder.blocks:
            for layer in block.children():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(tensor=layer.weight)
                    nn.init.zeros_(tensor=layer.bias)
                if isinstance(layer, LayerNorm):
                    nn.init.constant_(tensor=layer.weight, val=1)

        for name, module in self.encoder.named_modules():
            if isinstance(module, CombinedRotaryEmbedding):
                nn.init.constant_(tensor=module.thetas, val=1)
                nn.init.constant_(tensor=module.r_matrix, val=1)
                nn.init.constant_(tensor=module.r_pairs, val=1)
                nn.init.constant_(tensor=module.inv_freq, val=1)

    def apply_initialization(self, module):
        self._initialize_weights(module=module)

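# End-to-end sketch (illustrative only; relies on the small config and `model` defined
# further below):
#
#   feats = torch.randn(1, 128, 3000, device=device)        # log-mel features
#   ids = torch.randint(0, 51865, (1, 10), device=device)   # decoder input ids
#   out = model(input_features=feats, dec_input_ids=ids)    # {"loss": None, "logits": ...}
#   out["logits"].shape                                      # -> torch.Size([1, 10, 51865])
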
from datetime import datetime
log_dir = os.path.join('./output/Echo/', datetime.now().strftime(format='%m-%d_%H'))
os.makedirs(name=log_dir, exist_ok=True)

config = EchoConfig(
    checkpoint=False,
    cross=False,
    hybrid=False,
    sharpen=False,
    a_ctx=1500,
    a_head=4,
    a_layer=4,
    a_dims=512,
    mels=128,
    t_ctx=448,
    t_head=4,
    t_layer=4,
    t_dims=512,
    win_size=16,
    max_span=16,
    max_dist=16,
    base=50000,
    pad_token_id=50257,
    unk_token_id=50257,
    vocab=51865,
    eos_token_id=50257,
    bos_token_id=50257,
    decoder_start_token_id=50258,
)

model = Echo(config=config).to(device=device)
model.apply_initialization(module=model)

feature_extractor = WhisperFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path="openai/whisper-small",
    feature_size=128, sampling_rate=16000, do_normalize=True)

tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path="openai/whisper-small",
    language="en", task="transcribe")

processor = WhisperProcessor.from_pretrained(
    pretrained_model_name_or_path="openai/whisper-small",
    feature_size=128, sampling_rate=16000, do_normalize=True,
    language="en", task="transcribe")

class GradientClippingCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        torch.nn.utils.clip_grad_norm_(parameters=kwargs["model"].parameters(), max_norm=0.98)

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch

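# Collator sketch: each feature dict carries "input_features" (a log-mel array) and
# "labels" (token ids). The collator pads both, replaces label padding with -100 so the
# cross-entropy loss ignores it, and drops a leading decoder-start token when every row
# begins with one (the model re-adds it via shift_tokens_right).
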
def get_length_of_dataset(dataset):
    length = 0
    for item in dataset:
        length += len(item["audio"]["array"]) / item["audio"]["sampling_rate"]
    return length / 3600

def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, decoder_start_token_id=config.decoder_start_token_id)

datasets = IterableDatasetDict()

datasets["train"] = load_dataset(
    path="mozilla-foundation/common_voice_17_0", token="",
    name="en", split="train", streaming=True, trust_remote_code=True).take(10000)

datasets["test"] = load_dataset(
    path="mozilla-foundation/common_voice_17_0", token="",
    name="en", split="test", streaming=True, trust_remote_code=True).take(100)

dataset = datasets.cast_column(column="audio", feature=Audio(sampling_rate=16000))

dataset = dataset.map(function=prepare_dataset,
                      remove_columns=list(next(iter(dataset.values())).features)).with_format(type="torch")

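# Data pipeline sketch: both splits stream from the Hub (no full download), .take() caps
# the number of examples, the audio column is decoded at 16 kHz, and prepare_dataset maps
# each row to {"input_features", "labels"} while the original columns are removed.
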
class MetricsCallback(TrainerCallback):
    def __init__(self, tb_writer, tokenizer, metric, optimizer, scheduler, log_every_n_steps=1):
        super().__init__()
        self.tb_writer = tb_writer
        self.tokenizer = tokenizer
        self.metric = metric
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.log_every_n_steps = log_every_n_steps
        self.predictions = None
        self.label_ids = None

    def compute_wer(self, pred_str, label_str):
        wer = 100 * self.metric.compute(predictions=pred_str, references=label_str)
        return wer

    def on_evaluate(self, args, state, control, model, metrics=None, **kwargs):
        if metrics is not None:
            self.eval_loss = metrics.get('eval_loss')

            current_learning_rate = self.optimizer.param_groups[0]['lr']
            if state.global_step % self.log_every_n_steps == 0:
                self.tb_writer.add_scalar('learning_rate', current_learning_rate, state.global_step)
                print(f"Learning Rate: {current_learning_rate:.8f}")

            self.tb_writer.add_scalar('eval_loss', self.eval_loss, state.global_step)

            for key, value in metrics.items():
                if key.startswith("eval_"):
                    self.tb_writer.add_scalar(key, value, state.global_step)

        if self.predictions is not None and self.label_ids is not None:
            pred_str = self.tokenizer.batch_decode(self.predictions, skip_special_tokens=True)
            label_str = self.tokenizer.batch_decode(self.label_ids, skip_special_tokens=True)

            if state.global_step % self.log_every_n_steps == 0:
                total_samples = len(pred_str)
                random_indices = random.sample(range(total_samples), 1)

                for sample_index in random_indices:
                    self.tb_writer.add_text(f"Prediction_{sample_index}", pred_str[sample_index], state.global_step)
                    self.tb_writer.add_text(f"Label_{sample_index}", label_str[sample_index], state.global_step)
                    print(f"Evaluation: - Step {state.global_step} - Loss: {self.eval_loss:.2f}")
                    print(f"Prediction: {pred_str[sample_index]}")
                    print(f"Label: {label_str[sample_index]}")
                    print("-" * 10)

        self.predictions = None
        self.label_ids = None

def create_compute_metrics(callback_instance):
    def compute_metrics(eval_pred):
        pred_logits = eval_pred.predictions
        label_ids = eval_pred.label_ids

        if isinstance(pred_logits, tuple):
            pred_ids = pred_logits[0]
        else:
            pred_ids = pred_logits
        if pred_ids.ndim == 3:
            pred_ids = np.argmax(pred_ids, axis=-1)

        label_ids[label_ids == -100] = callback_instance.tokenizer.pad_token_id
        callback_instance.predictions = pred_ids
        callback_instance.label_ids = label_ids
        pred_str = callback_instance.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = callback_instance.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        wer = 100 * callback_instance.metric.compute(predictions=pred_str, references=label_str)
        pred_flat = pred_ids.flatten()
        labels_flat = label_ids.flatten()
        mask = labels_flat != callback_instance.tokenizer.pad_token_id

        accuracy = accuracy_score(y_true=labels_flat[mask], y_pred=pred_flat[mask])
        precision = precision_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)
        recall = recall_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)
        f1 = f1_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)
        return {"wer": wer, "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
    return compute_metrics

metric = evaluate.load(path="wer")
tb_writer = SummaryWriter(log_dir=log_dir)

training_args = Seq2SeqTrainingArguments(
    output_dir=log_dir,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    eval_accumulation_steps=1,
    tf32=True,
    bf16=True,
    eval_strategy="steps",
    save_strategy="steps",
    max_steps=10000,
    save_steps=10000,
    eval_steps=100,
    warmup_steps=100,
    logging_steps=10,
    logging_dir=log_dir + "/logs_hf",
    report_to=["tensorboard"],
    load_best_model_at_end=False,
    metric_for_best_model="loss",
    greater_is_better=False,
    push_to_hub=False,
    disable_tqdm=False,
    save_total_limit=1,
    remove_unused_columns=False,
    label_names=["labels"],
    eval_on_start=True,
)

class MaxFactor(Optimizer):
    def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(None, 1e-3), d=1.0,
                 weight_decay=0.0, gamma=0.99, eps_rms=1e-8, maximize=False):

        defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay,
                        gamma=gamma, eps_rms=eps_rms, maximize=maximize)

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
            eps1, eps2 = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0, dtype=torch.float32)
                    if p.grad.dim() > 1:
                        row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)
                        row_shape[-1], col_shape[-2] = 1, 1
                        state["row_var"], state["col_var"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)
                    state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                row_vars.append(state.get("row_var", None))
                col_vars.append(state.get("col_var", None))
                v.append(state["v"])
                state_steps.append(state["step"])
                params_with_grad.append(p)
                grads.append(grad)

            for i, param in enumerate(params_with_grad):
                grad = grads[i]

                if group["maximize"]:
                    grad = -grad
                step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]

                if eps1 is None:
                    eps1 = torch.finfo(param.dtype).eps

                step_t += 1
                step_float = step_t.item()
                one_minus_beta2_t = step_float ** group["beta2_decay"]
                rho_t = min(group["lr"], 1 / (step_float ** 0.5))
                alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t

                if group["weight_decay"] != 0:
                    param.mul_(1 - group["lr"] * group["weight_decay"])

                if grad.dim() > 1:
                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
                    row_var.lerp_(row_mean, one_minus_beta2_t)
                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))
                    col_var.lerp_(col_mean, one_minus_beta2_t)
                    var_estimate = row_var @ col_var
                    max_row_var = row_var.max(dim=-2, keepdim=True)[0]
                    var_estimate.div_(max_row_var.clamp_(min=eps1))

                else:
                    vi.mul_(group["gamma"]).add_(grad ** 2, alpha=1 - group["gamma"])
                    var_estimate = vi

                update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
                update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))
                denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
                param.add_(-alpha / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])

        return loss

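# MaxFactor keeps an Adafactor-style factored second moment for matrices (a row variance
# times a column variance, rescaled by the largest row statistic) and a plain EMA of
# squared gradients for vectors, then applies a sign-based update scaled by the per-row
# maximum magnitude and an RMS-style trust ratio. Rough usage sketch (mirrors the
# instantiation below; feats and label_ids are placeholder tensors):
#
#   opt = MaxFactor(model.parameters(), lr=0.025, weight_decay=0.0025)
#   out = model(input_features=feats, labels=label_ids)
#   out["loss"].backward(); opt.step(); opt.zero_grad()
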
optimizer = MaxFactor(
    model.parameters(),
    lr=0.025,
    beta2_decay=-0.8,
    eps=(None, 1e-4),
    d=1.0,
    weight_decay=0.0025,
    gamma=0.99,
    eps_rms=1e-8,
    maximize=False,
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=training_args.max_steps,
    eta_min=0.0,
    last_epoch=-1
)

metrics_callback = MetricsCallback(tb_writer=tb_writer, tokenizer=tokenizer, metric=metric, optimizer=optimizer, scheduler=scheduler, log_every_n_steps=10)
compute_metrics = create_compute_metrics(callback_instance=metrics_callback)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=feature_extractor,
    callbacks=[metrics_callback],
    optimizers=(optimizer, scheduler)
)

trainer.train(resume_from_checkpoint=False)

from tensorboard import program
tb = program.TensorBoard()
tb.configure(argv=[None, '--logdir', log_dir])
url = tb.launch()
print(f"TensorBoard started at {url}")