from inspect import isfunction
from typing import Callable, Optional

import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn

from .t_cond_mlp import (
    AdaptiveLayerNorm1D,
    FrequencyEmbedder,
    normalization_layer,
)


def exists(val):
    """Return True if ``val`` is not None."""
    return val is not None


def default(val, d):
    """Return ``val`` unless it is None, otherwise the fallback ``d``.

    If ``d`` is a plain function it is called to produce the fallback (so an
    expensive default is only built when needed); any other object is
    returned as-is.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


class PreNorm(nn.Module):
    """Normalize the input, then apply the wrapped module ``fn``.

    Args:
        dim: Feature size handed to ``normalization_layer``.
        fn: Callable applied to the normalized tensor.
        norm: Normalization kind, forwarded to ``normalization_layer``.
        norm_cond_dim: Conditioning size, forwarded to ``normalization_layer``.
    """

    def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
        super().__init__()
        self.norm = normalization_layer(norm, dim, norm_cond_dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Return ``fn(norm(x), **kwargs)``.

        Extra positional ``args`` are passed only to an ``AdaptiveLayerNorm1D``
        norm (presumably conditioning tensors — confirm in t_cond_mlp); any
        other norm ignores them. Keyword ``kwargs`` always go to ``fn``.
        """
        if isinstance(self.norm, AdaptiveLayerNorm1D):
            return self.fn(self.norm(x, *args), **kwargs)
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> hidden_dim), GELU, Linear(hidden_dim -> dim).

    Dropout with probability ``dropout`` follows the activation and the output
    projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Expects ``x`` of shape (batch, tokens, dim) and returns the same shape.

    Args:
        dim: Input/output feature size.
        heads: Number of attention heads.
        dim_head: Per-head feature size; the inner width is ``heads * dim_head``.
        dropout: Dropout applied to the attention weights and the output
            projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        # With one head whose width already equals ``dim`` the output
        # projection is unnecessary, so it is replaced by an identity.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5  # 1 / sqrt(d_head) logit scaling

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class CrossAttention(nn.Module):
    """Multi-head attention whose keys/values come from a separate ``context``.

    Queries are projected from ``x`` (feature size ``dim``); keys and values are
    projected from ``context`` (feature size ``context_dim``, defaulting to
    ``dim``). When no context is given at call time, ``x`` itself is used.
    """

    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        context_dim = default(context_dim, dim)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x, context=None):
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q = self.to_q(x)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm residual blocks (self-attention, then MLP)."""

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            ff = FeedForward(dim, mlp_dim, dropout=dropout)
            self.layers.append(
                nn.ModuleList(
                    [
                        PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
                        PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
                    ]
                )
            )

    def forward(self, x: torch.Tensor, *args):
        """Apply every block; ``args`` are forwarded to each PreNorm's norm."""
        for attn, ff in self.layers:
            x = attn(x, *args) + x
            x = ff(x, *args) + x
        return x


