model-editing / nn.py
Charles Lin
Add algorithms from efk codebase
e56055d
raw
history blame
11.2 kB
import torch
import torch.nn as nn
import logging
import time
from utils import factorization
LOG = logging.getLogger(__name__)
class FixableDropout(nn.Module):
    """Dropout whose mask stays fixed until :meth:`resample` is called.

    In training mode, the first input with a given shape (on a given device)
    draws a Bernoulli keep-mask from a generator seeded with ``self.seed``.
    That mask is cached and reused for every later input with the same shape
    and device. Kept elements are scaled by ``1 / (1 - p)``. In eval mode the
    input is returned unchanged.

    Args:
        p: probability of zeroing an element (keep probability is ``1 - p``).
    """

    def __init__(self, p: float):
        super().__init__()
        self.p = p
        # (shape, device) -> boolean keep-mask
        self.mask_cache = {}
        self.seed = 0

    def resample(self, seed=None):
        """Drop all cached masks and set the seed used to draw new ones.

        Args:
            seed: RNG seed; defaults to the current time in microseconds.
        """
        if seed is None:
            seed = int(time.time() * 1e6)
        self.mask_cache = {}
        self.seed = seed

    def forward(self, x):
        if not self.training:
            return x
        # Key on device as well as shape: a mask cached on one device cannot be
        # multiplied with a tensor living on another (e.g. after .cuda()).
        key = (x.shape, x.device)
        mask = self.mask_cache.get(key)
        if mask is None:
            generator = torch.Generator(device=x.device).manual_seed(self.seed)
            mask = torch.bernoulli(
                torch.full_like(x, 1 - self.p), generator=generator
            ).bool()
            self.mask_cache[key] = mask
        return (mask * x) / (1 - self.p)

    def extra_repr(self) -> str:
        return f"p={self.p}"
class ActMLP(nn.Module):
    """Learned elementwise activation: a small 1 -> 1 MLP applied to every scalar.

    The inner network is built with ``MLP(..., init="id")``, which constructs
    an identity-initialized MLP.

    Args:
        hidden_dim: hidden width of the inner MLP.
        n_hidden: number of hidden layers of the inner MLP.
    """

    def __init__(self, hidden_dim, n_hidden):
        super().__init__()
        self.mlp = MLP(1, 1, hidden_dim, n_hidden, init="id")

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted; output has the same shape as x.
        return self.mlp(x.reshape(-1, 1)).reshape(x.shape)
class LightIDMLP(nn.Module):
    """Single-bottleneck residual block: ``x + W2 relu(W1 x + b1)``.

    ``W2`` starts at zero and has no bias, so the block returns its input
    unchanged at construction time. Only ``indim`` and ``rank`` shape the
    network. ``outdim``, ``hidden_dim``, ``n_hidden``, ``init`` and ``act`` are
    accepted so the constructor matches the other MLP classes in this module,
    but they are not used.
    """

    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
    ):
        super().__init__()
        widths = [indim, rank, indim]
        LOG.info(f"Building LightIDMLP {widths}")
        down = nn.Linear(indim, rank)
        up = nn.Linear(rank, indim)
        up.weight.data.zero_()
        up.bias = None
        self.layer1 = down
        self.layer2 = up

    def forward(self, x):
        return x + self.layer2(self.layer1(x).relu())
class IDMLP(nn.Module):
    """Stack of ``n_hidden + 1`` square LRLinear layers of width ``indim``.

    Each layer is ``LRLinear(indim, indim, rank=rank, init=init,
    n_modes=n_modes)``. ``init`` must be one LRLinear accepts ("id" or
    "xavier"); LRLinear raises ValueError otherwise. ``outdim``,
    ``hidden_dim`` and ``act`` are accepted for signature compatibility but
    not used.
    """

    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
        n_modes: int = None,
    ):
        super().__init__()
        LOG.info(f"Building IDMLP ({init}) {[indim] * (n_hidden + 2)}")
        stack = []
        for position in range(n_hidden + 1):
            is_hidden = position < n_hidden
            stack.append(
                LRLinear(
                    indim,
                    indim,
                    rank=rank,
                    relu=is_hidden,
                    init=init,
                    n_modes=n_modes,
                )
            )
        self.layers = nn.ModuleList(stack)

    def forward(self, x, mode=None):
        """Apply every layer in order, passing ``mode`` through to each."""
        out = x
        for block in self.layers:
            out = block(out, mode=mode)
        return out
