# text2punks/transformer.py

from functools import partial
from itertools import islice, cycle
from torch import nn
from text2punks.attention import Attention, SparseAxialCausalAttention

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, depth = 1):
    # normalize a single value or a list into a tuple of length `depth`
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth

# classes

class SequentialSequence(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(self, x):
        # each layer is an (attention, feedforward) pair applied with residual connections
        for f, g in self.layers:
            x = x + f(x)
            x = x + g(x)
        return x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, dropout = 0.):
        super().__init__()
        # note: dropout sits between the two linear layers here; some GPT-style
        # implementations instead place nn.Dropout(resid_pdrop) after the output
        # projection nn.Linear(4 * n_embd, n_embd)
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        return self.net(x)

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal = True,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        resid_dropout = 0.,
        embd_dropout = 0.,
        ff_dropout = 0.,
        image_size = 24,
        attn_types = None,
    ):
        super().__init__()
        layers = nn.ModuleList([])

        attn_types = default(attn_types, ('full',))
        attn_types = cast_tuple(attn_types)
        # cycle through the requested attention types, one per layer, for `depth` layers
        attn_type_layer = islice(cycle(attn_types), depth)

        for attn_type in attn_type_layer:
            if attn_type == 'full':
                attn_class = partial(Attention, causal = causal)
            elif attn_type == 'axial_row':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_size)
            elif attn_type == 'axial_col':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_size)
            else:
                raise ValueError(f'attention type "{attn_type}" is not valid')

            attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)

            layers.append(nn.ModuleList([
                PreNorm(dim, attn),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
            ]))

        # final layer always uses full attention (appended after the `depth` layers above)
        attn_class = partial(Attention, causal = causal)
        attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)

        layers.append(nn.ModuleList([
            PreNorm(dim, attn),
            PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
        ]))

        self.layers = SequentialSequence(layers)
        self.embd_drop = nn.Dropout(embd_dropout)

    def forward(self, x):
        x = self.embd_drop(x)
        return self.layers(x)
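

# Usage sketch (illustrative only): runs the Transformer on random token embeddings
# to show the expected input/output shapes. The hyperparameter values, the sequence
# length split (64 text tokens + 24 * 24 image tokens), and the (batch, seq_len, dim)
# input layout are assumptions for this demo, not settings taken from the Text2Punks
# training configuration; it also assumes the Attention module accepts inputs of
# length `seq_len`, as this file's usage suggests.
if __name__ == "__main__":
    import torch

    model = Transformer(
        dim = 256,
        depth = 2,
        seq_len = 640,              # assumed: 64 text tokens + 24 * 24 image tokens
        heads = 8,
        dim_head = 64,
        image_size = 24,
        attn_types = ('full',),     # stick to full attention for this sketch
    )

    x = torch.randn(2, 640, 256)    # (batch, seq_len, dim) token embeddings
    out = model(x)                  # residual attention + feedforward stack
    print(out.shape)                # expected: torch.Size([2, 640, 256])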