import torch.nn as nn
from models.positional_embeddings import FourierEmbedding, PositionalEmbedding
from models.networks.transformers import FusedMLP
import torch
import torch.nn.functional as F
import numpy as np
from einops import rearrange


class TimeEmbedder(nn.Module):
    """Embed a time / noise-level input and map it through a small MLP.

    Args:
        noise_embedding_type: ``"positional"`` selects ``PositionalEmbedding``;
            any other value silently falls back to ``FourierEmbedding``.
        dim: Number of channels produced by the base embedding.
        time_scaling: Factor applied to ``t`` before it is embedded.
        expansion: Output width multiplier; ``forward`` returns
            ``dim * expansion`` features on the last axis.
    """

    def __init__(
        self,
        noise_embedding_type: str,
        dim: int,
        time_scaling: float,
        expansion: int = 4,
    ):
        super().__init__()
        self.encode_time = (
            PositionalEmbedding(num_channels=dim, endpoint=True)
            if noise_embedding_type == "positional"
            else FourierEmbedding(num_channels=dim)
        )
        self.time_scaling = time_scaling
        self.map_time = nn.Sequential(
            nn.Linear(dim, dim * expansion),
            nn.SiLU(),
            nn.Linear(dim * expansion, dim * expansion),
        )

    def forward(self, t):
        """Return the MLP-mapped, standardised embedding of ``t * time_scaling``."""
        time = self.encode_time(t * self.time_scaling)
        # Standardise each embedding vector along its last axis before the MLP.
        # NOTE(review): no epsilon is used, so a constant embedding vector would
        # divide by zero; presumably impossible for these encoders — confirm.
        time_mean = time.mean(dim=-1, keepdim=True)
        time_std = time.std(dim=-1, keepdim=True)
        time = (time - time_mean) / time_std
        return self.map_time(time)


def get_timestep_embedding(timesteps, embedding_dim, dtype=torch.float32):
    """Return sinusoidal embeddings of 1-D ``timesteps``.

    ``timesteps`` are multiplied by 1000 and projected onto ``embedding_dim // 2``
    geometrically spaced frequencies, from 1 down to 1/10000. The sines and
    cosines are concatenated. For odd ``embedding_dim`` one zero column is
    appended.

    Args:
        timesteps: 1-D tensor of time values.
        embedding_dim: Width of the returned embedding.
        dtype: Floating dtype used for the computation and the result.

    Returns:
        Tensor of shape ``(timesteps.shape[0], embedding_dim)``.
    """
    assert len(timesteps.shape) == 1
    timesteps = timesteps * 1000.0
    half_dim = embedding_dim // 2
    # Guard the denominator: for embedding_dim in {2, 3}, half_dim - 1 == 0 and
    # the log-step became inf, turning the embedding into NaN (0 * -inf). With a
    # single frequency the step is irrelevant, so clamp it to 1.
    emb = np.log(10000) / max(half_dim - 1, 1)
    emb = (torch.arange(half_dim, dtype=dtype, device=timesteps.device) * -emb).exp()
    emb = timesteps.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1))
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb


class AdaLNMLPBlock(nn.Module):
    """Residual MLP block modulated by a conditioning vector (adaptive LayerNorm).

    The conditioning ``y`` (last axis of width ``dim``) is mapped to a scale
    ``gamma``, a shift ``mu`` and a residual gate ``sigma``. The block then
    computes ``x + mlp((1 + gamma) * LN(x) + mu) * sigma``.
    """

    def __init__(self, dim, expansion):
        super().__init__()
        self.mlp = FusedMLP(
            dim, dropout=0.0, hidden_layer_multiplier=expansion, activation=nn.GELU
        )
        self.ada_map = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 3))
        self.ln = nn.LayerNorm(dim, elementwise_affine=False)
        # Zero-initialise the MLP's last sub-module so the residual branch starts
        # at zero. NOTE(review): assumes FusedMLP is indexable and that
        # self.mlp[-1] is its output Linear — confirm against FusedMLP.
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.zeros_(self.mlp[-1].bias)

    def forward(self, x, y):
        gamma, mu, sigma = self.ada_map(y).chunk(3, dim=-1)
        x_res = (1 + gamma) * self.ln(x) + mu
        x = x + self.mlp(x_res) * sigma
        return x


class GeoAdaLNMLP(nn.Module):
    """Conditional AdaLN-MLP denoiser over ``input_dim``-dimensional targets.

    Args:
        input_dim: Width of the input ``batch["y"]`` and of the output.
        dim: Hidden width. The time embedder produces ``(dim // 4) * 4``
            features that are added to the mapped conditioning, so ``dim``
            presumably must be divisible by 4.
        depth: Number of ``AdaLNMLPBlock`` layers.
        expansion: Hidden-layer multiplier of each block's MLP.
        cond_dim: Width of the conditioning ``batch["emb"]``.
    """

    def __init__(self, input_dim, dim, depth, expansion, cond_dim):
        super().__init__()
        self.time_embedder = TimeEmbedder("positional", dim // 4, 1000, expansion=4)
        self.cond_mapper = nn.Linear(cond_dim, dim)
        self.initial_mapper = nn.Linear(input_dim, dim)
        self.blocks = nn.ModuleList(
            [AdaLNMLPBlock(dim, expansion) for _ in range(depth)]
        )
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2),
        )
        self.final_ln = nn.LayerNorm(dim, elementwise_affine=False)
        self.final_linear = nn.Linear(dim, input_dim)

    def forward(self, batch):
        """Predict an ``input_dim`` output from ``batch``.

        Reads ``batch["y"]`` (input), ``batch["gamma"]`` (time / noise level fed
        to the time embedder) and ``batch["emb"]`` (conditioning).
        """
        x = batch["y"]
        x = self.initial_mapper(x)
        gamma = batch["gamma"]
        cond = batch["emb"]
        t = self.time_embedder(gamma)
        cond = self.cond_mapper(cond)
        cond = cond + t
        for block in self.blocks:
            x = block(x, cond)
        gamma_last, mu_last = self.final_adaln(cond).chunk(2, dim=-1)
        x = (1 + gamma_last) * self.final_ln(x) + mu_last
        x = self.final_linear(x)
        return x


