# File size: 5,899 Bytes
# Source revision (presumably a commit hash): 7dcff9f
import torch
from torch import nn
from torchdiffeq import odeint
import wandb
import math
class ODELinear(nn.Module):
    """Right-hand side of the frequency ODE: a two-layer projection plus a closed-form term.

    The module owns an up projection of shape ``(dim // 2, factor * dim)`` and a
    down projection of shape ``(factor * dim, dim // 2)``. The down projection is
    initialised to zeros, so right after construction ``forward`` returns only
    the closed-form term derived from ``get_time_embedding``.

    Args:
        dim: Rotary dimension; frequencies are indexed by ``arange(0, dim, 2)``.
        factor: Width multiplier of the hidden layer (hidden size is ``factor * dim``).
        act: Name of the hidden activation, either ``"tanh"`` or ``"silu"``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``act`` is neither ``"tanh"`` nor ``"silu"``.
    """

    def __init__(
        self,
        dim: int,
        factor,
        act,
        **kwargs
    ):
        super().__init__()
        hidden = factor * dim
        self.ode_up_proj = nn.Parameter(torch.empty(dim // 2, hidden))
        self.ode_down_proj = nn.Parameter(torch.empty(hidden, dim // 2))
        self.dim = dim
        # Walk the supported activations; the for/else raises when none matched.
        for name, activation_cls in (("tanh", torch.nn.Tanh), ("silu", torch.nn.SiLU)):
            if act == name:
                self.act = activation_cls()
                break
        else:
            raise ValueError(f"act must be one of ['tanh', 'silu'], got {act}")
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the down projection and Kaiming-uniform initialise the up projection."""
        nn.init.zeros_(self.ode_down_proj)
        nn.init.kaiming_uniform_(self.ode_up_proj, a=math.sqrt(5))

    def get_time_embedding(self, t, base=10000, device='cuda', dtype=torch.float32):
        """Return the closed-form frequency derivative and the scaled inverse frequencies.

        ``alpha`` is ``1`` when ``t < 1`` and ``2 * t - 1`` otherwise. The inverse
        frequencies use the rescaled base ``base * alpha ** (dim / (dim - 2))``.

        Args:
            t: Scalar time (the comparison ``t < 1`` must yield a single truth value).
            base: Rotary base before rescaling.
            device: Device on which both results are produced.
            dtype: Dtype of both returned tensors.

        Returns:
            A pair ``(delta_ntk_freq, ntk_inv_freq)``, each with one entry per
            element of ``arange(0, dim, 2)``.
        """
        alpha = 1 if t < 1 else 2 * t - 1
        half_span = self.dim - 2
        index = torch.arange(0, self.dim, 2, dtype=torch.float32).to(device)

        scaled_base = base * alpha ** (self.dim / half_span)
        ntk_inv_freq = 1.0 / (scaled_base ** (index / self.dim))

        slope = -2 * index / half_span
        denominator = base ** (index / self.dim) * (alpha ** (index / half_span + 1))
        delta_ntk_freq = slope / denominator

        return (
            delta_ntk_freq.to(device, dtype=dtype),
            ntk_inv_freq.to(device, dtype=dtype),
        )

    def forward(self, t, x: torch.Tensor):
        """Evaluate the ODE right-hand side at time ``t`` for state ``x``.

        Args:
            t: Scalar time tensor (it is moved to ``x.device`` with ``.to``).
            x: Current state; its last axis must match the up projection's
                first axis (``dim // 2``).

        Returns:
            ``act((x + log(freq)) @ up) @ down + delta / freq`` where
            ``(delta, freq)`` come from ``get_time_embedding``.
        """
        target = x.device
        delta, inv_freq = self.get_time_embedding(
            t.to(target), device=target, dtype=x.dtype
        )
        shifted = x + torch.log(inv_freq)
        analytic_term = delta / inv_freq
        hidden = self.act(shifted @ self.ode_up_proj.float())
        learned_term = hidden @ self.ode_down_proj.float()
        return learned_term + analytic_term
class CLEXScalingRotaryEmbedding(nn.Module):
    """Rotary embedding whose inverse frequencies are rescaled by integrating an ODE.

    The base inverse frequencies are ``1 / base ** (arange(0, dim, 2) / dim)``.
    For a scaling factor ``t > 1`` the log-frequencies are integrated from time
    ``1`` to ``t`` with ``torchdiffeq.odeint`` (fixed-step ``rk4``), using an
    ``ODELinear`` module as the right-hand side.

    Args:
        dim: Rotary dimension.
        max_position_embeddings: Number of positions covered by scaling factor 1.
        rope_scaling: Mapping that must provide ``"max_factor"`` (largest scaling
            factor), ``"param_factor"`` and ``"act"`` (forwarded to ``ODELinear``)
            and ``"time_dt"`` (ODE solver step size).
        base: Rotary base.
        device: Device on which the ``inv_freq`` buffer is created.

    Raises:
        TypeError: If ``rope_scaling`` is ``None``.
    """

    def __init__(self, dim, max_position_embeddings=2048, rope_scaling=None, base=1000000, device=None) -> None:
        super().__init__()
        if rope_scaling is None:
            # Subscripting None raised TypeError before; keep the type, improve the message.
            raise TypeError(
                "rope_scaling is required and must provide the keys "
                "'max_factor', 'param_factor', 'act' and 'time_dt'"
            )
        self.max_t = rope_scaling["max_factor"]
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        self.proj_func = ODELinear(dim, rope_scaling["param_factor"], rope_scaling["act"])

        # Inference caches: per-factor inverse frequencies, and the cos/sin
        # table built for the most recently requested factor.
        self.rope_cached = None
        self.max_t_cached = 0
        self.freq_cached = None

        self.time_dt = rope_scaling["time_dt"]
        self.ode_args = {
            "method": "rk4",
            "options": {"step_size": self.time_dt},
        }

    def sample_random_times(self, max_t, device):
        """Draw one integer uniformly from ``[1, max_t)`` as a tensor of shape ``(1,)``."""
        return torch.randint(1, max_t, (1,), dtype=torch.long, device=device)

    def get_random_position_ids(self, n=2048, max=8192):
        """Return up to ``n`` distinct integers from ``[0, max)`` in ascending order."""
        positions = torch.randperm(max)[:n].sort().values
        return positions

    @staticmethod
    def _rotary_angles(positions, inv_freq):
        """Return ``positions x inv_freq`` angles, duplicated along the last axis."""
        # reshape(-1) instead of squeeze(): a length-1 position vector must stay
        # 1-D, otherwise torch.outer rejects it.
        freqs = torch.outer(positions.float().reshape(-1), inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

    def _integrate_log_inv_freq(self, time_grid, device):
        """Integrate the log inverse frequencies over ``time_grid``; one row per grid point."""
        log_inv_freq = torch.log(self.inv_freq.to(device, dtype=torch.float32))
        return odeint(self.proj_func, log_inv_freq, time_grid, **self.ode_args)

    def get_continuous_freq(self, time_grid, ex_positions, device):
        """Solve the frequency ODE over ``time_grid``.

        Args:
            time_grid: 1-D tensor of integration times.
            ex_positions: Positions used to build the angle table (only read
                when ``time_grid`` has exactly two entries).
            device: Device on which the ODE is solved.

        Returns:
            When ``time_grid`` has exactly two entries: the angle table for the
            final time (positions x frequencies, duplicated along the last
            axis). Otherwise: the scaled inverse frequencies, one row per grid
            point.
        """
        solution = self._integrate_log_inv_freq(time_grid, device)
        if time_grid.size(0) != 2:
            return torch.exp(solution)
        return self._rotary_angles(ex_positions, torch.exp(solution[1]))

    def forward(self, input_embeds, seq_len, do_train=False):
        """Build the stacked cos/sin table for ``seq_len`` positions.

        Args:
            input_embeds: Only its dtype is used; the result is cast to it.
            seq_len: Number of positions requested.
            do_train: When true, a scaling factor is sampled uniformly from
                ``1..max_t`` and positions are randomly sub-sampled from the
                extended range. When false, the factor is
                ``ceil(seq_len / max_position_embeddings)`` clamped to
                ``[1, max_t]`` and results are cached.

        Returns:
            Tensor ``[cos, sin]`` stacked on axis 0. Axis 1 holds at most
            ``seq_len`` positions (fewer at inference if ``seq_len`` exceeds
            ``max_position_embeddings * max_t``).
        """
        device = self.proj_func.ode_up_proj.device
        dtype = input_embeds.dtype
        scale_factor = seq_len // self.max_position_embeddings

        if do_train:
            # Convert once to a Python int so the arithmetic below does not mix
            # device tensors with CPU tensors and Python numbers.
            t_val = int(self.sample_random_times(self.max_t + 1, device)[0])
            if scale_factor < 1.0:
                scale_factor = 1
            sampled_position_ids = self.get_random_position_ids(
                n=seq_len - 2, max=seq_len * t_val - 2
            ).float()
            # Always keep the first and the last position of the extended range.
            ex_positions = torch.cat([
                torch.zeros(1),
                (sampled_position_ids + 1) / scale_factor,
                torch.tensor([seq_len * t_val // scale_factor - 1], dtype=torch.float32),
            ]).to(device, dtype=torch.float32)
        else:
            t_val = scale_factor if seq_len % self.max_position_embeddings == 0.0 else scale_factor + 1
            # Clamp to [1, max_t]; seq_len == 0 would otherwise give factor 0
            # and index the frequency cache at -1.
            t_val = max(1, min(t_val, self.max_t))
            ex_positions = torch.arange(
                0, self.max_position_embeddings * t_val, dtype=torch.float32
            ).to(device)

        if t_val == 1.0:
            # No scaling: plain rotary table from the base frequencies.
            embed = self._rotary_angles(ex_positions, self.inv_freq.to(device))
            cos, sin = embed.cos(), embed.sin()
        elif do_train:
            time_grid = torch.tensor([1.0, t_val]).float().to(device)
            embed = self.get_continuous_freq(time_grid, ex_positions, device)
            cos, sin = embed.cos(), embed.sin()
        else:
            if self.freq_cached is None:
                time_grid = torch.arange(1.0, self.max_t + 1.0, dtype=torch.float32).to(device)
                # Call the integrator directly: get_continuous_freq returns an
                # angle table for 2-point grids, which would corrupt this cache
                # when max_t == 2.
                self.freq_cached = torch.exp(self._integrate_log_inv_freq(time_grid, device))
            if t_val != self.max_t_cached:
                scale_inv_freq = self.freq_cached[int(t_val - 1.0)]
                embed = self._rotary_angles(ex_positions, scale_inv_freq)
                self.rope_cached = torch.cat(
                    (embed.cos()[None, :, :], embed.sin()[None, :, :]), dim=0
                )
                self.max_t_cached = t_val
            cos, sin = self.rope_cached

        return torch.cat(
            (cos[None, :seq_len].to(dtype=dtype),
             sin[None, :seq_len].to(dtype=dtype)),
            dim=0
        )
# End of listing.