class LatentIDMLP(nn.Module):
    """MLP through a ``rank``-wide latent space, optionally residual.

    Layers are ``Linear(indim, rank)``, then ``n_hidden - 1`` times
    ``Linear(rank, rank)``, then ``Linear(rank, outdim)``, with ReLU between
    them. For ``n_hidden <= 1`` there is exactly one ``rank``-wide hidden
    activation. All layers except the last get Xavier-normal weights.

    With ``init="id"`` the last layer's weight and bias start at zero and the
    input is added to the output, so the module is the identity at
    construction. The residual add assumes ``outdim == indim`` (or a shape
    that broadcasts against ``x``); TODO confirm with callers.
    ``hidden_dim`` and ``act`` are accepted for signature compatibility but
    not used.
    """

    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
    ):
        super().__init__()
        # Log the widths actually built; the loop below creates
        # max(n_hidden, 1) rank-wide hidden activations.
        dims = [indim] + [rank] * max(n_hidden, 1) + [outdim]
        LOG.info(f"Building Latent IDMLP ({init}) {dims}")

        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(indim, rank))
        for _ in range(n_hidden - 1):
            self.layers.append(nn.Linear(rank, rank))
        self.layers.append(nn.Linear(rank, outdim))

        for layer in self.layers[:-1]:
            nn.init.xavier_normal_(layer.weight.data)

        if init == "id":
            # Zero last layer => residual branch outputs 0 at initialization.
            self.layers[-1].weight.data.zero_()
            self.layers[-1].bias.data.zero_()

        self.init = init

    def forward(self, x):
        out = x
        for layer in self.layers[:-1]:
            out = layer(out).relu()
        out = self.layers[-1](out)
        if self.init == "id":
            return out + x
        return out
class KLinear(nn.Module):
    """Affine layer whose weight is a scaled sum of Kronecker products.

    ``factorization(inf)`` and ``factorization(outf)`` supply ``(d1, d2)``
    pairs. Presumably ``d1 * d2`` equals the argument; TODO confirm against
    ``utils.factorization``. Only the final component's shape is asserted to
    be ``(outf, inf)``. Walking the pairs in reverse, components are added
    until the factor parameter count would exceed ``pfrac * inf * outf``; at
    least one component is always kept. With ``symmetric=True`` each pair adds
    both ``A(o1,i1) (x) B(o2,i2)`` and the swapped ``A(o2,i2) (x) B(o1,i1)``.

    The effective weight is ``sum_k a_k (x) b_k / (2 * sqrt(K))`` for K
    components. ``a`` factors are Kaiming-normal. ``b`` factors start at zero
    when ``zero_init``, so the layer initially outputs only its (zero) bias.

    Args:
        inf: input features.
        outf: output features.
        pfrac: parameter budget as a fraction of a dense ``outf x inf`` weight.
        symmetric: also add the swapped-factor component for each pair.
        zero_init: zero-initialize the ``b`` factors.
    """

    def __init__(self, inf, outf, pfrac=0.05, symmetric=True, zero_init: bool = True):
        super().__init__()
        self.inf = inf

        in_fact = factorization(inf)
        out_fact = factorization(outf)

        total_params = 0
        self.a, self.b = nn.ParameterList(), nn.ParameterList()
        for (i1, i2), (o1, o2) in zip(reversed(in_fact), reversed(out_fact)):
            new_params = (o1 * i1 + o2 * i2) * (2 if symmetric else 1)
            if (total_params + new_params) / (inf * outf) > pfrac and len(self.a) > 0:
                break
            total_params += new_params
            self.a.append(nn.Parameter(torch.empty(o1, i1)))
            if symmetric:
                self.a.append(nn.Parameter(torch.empty(o2, i2)))
            self.b.append(nn.Parameter(torch.empty(o2, i2)))
            if symmetric:
                self.b.append(nn.Parameter(torch.empty(o1, i1)))

        assert self.a[-1].kron(self.b[-1]).shape == (outf, inf)

        for factor in self.a:
            nn.init.kaiming_normal_(factor.data)
        for factor in self.b:
            if zero_init:
                factor.data.zero_()
            else:
                nn.init.kaiming_normal_(factor.data)

        # Use the module logger like the rest of this file instead of print().
        LOG.info(f"Created ({symmetric}) k-layer using {total_params/(outf*inf):.3f} params, {len(self.a)} comps")
        self.bias = nn.Parameter(torch.zeros(outf))

    def forward(self, x):
        """Map ``x`` of shape (..., inf) to (..., outf)."""
        assert x.shape[-1] == self.inf, f"Expected input with {self.inf} dimensions, got {x.shape}"
        w = sum(a.kron(b) for a, b in zip(self.a, self.b)) / (2 * len(self.a) ** 0.5)
        # w is (outf, inf). The old `w @ x.T` produced (outf, batch), so the
        # (outf,) bias was broadcast along the batch axis. Right-multiply by
        # w.T instead so the bias lands on the feature axis.
        y = x @ w.T
        if self.bias is not None:
            y = y + self.bias
        return y
