Spaces:
Running
Running
import torch.nn as nn | |
from models.positional_embeddings import FourierEmbedding, PositionalEmbedding | |
from models.networks.transformers import FusedMLP | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from einops import rearrange | |
class TimeEmbedder(nn.Module):
    """Encode a time / noise-level input and project it with a small MLP.

    The input is multiplied by ``time_scaling``, passed through either a
    ``PositionalEmbedding`` or a ``FourierEmbedding``, standardised along the
    last axis and finally mapped to ``dim * expansion`` features.

    Args:
        noise_embedding_type: ``"positional"`` selects ``PositionalEmbedding``;
            every other value selects ``FourierEmbedding``.
        dim: Number of channels requested from the encoder, and the input
            width of the mapping MLP.
        time_scaling: Factor applied to the input before it is encoded.
        expansion: Width multiplier of the mapping MLP (output width is
            ``dim * expansion``).
    """

    def __init__(
        self,
        noise_embedding_type: str,
        dim: int,
        time_scaling: float,
        expansion: int = 4,
    ):
        super().__init__()
        if noise_embedding_type == "positional":
            encoder = PositionalEmbedding(num_channels=dim, endpoint=True)
        else:
            encoder = FourierEmbedding(num_channels=dim)
        self.encode_time = encoder
        self.time_scaling = time_scaling

        hidden_dim = dim * expansion
        self.map_time = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, t):
        """Return the mapped embedding of ``t``.

        The raw encoding is shifted to zero mean and divided by its standard
        deviation, both taken over the last axis, before entering the MLP.
        """
        encoded = self.encode_time(t * self.time_scaling)
        centered = encoded - encoded.mean(dim=-1, keepdim=True)
        standardized = centered / encoded.std(dim=-1, keepdim=True)
        return self.map_time(standardized)