class TransformerCrossAttn(nn.Module):
    """Stack of pre-norm residual blocks: self-attention, cross-attention, MLP."""

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            ca = CrossAttention(
                dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
            ff = FeedForward(dim, mlp_dim, dropout=dropout)
            self.layers.append(
                nn.ModuleList(
                    [
                        PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
                        PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
                        PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
                    ]
                )
            )

    def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
        """Apply every block.

        Args:
            x: Token tensor.
            *args: Forwarded to each PreNorm's norm.
            context: Cross-attention context shared by all layers (used only
                when ``context_list`` is None; None means self-context).
            context_list: Optional per-layer contexts, one per layer.

        Raises:
            ValueError: If ``context_list`` length differs from the depth.
        """
        if context_list is None:
            context_list = [context] * len(self.layers)
        if len(context_list) != len(self.layers):
            raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")

        for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
            x = self_attn(x, *args) + x
            x = cross_attn(x, *args, context=context_list[i]) + x
            x = ff(x, *args) + x
        return x


class DropTokenDropout(nn.Module):
    """Randomly remove whole tokens (sequence positions) during training.

    One Bernoulli(p) mask over positions is drawn from the first sample and
    applied to the entire batch, so the output sequence can be shorter than
    the input (empty if every position is drawn). Identity in eval mode or
    when ``p == 0``.

    Raises:
        ValueError: If ``p`` is outside [0, 1].
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if self.training and self.p > 0:
            zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
            # TODO: permutation idx for each batch using torch.argsort
            if zero_mask.any():
                x = x[:, ~zero_mask, :]
        return x


class ZeroTokenDropout(nn.Module):
    """Randomly zero whole token vectors during training.

    Each (sample, position) pair is zeroed independently with probability
    ``p``; the sequence length is unchanged and surviving tokens are not
    rescaled. Identity in eval mode or when ``p == 0``.

    Raises:
        ValueError: If ``p`` is outside [0, 1].
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if self.training and self.p > 0:
            zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
            # Zero-out the masked tokens out of place: ``x`` may be the
            # caller's tensor (e.g. the encoder's raw input), which must not
            # be modified.
            x = x.masked_fill(zero_mask.unsqueeze(-1), 0)
        return x


class TransformerEncoder(nn.Module):
    """Token embedding + learned positional embedding + self-attention stack.

    Args:
        num_tokens: Length of the learned positional table.
        token_dim: Feature size of each input token.
        dim: Model width.
        depth, heads, mlp_dim, dim_head, dropout: See ``Transformer``.
        emb_dropout: Probability for the embedding-level token dropout.
        emb_dropout_type: "drop" (DropTokenDropout) or "zero"
            (ZeroTokenDropout).
        emb_dropout_loc: Where that dropout is applied: "input" (raw tokens),
            "token" (after the token embedding, before positions are added) or
            "token_afterpos" (after positions are added). Any other value
            applies no embedding dropout.
        norm, norm_cond_dim: See ``PreNorm``.
        token_pe_numfreq: If > 0, tokens pass through ``FrequencyEmbedder``
            before the linear projection. The Rearrange layers pin the token
            axis to ``num_tokens``, so inputs appear to need exactly that many
            tokens in this mode.

    Raises:
        ValueError: On an unknown ``emb_dropout_type``.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = "drop",
        emb_dropout_loc: str = "token",
        norm: str = "layer",
        norm_cond_dim: int = -1,
        token_pe_numfreq: int = -1,
    ):
        super().__init__()
        if token_pe_numfreq > 0:
            token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
            self.to_token_embedding = nn.Sequential(
                Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
                FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
                Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
                nn.Linear(token_dim_new, dim),
            )
        else:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        # Learned positions, initialized from a standard normal and sliced to
        # the current sequence length in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        else:
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
        self.emb_dropout_loc = emb_dropout_loc

        self.transformer = Transformer(
            dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
        )

    def forward(self, inp: torch.Tensor, *args, **kwargs):
        """Embed ``inp`` and run the transformer; ``args`` go to the norms."""
        x = inp

        if self.emb_dropout_loc == "input":
            x = self.dropout(x)
        x = self.to_token_embedding(x)

        if self.emb_dropout_loc == "token":
            x = self.dropout(x)
        _, n, _ = x.shape
        # NOTE(review): after DropTokenDropout the surviving tokens receive the
        # embeddings of positions 0..n-1 rather than their original positions;
        # "token_afterpos" keeps them aligned.
        # Out-of-place add so no caller-visible tensor is mutated.
        x = x + self.pos_embedding[:, :n]

        if self.emb_dropout_loc == "token_afterpos":
            x = self.dropout(x)
        x = self.transformer(x, *args)
        return x


class TransformerDecoder(nn.Module):
    """Token embedding + learned positions + self/cross-attention stack.

    Args:
        num_tokens: Length of the learned positional table.
        token_dim: Feature size of each input token.
        dim: Model width.
        depth, heads, mlp_dim, dim_head, dropout: See ``TransformerCrossAttn``.
        emb_dropout: Probability for the embedding-level dropout.
        emb_dropout_type: "drop" (DropTokenDropout), "zero"
            (ZeroTokenDropout) or "normal" (element-wise ``nn.Dropout``).
        norm, norm_cond_dim: See ``PreNorm``.
        context_dim: Feature size of the cross-attention context.
        skip_token_embedding: Use the input tokens directly (requires
            ``token_dim == dim``).

    Raises:
        ValueError: If ``skip_token_embedding`` is set with
            ``token_dim != dim``, or on an unknown ``emb_dropout_type``.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = "drop",
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
        skip_token_embedding: bool = False,
    ):
        super().__init__()
        if not skip_token_embedding:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        else:
            self.to_token_embedding = nn.Identity()
            if token_dim != dim:
                raise ValueError(
                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
                )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        elif emb_dropout_type == "normal":
            self.dropout = nn.Dropout(emb_dropout)
        else:
            # Previously fell through silently and failed later with an
            # AttributeError on ``self.dropout`` in forward().
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")

        self.transformer = TransformerCrossAttn(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            norm=norm,
            norm_cond_dim=norm_cond_dim,
            context_dim=context_dim,
        )

    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
        """Embed ``inp`` and run the cross-attention transformer.

        ``args`` go to the norms; ``context`` / ``context_list`` go to
        ``TransformerCrossAttn.forward``.
        """
        x = self.to_token_embedding(inp)
        x = self.dropout(x)
        # Measure the length AFTER dropout: DropTokenDropout may shorten the
        # sequence, and slicing the positional table to the pre-dropout
        # length made the addition below fail on a shape mismatch.
        _, n, _ = x.shape
        # Out-of-place add: with skip_token_embedding=True and an identity
        # dropout (eval mode or p == 0), ``x`` is still the caller's ``inp``.
        x = x + self.pos_embedding[:, :n]

        x = self.transformer(x, *args, context=context, context_list=context_list)
        return x