class LRLinear(nn.Module):
    """Low-rank layer: ``clamp(u @ v @ x + bias, min=0)``, optionally residual.

    ``u`` is ``(outf, r)`` and ``v`` is ``(r, inf)`` with ``r = min(rank, inf)``,
    or ``r = inf`` when ``rank`` is None.

    Args:
        inf: input features (last dim of ``x``).
        outf: output features.
        rank: inner rank; None means full rank (``inf``).
        relu: accepted for interface compatibility but NOT used; ``forward``
            always applies ``clamp(min=0)``.
        init: "id" (``u = 0``, so the low-rank branch starts at zero, and the
            input is added to the output; requires ``outf == inf``) or
            "xavier" (Xavier-uniform ``u``/``v``, no residual).
        n_modes: if given, adds per-mode learned scale (init 1) and shift
            (init 0) embeddings applied before the activation.

    Raises:
        ValueError: for an unrecognized ``init``.
    """

    def __init__(self, inf, outf, rank: int = None, relu=False, init="id", n_modes=None):
        super().__init__()

        mid_dim = inf if rank is None else min(rank, inf)
        if init == "id":
            self.u = nn.Parameter(torch.zeros(outf, mid_dim))
            self.v = nn.Parameter(torch.randn(mid_dim, inf))
        elif init == "xavier":
            self.u = nn.Parameter(torch.empty(outf, mid_dim))
            self.v = nn.Parameter(torch.empty(mid_dim, inf))
            nn.init.xavier_uniform_(self.u.data, gain=nn.init.calculate_gain("relu"))
            nn.init.xavier_uniform_(self.v.data, gain=1.0)
        else:
            raise ValueError(f"Unrecognized initialization {init}")

        if n_modes is not None:
            self.mode_shift = nn.Embedding(n_modes, outf)
            self.mode_shift.weight.data.zero_()
            self.mode_scale = nn.Embedding(n_modes, outf)
            self.mode_scale.weight.data.fill_(1)
        # Always defined so forward() can report a missing mode setup via its
        # assert instead of failing with AttributeError.
        self.n_modes = n_modes

        self.bias = nn.Parameter(torch.zeros(outf))
        self.inf = inf
        self.init = init

    def forward(self, x, mode=None):
        """Apply the layer to ``x`` of shape (..., inf); returns (..., outf)."""
        if mode is not None:
            assert self.n_modes is not None, "Linear got a mode but wasn't initialized for it"
            assert mode < self.n_modes, f"Input mode {mode} outside of range {self.n_modes}"
        assert x.shape[-1] == self.inf, f"Input wrong dim ({x.shape}, {self.inf})"

        # Equals (u @ (v @ x.T)).T for 2-D x. It also supports extra leading
        # batch dims, where Tensor.T would reverse every axis.
        pre_act = x @ self.v.T @ self.u.T
        if self.bias is not None:
            pre_act = pre_act + self.bias

        if mode is not None:
            if not isinstance(mode, torch.Tensor):
                mode = torch.tensor(mode, device=x.device)
            scale, shift = self.mode_scale(mode), self.mode_shift(mode)
            pre_act = pre_act * scale + shift

        # need clamp instead of relu so gradient at 0 isn't 0
        acts = pre_act.clamp(min=0)
        if self.init == "id":
            return acts + x
        return acts
