{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import base64, os, evaluate, random, gzip, math, torch, numpy as np, json, warnings\n", "from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score\n", "from datasets import load_dataset, IterableDatasetDict, Audio\n", "from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments, WhisperProcessor,WhisperFeatureExtractor,\n", "WhisperTokenizerFast)\n", "import torch.nn.functional as F\n", "import transformers\n", "from itertools import chain\n", "from torch.utils.checkpoint import checkpoint\n", "from typing import Dict, Optional, Tuple\n", "from torch import Tensor, nn\n", "from dataclasses import dataclass\n", "from typing import Dict, Optional, Tuple, Union, List, Any\n", "from torch.nn.functional import scaled_dot_product_attention\n", "\n", "torch.backends.cudnn.allow_tf32 = True\n", "torch.backends.cuda.matmul.allow_tf32 = True\n", "transformers.utils.logging.set_verbosity_error()\n", "device = torch.device(device=\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "dtype = torch.float32\n", "torch.set_default_dtype(dtype)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class Dimensions:\n", " mels: int\n", " audio_ctx: int\n", " audio_state: int\n", " audio_head: int\n", " audio_layerA: int\n", " audio_layerB: int\n", " vocab: int\n", " text_ctx: int\n", " text_state: int\n", " text_head: int\n", " text_layerA: int\n", " text_layerB: int\n", " dropout: float\n", " activation: str\n", " checkpoint: bool\n", "\n", "\n", "class LayerNorm(nn.LayerNorm):\n", " def forward(self, x: Tensor) -> Tensor:\n", " return super().forward(input=x.float()).type(dtype=x.dtype)\n", "\n", "\n", "class Linear(nn.Linear):\n", " def forward(self, x: Tensor) -> Tensor:\n", " return F.linear(\n", " input=x,\n", " weight=self.weight.to(dtype=x.dtype),\n", " bias=None if self.bias is None else self.bias.to(dtype=x.dtype),\n", " )\n", "\n", "\n", "class Conv1d(nn.Conv1d):\n", " def _conv_forward(\n", " self, x: Tensor, weight: Tensor, bias: Optional[Tensor]\n", " ) -> Tensor:\n", " return super()._conv_forward(\n", " input=x,\n", " weight=weight.to(dtype=x.dtype),\n", " bias=None if bias is None else bias.to(dtype=x.dtype),\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# @torch.jit.script\n", "# def _apply_qrotation(x: torch.Tensor, theta: torch.Tensor, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:\n", "# u = u / torch.norm(u)\n", "# v = v / torch.norm(v)\n", "\n", "# half_theta = theta / 2\n", "# cos_ht = torch.cos(half_theta)\n", "# sin_ht = torch.sin(half_theta)\n", "\n", "# q = torch.cat([cos_ht.unsqueeze(0), sin_ht * u])\n", " \n", "# x_shape = x.shape\n", "# x = x.view(-1, 3)\n", "\n", "# uv_cross = torch.cross(u.unsqueeze(0), x)\n", "# uuv_cross = torch.cross(u.unsqueeze(0), uv_cross)\n", "# x_rot = x + 2 * (q[0] * uv_cross + uuv_cross)\n", "\n", "# x_rot = x_rot.view(*x_shape)\n", "# return x_rot\n", "\n", "# @torch.jit.script\n", "# def _create_rotation_matrix(dims: int, i: int, j: int, theta: torch.Tensor, device: torch.device) -> torch.Tensor:\n", "# G = torch.eye(dims, device=device)\n", "# c, s = torch.cos(theta), torch.sin(theta)\n", "# G[i, i], G[j, j] = c, c\n", "# G[i, j], G[j, i] = -s, s\n", " \n", "# if dims == 3:\n", "# u = torch.eye(dims, device=device)[i]\n", "# v = torch.eye(dims, device=device)[j]\n", "# x = torch.eye(dims, device=device)\n", " \n", "# Q = _apply_qrotation(x, theta=theta, u=u, v=v)\n", "# G = (G + Q) / 2\n", "# return G\n", "\n", "# @torch.jit.script\n", "# def _apply_rope_transform(\n", "# x: torch.Tensor, \n", "# sin: torch.Tensor, \n", "# cos: torch.Tensor\n", "# ) -> torch.Tensor:\n", "# x1, x2 = x[..., ::2], x[..., 1::2]\n", "# return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n", "\n", "\n", "# class rotary2(nn.Module):\n", "# def __init__(self, ctx, dims, heads, base=10000, theta_learnable=False,\n", "# rot_learnable=False, matrix_learnable=False, freq_learnable=False,\n", "# ):\n", "# super().__init__()\n", "# self.ctx = ctx\n", "# self.dims = dims\n", "# self.heads = heads\n", "# self.base = base\n", "\n", "# self.head_dim = self.dims // self.heads\n", "# self.rot = self.head_dim // 2\n", "\n", "# self.thetas = nn.Parameter(torch.zeros(self.rot))\n", "# self.r_pairs = nn.Parameter(torch.rand(self.rot, 2) * self.head_dim)\n", "# self.theta_scale = nn.Parameter(torch.ones(1), requires_grad=theta_learnable)\n", "# self.rot_scale = nn.Parameter(torch.ones(1), requires_grad=rot_learnable)\n", "# self.r_matrix = nn.Parameter(\n", "# torch.eye(self.head_dim), requires_grad=matrix_learnable\n", "# )\n", "\n", "# freq_data = 1.0 / (\n", "# self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)\n", "# )\n", "\n", "# self.inv_freq = nn.Parameter(freq_data, requires_grad=freq_learnable)\n", "\n", "# self.reset_parameters()\n", "\n", "# def reset_parameters(self):\n", "# nn.init.orthogonal_(self.r_matrix)\n", "# nn.init.zeros_(self.thetas)\n", "\n", "# def q_rotation(self, x, theta, u, v):\n", "# return _apply_qrotation(x, theta, u, v)\n", "\n", "# def rotation_matrix(self, dims, i, j, theta):\n", "# return _create_rotation_matrix(dims, i, j, theta, theta.device)\n", "\n", "# @torch.jit.script_method # type: ignore\n", "# def apply_rotations(self, x: torch.Tensor) -> torch.Tensor:\n", "# adjusted_rot = int(self.rot_scale.item() * self.rot)\n", "# for k in range(adjusted_rot):\n", "# i, j = int(self.r_pairs[k, 0].item()), int(self.r_pairs[k, 1].item())\n", "# theta = self.thetas[k] * self.theta_scale\n", "# G = _create_rotation_matrix(self.head_dim, i, j, theta, x.device)\n", "# x = x @ G\n", "# return x\n", "\n", "# def forward(self, x: torch.Tensor) -> torch.Tensor:\n", "# batch_size, seq_len = x.shape[0], x.shape[1]\n", " \n", "# if x.dim() == 3:\n", "# if x.shape[2] != self.dims:\n", "# raise ValueError(f\"Expected dim {self.dims}, got {x.shape[2]}\")\n", "# x = x.view(batch_size, seq_len, self.heads, self.head_dim)\n", "# elif x.dim() == 4:\n", "# if x.shape[2] != self.heads or x.shape[3] != self.head_dim:\n", "# raise ValueError(f\"Expected {self.heads} heads and {self.head_dim} head_dim\")\n", "# else:\n", "# raise ValueError(f\"Expected 3D or 4D input, got {x.dim()}D\")\n", "\n", "# x_flat = x.reshape(-1, self.head_dim)\n", "# x_rotated = self.apply_rotations(x_flat)\n", "# x_rotated = x_rotated @ self.r_matrix\n", " \n", "# x = x_rotated.view(batch_size, seq_len, self.heads, self.head_dim)\n", " \n", "# position = torch.arange(seq_len, device=x.device, dtype=x.dtype).unsqueeze(1)\n", "# div_term = self.inv_freq.unsqueeze(0)\n", "# sinusoid_inp = position * div_term\n", " \n", "# sin = torch.sin(sinusoid_inp).unsqueeze(0).unsqueeze(2)\n", "# cos = torch.cos(sinusoid_inp).unsqueeze(0).unsqueeze(2)\n", " \n", "# x = _apply_rope_transform(x, sin, cos)\n", " \n", "# x = x.view(batch_size, seq_len, self.dims)\n", "# x = x * math.sqrt(self.dims)\n", "# return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class rotary(nn.Module):\n", " def __init__(self, ctx, dims, heads, base=10000, theta_learnable=False,\n", " rot_learnable=False, matrix_learnable=False, freq_learnable=False,\n", " ):\n", " super().__init__()\n", " self.ctx = ctx\n", " self.dims = dims\n", " self.heads = heads\n", " self.base = base\n", "\n", " self.head_dim = self.dims // self.heads\n", " self.rot = self.head_dim // 2\n", "\n", " self.thetas = nn.Parameter(torch.zeros(self.rot))\n", " self.r_pairs = nn.Parameter(torch.rand(self.rot, 2) * self.head_dim)\n", " self.theta_scale = nn.Parameter(torch.ones(1), requires_grad=theta_learnable)\n", " self.rot_scale = nn.Parameter(torch.ones(1), requires_grad=rot_learnable)\n", " self.r_matrix = nn.Parameter(\n", " torch.eye(self.head_dim), requires_grad=matrix_learnable\n", " )\n", "\n", " freq_data = 1.0 / (\n", " self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)\n", " )\n", "\n", " self.inv_freq = nn.Parameter(freq_data, requires_grad=freq_learnable)\n", "\n", " self.reset_parameters()\n", "\n", " def reset_parameters(self):\n", " nn.init.orthogonal_(self.r_matrix)\n", " nn.init.zeros_(self.thetas)\n", "\n", " def q_rotation(self, x, theta, u, v):\n", " u = u / torch.norm(u)\n", " v = v / torch.norm(v)\n", "\n", " half_theta = theta / 2\n", " cos_ht = torch.cos(half_theta)\n", " sin_ht = torch.sin(half_theta)\n", "\n", " q = torch.cat([cos_ht.unsqueeze(0), sin_ht * u])\n", " q_conj = torch.cat([cos_ht.unsqueeze(0), -sin_ht * u])\n", "\n", " x_shape = x.shape\n", " x = x.view(-1, 3)\n", "\n", " uv_cross = torch.cross(u.unsqueeze(0), x)\n", " uuv_cross = torch.cross(u.unsqueeze(0), uv_cross)\n", " x_rot = x + 2 * (q[0] * uv_cross + uuv_cross)\n", "\n", " x_rot = x_rot.view(*x_shape)\n", " return x_rot\n", "\n", " def rotation_matrix(self, dims, i, j, theta):\n", " G = torch.eye(dims, device=theta.device)\n", " c, s = torch.cos(theta), torch.sin(theta)\n", " G[i, i], G[j, j] = c, c\n", " G[i, j], G[j, i] = -s, s\n", "\n", " if dims == 3:\n", " u = torch.eye(dims, device=theta.device)[i]\n", " v = torch.eye(dims, device=theta.device)[j]\n", " Q = self.q_rotation(\n", " torch.eye(dims, device=theta.device), theta=theta, u=u, v=v\n", " )\n", " G = (G + Q) / 2\n", " return G\n", "\n", " def apply_rotations(self, x):\n", " adjusted_rot = int(torch.round(self.rot_scale * self.rot))\n", " for k in range(adjusted_rot):\n", " i, j = self.r_pairs[k].long()\n", " theta = self.thetas[k] * self.theta_scale\n", " G = self.rotation_matrix(self.head_dim, i.item(), j.item(), theta)\n", " x = x @ G\n", " return x\n", "\n", " def forward(self, x):\n", " batch_size, seq_len, *rest = x.size()\n", "\n", " if len(rest) == 1:\n", " dims = rest[0]\n", " if dims != self.heads * self.head_dim:\n", " raise ValueError(\n", " f\"Needed {self.heads * self.head_dim}, but got too many {dims}\"\n", " )\n", " elif len(rest) == 2:\n", " heads, head_dim = rest\n", " if heads != self.heads or head_dim != self.head_dim:\n", " raise ValueError(\n", " f\"This many heads {self.heads} and head_dims {self.head_dim} we need, got this many heads {heads} and head_dims {head_dim} we did.\"\n", " )\n", " else:\n", " raise ValueError(f\"Expected the thingy to be 3D or 4D, but got {x.dim()}D\")\n", "\n", " x = x.view(batch_size, seq_len, self.heads, self.head_dim)\n", " x = x.reshape(-1, self.head_dim)\n", "\n", " x = self.apply_rotations(x)\n", " x = x @ self.r_matrix\n", "\n", " x = x.view(batch_size, seq_len, self.heads, self.head_dim)\n", "\n", " position = torch.arange(seq_len, device=x.device, dtype=x.dtype).unsqueeze(1)\n", " div_term = self.inv_freq.unsqueeze(0)\n", " sinusoid_inp = position * div_term\n", "\n", " sin = torch.sin(sinusoid_inp).unsqueeze(0).unsqueeze(2)\n", " cos = torch.cos(sinusoid_inp).unsqueeze(0).unsqueeze(2)\n", "\n", " x1, x2 = x[..., ::2], x[..., 1::2]\n", " x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n", " x = x.view(batch_size, seq_len, self.dims)\n", " x = x * math.sqrt(self.dims)\n", "\n", " return x\n", "\n", "\n", "class PositionalEncoding(nn.Module):\n", " def __init__(self, dims, ctx):\n", " super(PositionalEncoding, self).__init__()\n", " self.dims = dims\n", " self.ctx = ctx\n", " self.pe = self.get_positional_encoding(max_seq_len=ctx)\n", "\n", " def get_positional_encoding(self, max_seq_len):\n", " pe = torch.zeros(max_seq_len, self.dims)\n", " position = torch.arange(0, max_seq_len, dtype=torch.float32).unsqueeze(1)\n", " div_term = torch.exp(\n", " torch.arange(0, self.dims, 2, dtype=torch.float32)\n", " * (-math.log(10000.0) / self.dims)\n", " )\n", " pe[:, 0::2] = torch.sin(position * div_term)\n", " pe[:, 1::2] = torch.cos(position * div_term)\n", " pe = pe.unsqueeze(0)\n", " return pe.to(device)\n", "\n", " def forward(self, x):\n", " seq_len = x.size(1)\n", " pe = self.pe[:, :seq_len, :]\n", " x = x * math.sqrt(self.dims)\n", " x = x + pe\n", "\n", " return x\n", "\n", "\n", "def sinusoids(length, channels, max_timescale=10000):\n", " \"\"\"Returns sinusoids for positional embedding\"\"\"\n", " assert channels % 2 == 0\n", " log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)\n", " inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))\n", " scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]\n", " return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# @torch.jit.script\n", "# def _manual_attention(q: Tensor, k: Tensor, v: Tensor, scale: float,\n", "# mask: Optional[Tensor] = None, ctx: int = 0, k_ctx: int = 0\n", "# ) -> Tuple[Tensor, Tensor]:\n", "# qk = (q * scale) @ ((k * scale).transpose(-1, -2))\n", "# if mask is not None:\n", "# qk = qk + mask[:ctx, :k_ctx]\n", "# qk_float = qk.float()\n", "# w = F.softmax(qk_float, dim=-1).to(q.dtype)\n", "# out = (w @ v)\n", "# return out, qk_float\n", "\n", "class MultiheadA(nn.Module):\n", " use_sdpa: bool = True\n", "\n", " def __init__(self, dims: int, heads: int):\n", " super().__init__()\n", "\n", " assert dims % heads == 0, f\"dims ({dims}) must be divisible by heads ({heads})\"\n", " assert isinstance(dims, int) and isinstance(\n", " heads, int\n", " ), \"dims and heads must be integers\"\n", "\n", " self.heads = heads\n", " self.dims = dims\n", " self.head_dim = dims // heads\n", " self.scale = (self.head_dim) ** -0.25\n", "\n", " self.query = nn.Linear(in_features=dims, out_features=dims)\n", " self.key = nn.Linear(in_features=dims, out_features=dims, bias=False)\n", " self.value = nn.Linear(in_features=dims, out_features=dims)\n", " self.out = nn.Linear(in_features=dims, out_features=dims)\n", "\n", " self._init_weights()\n", "\n", " self.register_buffer(\"_has_cuda\", torch.tensor(torch.cuda.is_available()))\n", "\n", " def _init_weights(self):\n", "\n", " std = 0.02\n", " nn.init.normal_(self.query.weight, std=std)\n", " nn.init.normal_(self.key.weight, std=std)\n", " nn.init.normal_(self.value.weight, std=std)\n", " nn.init.normal_(self.out.weight, std=std)\n", " if self.query.bias is not None:\n", " nn.init.zeros_(self.query.bias)\n", " if self.value.bias is not None:\n", " nn.init.zeros_(self.value.bias)\n", " if self.out.bias is not None:\n", " nn.init.zeros_(self.out.bias)\n", "\n", " def forward(\n", " self,\n", " x: Tensor,\n", " xa: Optional[Tensor] = None,\n", " mask: Optional[Tensor] = None,\n", " kv_cache: Optional[Dict] = None,\n", " ) -> Tuple[Tensor, Optional[Tensor]]:\n", "\n", " if __debug__:\n", " assert x.dim() == 3, f\"Expected 3D input tensor, got {x.dim()}D\"\n", " if xa is not None:\n", " assert (\n", " xa.dim() == 3\n", " ), f\"Expected 3D cross-attention tensor, got {xa.dim()}D\"\n", "\n", " q = self.query(x)\n", "\n", " if kv_cache is None or xa is None or self.key not in kv_cache:\n", " kv_input = xa if xa is not None else x\n", "\n", " k = self.key(kv_input)\n", " v = self.value(kv_input)\n", "\n", " if kv_cache is not None and xa is not None:\n", " kv_cache[self.key] = k\n", " kv_cache[self.value] = v\n", " else:\n", " k = kv_cache[self.key]\n", " v = kv_cache[self.value]\n", "\n", " wv, qk = self._attention(q=q, k=k, v=v, mask=mask)\n", "\n", " return self.out(wv), qk\n", "\n", " def _attention(\n", " self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):\n", "\n", " batch, ctx, _ = q.shape\n", " k_ctx = k.size(1)\n", "\n", " head_dim = self.dims // self.heads\n", " reshape_dim = (batch, -1, self.heads, head_dim)\n", "\n", " q = q.view(*reshape_dim).transpose(1, 2)\n", " k = k.view(batch, k_ctx, self.heads, head_dim).transpose(1, 2)\n", " v = v.view(batch, k_ctx, self.heads, head_dim).transpose(1, 2)\n", "\n", " if MultiheadA.use_sdpa:\n", " with torch.autocast(device_type=\"cuda\", enabled=True):\n", " out = F.scaled_dot_product_attention(\n", " query=q,\n", " key=k,\n", " value=v,\n", " attn_mask=None,\n", " is_causal=mask is not None and ctx > 1,\n", " )\n", " out = out.transpose(1, 2).flatten(2)\n", " return out, None\n", " else:\n", " qk = (q * self.scale) @ ((k * self.scale).transpose(-1, -2))\n", " if mask is not None:\n", " qk = qk + mask[:ctx, :k_ctx]\n", " qk_float = qk.float()\n", " w = F.softmax(qk_float, dim=-1).to(q.dtype)\n", " out = (w @ v).transpose(1, 2).flatten(2)\n", " print(\"mulita\",out.shape)\n", " return out, qk_float.detach()\n", "\n", "\n", "class MultiHeadB(nn.Module):\n", "\n", " use_sdpa: bool = True\n", "\n", " def __init__(self, dims: int, heads: int):\n", " super().__init__()\n", "\n", " if dims % heads != 0:\n", " raise ValueError(f\"dims ({dims}) must be divisible by heads ({heads})\")\n", " if not isinstance(dims, int) or not isinstance(heads, int):\n", " raise TypeError(\"dims and heads must be integers\")\n", "\n", " self.heads = heads\n", " self.dims = dims\n", " self.head_dim = dims // heads\n", "\n", " self.query = Linear(in_features=dims, out_features=dims)\n", " self.key = Linear(in_features=dims, out_features=dims, bias=False)\n", " self.value = Linear(in_features=dims, out_features=dims)\n", " self.out = Linear(in_features=dims, out_features=dims)\n", "\n", " def init_weights(self):\n", " nn.init.normal_(self.query.weight, std=0.02)\n", " nn.init.normal_(self.key.weight, std=0.02)\n", " nn.init.normal_(self.value.weight, std=0.02)\n", " nn.init.normal_(self.out.weight, std=0.02)\n", " if self.query.bias is not None:\n", " nn.init.zeros_(self.query.bias)\n", " if self.value.bias is not None:\n", " nn.init.zeros_(self.value.bias)\n", " if self.out.bias is not None:\n", " nn.init.zeros_(self.out.bias)\n", "\n", " def forward(\n", " self,\n", " x: Tensor,\n", " xa: Optional[Tensor] = None,\n", " mask: Optional[Tensor] = None,\n", " kv_cache: Optional[dict] = None,\n", " ) -> Tuple[Tensor, Optional[Tensor]]:\n", "\n", " if x.dim() != 3:\n", " raise ValueError(f\"Expected 3D input tensor, got {x.dim()}D\")\n", " if xa is not None and xa.dim() != 3:\n", " raise ValueError(f\"Expected 3D cross-attention tensor, got {xa.dim()}D\")\n", "\n", " q = self.query(x)\n", "\n", " if kv_cache is None or xa is None or self.key not in kv_cache:\n", " k = self.key(x if xa is None else xa)\n", " v = self.value(x if xa is None else xa)\n", " else:\n", " k = kv_cache[self.key]\n", " v = kv_cache[self.value]\n", "\n", " wv, qk = self.qkv_attention(q=q, k=k, v=v, mask=mask)\n", " return self.out(wv), qk\n", "\n", " def qkv_attention(\n", " self,\n", " q: Tensor,\n", " k: Tensor,\n", " v: Tensor,\n", " mask: Optional[Tensor] = None,\n", " ) -> Tuple[Tensor, Optional[Tensor]]:\n", "\n", " batch, ctx, dims = q.shape\n", " scale = (dims // self.heads) ** -0.25\n", "\n", " q = q.view(batch, ctx, self.heads, -1).permute(0, 2, 1, 3)\n", " k = k.view(batch, k.size(1), self.heads, -1).permute(0, 2, 1, 3)\n", " v = v.view(batch, v.size(1), self.heads, -1).permute(0, 2, 1, 3)\n", "\n", " if self.use_sdpa and torch.cuda.is_available():\n", " with torch.autocast(\"cuda\"):\n", " a = scaled_dot_product_attention(\n", " query=q, key=k, value=v, is_causal=mask is not None and ctx > 1\n", " )\n", " out = a.permute(0, 2, 1, 3).flatten(start_dim=2)\n", " qk = None\n", " else:\n", "\n", " qk = (q * scale) @ (k * scale).transpose(-1, -2)\n", " if mask is not None:\n", " qk = qk + mask[:ctx, :ctx]\n", " qk = qk.float()\n", "\n", " w = F.softmax(qk, dim=-1).to(q.dtype)\n", " out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)\n", " qk = qk.detach()\n", "\n", " return out, qk\n", "\n", "class MultiheadC(nn.Module):\n", " use_sdpa: bool = True\n", "\n", " def __init__(self, dims: int, heads: int, max_dist: int):\n", " super().__init__()\n", " if dims % heads != 0:\n", " raise ValueError(f\"dims ({dims}) must be divisible by heads ({heads})\")\n", " if dims % 2 != 0:\n", " raise ValueError(f\"dims ({dims}) must be even for rotary embeddings\")\n", " self.heads = heads\n", " self.head_dim = dims // heads\n", " self.dims = dims\n", " self.max_dist = max_dist\n", "\n", " scale = 1 / math.sqrt(self.head_dim)\n", " self.query = nn.Linear(in_features=dims, out_features=dims)\n", " self.key = nn.Linear(in_features=dims, out_features=dims, bias=False)\n", " self.value = nn.Linear(in_features=dims, out_features=dims)\n", " self.out = nn.Linear(in_features=dims, out_features=dims)\n", "\n", " nn.init.normal_(tensor=self.query.weight, std=scale)\n", " nn.init.normal_(tensor=self.key.weight, std=scale)\n", " nn.init.normal_(tensor=self.value.weight, std=scale)\n", " nn.init.zeros_(tensor=self.out.bias)\n", "\n", " def forward(\n", " self,\n", " x: Tensor,\n", " xa: Optional[Tensor] = None,\n", " mask: Optional[Tensor] = None,\n", " kv_cache: Optional[Dict] = None,\n", " ) -> Tuple[Tensor, Optional[Tensor]]:\n", "\n", " q = self.query(x)\n", "\n", " if kv_cache is None or xa is None or self.key not in kv_cache:\n", " k = self.key(x if xa is None else xa)\n", " v = self.value(x if xa is None else xa)\n", " else:\n", " k = kv_cache[self.key]\n", " v = kv_cache[self.value]\n", "\n", " wv, qk = self.qkv_attention(q=q, k=k, v=v, mask=mask)\n", " return self.out(wv), qk\n", "\n", " def qkv_attention(\n", " self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None\n", " ) -> Tuple[Tensor, Optional[Tensor]]:\n", "\n", " batch, ctx, dims = q.shape\n", " scale = (dims // self.heads) ** -0.25\n", " q = q.view(batch, ctx, self.heads, self.head_dim).permute(0, 2, 1, 3)\n", " k = k.view(batch, ctx, self.heads, self.head_dim).permute(0, 2, 1, 3)\n", " v = v.view(batch, ctx, self.heads, self.head_dim).permute(0, 2, 1, 3)\n", "\n", " with torch.autocast(device_type=\"cuda\"):\n", " a = scaled_dot_product_attention(\n", " query=q, key=k, value=v, is_causal=mask is not None and ctx > 1\n", " )\n", " out = a.permute(0, 2, 1, 3).flatten(start_dim=2)\n", " qk = None\n", "\n", " return out, qk\n", "\n", "class miniAttention(nn.Module):\n", " def __init__(self, dims, max_dist, heads=1, qkv_bias=False, qk_scale=None, attn_drop=0.0,\n", " proj_drop=0.0):\n", " super().__init__()\n", " if dims % heads != 0:\n", " raise ValueError(f\"dims ({dims}) must be divisible by heads ({heads})\")\n", " if dims % 2 != 0:\n", " raise ValueError(f\"dims ({dims}) must be even for rotary embeddings\")\n", " self.heads = heads\n", " self.head_dim = dims // heads\n", " self.dims = dims\n", " self.max_dist = max_dist\n", " self.scale = qk_scale or self.head_dim**-0.5\n", "\n", " self.qkv = nn.Linear(dims, dims * 3, bias=qkv_bias)\n", " self.attn_drop = nn.Dropout(attn_drop)\n", " self.proj = nn.Linear(dims, dims)\n", " self.proj_drop = nn.Dropout(proj_drop)\n", "\n", " def forward(\n", " self,\n", " x: Tensor,\n", " xa: Optional[Tensor] = None,\n", " mask: Optional[Tensor] = None,\n", " kv_cache: Optional[dict] = None,\n", " ):\n", " B, N, C = x.shape\n", " qkv = (self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4))\n", " q, k, v = qkv[0], qkv[1], qkv[2]\n", " q = q * self.scale\n", " attn = q @ k.transpose(-2, -1)\n", " attn = attn.softmax(dim=-1)\n", " attn = self.attn_drop(attn)\n", " x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n", " x = self.proj(x)\n", " x = self.proj_drop(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Refiner:\n", " def __init__(self, states, actions, alpha=0.1, gamma=0.9, epsilon=0.1):\n", " self.states = states\n", " self.actions = actions\n", " self.R = {}\n", " self.alpha = alpha\n", " self.gamma = gamma\n", " self.epsilon = epsilon\n", " self.default_value = 0.0\n", "\n", " def get_value(self, state, action):\n", " return self.R.get((state, action), self.default_value)\n", "\n", " def set_value(self, state, action, value):\n", " self.R[(state, action)] = value\n", "\n", " def choose_action(self, state):\n", " if np.random.random() < self.epsilon:\n", " return np.random.randint(self.actions)\n", " else:\n", " action_values = [self.get_value(state, a) for a in range(self.actions)]\n", " return np.argmax(action_values)\n", "\n", " def update(self, state, action, reward, next_state):\n", " next_values = [self.get_value(next_state, a) for a in range(self.actions)]\n", " best_next_value = max(next_values)\n", "\n", " old_value = self.get_value(state, action)\n", " td_target = reward + self.gamma * best_next_value\n", " td_error = td_target - old_value\n", " new_value = old_value + self.alpha * td_error\n", " self.set_value(state, action, new_value)\n", "\n", "class Predictor(nn.Module):\n", " def __init__(self, dims):\n", " super().__init__()\n", " self.linear = nn.Linear(in_features=dims, out_features=1)\n", " nn.init.xavier_normal_(self.linear.weight)\n", " nn.init.zeros_(self.linear.bias)\n", "\n", " def forward(self, global_out):\n", " if global_out.dim() > 2:\n", " global_out = global_out.mean(dim=1)\n", " scale = torch.sigmoid(self.linear(global_out))\n", " \n", " return scale\n", "\n", "class AdaptiveSpan(nn.Module):\n", " def __init__(self, dims, heads, max_dist, sharpen=True, temp_scale=0.01):\n", " super().__init__()\n", " self.heads = heads\n", " self.max_dist = max_dist\n", " self.dims = dims\n", " self.temp_scale = temp_scale\n", " self.sharpen = sharpen\n", " self.span_scale = nn.Parameter(torch.tensor(1.0))\n", "\n", " self.head_dim = dims // heads\n", " self.register_buffer(\"scale\", torch.tensor(self.head_dim**-0.25))\n", "\n", " def forward(self, query, key, value, max_dist=None, max_span=None, span_scale=None):\n", " if max_dist is None:\n", " max_dist = self.max_dist\n", " if max_span is None:\n", " max_span = query.shape[1] # Default to sequence length\n", " if span_scale is None:\n", " span_scale = self.span_scale\n", " \n", " span_mean = span_scale.mean().item()\n", " span_len = min(int(max_span * span_mean), query.shape[1], key.shape[1], value.shape[1])\n", " eff_span = min(span_len, max_dist)\n", " \n", " if eff_span == 0:\n", " batch_size = query.shape[0]\n", " return (torch.zeros(batch_size, eff_span, self.dims, device=query.device), None)\n", " \n", " q_span = query[:, :eff_span, :]\n", " k_span = key[:, :eff_span, :]\n", " v_span = value[:, :eff_span, :]\n", "\n", " batch_size = q_span.shape[0]\n", "\n", " reshape_dims = (batch_size, -1, self.heads, self.head_dim)\n", " q = q_span.view(*reshape_dims).permute(0, 2, 1, 3)\n", " k = k_span.view(*reshape_dims).permute(0, 2, 1, 3)\n", " v = v_span.view(*reshape_dims).permute(0, 2, 1, 3)\n", "\n", " with torch.autocast(device_type=\"cuda\", enabled=torch.cuda.is_available()):\n", " temperature = (\n", " 1.0 + self.temp_scale * (1.0 - span_mean)\n", " if self.sharpen\n", " else 0.5 + self.temp_scale * span_mean\n", " )\n", " scores = torch.matmul(q, k.transpose(-2, -1))\n", " weights = torch.softmax((scores / temperature) * self.scale, dim=-1)\n", " out = torch.matmul(weights, v)\n", " out = out.permute(0, 2, 1, 3).reshape(batch_size, eff_span, self.dims)\n", "\n", " return out, weights\n", "\n", "class FocusA(nn.Module):\n", " def __init__(self, dims, heads, max_dist, sharpen=True, win_size=256, max_span=512):\n", " super().__init__()\n", " self.heads = heads\n", " self.max_dist = max_dist\n", " self.dims = dims\n", " self.max_span = max_span\n", " self.sliding_window = win_size\n", " self.temp_scale = 0.01\n", " self.sharpen = sharpen\n", " self.head_dim = dims // heads\n", " self.batch_size = None # Will be set during forward pass\n", "\n", " self.refiner = Refiner(\n", " states=10000, actions=10, alpha=0.1, gamma=0.9, epsilon=0.1\n", " )\n", " self.span_pred = Predictor(dims=dims)\n", " self.attn_local = AdaptiveSpan(\n", " dims=dims, heads=heads, max_dist=max_dist, sharpen=True, temp_scale=0.01\n", " )\n", " self.attn_global = MultiheadC(dims=dims, heads=heads, max_dist=max_dist)\n", "\n", " self.projection = nn.Linear(in_features=2 * dims, out_features=dims)\n", "\n", " self.ln_a = nn.LayerNorm(normalized_shape=dims)\n", " self.ln_b = nn.LayerNorm(normalized_shape=dims)\n", "\n", " mask = torch.empty(max_span, max_span).fill_(float(\"-inf\")).triu_(diagonal=1)\n", " self.register_buffer(\"mask\", mask, persistent=False)\n", "\n", " self.register_buffer(\"window_mask\", None, persistent=False)\n", " self.register_buffer(\"threshold\", torch.tensor(1e-4), persistent=False)\n", " self.register_buffer(\"s_factor\", torch.tensor(0.1), persistent=False)\n", "\n", " def forward(self, x, xa=None, mask=None, kv_cache=None):\n", " if mask is None:\n", " mask = self.mask\n", " \n", " local = self.ln_a(x)\n", " globe = self.ln_b(x)\n", "\n", " globe_out, _ = self.attn_global(globe, globe, globe)\n", " base_scale = self.span_pred(globe_out)\n", " state = self.extract(local)\n", "\n", " action = self.refiner.choose_action(state=state)\n", " refine = self.action_scale(action=action)\n", "\n", " span_scale = torch.clamp(base_scale * refine, min=0.0, max=1.0)\n", " span_mean = span_scale.mean().item()\n", "\n", " with torch.no_grad():\n", " current_win_size = max(1, int(self.sliding_window * span_mean))\n", " current_span_len = max(1, int(self.max_span * span_mean))\n", "\n", " effective_max = min(self.max_dist, local.size(1))\n", " local_max = min(self.max_dist, current_span_len, current_win_size)\n", " globe_max = effective_max\n", "\n", " self.attn_local.max_dist = local_max\n", " self.attn_global.max_dist = globe_max\n", "\n", " local_out = self.slide_win(\n", " x=local,\n", " win_size=current_win_size,\n", " span_len=current_span_len,\n", " span_scale=span_scale,\n", " mask=mask,\n", " )\n", " with torch.no_grad():\n", " quality = self.quality(output=local_out)\n", " next_state = self.extract(local_out)\n", " self.refiner.update(\n", " state=state, action=action, reward=quality, next_state=next_state)\n", " combined = torch.cat([local_out, globe_out], dim=-1)\n", " x = self.projection(combined)\n", " return x\n", "\n", " def quality(self, output):\n", " with torch.no_grad():\n", " safe_output = output.clamp(min=1e-10)\n", " entropy = -(safe_output * torch.log(safe_output)).sum(-1).mean()\n", " coverage = (output > 0.01).float().mean()\n", " return float(coverage - 0.1 * entropy)\n", "\n", " def extract(self, x):\n", " with torch.no_grad():\n", " mean_state = x.mean(dim=(0, 1))\n", " var_state = x.var(dim=(0, 1), unbiased=False)\n", " state = torch.cat([mean_state, var_state])\n", " state_id = self.discretize(state.cpu().numpy())\n", " return state_id\n", "\n", " def discretize(self, state):\n", " bins = np.linspace(-1, 1, num=10)\n", " state_discrete = np.digitize(state, bins)\n", " state_hash = hash(tuple(state_discrete))\n", " state_id = state_hash % (self.refiner.states - 1)\n", " return state_id\n", "\n", " def action_scale(self, action):\n", " span_value = action / (self.refiner.actions - 1)\n", " device = next(self.parameters()).device\n", " dtype = next(self.parameters()).dtype\n", " span_scale = torch.tensor([span_value], device=device, dtype=dtype)\n", " return span_scale\n", "\n", " def _focus(self, query, key, value, span_scale, mask):\n", " max_iterations = 10\n", " iteration = 0\n", " prev_attn = torch.zeros_like(input=query)\n", " attn_out = torch.zeros_like(input=query)\n", " attn_weights = None\n", "\n", " threshold = self.threshold.item()\n", " s_factor = self.s_factor.item()\n", "\n", " while iteration < max_iterations:\n", " span_len = int(self.max_span * span_scale.mean().item())\n", " span_len = min(span_len, query.size(1), key.size(1), value.size(1))\n", " eff_span = min(span_len, self.max_dist)\n", "\n", " if eff_span == 0:\n", " break\n", "\n", " q_span = query[:, :eff_span, :]\n", " k_span = key[:, :eff_span, :]\n", " v_span = value[:, :eff_span, :]\n", "\n", " batch_size, seq_len, dims = q_span.size()\n", " d_k = dims // self.heads\n", " scale_factor = 1 / math.sqrt(d_k)\n", "\n", " q = q_span.view(batch_size, seq_len, self.heads, -1).transpose(1, 2)\n", " k = k_span.view(batch_size, seq_len, self.heads, -1).transpose(1, 2)\n", " v = v_span.view(batch_size, seq_len, self.heads, -1).transpose(1, 2)\n", "\n", " if self.sharpen:\n", " temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())\n", " else:\n", " temperature = 0.5 + self.temp_scale * span_scale.mean().item()\n", " attn_scores = (\n", " torch.matmul(q, k.transpose(-2, -1)) * scale_factor / temperature\n", " )\n", " if mask.size(-2) != attn_scores.size(-2) or mask.size(\n", " -1\n", " ) != attn_scores.size(-1):\n", "\n", " mask_q_len = min(mask.size(-2), attn_scores.size(-2))\n", " mask_k_len = min(mask.size(-1), attn_scores.size(-1))\n", " resized_mask = torch.ones(\n", " (\n", " batch_size,\n", " self.heads,\n", " attn_scores.size(-2),\n", " attn_scores.size(-1),\n", " ),\n", " device=mask.device,\n", " dtype=mask.dtype,\n", " )\n", " resized_mask[:, :, :mask_q_len, :mask_k_len] = mask[\n", " :, :, :mask_q_len, :mask_k_len\n", " ]\n", " mask = resized_mask\n", "\n", " attn_scores = attn_scores.masked_fill(mask == 0, float(\"-inf\"))\n", " attn_weights = torch.softmax(attn_scores, dim=-1)\n", " attn_out = torch.matmul(attn_weights, v)\n", " attn_out = (\n", " attn_out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)\n", " )\n", "\n", " diff = torch.abs(attn_out - prev_attn).mean()\n", " dynamic_threshold = threshold + s_factor * diff\n", "\n", " if diff < dynamic_threshold:\n", " break\n", "\n", " prev_attn = attn_out\n", " query = query + attn_out\n", " iteration += 1\n", " return attn_out, attn_weights\n", "\n", " def slide_win(self, x, win_size, span_len, span_scale, mask):\n", " batch_size, seq_len, dims = x.size()\n", " self.batch_size = batch_size\n", " num_windows = (seq_len + win_size - 1) // win_size\n", " output = torch.zeros_like(x)\n", " device = x.device\n", " default_mask = None\n", "\n", " for i in range(num_windows):\n", " start_idx = i * win_size\n", " end_idx = min((i + 1) * win_size, seq_len)\n", " window_size = end_idx - start_idx\n", "\n", " key_start = max(0, start_idx - span_len + win_size)\n", " key_end = min(start_idx + span_len, seq_len)\n", " span_size = key_end - key_start\n", "\n", " query = x[:, start_idx:end_idx, :]\n", " key = x[:, key_start:key_end, :]\n", " value = key\n", "\n", " if mask is not None:\n", " if mask.dim() == 4:\n", " window_mask = mask[:, :, start_idx:end_idx, key_start:key_end]\n", " if window_mask.size(1) == 1:\n", " window_mask = window_mask.expand(-1, self.heads, -1, -1)\n", " else:\n", " if (\n", " default_mask is None\n", " or default_mask.size(-2) != window_size\n", " or default_mask.size(-1) != span_size\n", " ):\n", " default_mask = torch.ones(\n", " (batch_size, self.heads, window_size, span_size),\n", " device=device,\n", " dtype=torch.bool,\n", " )\n", " window_mask = default_mask\n", " else:\n", " if (\n", " default_mask is None\n", " or default_mask.size(-2) != window_size\n", " or default_mask.size(-1) != span_size\n", " ):\n", " default_mask = torch.ones(\n", " (batch_size, self.heads, window_size, span_size),\n", " device=device,\n", " dtype=torch.bool,\n", " )\n", " window_mask = default_mask\n", "\n", " attn_out, _ = self._focus(\n", " query=query,\n", " key=key,\n", " value=value,\n", " span_scale=span_scale,\n", " mask=window_mask,\n", " )\n", "\n", " output[:, start_idx:end_idx, :] = attn_out\n", "\n", " return output" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Residual(nn.Module):\n", " def __init__(\n", " self, param: Dimensions, dims: int, heads: int, dropout: float, activation: str\n", " ):\n", " super().__init__()\n", " self.param = param\n", " self.dims = dims\n", "\n", " activation_map = {\n", " \"gelu\": nn.GELU(),\n", " \"relu\": nn.ReLU(),\n", " \"sigmoid\": nn.Sigmoid(),\n", " \"tanh\": nn.Tanh(),\n", " \"leaky_relu\": nn.LeakyReLU(),\n", " \"elu\": nn.ELU(),\n", " }\n", " act_fn = activation_map.get(activation, nn.ReLU())\n", "\n", " self.attn = MultiheadA(dims=dims, heads=heads)\n", " self.cross = MultiHeadB(dims=dims, heads=heads)\n", "\n", " self.mlp = nn.Sequential(\n", " nn.Dropout(p=dropout),\n", " nn.Linear(in_features=dims, out_features=dims * 4, bias=True),\n", " act_fn,\n", " nn.Dropout(p=dropout),\n", " nn.Linear(in_features=dims * 4, out_features=dims, bias=True),\n", " )\n", "\n", " self.ln_a = nn.LayerNorm(normalized_shape=dims)\n", " self.ln_b = nn.LayerNorm(normalized_shape=dims)\n", " self.ln_c = nn.LayerNorm(normalized_shape=dims)\n", "\n", " self._init_weights()\n", "\n", " def _init_weights(self):\n", "\n", " for m in self.mlp:\n", " if isinstance(m, nn.Linear):\n", " nn.init.kaiming_normal_(m.weight)\n", " if m.bias is not None:\n", " nn.init.zeros_(m.bias)\n", "\n", " def forward(\n", " self,\n", " x: Tensor,\n", " xa: Optional[Tensor] = None,\n", " mask: Optional[Tensor] = None,\n", " kv_cache: Optional[Dict[str, Tensor]] = None,\n", " ) -> Tensor:\n", " \n", " y = x\n", " z = self.ln_a(x)\n", " x = x + self.attn(z, mask=mask, kv_cache=kv_cache)[0]\n", " if xa is not None:\n", " z = self.ln_b(x)\n", " x = x + self.cross(z, xa, mask=mask, kv_cache=kv_cache)[0]\n", " x = x + self.mlp(self.ln_c(x))\n", "\n", " return x + y\n", " \n", "class AudioEncoder(nn.Module):\n", " def __init__(self, param: Dimensions, mels: int, ctx: int, dims: int, heads: int, \n", " checkpoint: bool, dropout: float, activation: str, layerA: int, layerB: int):\n", " super().__init__()\n", " \n", " self.checkpoint = checkpoint\n", "\n", " act_map = {\n", " \"gelu\": nn.GELU(),\n", " \"relu\": nn.ReLU(),\n", " \"sigmoid\": nn.Sigmoid(),\n", " \"tanh\": nn.Tanh(),\n", " \"leaky_relu\": nn.LeakyReLU(),\n", " \"elu\": nn.ELU(),\n", " }\n", " act = act_map.get(activation, nn.ReLU())\n", "\n", " self.rotation = rotary(ctx=ctx, dims=dims, heads=heads, base=10000)\n", " self.position = sinusoids(length=ctx, channels=dims)\n", " self.register_buffer(\"positions\", self.position, persistent=False)\n", "\n", " self.convx = nn.Sequential(\n", " nn.Conv1d(mels, dims, kernel_size=3, padding=1, bias=False),\n", " nn.BatchNorm1d(dims),\n", " act,\n", " nn.Dropout(p=dropout),\n", " nn.Conv1d(dims, dims, kernel_size=3, stride=2, padding=1, bias=False),\n", " nn.BatchNorm1d(dims),\n", " act,\n", " nn.Dropout(p=dropout),\n", " )\n", "\n", " for m in self.convx:\n", " if isinstance(m, nn.Conv1d):\n", " nn.init.kaiming_normal_(m.weight)\n", "\n", " self.blockA = nn.ModuleList([\n", " Residual(param, dims, heads, dropout, activation) \n", " for _ in range(layerA)]) if layerA > 0 else None\n", "\n", " self.blockB = nn.ModuleList([\n", " FocusA(dims=dims, heads=heads, max_dist=ctx) \n", " for _ in range(layerB)]) if layerB > 0 else None\n", "\n", " self.ln_post = nn.LayerNorm(dims)\n", "\n", " def forward(self, x) -> Tensor:\n", " x = checkpoint(self._forward, x, use_reentrant=True) if self.checkpoint else self._forward(x)\n", " for block in chain(self.blockB or [], self.blockA or []):\n", " x = checkpoint(block, x, use_reentrant=True) if self.checkpoint else block(x) \n", " return self.ln_post(x)\n", "\n", " def _forward(self, x) -> Tensor:\n", " x = F.gelu(self.convx(x))\n", " x = x.permute(0, 2, 1) \n", " x = (x + self.positions).to(x.dtype) # type: ignore\n", " x = self.rotation(x)\n", " return x\n", "\n", "class TextDecoder(nn.Module):\n", " def __init__(self, param: Dimensions, vocab: int, ctx: int, dims: int, heads: int, \n", " checkpoint: bool, dropout: float, activation: str, layerA: int, layerB: int):\n", " super().__init__()\n", " \n", " self.checkpoint = checkpoint\n", " self.token_embedding = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)\n", " nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)\n", " \n", " self.positional_embedding = nn.Parameter(data=torch.empty(ctx, dims))\n", " nn.init.normal_(tensor=self.positional_embedding, mean=0.0, std=0.02)\n", " \n", " self.positional_encoding = PositionalEncoding(ctx=ctx, dims=dims)\n", " self.ln = LayerNorm(normalized_shape=dims)\n", " \n", " self.blockA = nn.ModuleList(modules=[Residual(param=param, dims=dims, heads=heads, \n", " dropout=dropout, activation=activation) \n", " for _ in range(layerA)]) if layerA > 0 else None\n", "\n", " self.blockB = nn.ModuleList(modules=[FocusA(dims=dims, heads=heads, max_dist=ctx) \n", " for _ in range(layerB)]) if layerB > 0 else None\n", "\n", " mask = torch.empty(ctx, ctx).fill_(value=-np.inf).triu_(diagonal=1)\n", " self.register_buffer(name=\"mask\", tensor=mask, persistent=False)\n", " self.mask = mask\n", " \n", " def forward(self, x, xa, kv_cache = None):\n", " x = checkpoint(function=self._forward, x=x, xa=xa, kv_cache=kv_cache) if self.checkpoint else self._forward(x=x, xa=xa, kv_cache=kv_cache)\n", " for block in chain(self.blockA or [], self.blockB or []): \n", " x = checkpoint(function=block, x=x, xa=xa, mask=self.mask, kv_cache=kv_cache) if self.checkpoint else block(x=x, xa=xa, mask=self.mask, kv_cache=kv_cache)\n", " x = self.ln(x)\n", " x = (x @ torch.transpose(self.token_embedding.weight.to(dtype=x.dtype), dim0=0, dim1=1)).float()\n", " return x\n", " \n", " def _forward(self, x, xa, kv_cache): \n", " offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0\n", " x = (self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]])\n", " x = self.positional_encoding(x)\n", " x = x.to(dtype=xa.dtype)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class Dimensions:\n", " vocab: int\n", " text_ctx: int\n", " text_state: int\n", " text_head: int\n", " text_layerA: int\n", " text_layerB: int\n", " audio_ctx: int\n", " audio_state: int\n", " audio_head: int\n", " audio_layerA: int\n", " audio_layerB: int\n", " mels: int\n", " checkpoint: bool = False\n", " dropout: float = 0.1\n", " activation: str = \"gelu\"\n", "\n", "class Echo(nn.Module):\n", "\n", " PAD_TOKEN_ID = 50257\n", " START_TOKEN_ID = 50258\n", "\n", " def __init__(self, param: Dimensions):\n", " super().__init__()\n", " self.param = param\n", "\n", " self._build_model()\n", "\n", " self.to(self.device)\n", "\n", " def _build_model(self):\n", "\n", " self.encoder = AudioEncoder(\n", " param=self.param,\n", " mels=self.param.mels,\n", " ctx=self.param.audio_ctx,\n", " dims=self.param.audio_state,\n", " heads=self.param.audio_head,\n", " layerA=self.param.audio_layerA,\n", " layerB=self.param.audio_layerB,\n", " checkpoint=self.param.checkpoint,\n", " dropout=self.param.dropout,\n", " activation=self.param.activation,\n", " )\n", "\n", " self.decoder = TextDecoder(\n", " param=self.param,\n", " vocab=self.param.vocab,\n", " ctx=self.param.text_ctx,\n", " dims=self.param.text_state,\n", " heads=self.param.text_head,\n", " layerA=self.param.text_layerA,\n", " layerB=self.param.text_layerB,\n", " checkpoint=self.param.checkpoint,\n", " dropout=self.param.dropout,\n", " activation=self.param.activation,\n", " )\n", "\n", " @property\n", " def device(self) -> torch.device:\n", "\n", " return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", " @staticmethod\n", " def shift_tokens_right(\n", " input_ids: torch.Tensor,\n", " pad_token_id: int = PAD_TOKEN_ID,\n", " decoder_start_token_id: int = START_TOKEN_ID,\n", " ) -> torch.Tensor:\n", " \"\"\" Shift input tokens right for teacher forcing. Returns: Shifted input tokens \"\"\"\n", " batch_size, seq_len = input_ids.shape\n", " shifted_input_ids = torch.zeros_like(input_ids)\n", " shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()\n", " shifted_input_ids[:, 0] = decoder_start_token_id\n", " shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)\n", " return shifted_input_ids\n", "\n", " def forward(\n", " self,\n", " input_features: torch.Tensor,\n", " labels: Optional[torch.Tensor] = None,\n", " dec_input_ids: Optional[torch.Tensor] = None,\n", " ) -> Dict[str, Optional[torch.Tensor]]:\n", "\n", " if labels is not None and dec_input_ids is None:\n", " dec_input_ids = self.shift_tokens_right(\n", " input_ids=labels,\n", " pad_token_id=self.PAD_TOKEN_ID,\n", " decoder_start_token_id=self.START_TOKEN_ID,\n", " )\n", "\n", " with torch.autocast(device_type=\"cuda\", enabled=torch.cuda.is_available()):\n", " encoded_features = self.encoder(input_features)\n", "\n", " logits = self.decoder(dec_input_ids, encoded_features)\n", "\n", " loss = None\n", " if labels is not None:\n", " loss_fct = nn.CrossEntropyLoss(ignore_index=-100)\n", " labels = labels.to(logits.device).long()\n", "\n", " flattened_logits = logits.view(-1, self.param.vocab)\n", " flattened_labels = labels.view(-1)\n", "\n", " loss = loss_fct(flattened_logits, flattened_labels)\n", "\n", " return {\"loss\": loss, \"logits\": logits}\n", "\n", " def _init_weights(self, module):\n", " std = 0.02\n", "\n", " if isinstance(module, (nn.Linear, nn.Conv1d)):\n", " nn.init.normal_(module.weight, mean=0.0, std=std)\n", " if module.bias is not None:\n", " nn.init.zeros_(module.bias)\n", "\n", " elif isinstance(module, nn.Embedding):\n", " nn.init.normal_(module.weight, mean=0.0, std=std)\n", " if module.padding_idx is not None:\n", " module.weight.data[module.padding_idx].zero_()\n", "\n", " elif isinstance(module, AudioEncoder):\n", " module.convx.apply(self._init_weights)\n", "\n", " elif isinstance(module, TextDecoder):\n", " nn.init.normal_(module.positional_embedding, mean=0.0, std=std)\n", " nn.init.normal_(module.token_embedding.weight, mean=0.0, std=std)\n", "\n", " elif isinstance(module, Residual):\n", " for layer in module.mlp:\n", " if isinstance(layer, nn.Linear):\n", " nn.init.normal_(layer.weight, std=std)\n", " nn.init.zeros_(layer.bias)\n", "\n", " for ln_name in [\"ln_a\", \"ln_b\", \"ln_c\"]:\n", " if hasattr(module, ln_name):\n", " ln = getattr(module, ln_name)\n", " nn.init.normal_(ln.weight, mean=1.0, std=std)\n", " nn.init.zeros_(ln.bias)\n", "\n", " if hasattr(module, \"attn\") and hasattr(module.attn, \"init_weights\"):\n", " module.attn.init_weights()\n", " if hasattr(module, \"cross\") and hasattr(module.cross, \"init_weights\"):\n", " module.cross.init_weights()\n", "\n", " def init_weights(self):\n", " self.apply(self._init_weights)\n", "\n", " @torch.no_grad()\n", " def generate(\n", " self,\n", " audio_features: torch.Tensor,\n", " max_length: int = 100,\n", " temperature: float = 1.0,\n", " ) -> torch.Tensor:\n", " encoded_features = self.encoder(audio_features).to(self.device)\n", "\n", " batch_size = audio_features.size(0)\n", " generated = torch.full(\n", " (batch_size, 1), self.START_TOKEN_ID, dtype=torch.long, device=self.device)\n", "\n", " kv_cache = {}\n", "\n", " for _ in range(max_length - 1):\n", " logits = self.decoder(generated, encoded_features, kv_cache=kv_cache)\n", " next_token_logits = logits[:, -1, :] / max(temperature, 1e-7)\n", " probs = F.softmax(next_token_logits, dim=-1)\n", " next_tokens = torch.multinomial(probs, num_samples=1)\n", " generated = torch.cat([generated, next_tokens], dim=-1)\n", " if (next_tokens == self.PAD_TOKEN_ID).all():\n", " break\n", "\n", " return generated" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datetime import datetime\n", "\n", "log_dir = os.path.join(\"./output/Whisper\", datetime.now().strftime(format=\"%m-%d_%H\"))\n", "os.makedirs(name=log_dir, exist_ok=True)\n", "\n", "param = Dimensions(\n", " mels=128,\n", " audio_ctx=1500,\n", " audio_head=8,\n", " audio_layerA=8,\n", " audio_layerB=2,\n", " audio_state=1024,\n", " vocab=51865,\n", " text_ctx=448,\n", " text_head=8,\n", " text_layerA=8,\n", " text_layerB=0,\n", " text_state=1024,\n", " checkpoint=False,\n", " dropout=0.001,\n", " activation=\"gelu\",\n", ")\n", "\n", "model = Echo(param=param).to(device=device)\n", "model.init_weights()\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "class MaxFactor(torch.optim.Optimizer):\n", " def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-10, 1e-3), d=1.0, \n", " weight_decay=0.01, gamma=0.99, eps_rms=1e-8, maximize=False):\n", " \n", " defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay, \n", " gamma=gamma, eps_rms=eps_rms, maximize=maximize)\n", " super().__init__(params=params, defaults=defaults)\n", "\n", " def _get_lr(self, param_group, param_state):\n", " step = param_state[\"step\"]\n", " min_step = 1e-5 * step\n", " rel_step_sz = min(min_step, 1.0 / step.sqrt())\n", " param_scale = max(param_group[\"eps\"][1], param_state[\"RMS\"])\n", " return param_scale * rel_step_sz\n", "\n", " @staticmethod\n", " def _rms(tensor):\n", " return tensor.norm() / (tensor.numel() ** 0.5)\n", "\n", " @torch.no_grad()\n", " def step(self, closure=None):\n", " loss = None\n", " if closure is not None:\n", " with torch.enable_grad():\n", " loss = closure()\n", "\n", " for group in self.param_groups:\n", " params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []\n", " eps1, eps2 = group[\"eps\"]\n", " for p in group[\"params\"]:\n", " if p.grad is None:\n", " continue\n", " grad = p.grad\n", " if grad.dtype in {torch.float16, torch.bfloat16}:\n", " grad = grad.float()\n", "\n", " state = self.state[p]\n", " if len(state) == 0:\n", " state[\"step\"] = torch.tensor(0.0, dtype=torch.float32)\n", " if p.grad.dim() > 1:\n", " row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)\n", " row_shape[-1], col_shape[-2] = 1, 1\n", " state[\"row_var\"], state[\"col_var\"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)\n", " state[\"v\"] = torch.zeros_like(p, memory_format=torch.preserve_format)\n", "\n", " row_vars.append(state.get(\"row_var\", None))\n", " col_vars.append(state.get(\"col_var\", None))\n", " v.append(state[\"v\"])\n", " state_steps.append(state[\"step\"])\n", " params_with_grad.append(p)\n", " grads.append(grad)\n", "\n", " for i, param in enumerate(params_with_grad):\n", " grad = grads[i]\n", "\n", " if group[\"maximize\"]:\n", " grad = -grad\n", " step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]\n", "\n", " if eps1 is None:\n", " eps1 = torch.finfo(param.dtype).eps\n", " \n", " step_t += 1\n", " step_float = step_t.item()\n", " one_minus_beta2_t = step_float ** group[\"beta2_decay\"]\n", " rho_t = min(group[\"lr\"], 1 / (step_float ** 0.5))\n", " alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t\n", "\n", " if group[\"weight_decay\"]!= 0:\n", " param.mul_(1 - group[\"lr\"] * group[\"weight_decay\"])\n", "\n", " if grad.dim() > 1:\n", " row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1) + 1e-8)\n", " row_var.lerp_(row_mean, one_minus_beta2_t)\n", " col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2) + 1e-8)\n", " col_var.lerp_(col_mean, one_minus_beta2_t)\n", " var_estimate = row_var @ col_var\n", " max_row_var = row_var.max(dim=-2, keepdim=True)[0] \n", " var_estimate.div_(max_row_var.clamp_(min=eps1))\n", " else:\n", " vi.mul_(group[\"gamma\"]).add_(grad ** 2, alpha=1 - group[\"gamma\"])\n", " var_estimate = vi\n", "\n", " update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)\n", " update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))\n", " denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group[\"d\"]))\n", " param.add_(-alpha / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])\n", " return loss\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "token=\"\"\n", "\n", "extractor = WhisperFeatureExtractor.from_pretrained(\n", " pretrained_model_name_or_path=\"openai/whisper-small\", token=token,\n", " feature_size=128, sampling_rate=16000, return_tensors=\"pt\", do_normalize=True)\n", "\n", "tokenizer = WhisperTokenizerFast.from_pretrained(\n", " pretrained_model_name_or_path=\"openai/whisper-small\", \n", " language=\"en\", task=\"transcribe\", token=token)\n", "\n", "processor = WhisperProcessor.from_pretrained(\n", " pretrained_model_name_or_path=\"openai/whisper-small\", token=token)\n", "\n", "@dataclass\n", "class DataCollatorSpeechSeq2SeqWithPadding:\n", " processor: Any\n", " extractor: Any\n", " tokenizer: Any\n", " decoder_start_token_id: int\n", "\n", " def __call__(self, features: List[Dict[str, Union[List[int], Tensor]]]) -> Dict[str, Tensor]:\n", " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n", " batch = self.extractor.pad(input_features, return_tensors=\"pt\")\n", " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", " labels_batch = self.tokenizer.pad(label_features, return_tensors=\"pt\")\n", " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", " if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():\n", " labels = labels[:, 1:]\n", " batch[\"labels\"] = labels\n", " return batch\n", " \n", "def prepare_dataset(batch):\n", " audio = batch[\"audio\"]\n", " batch[\"input_features\"] = extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n", " batch[\"labels\"] = tokenizer(batch[\"sentence\"]).input_ids\n", " return batch\n", "\n", "data_collator = DataCollatorSpeechSeq2SeqWithPadding(\n", " processor=processor, extractor=extractor,\n", " tokenizer=tokenizer, decoder_start_token_id=50258)\n", "\n", "dataset = IterableDatasetDict()\n", "\n", "dataset[\"train\"] = load_dataset(\n", " path=\"mozilla-foundation/common_voice_17_0\", split=\"train\",\n", " name=\"en\", streaming=True, token=token, \n", " trust_remote_code=True, save_infos=True)#.shuffle()#.take(10000)\n", "\n", "dataset[\"test\"] = load_dataset(\n", " path=\"mozilla-foundation/common_voice_17_0\",\n", " name=\"en\", split=\"test\", streaming=True, \n", " token=token, trust_remote_code=True, save_infos=True).take(500)\n", "\n", "dataset = dataset.cast_column(column=\"audio\", feature=Audio(sampling_rate=16000))\n", "\n", "dataset = dataset.map(function=prepare_dataset, \n", " remove_columns=list(next(iter(dataset.values()))\n", " .features)).with_format(type=\"torch\")\n", "\n", "metric = evaluate.load(path=\"wer\")\n", "\n", "def compute_metrics(eval_pred):\n", " pred_logits = eval_pred.predictions\n", " label_ids = eval_pred.label_ids\n", "\n", " if isinstance(pred_logits, tuple):\n", " pred_ids = pred_logits[0]\n", " else:\n", " pred_ids = pred_logits\n", " if pred_ids.ndim == 3:\n", " pred_ids = np.argmax(pred_ids, axis=-1)\n", "\n", " label_ids[label_ids == -100] = tokenizer.pad_token_id\n", " pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", " wer = 100 * metric.compute(predictions=pred_str, references=label_str) # type: ignore\n", " pred_flat = pred_ids.flatten()\n", " labels_flat = label_ids.flatten()\n", " mask = labels_flat != tokenizer.pad_token_id\n", " \n", " if len(pred_str) > 0:\n", " sample_idx = random.randint(0, len(pred_str) - 1)\n", " print(\"-\" * 10)\n", " print(f\"Prediction: {pred_str[sample_idx]}\")\n", " print(f\"Label: {label_str[sample_idx]}\")\n", " print(\"-\" * 10)\n", "\n", " acc = accuracy_score(y_true=labels_flat[mask], y_pred=pred_flat[mask])\n", " pre = precision_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], \n", " average='weighted', zero_division=0)\n", " rec = recall_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], \n", " average='weighted', zero_division=0)\n", " f1 = f1_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], \n", " average='weighted', zero_division=0)\n", " \n", " return {\n", " \"wer\": wer,\n", " \"accuracy\": acc,\n", " \"precision\": pre,\n", " \"recall\": rec,\n", " \"f1\": f1}\n", " \n", "log_dir = os.path.join(os.getcwd(), \"whisper_training_logs\")\n", "os.makedirs(log_dir, exist_ok=True)\n", "\n", "args = Seq2SeqTrainingArguments(\n", " output_dir=log_dir,\n", " per_device_train_batch_size=1,\n", " per_device_eval_batch_size=1,\n", " gradient_accumulation_steps=1,\n", " eval_accumulation_steps=1,\n", " tf32=True,\n", " bf16=True,\n", " eval_strategy=\"steps\",\n", " save_strategy=\"steps\",\n", " max_steps=100000,\n", " save_steps=1000,\n", " eval_steps=1000,\n", " warmup_steps=300,\n", " num_train_epochs=1,\n", " logging_steps=1,\n", " logging_dir=os.path.join(log_dir, \"logs_hf\"),\n", " report_to=[\"tensorboard\"],\n", " push_to_hub=False,\n", " disable_tqdm=False,\n", " save_total_limit=1,\n", " remove_unused_columns=False,\n", " label_names=[\"labels\"],\n", " eval_on_start=False,\n", " # optim=\"adafactor\",\n", ")\n", "\n", "optimizer = MaxFactor(\n", " model.parameters(), \n", " lr=0.01, \n", " beta2_decay=-0.8,\n", " eps=(1e-10, 1e-4), \n", " d=1.0,\n", " weight_decay=0.01, \n", " gamma=0.99, \n", " eps_rms=1e-8,\n", " maximize=False,\n", ")\n", "\n", "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", " optimizer=optimizer,\n", " T_max=args.max_steps,\n", " eta_min=0.0,\n", " last_epoch=-1 \n", ")\n", "\n", "trainer = Seq2SeqTrainer(\n", " args=args,\n", " model=model,\n", " train_dataset=dataset[\"train\"],\n", " eval_dataset=dataset[\"test\"],\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", " processing_class=extractor,\n", " optimizers=(optimizer, scheduler),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer.train(resume_from_checkpoint=False)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 2 }