Khalil
First commit, add text2punps scripts, app file, and requirements file
b41a54a
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
# helpers
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
def max_neg_value(t):
    """Return the most negative finite value representable in ``t``'s floating dtype.

    Used as the fill value for masked attention logits.
    """
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max
# classes
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: input/output feature dimension.
        seq_len: stored as ``self.seq_len``; not read by ``forward``.
        causal: when True, each query position is blocked from attending to
            key positions after it.
        heads: number of attention heads.
        dim_head: feature dimension per head.
        attn_dropout: dropout applied to the attention weights.
        resid_dropout: dropout applied after the output projection.
    """

    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
        super().__init__()
        self.causal = causal
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5

        inner_dim = heads * dim_head
        # Submodules are registered in a fixed order (dropout, qkv, out) so that
        # parameter initialisation and state_dict ordering stay stable.
        self.attn_drop = nn.Dropout(attn_dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(resid_dropout))

    def forward(self, x):
        """Self-attend over ``x`` (batch, n, dim) and return a tensor of the same shape."""
        heads = self.heads
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )

        scores = torch.einsum('b h i d, b h j d -> b h i j', q * self.scale, k)
        fill = max_neg_value(scores)

        if self.causal:
            # strictly-upper-triangular positions (future keys) are blocked
            n_q, n_k = scores.shape[-2:]
            future = torch.ones(n_q, n_k, device = x.device).triu_(n_k - n_q + 1).bool()
            scores.masked_fill_(future, fill)

        weights = self.attn_drop(torch.softmax(scores, dim=-1))
        mixed = torch.einsum('b h i j, b h j d -> b h i d', weights, v)
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))
# sparse axial causal attention
class SparseAxialCausalAttention(nn.Module):
    """Multi-head attention over a concatenated [text ; image] token sequence.

    The first ``seq_len + 1 - image_size ** 2`` positions are treated as text
    and the last ``image_size ** 2`` as an ``image_size x image_size`` grid of
    image tokens in row-major order.

    - Text tokens use causal self-attention among themselves only.
    - Each image token attends to every text token, plus the image tokens on
      its own grid row (``axis = 0``) or its own grid column (``axis = 1``)
      that do not come after it along that line.

    Args:
        dim: input/output feature dimension.
        seq_len: the input is zero-padded to ``seq_len + 1`` positions internally.
        image_size: side length of the square image-token grid.
        axis: 0 or 1, selecting row-wise or column-wise image attention (see above).
        heads: number of attention heads.
        dim_head: feature dimension per head.
        attn_dropout: dropout applied to the attention weights.
        resid_dropout: dropout applied after the output projection.
    """

    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
        super().__init__()
        assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
        self.axis = axis
        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size
        self.attn_drop = nn.Dropout(attn_dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(resid_dropout)
        )

    def forward(self, x):
        """Attend over ``x`` of shape (batch, n, dim).

        Returns:
            Tensor of shape (batch, n, dim): the sequence is zero-padded to
            ``seq_len + 1`` positions internally and the output is cropped back
            to the first ``n`` positions.
        """
        n, device = x.shape[1], x.device
        h, img_size, axis = self.heads, self.image_size, self.axis
        img_seq_len = img_size ** 2
        text_len = self.seq_len + 1 - img_seq_len

        # Right-pad with zeros to exactly seq_len + 1 positions.
        # NOTE(review): if n > seq_len + 1 the pad amount is negative, which
        # crops x instead; the result then has seq_len + 1 positions rather
        # than n. Confirm callers never pass longer inputs.
        x = F.pad(x, (0, 0, 0, self.seq_len + 1 - n), value = 0)

        # Derive queries / keys / values and fold heads into the batch dim.
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
        q = q * self.scale

        # Split every projection into its text prefix and image suffix.
        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # Text attention: causal self-attention restricted to the text prefix.
        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)
        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)
        attn_text = torch.softmax(dots_text, dim = -1)
        attn_text = self.attn_drop(attn_text)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # Image attention: reshape the flat image tokens into (line, position
        # along line), where a "line" is a grid row for axis 0 and a column for axis 1.
        split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
        merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'
        q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))

        # Similarities to all text keys, then to the keys on the same line.
        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)

        # Text keys are always visible. Along the line, keys after the query are
        # blocked. The (i, text_len + img_size) mask broadcasts over the
        # (batch * heads, line) dims, so no per-batch mask is materialised.
        i = dots.shape[-2]
        line_causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
        text_visible = torch.zeros(i, text_len, device = device, dtype = torch.bool)
        dots.masked_fill_(torch.cat((text_visible, line_causal_mask), dim = -1), mask_value)

        attn = torch.softmax(dots, dim = -1)
        attn = self.attn_drop(attn)

        # Aggregate values from the text keys and from the same-line image keys.
        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
        out_image = out_image_to_image + out_image_to_text

        # Merge lines back into the flat row-major image token order.
        out_image = rearrange(out_image, merge_axis_einops, x = img_size)

        # Recombine text and image outputs, unfold heads, project, and crop the padding.
        out = torch.cat((out_text, out_image), dim = 1)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return out[:, :n]