class MLP(nn.Module):
    """Multi-layer perceptron with several initialization schemes.

    Layer widths are ``[indim] + [hidden_dim] * n_hidden + [outdim]``. The
    activation is placed after every layer except the last, and the last
    layer has no bias.

    Args:
        indim: input width.
        outdim: output width; overwritten with ``indim`` when ``init`` starts
            with "id".
        hidden_dim: hidden width; None means ``2 * outdim``. With
            ``init="id"`` it is rounded up to a multiple of ``2 * indim``.
        n_hidden: number of hidden layers.
        init: "xavier_uniform" (default; any value other than the ones below
            also gets Xavier-uniform weights), "ortho" (orthogonal weights),
            "id" (network computes the identity at construction), or
            "id_alpha" (output is ``x + alpha * mlp(x)`` with ``alpha`` zero).
        act: "relu", or "learned" for a single ActMLP instance shared by all
            activation positions.
        rank: if given, layers are ``LRLinear`` with this rank, and the
            nn.Linear-specific initializations below are skipped.

    Raises:
        ValueError: if ``act`` is not recognized.
    """

    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = "xavier_uniform",
        act: str = "relu",
        rank: int = None,
    ):
        super().__init__()

        self.init = init

        if act == "relu":
            self.act = nn.ReLU()
        elif act == "learned":
            self.act = ActMLP(10, 1)
        else:
            raise ValueError(f"Unrecognized activation function '{act}'")

        if hidden_dim is None:
            hidden_dim = outdim * 2

        if init.startswith("id") and outdim != indim:
            LOG.info(f"Overwriting outdim ({outdim}) to be indim ({indim})")
            outdim = indim

        if init == "id":
            old_hidden_dim = hidden_dim
            # The identity construction below splits the hidden units into a
            # positive half (relu(x)) and a negative half (relu(-x)), made of
            # whole copies of the indim units. That needs an even number of
            # copies, i.e. a multiple of 2 * indim (at least one pair).
            step = 2 * indim
            n_pairs = max(1, -(-hidden_dim // step))  # ceil division
            hidden_dim = n_pairs * step
            if old_hidden_dim != hidden_dim:
                LOG.info(
                    f"Overwriting hidden dim ({old_hidden_dim}) to be {hidden_dim}"
                )

        if init == "id_alpha":
            self.alpha = nn.Parameter(torch.zeros(1, outdim))

        dims = [indim] + [hidden_dim] * n_hidden + [outdim]
        LOG.info(f"Building ({init}) MLP: {dims} (rank {rank})")

        layers = []
        for idx, (ind, outd) in enumerate(zip(dims[:-1], dims[1:])):
            if rank is None:
                layers.append(nn.Linear(ind, outd))
            else:
                layers.append(LRLinear(ind, outd, rank=rank))
            if idx < n_hidden:
                layers.append(self.act)

        if rank is None and init == "id" and n_hidden > 0:
            copies = hidden_dim // indim
            # First layer: stacked identities, second half negated, so the
            # hidden units are relu(x) copies followed by relu(-x) copies.
            layers[0].weight.data = torch.eye(indim).repeat(copies, 1)
            layers[0].weight.data[hidden_dim // 2:] *= -1
            # Last layer: add the positive copies, subtract the negative ones,
            # and average over copies, giving relu(x) - relu(-x) = x.
            layers[-1].weight.data = torch.eye(outdim).repeat(
                1, hidden_dim // outdim
            )
            layers[-1].weight.data[:, hidden_dim // 2:] *= -1
            layers[-1].weight.data /= copies / 2.0

        for layer in layers:
            if isinstance(layer, nn.Linear):
                if init == "ortho":
                    nn.init.orthogonal_(layer.weight)
                elif init == "id":
                    # Square layers (hidden -> hidden, or indim -> indim when
                    # n_hidden == 0) become the identity. Size comes from the
                    # layer itself; hidden_dim is wrong when n_hidden == 0.
                    if layer.weight.shape[0] == layer.weight.shape[1]:
                        layer.weight.data = torch.eye(layer.weight.shape[0])
                else:
                    gain = 3 ** 0.5 if (layer is layers[-1]) else 1.0
                    nn.init.xavier_uniform_(layer.weight, gain=gain)
                layer.bias.data[:] = 0

        layers[-1].bias = None

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        if self.init == "id_alpha":
            return x + self.alpha * self.mlp(x)
        return self.mlp(x)
if __name__ == "__main__":
    # Smoke test: "id" / "id_alpha" MLPs must start as the identity map; the
    # other inits must not.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    m0 = MLP(1000, 1000, 1500, 3)
    m1 = MLP(1000, 1000, 1500, 3, init="id")
    m2 = MLP(1000, 1000, 1500, 3, init="id_alpha")
    m3 = MLP(1000, 1000, 1500, 3, init="ortho", act="learned")

    x = 0.01 * torch.randn(999, 1000)
    y0 = m0(x)
    y1 = m1(x)
    y2 = m2(x)
    y3 = m3(x)

    print("y0", (y0 - x).abs().max())
    print("y1", (y1 - x).abs().max())
    print("y2", (y2 - x).abs().max())
    print("y3", (y3 - x).abs().max())

    assert not torch.allclose(y0, x)
    assert torch.allclose(y1, x)
    assert torch.allclose(y2, x)
    assert not torch.allclose(y3, x)
    # The leftover pdb.set_trace() breakpoint was removed; it blocked
    # non-interactive runs of this script.