from functools import partial
from itertools import islice, cycle

from torch import nn

from text2punks.attention import Attention, SparseAxialCausalAttention

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, depth = 1):
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth

# classes

class SequentialSequence(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(self, x):
        # each entry is an (attention, feedforward) pair, applied with residual connections
        for (f, g) in self.layers:
            x = x + f(x)
            x = x + g(x)
        return x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim)
        )
        # note: dropout is applied before the output projection here,
        # whereas GPT-style blocks apply nn.Dropout(resid_pdrop) after nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        return self.net(x)

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal = True,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        resid_dropout = 0.,
        embd_dropout = 0.,
        ff_dropout = 0.,
        image_size = 24,
        attn_types = None,
    ):
        super().__init__()
        layers = nn.ModuleList([])

        # cycle through the requested attention types, one per layer
        attn_types = default(attn_types, ('full',))
        attn_types = cast_tuple(attn_types)
        attn_type_layer = islice(cycle(attn_types), depth)

        for attn_type in attn_type_layer:
            if attn_type == 'full':
                attn_class = partial(Attention, causal = causal)
            elif attn_type == 'axial_row':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_size)
            elif attn_type == 'axial_col':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_size)
            else:
                raise ValueError(f'attention type "{attn_type}" is not valid')

            attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)

            layers.append(nn.ModuleList([
                PreNorm(dim, attn),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
            ]))

        # full attention in the last layer
        attn_class = partial(Attention, causal = causal)
        attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)

        layers.append(nn.ModuleList([
            PreNorm(dim, attn),
            PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
        ]))

        self.layers = SequentialSequence(layers)
        self.embd_drop = nn.Dropout(embd_dropout)

    def forward(self, x):
        x = self.embd_drop(x)
        return self.layers(x)
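
# Minimal usage sketch (illustrative only): the dimensions and attention schedule
# below are assumptions for demonstration, not values taken from the repository's
# training configuration. It assumes the sequence is laid out as text tokens
# followed by a flattened image_size x image_size grid, already embedded to `dim`.
if __name__ == '__main__':
    import torch

    text_len = 256                          # hypothetical text token budget
    image_size = 24                         # matches the default 24 x 24 punk grid
    seq_len = text_len + image_size ** 2    # total tokens per sample

    model = Transformer(
        dim = 512,
        depth = 4,
        seq_len = seq_len,
        heads = 8,
        dim_head = 64,
        image_size = image_size,
        attn_types = ('full', 'axial_row', 'axial_col'),  # an extra full-attention layer is appended in __init__
    )

    x = torch.randn(2, seq_len, 512)        # (batch, seq_len, dim) token embeddings
    out = model(x)                          # -> (2, seq_len, 512)
    print(out.shape)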