def get_timestep_embedding(timesteps, embedding_dim, dtype=torch.float32):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    The timesteps are multiplied by 1000 and then by ``embedding_dim // 2``
    geometrically spaced frequencies running from 1 down to 1/10000. The
    sines fill the first half of the output and the cosines the second half;
    an odd ``embedding_dim`` gets one trailing zero column.

    Args:
        timesteps: 1-D tensor of timesteps.
        embedding_dim: Width of the returned embedding.
        dtype: Floating dtype used for the frequencies and the result.

    Returns:
        Tensor of shape ``(len(timesteps), embedding_dim)``.
    """
    assert len(timesteps.shape) == 1, "timesteps must be a 1-D tensor"
    scaled = (timesteps * 1000.0).to(dtype)

    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim 2 or 3) ``half_dim - 1`` is zero;
    # dividing by it made the step infinite and turned every output into NaN.
    # Clamping the denominator leaves that single frequency at exp(0) == 1 and
    # does not change the result for any larger ``half_dim``.
    log_step = float(np.log(10000)) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=dtype, device=timesteps.device) * -log_step
    )

    angles = scaled[:, None] * freqs[None, :]
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
    if embedding_dim % 2 == 1:
        # Zero-pad the last column so the width matches an odd embedding_dim.
        emb = F.pad(emb, (0, 1))
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
class AdaLNMLPBlock(nn.Module):
    """Residual MLP block whose normalisation is modulated by a condition.

    A conditioning vector is projected to three ``dim``-wide chunks. The
    first two scale and shift the (non-affine) layer-normalised input, the
    third multiplies the MLP output before it is added back to the input.

    Args:
        dim: Feature width of both the input and the conditioning vector.
        expansion: Passed to ``FusedMLP`` as ``hidden_layer_multiplier``.
    """

    def __init__(self, dim, expansion):
        super().__init__()
        self.mlp = FusedMLP(
            dim,
            dropout=0.0,
            hidden_layer_multiplier=expansion,
            activation=nn.GELU,
        )
        # A single projection yields scale, shift and gate (dim features each).
        self.ada_map = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 3))
        self.ln = nn.LayerNorm(dim, elementwise_affine=False)

        # Zero the last layer of the MLP. NOTE(review): this indexes FusedMLP
        # with [-1], so it presumably behaves like a sequence whose final
        # element is the output layer; if so the residual branch contributes
        # nothing at initialisation.
        last_layer = self.mlp[-1]
        nn.init.zeros_(last_layer.weight)
        nn.init.zeros_(last_layer.bias)

    def forward(self, x, y):
        """Apply the block to ``x`` using ``y`` as the conditioning vector."""
        scale, shift, gate = self.ada_map(y).chunk(3, dim=-1)
        modulated = (1 + scale) * self.ln(x) + shift
        return x + self.mlp(modulated) * gate
class GeoAdaLNMLP(nn.Module):
    """Conditional MLP built from a stack of ``AdaLNMLPBlock`` layers.

    The input ``batch["y"]`` is lifted to ``dim`` features, processed by
    ``depth`` blocks that are all modulated by the same conditioning vector
    (mapped ``batch["emb"]`` plus the time embedding of ``batch["gamma"]``),
    then normalised, modulated once more and projected back to ``input_dim``.

    Args:
        input_dim: Width of ``batch["y"]`` and of the output.
        dim: Hidden width. Must be divisible by 4 (see ``__init__``).
        depth: Number of ``AdaLNMLPBlock`` layers.
        expansion: Hidden-layer multiplier handed to every block.
        cond_dim: Width of ``batch["emb"]``.

    Raises:
        ValueError: If ``dim`` is not divisible by 4.
    """

    def __init__(self, input_dim, dim, depth, expansion, cond_dim):
        super().__init__()
        # The time embedder is built with dim // 4 channels and expansion 4,
        # so it outputs (dim // 4) * 4 features which are added to a dim-wide
        # tensor in forward(). Any other dim only failed later with a shape
        # error, so reject it here with a clear message.
        if dim % 4 != 0:
            raise ValueError(
                f"dim must be divisible by 4 to match the time embedding width, got {dim}"
            )
        self.time_embedder = TimeEmbedder("positional", dim // 4, 1000, expansion=4)
        self.cond_mapper = nn.Linear(cond_dim, dim)
        self.initial_mapper = nn.Linear(input_dim, dim)
        self.blocks = nn.ModuleList(
            [AdaLNMLPBlock(dim, expansion) for _ in range(depth)]
        )
        # Produces the scale and shift applied after the final LayerNorm.
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2),
        )
        self.final_ln = nn.LayerNorm(dim, elementwise_affine=False)
        self.final_linear = nn.Linear(dim, input_dim)

    def forward(self, batch):
        """Run the network on a batch dictionary.

        Args:
            batch: Mapping with the keys ``"y"`` (network input), ``"gamma"``
                (fed to the time embedder) and ``"emb"`` (conditioning).

        Returns:
            Tensor with ``input_dim`` features on the last axis.
        """
        x = self.initial_mapper(batch["y"])
        time_emb = self.time_embedder(batch["gamma"])
        cond = self.cond_mapper(batch["emb"]) + time_emb

        for block in self.blocks:
            x = block(x, cond)

        scale, shift = self.final_adaln(cond).chunk(2, dim=-1)
        x = (1 + scale) * self.final_ln(x) + shift
        return self.final_linear(x)
class GeoAdaLNMLPVonFisher(nn.Module):
    """Predict a unit direction ``mu`` and a positive scalar ``kappa``.

    A learned register vector is broadcast over the batch and refined by
    ``depth`` ``AdaLNMLPBlock`` layers conditioned on the mapped
    ``batch["emb"]``. Two heads then produce an L2-normalised ``mu`` and a
    Softplus-activated ``kappa``. NOTE(review): going by the class name these
    are presumably the parameters of a von Mises-Fisher distribution.

    Args:
        input_dim: Width of the predicted direction ``mu``.
        dim: Hidden width.
        depth: Number of ``AdaLNMLPBlock`` layers.
        expansion: Hidden-layer multiplier handed to every block.
        cond_dim: Width of ``batch["emb"]``.
    """

    def __init__(self, input_dim, dim, depth, expansion, cond_dim):
        super().__init__()
        self.cond_mapper = nn.Linear(cond_dim, dim)
        self.blocks = nn.ModuleList(
            [AdaLNMLPBlock(dim, expansion) for _ in range(depth)]
        )
        # Produces the scale and shift applied after the final LayerNorm.
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2),
        )
        self.final_ln = nn.LayerNorm(dim, elementwise_affine=False)
        self.mu_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, input_dim),
        )
        # Softplus keeps kappa strictly positive.
        self.kappa_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, 1),
            torch.nn.Softplus(),
        )
        # Learned starting state shared by every sample; re-initialised from a
        # truncated normal (std 0.02, clipped to two standard deviations).
        self.init_registers = torch.nn.Parameter(torch.randn(dim), requires_grad=True)
        torch.nn.init.trunc_normal_(
            self.init_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02
        )

    def forward(self, batch):
        """Return ``(mu, kappa)`` for the conditioning in ``batch["emb"]``.

        Returns:
            mu: ``(batch, input_dim)`` tensor with unit L2 norm on the last
                axis (all zeros if the raw prediction is exactly zero).
            kappa: ``(batch, 1)`` tensor of positive values.
        """
        cond = self.cond_mapper(batch["emb"])
        # expand() broadcasts the register without copying it per sample.
        x = self.init_registers.unsqueeze(0).expand(cond.shape[0], -1)
        for block in self.blocks:
            x = block(x, cond)

        scale, shift = self.final_adaln(cond).chunk(2, dim=-1)
        x = (1 + scale) * self.final_ln(x) + shift

        # F.normalize clamps the norm away from zero, so a zero prediction
        # gives zeros instead of the NaNs produced by a plain division.
        mu = F.normalize(self.mu_predictor(x), dim=-1)
        kappa = self.kappa_predictor(x)
        return mu, kappa
class GeoAdaLNMLPVonFisherMixture(nn.Module):
    """Predict ``num_mixtures`` unit directions with concentrations and weights.

    A learned register vector is broadcast over the batch and refined by
    ``depth`` ``AdaLNMLPBlock`` layers conditioned on the mapped
    ``batch["emb"]``. Three heads then produce, per mixture component, an
    L2-normalised ``mu``, a Softplus-activated ``kappa`` and a Softmax weight.
    NOTE(review): going by the class name these are presumably the parameters
    of a mixture of von Mises-Fisher distributions.

    Args:
        input_dim: Width of each predicted direction.
        dim: Hidden width.
        depth: Number of ``AdaLNMLPBlock`` layers.
        expansion: Hidden-layer multiplier handed to every block.
        cond_dim: Width of ``batch["emb"]``.
        num_mixtures: Number of mixture components.
    """

    def __init__(self, input_dim, dim, depth, expansion, cond_dim, num_mixtures=3):
        super().__init__()
        self.cond_mapper = nn.Linear(cond_dim, dim)
        self.blocks = nn.ModuleList(
            [AdaLNMLPBlock(dim, expansion) for _ in range(depth)]
        )
        # Produces the scale and shift applied after the final LayerNorm.
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2),
        )
        self.final_ln = nn.LayerNorm(dim, elementwise_affine=False)
        # All component directions are emitted as one flat vector and split
        # into (num_mixtures, input_dim) in forward().
        self.mu_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, input_dim * num_mixtures),
        )
        # Softplus keeps every kappa strictly positive.
        self.kappa_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, num_mixtures),
            torch.nn.Softplus(),
        )
        # Softmax makes the component weights sum to one.
        self.mixture_weights = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, num_mixtures),
            torch.nn.Softmax(dim=-1),
        )
        self.num_mixtures = num_mixtures
        # Learned starting state shared by every sample; re-initialised from a
        # truncated normal (std 0.02, clipped to two standard deviations).
        self.init_registers = torch.nn.Parameter(torch.randn(dim), requires_grad=True)
        torch.nn.init.trunc_normal_(
            self.init_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02
        )

    def forward(self, batch):
        """Return ``(mu, kappa, weights)`` for the conditioning in ``batch["emb"]``.

        Returns:
            mu: ``(batch, num_mixtures, input_dim)`` tensor with unit L2 norm
                on the last axis (zeros where the raw prediction is exactly
                zero).
            kappa: ``(batch, num_mixtures)`` tensor of positive values.
            weights: ``(batch, num_mixtures)`` tensor summing to one per row.
        """
        cond = self.cond_mapper(batch["emb"])
        # expand() broadcasts the register without copying it per sample.
        x = self.init_registers.unsqueeze(0).expand(cond.shape[0], -1)
        for block in self.blocks:
            x = block(x, cond)

        scale, shift = self.final_adaln(cond).chunk(2, dim=-1)
        x = (1 + scale) * self.final_ln(x) + shift

        # Split the flat prediction into components, mixture index outermost
        # (same layout as the pattern "b (n d) -> b n d").
        raw_mu = self.mu_predictor(x)
        raw_mu = raw_mu.reshape(raw_mu.shape[0], self.num_mixtures, -1)
        # F.normalize clamps the norm away from zero, so a zero prediction
        # gives zeros instead of the NaNs produced by a plain division.
        mu = F.normalize(raw_mu, dim=-1)
        kappa = self.kappa_predictor(x)
        weights = self.mixture_weights(x)
        return mu, kappa, weights