import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
import opt_einsum as oe

contract = oe.contract

""" Utils for the training loop. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """


class OptimModule(nn.Module):
    """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """

    def register(self, name, tensor, lr=None, wd=0.0):
        """Register a tensor with a configurable learning rate and weight decay (0 by default)"""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {}
            if lr is not None: optim["lr"] = lr
            if wd is not None: optim["weight_decay"] = wd
            setattr(getattr(self, name), "_optim", optim)
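

# Sketch (not part of the original module): one way a training script could consume the
# `_optim` attributes set by `OptimModule.register`, by splitting parameters into
# per-hyperparameter optimizer groups. The helper name `build_param_groups` and the base
# hyperparameters are illustrative assumptions, not an existing API.
def build_param_groups(model, base_lr=1e-3, base_wd=0.1):
    default_params, custom_groups = [], []
    for p in model.parameters():
        if hasattr(p, "_optim"):
            # Registered tensors carry their own lr / weight_decay overrides.
            custom_groups.append({"params": [p], **p._optim})
        else:
            default_params.append(p)
    return [{"params": default_params, "lr": base_lr, "weight_decay": base_wd}] + custom_groups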


def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
    """FFT-based long convolution: circular convolution over a 2x zero-padded length,
    truncated back to `seqlen`, which equals a causal linear convolution of `u` with `k`,
    plus a per-channel skip connection scaled by `D`."""
    seqlen = u.shape[-1]

    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    if k_rev is not None:
        # Add the conjugated spectrum of the reversed kernel for bidirectional filtering.
        k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
        k_f = k_f + k_rev_f.conj()
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)

    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

    # D has shape (H,): broadcast the skip connection over the length dimension.
    out = y + u * D.unsqueeze(-1)

    if gelu:
        out = F.gelu(out)
    if dropout_mask is not None:
        return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
    else:
        return out.to(dtype=u.dtype)
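

# Sketch (an addition, not from the original code): a direct O(L^2) reference for the
# causal convolution that `fftconv_ref` computes with gelu=False, dropout_mask=None and
# no k_rev, assuming u of shape (B, H, L), k of shape (H, L) and D of shape (H,).
# The helper name `_causal_conv_direct` is illustrative; it is only useful for
# eyeballing correctness on small inputs.
def _causal_conv_direct(u, k, D):
    L = u.shape[-1]
    y = torch.zeros_like(u)
    for t in range(L):
        # y[..., t] = sum_{s=0..t} k[:, s] * u[..., t - s]
        y[..., t] = sum(k[:, s] * u[..., t - s] for s in range(t + 1))
    return y + D.unsqueeze(-1) * u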


@torch.jit.script
def mul_sum(q, y):
    return (q * y).sum(dim=1)


class Sin(nn.Module):
    """Sine activation with an (optionally trainable) per-channel frequency, used inside the filter MLP."""

    def __init__(self, dim, w=10, w_mod=1, train_freq=True):
        super().__init__()

        init_tensor = torch.ones(1, dim)
        self.freq = (
            nn.Parameter(w * init_tensor)
            if train_freq
            else w * torch.ones(1, dim)
        )
        self.w_mod = w_mod

    def forward(self, x):
        return torch.sin(self.w_mod * self.freq * x)


class PositionalEmbedding(OptimModule):
    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):
        """Complex exponential positional embeddings for Hyena filters."""
        super().__init__()

        self.seq_len = seq_len

        t = torch.linspace(0, 1, self.seq_len)[None, :, None]

        if emb_dim > 1:
            bands = (emb_dim - 1) // 2

        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        w = 2 * math.pi * t_rescaled / seq_len

        f = torch.linspace(1e-4, bands - 1, bands)[None, None]
        z = torch.exp(-1j * f * w)
        z = torch.cat([t, z.real, z.imag], dim=-1)
        self.register("z", z, lr=lr_pos_emb)
        self.register("t", t, lr=0.0)

    def forward(self, L):
        return self.z[:, :L], self.t[:, :L]
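
# Note on PositionalEmbedding (added comment, not in the original source): with t the
# time axis normalized to [0, 1] and w = 2*pi*n/seq_len the angular positions, z stacks
# [t, Re(e^{-i f w}), Im(e^{-i f w})] over `bands = (emb_dim - 1) // 2` frequencies f,
# giving 1 + 2 * bands = emb_dim channels (HyenaFilter asserts emb_dim is odd and >= 3).
# forward(L) returns z[:, :L] of shape (1, L, emb_dim) and t[:, :L] of shape (1, L, 1).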


class ExponentialModulation(OptimModule):
    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulation_lr=0.0,
        shift: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.shift = shift
        max_decay = math.log(target) / fast_decay_pct
        min_decay = math.log(target) / slow_decay_pct
        deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
        self.register("deltas", deltas, lr=modulation_lr)

    def forward(self, t, x):
        decay = torch.exp(-t * self.deltas.abs())
        x = x * (decay + self.shift)
        return x
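
# Note on ExponentialModulation (added comment, not in the original source): with t
# normalized to [0, 1], each channel's envelope exp(-|delta| * t) decays from 1 to
# `target` at t = pct, since |delta| = -log(target) / pct. Channels therefore span
# effective window lengths ranging from fast_decay_pct to slow_decay_pct of the sequence.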


class HyenaFilter(OptimModule):
    def __init__(
        self,
        d_model,
        emb_dim=3,
        order=16,
        seq_len=1024,
        lr=1e-3,
        lr_pos_emb=1e-5,
        dropout=0.0,
        w=1,
        w_mod=1,
        wd=0,
        bias=True,
        num_inner_mlps=2,
        linear_mixer=False,
        modulate: bool = True,
        normalized=False,
        bidirectional=False,
        **kwargs,
    ):
        """
        Implicit long filter with modulation.

        Args:
            d_model: number of channels in the input
            emb_dim: dimension of the positional encoding; (`emb_dim` - 1) // 2 is the number of bands
            order: width of the FFN
            num_inner_mlps: number of inner linear layers inside the filter MLP

        Note:
            filter_dropout is not implemented
        """
        super().__init__()

        self.d_model = d_model
        self.emb_dim = emb_dim
        self.seq_len = seq_len
        self.modulate = modulate
        self.use_bias = bias
        self.bidirectional = bidirectional

        self.bias = nn.Parameter(torch.randn(self.d_model))
        self.dropout = nn.Dropout(dropout)

        act = Sin(dim=order, w=w, w_mod=w_mod)
        assert (
            emb_dim % 2 != 0 and emb_dim >= 3
        ), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)

        # MLP mapping positional embeddings to filter values, with sine activations.
        if linear_mixer is False:
            self.implicit_filter = nn.Sequential(
                nn.Linear(emb_dim, order),
                act,
            )
            for i in range(num_inner_mlps):
                self.implicit_filter.append(nn.Linear(order, order))
                self.implicit_filter.append(act)
            self.implicit_filter.append(nn.Linear(order, d_model, bias=False))
        else:
            self.implicit_filter = nn.Sequential(
                nn.Linear(emb_dim, d_model, bias=False),
            )

        if self.bidirectional:
            self.implicit_filter_rev = nn.Sequential(
                nn.Linear(emb_dim, order),
                act,
            )
            for i in range(num_inner_mlps):
                self.implicit_filter_rev.append(nn.Linear(order, order))
                self.implicit_filter_rev.append(act)
            self.implicit_filter_rev.append(nn.Linear(order, d_model, bias=False))

        self.modulation = ExponentialModulation(d_model, **kwargs)

        self.normalized = normalized
        # Tag the filter MLP tensors with their own optimizer hyperparameters.
        for c in self.implicit_filter.children():
            for name, v in c.state_dict().items():
                optim = {"weight_decay": wd, "lr": lr}
                setattr(getattr(c, name), "_optim", optim)

    def filter(self, L, *args, **kwargs):
        z, t = self.pos_emb(L)
        h = self.implicit_filter(z)
        if self.modulate:
            h = self.modulation(t, h)
        if self.normalized:
            h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
        return h

    def filter_rev(self, L, *args, **kwargs):
        z, t = self.pos_emb(L)
        h = self.implicit_filter_rev(z)
        if self.modulate:
            h = self.modulation(t, h)
        if self.normalized:
            h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
        return h

    def forward(self, x, L, k_fwd=None, k_rev=None, bias=None, *args, **kwargs):
        if k_fwd is None:
            k_fwd = self.filter(L)
        if self.bidirectional and k_rev is None:
            k_rev = self.filter_rev(L)

        # Ensure compatibility with filters that return a tuple
        k_fwd = k_fwd[0] if type(k_fwd) is tuple else k_fwd
        if bias is None:
            bias = self.bias
        bias = bias if self.use_bias else 0 * bias

        if self.bidirectional:
            k_rev = k_rev[0] if type(k_rev) is tuple else k_rev
            # Pad the forward and flipped backward kernels so one FFT conv covers both directions.
            k = F.pad(k_fwd, (0, L)) \
                + F.pad(k_rev.flip(-1), (L, 0))
        else:
            k = k_fwd

        y = fftconv_ref(
            x,
            k,
            bias,
            dropout_mask=None,
            gelu=False,
        )

        return y.to(dtype=x.dtype)
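

if __name__ == "__main__":
    # Minimal smoke test (a sketch added here, not part of the original module). The
    # kernel returned by .filter(L) has shape (1, L, d_model); it is rearranged to
    # (d_model, L) as expected by fftconv_ref. All sizes below are illustrative.
    B, D_MODEL, L = 2, 8, 64
    torch.manual_seed(0)
    filt = HyenaFilter(d_model=D_MODEL, emb_dim=3, order=16, seq_len=L)
    x = torch.randn(B, D_MODEL, L)
    k = rearrange(filt.filter(L), "1 l d -> d l")  # (d_model, L)
    y = filt(x, L, k_fwd=k)
    print(y.shape)  # expected: torch.Size([2, 8, 64])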