class GeoAdaLNMLPVonFisher(nn.Module):
    """Predict von Mises-Fisher parameters (unit mean ``mu``, concentration ``kappa``).

    A learned register vector, broadcast over the batch, is refined by
    ``AdaLNMLPBlock`` layers conditioned on ``batch["emb"]``.
    """

    def __init__(self, input_dim, dim, depth, expansion, cond_dim):
        super().__init__()
        self.cond_mapper = nn.Linear(cond_dim, dim)
        self.blocks = nn.ModuleList(
            [AdaLNMLPBlock(dim, expansion) for _ in range(depth)]
        )
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2),
        )
        self.final_ln = nn.LayerNorm(dim, elementwise_affine=False)
        self.mu_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, input_dim),
        )
        # Softplus keeps the concentration non-negative.
        self.kappa_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, 1),
            torch.nn.Softplus(),
        )
        self.init_registers = torch.nn.Parameter(torch.randn(dim), requires_grad=True)
        torch.nn.init.trunc_normal_(
            self.init_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02
        )

    def forward(self, batch):
        """Return ``(mu, kappa)``.

        ``mu`` has shape ``(B, input_dim)`` with unit L2 norm; ``kappa`` has shape
        ``(B, 1)``. ``B`` is ``batch["emb"].shape[0]``, presumably 2-D
        ``(B, cond_dim)`` — confirm.
        """
        cond = batch["emb"]
        cond = self.cond_mapper(cond)
        x = self.init_registers.unsqueeze(0).repeat(cond.shape[0], 1)
        for block in self.blocks:
            x = block(x, cond)
        gamma_last, mu_last = self.final_adaln(cond).chunk(2, dim=-1)
        x = (1 + gamma_last) * self.final_ln(x) + mu_last
        mu = self.mu_predictor(x)
        # F.normalize clamps the norm at 1e-12, so a zero prediction no longer
        # produces NaN (plain mu / mu.norm() did); otherwise identical.
        mu = F.normalize(mu, dim=-1)
        kappa = self.kappa_predictor(x)
        return mu, kappa


class GeoAdaLNMLPVonFisherMixture(nn.Module):
    """Predict a ``num_mixtures``-component von Mises-Fisher mixture.

    Same trunk as ``GeoAdaLNMLPVonFisher``. The heads output, per component, a
    unit mean direction and a concentration, plus softmax mixture weights.
    """

    def __init__(self, input_dim, dim, depth, expansion, cond_dim, num_mixtures=3):
        super().__init__()
        self.cond_mapper = nn.Linear(cond_dim, dim)
        self.blocks = nn.ModuleList(
            [AdaLNMLPBlock(dim, expansion) for _ in range(depth)]
        )
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2),
        )
        self.final_ln = nn.LayerNorm(dim, elementwise_affine=False)
        self.mu_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, input_dim * num_mixtures),
        )
        self.kappa_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, num_mixtures),
            torch.nn.Softplus(),
        )
        self.mixture_weights = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, num_mixtures),
            torch.nn.Softmax(dim=-1),
        )
        self.num_mixtures = num_mixtures
        self.init_registers = torch.nn.Parameter(torch.randn(dim), requires_grad=True)
        torch.nn.init.trunc_normal_(
            self.init_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02
        )

    def forward(self, batch):
        """Return ``(mu, kappa, weights)``.

        - ``mu``: shape ``(B, num_mixtures, input_dim)``, unit-norm per component.
        - ``kappa``: shape ``(B, num_mixtures)``, non-negative.
        - ``weights``: shape ``(B, num_mixtures)``, summing to 1.

        ``batch["emb"]`` is presumably ``(B, cond_dim)`` — confirm.
        """
        cond = batch["emb"]
        cond = self.cond_mapper(cond)
        x = self.init_registers.unsqueeze(0).repeat(cond.shape[0], 1)
        for block in self.blocks:
            x = block(x, cond)
        gamma_last, mu_last = self.final_adaln(cond).chunk(2, dim=-1)
        x = (1 + gamma_last) * self.final_ln(x) + mu_last
        mu = self.mu_predictor(x)
        mu = rearrange(mu, "b (n d) -> b n d", n=self.num_mixtures)
        # Normalise each component's direction; the eps clamp in F.normalize
        # avoids NaN for an all-zero component.
        mu = F.normalize(mu, dim=-1)
        kappa = self.kappa_predictor(x)
        weights = self.mixture_weights(x)
        return mu, kappa, weights