import torch
import torch.nn as nn
import logging
import time

from utils import factorization

LOG = logging.getLogger(__name__)

class FixableDropout(nn.Module):
    """Dropout that caches one mask per input shape and reuses it until resample() is called."""

    def __init__(self, p: float):
        super().__init__()
        self.p = p
        self.mask_cache = {}
        self.seed = 0

    def resample(self, seed=None):
        if seed is None:
            seed = int(time.time() * 1e6)
        self.mask_cache = {}
        self.seed = seed

    def forward(self, x):
        if self.training:
            if x.shape not in self.mask_cache:
                generator = torch.Generator(x.device).manual_seed(self.seed)
                self.mask_cache[x.shape] = torch.bernoulli(
                    torch.full_like(x, 1 - self.p), generator=generator
                ).bool()
                self.should_resample = False
            x = (self.mask_cache[x.shape] * x) / (1 - self.p)

        return x

    def extra_repr(self) -> str:
        return f"p={self.p}"

class ActMLP(nn.Module):
    """Learned elementwise activation: a small scalar-to-scalar MLP applied to each element."""

    def __init__(self, hidden_dim, n_hidden):
        super().__init__()
        self.mlp = MLP(1, 1, hidden_dim, n_hidden, init="id")  # MLP is defined later in this module

    def forward(self, x):
        return self.mlp(x.view(-1, 1)).view(x.shape)

class LightIDMLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
    ):
        super().__init__()
        LOG.info(f"Building LightIDMLP {[indim] + [rank] + [indim]}")
        # outdim, hidden_dim, n_hidden, init, and act are accepted for interface
        # compatibility but unused; the block is a rank-bottlenecked residual update.
        self.layer1 = nn.Linear(indim, rank)
        self.layer2 = nn.Linear(rank, indim)
        self.layer2.weight.data[:] = 0
        self.layer2.bias = None

    def forward(self, x):
        h = self.layer1(x).relu()
        return x + self.layer2(h)

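# Minimal usage sketch for LightIDMLP (dimensions are illustrative). Because layer2 is
# zero-initialized and the update is residual, the freshly built module is an identity map.
def _demo_light_id_mlp():
    net = LightIDMLP(32, 32, hidden_dim=None, n_hidden=None, rank=4)
    x = torch.randn(5, 32)
    assert torch.allclose(net(x), x)
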
class IDMLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
        n_modes: int = None,
    ):
        super().__init__()
        LOG.info(f"Building IDMLP ({init}) {[indim] * (n_hidden + 2)}")
        self.layers = nn.ModuleList(
            [
                LRLinear(indim, indim, rank=rank, relu=idx < n_hidden, init=init, n_modes=n_modes)
                for idx in range(n_hidden + 1)
            ]
        )

    def forward(self, x, mode=None):
        for layer in self.layers:
            x = layer(x, mode=mode)

        return x

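# Minimal usage sketch for IDMLP (dimensions are illustrative). With init="id" every LRLinear
# starts as a zero update on a residual path, so the stack is initially the identity; the mode
# argument selects a learned per-mode scale/shift inside each layer.
def _demo_id_mlp():
    net = IDMLP(32, 32, hidden_dim=None, n_hidden=1, init="id", rank=4, n_modes=2)
    x = torch.randn(5, 32)
    assert torch.allclose(net(x, mode=0), x)
    assert torch.allclose(net(x), x)  # mode is optional
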
class LatentIDMLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
    ):
        super().__init__()
        LOG.info(f"Building Latent IDMLP ({init}) {[indim] * (n_hidden + 2)}")

        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(indim, rank))
        for _ in range(n_hidden - 1):
            self.layers.append(nn.Linear(rank, rank))
        self.layers.append(nn.Linear(rank, outdim))

        for layer in self.layers[:-1]:
            nn.init.xavier_normal_(layer.weight.data)

        if init == "id":
            self.layers[-1].weight.data.zero_()
            self.layers[-1].bias.data.zero_()

        self.init = init

    def forward(self, x):
        out = x
        for layer in self.layers[:-1]:
            out = layer(out).relu()

        out = self.layers[-1](out)
        if self.init == "id":
            return out + x
        else:
            return out

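# Minimal usage sketch for LatentIDMLP (dimensions are illustrative). The input is projected
# through a rank-sized latent stack; with init="id" the output layer starts at zero, so the
# residual connection makes the fresh module an identity map.
def _demo_latent_id_mlp():
    net = LatentIDMLP(32, 32, hidden_dim=None, n_hidden=2, init="id", rank=8)
    x = torch.randn(5, 32)
    assert torch.allclose(net(x), x)
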
class KLinear(nn.Module):
    def __init__(self, inf, outf, pfrac=0.05, symmetric=True, zero_init: bool = True):
        super().__init__()
        self.inf = inf
        in_fact = factorization(inf)
        out_fact = factorization(outf)
        total_params = 0
        self.a, self.b = nn.ParameterList(), nn.ParameterList()
        for (i1, i2), (o1, o2) in zip(reversed(in_fact), reversed(out_fact)):
            new_params = (o1 * i1 + o2 * i2) * (2 if symmetric else 1)
            if (total_params + new_params) / (inf * outf) > pfrac and len(self.a) > 0:
                break
            total_params += new_params

            self.a.append(nn.Parameter(torch.empty(o1, i1)))
            if symmetric:
                self.a.append(nn.Parameter(torch.empty(o2, i2)))
            self.b.append(nn.Parameter(torch.empty(o2, i2)))
            if symmetric:
                self.b.append(nn.Parameter(torch.empty(o1, i1)))

            assert self.a[-1].kron(self.b[-1]).shape == (outf, inf)

        for factor in self.a:
            nn.init.kaiming_normal_(factor.data)
        for factor in self.b:
            if zero_init:
                factor.data.zero_()
            else:
                nn.init.kaiming_normal_(factor.data)

        print(f"Created ({symmetric}) k-layer using {total_params/(outf*inf):.3f} params, {len(self.a)} comps")
        self.bias = nn.Parameter(torch.zeros(outf))

    def forward(self, x):
        assert x.shape[-1] == self.inf, f"Expected input with {self.inf} dimensions, got {x.shape}"
        # Sum the Kronecker-factored components to form the full (outf, inf) weight matrix.
        w = sum([a.kron(b) for a, b in zip(self.a, self.b)]) / (2 * len(self.a) ** 0.5)
        # Transpose back to (batch, outf) so the bias broadcasts over the output dimension.
        y = (w @ x.T).T
        if self.bias is not None:
            y = y + self.bias

        return y

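# Minimal usage sketch for KLinear (dimensions are illustrative). It assumes utils.factorization(n)
# returns factor pairs (f1, f2) with f1 * f2 == n, as the constructor above requires. With the
# default zero_init=True the factors in self.b start at zero, so the initial output is just the
# (zero) bias.
def _demo_k_linear():
    layer = KLinear(64, 64)
    x = torch.randn(5, 64)
    y = layer(x)
    assert y.shape == (5, 64)
    assert torch.allclose(y, torch.zeros_like(y))
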
class LRLinear(nn.Module):
    def __init__(self, inf, outf, rank: int = None, relu=False, init="id", n_modes=None):
        super().__init__()

        mid_dim = min(rank, inf)
        if init == "id":
            self.u = nn.Parameter(torch.zeros(outf, mid_dim))
            self.v = nn.Parameter(torch.randn(mid_dim, inf))
        elif init == "xavier":
            self.u = nn.Parameter(torch.empty(outf, mid_dim))
            self.v = nn.Parameter(torch.empty(mid_dim, inf))
            nn.init.xavier_uniform_(self.u.data, gain=nn.init.calculate_gain("relu"))
            nn.init.xavier_uniform_(self.v.data, gain=1.0)
        else:
            raise ValueError(f"Unrecognized initialization {init}")

        if n_modes is not None:
            self.mode_shift = nn.Embedding(n_modes, outf)
            self.mode_shift.weight.data.zero_()
            self.mode_scale = nn.Embedding(n_modes, outf)
            self.mode_scale.weight.data.fill_(1)

        self.n_modes = n_modes
        self.bias = nn.Parameter(torch.zeros(outf))
        self.inf = inf
        self.init = init

    def forward(self, x, mode=None):
        if mode is not None:
            assert self.n_modes is not None, "Linear got a mode but wasn't initialized for it"
            assert mode < self.n_modes, f"Input mode {mode} outside of range {self.n_modes}"

        assert x.shape[-1] == self.inf, f"Input wrong dim ({x.shape}, {self.inf})"

        pre_act = (self.u @ (self.v @ x.T)).T
        if self.bias is not None:
            pre_act += self.bias

        if mode is not None:
            if not isinstance(mode, torch.Tensor):
                mode = torch.tensor(mode).to(x.device)
            scale, shift = self.mode_scale(mode), self.mode_shift(mode)
            pre_act = pre_act * scale + shift

        # need clamp instead of relu so gradient at 0 isn't 0
        acts = pre_act.clamp(min=0)
        if self.init == "id":
            return acts + x
        else:
            return acts

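# Minimal usage sketch for LRLinear (dimensions are illustrative). With init="id" the low-rank
# update (u @ v) starts at zero and the layer is residual, so it initially passes its input
# through unchanged; the per-mode scale/shift embeddings also start at the identity.
def _demo_lr_linear():
    layer = LRLinear(32, 32, rank=4, init="id", n_modes=2)
    x = torch.randn(5, 32)
    assert torch.allclose(layer(x), x)
    assert torch.allclose(layer(x, mode=1), x)
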
class MLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = "xavier_uniform",
        act: str = "relu",
        rank: int = None,
    ):
        super().__init__()

        self.init = init

        if act == "relu":
            self.act = nn.ReLU()
        elif act == "learned":
            self.act = ActMLP(10, 1)
        else:
            raise ValueError(f"Unrecognized activation function '{act}'")

        if hidden_dim is None:
            hidden_dim = outdim * 2

        if init.startswith("id") and outdim != indim:
            LOG.info(f"Overwriting outdim ({outdim}) to be indim ({indim})")
            outdim = indim

if init == "id": | |
old_hidden_dim = hidden_dim | |
if hidden_dim < indim * 2: | |
hidden_dim = indim * 2 | |
if hidden_dim % indim != 0: | |
hidden_dim += hidden_dim % indim | |
if old_hidden_dim != hidden_dim: | |
LOG.info( | |
f"Overwriting hidden dim ({old_hidden_dim}) to be {hidden_dim}" | |
) | |
if init == "id_alpha": | |
self.alpha = nn.Parameter(torch.zeros(1, outdim)) | |
dims = [indim] + [hidden_dim] * n_hidden + [outdim] | |
LOG.info(f"Building ({init}) MLP: {dims} (rank {rank})") | |
layers = [] | |
for idx, (ind, outd) in enumerate(zip(dims[:-1], dims[1:])): | |
if rank is None: | |
layers.append(nn.Linear(ind, outd)) | |
else: | |
layers.append(LRLinear(ind, outd, rank=rank)) | |
if idx < n_hidden: | |
layers.append(self.act) | |
if rank is None: | |
if init == "id": | |
if n_hidden > 0: | |
layers[0].weight.data = torch.eye(indim).repeat( | |
hidden_dim // indim, 1 | |
) | |
layers[0].weight.data[hidden_dim // 2:] *= -1 | |
layers[-1].weight.data = torch.eye(outdim).repeat( | |
1, hidden_dim // outdim | |
) | |
layers[-1].weight.data[:, hidden_dim // 2:] *= -1 | |
layers[-1].weight.data /= (hidden_dim // indim) / 2.0 | |
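                    # How this identity initialization works: the first layer stacks identity
                    # blocks and negates the bottom half, so after the ReLU the hidden units
                    # hold copies of relu(x) and relu(-x); the last layer tiles identity blocks,
                    # negates the right half, and rescales by 2 / (hidden_dim // indim). Since
                    # relu(x) - relu(-x) == x, the freshly built MLP computes the exact identity.
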
            for layer in layers:
                if isinstance(layer, nn.Linear):
                    if init == "ortho":
                        nn.init.orthogonal_(layer.weight)
                    elif init == "id":
                        if layer.weight.shape[0] == layer.weight.shape[1]:
                            layer.weight.data = torch.eye(hidden_dim)
                    else:
                        gain = 3 ** 0.5 if (layer is layers[-1]) else 1.0
                        nn.init.xavier_uniform_(layer.weight, gain=gain)

                    layer.bias.data[:] = 0

        layers[-1].bias = None
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        if self.init == "id_alpha":
            return x + self.alpha * self.mlp(x)
        else:
            return self.mlp(x)

if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    m0 = MLP(1000, 1000, 1500, 3)
    m1 = MLP(1000, 1000, 1500, 3, init="id")
    m2 = MLP(1000, 1000, 1500, 3, init="id_alpha")
    m3 = MLP(1000, 1000, 1500, 3, init="ortho", act="learned")

    x = 0.01 * torch.randn(999, 1000)

    y0 = m0(x)
    y1 = m1(x)
    y2 = m2(x)
    y3 = m3(x)

    print("y0", (y0 - x).abs().max())
    print("y1", (y1 - x).abs().max())
    print("y2", (y2 - x).abs().max())
    print("y3", (y3 - x).abs().max())

    assert not torch.allclose(y0, x)
    assert torch.allclose(y1, x)
    assert torch.allclose(y2, x)
    assert not torch.allclose(y3, x)

    import pdb; pdb.set_trace()  # fmt: skip