# gecko/model/multimodal_encoder.py
# Commit 09773e9 ("Initial commit") by farrosalferro24, 6.02 kB.
# This code is adapted from https://github.com/dhansmair/flamingo-mini
import torch
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum, nn
import math
import torch.nn.functional as F
from .configuration_gecko import GeckoConfig
from transformers.activations import ACT2FN
from torch.nn.init import trunc_normal_
from functools import partial
def feed_forward_layer(dim: int, mult: float = 4, activation: str = 'gelu'):
    """Build a pre-norm, bias-free feed-forward block.

    Layout: ``LayerNorm(dim) -> Linear(dim, int(dim * mult)) -> activation
    -> Linear(int(dim * mult), dim)``; neither Linear has a bias.

    Args:
        dim: input and output feature size.
        mult: expansion factor of the hidden layer (``int(dim * mult)`` units).
        activation: activation name. Supported: ``'gelu'``, ``'relu'``,
            ``'silu'``, ``'tanh'`` and ``'gelu_pytorch_tanh'``. These use the
            same spellings as ``transformers.ACT2FN``, so a
            ``projector_hidden_act`` value used for the MLP projector also
            works for the resampler's feed-forward layers.

    Returns:
        ``nn.Sequential`` containing [LayerNorm, Linear, activation, Linear].

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    # Native torch modules equivalent to the same-named ACT2FN entries.
    activations = {
        'gelu': nn.GELU,
        'relu': nn.ReLU,
        'silu': nn.SiLU,
        'tanh': nn.Tanh,
        'gelu_pytorch_tanh': partial(nn.GELU, approximate='tanh'),
    }
    if activation not in activations:
        # Explicit raise rather than `assert`, which disappears under `python -O`.
        raise ValueError(
            f'activation can only be one of {sorted(activations)}, got {activation!r}'
        )
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        activations[activation](),
        nn.Linear(inner_dim, dim, bias=False),
    )
class PerceiverAttentionLayer(nn.Module):
    """Multi-head perceiver cross-attention (adapted from flamingo-mini).

    Queries are computed from a set of latent vectors. Keys and values are
    computed from the concatenation of the input features and those same
    latents, so every latent attends to the features and to the latent set.
    """

    def __init__(self, dim: int, dim_head: int = 64, heads: int = 8):
        super().__init__()
        inner_dim = heads * dim_head
        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        # Separate pre-normalisation for the feature stream and the latent stream.
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        # Bias-free projections into and out of the multi-head space
        # (creation order kept stable so parameter initialisation is reproducible).
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, features, latents):
        """Let ``latents`` cross-attend to ``features`` (and to themselves).

        Args:
            features: tensor of shape (batch_size, n_tokens, dim).
            latents: tensor of shape (batch_size, n_latents, dim); the queries
                are computed from it.

        Returns:
            The attended latents after the output projection, with shape
            (batch_size, n_latents, dim).
        """
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]
        assert features.shape[2] == latents.shape[2]

        heads, head_dim = self.heads, self.dim_head
        batch, n_tokens, _ = features.shape
        n_latents = latents.shape[1]

        media = self.norm_media(features)
        latents = self.norm_latents(latents)

        def to_heads(t):
            # (b, n, heads * head_dim) -> (b, heads, n, head_dim), heads outermost.
            return t.reshape(t.shape[0], t.shape[1], heads, -1).permute(0, 2, 1, 3)

        queries = to_heads(self.to_q(latents))
        assert queries.shape == torch.Size([batch, heads, n_latents, head_dim])

        # Keys/values see the features followed by the latents along the sequence axis.
        context = torch.cat((media, latents), dim=-2)
        keys = to_heads(self.to_k(context))
        values = to_heads(self.to_v(context))
        assert values.shape == torch.Size([batch, heads, n_tokens + n_latents, head_dim])

        logits = einsum('b h q d, b h f d -> b h q f', queries * self.scale, keys)
        # Shift by the (gradient-detached) row maximum before the softmax.
        logits = logits - logits.amax(dim=-1, keepdim=True).detach()
        weights = logits.softmax(dim=-1)

        attended = einsum('b h q f, b h f v -> b h q v', weights, values)
        # (b, heads, n_latents, head_dim) -> (b, n_latents, heads * head_dim)
        merged = attended.permute(0, 2, 1, 3).reshape(batch, n_latents, -1)
        return self.to_out(merged)
class GeckoResamplerProjector(nn.Module):
    """Perceiver-resampler projector built from multi-head cross-attention layers.

    Vision features are first mapped linearly to the text hidden size. A
    learned set of ``num_queries`` latent vectors is then refined by ``depth``
    rounds of (cross-attention, feed-forward), each added residually, and the
    result is layer-normalised.
    """

    def __init__(
        self,
        config: GeckoConfig,
        num_queries: int = 64,
        depth: int = 2,
        dim_head: int = 32,
        heads: int = 4,
        ff_mult: int = 2,
    ):
        super().__init__()
        hidden = config.text_config.hidden_size
        self.dim = hidden
        self.num_queries = num_queries
        # Learned latent queries (standard-normal init), shared across the batch.
        self.latents = nn.Parameter(torch.randn(num_queries, hidden))  # type: ignore[reportPrivateUsage]
        # Maps vision-encoder features into the text hidden size.
        self.linear = nn.Linear(config.vision_config.hidden_size, hidden)
        # Each entry is a ModuleList([attention, feed_forward]); submodule
        # names/order match a loop-built list so checkpoints stay compatible.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttentionLayer(dim=hidden, dim_head=dim_head, heads=heads),
                        feed_forward_layer(dim=hidden, mult=ff_mult, activation=config.projector_hidden_act),
                    ]
                )
                for _ in range(depth)
            ]
        )
        # Final normalisation over the hidden (last) dimension.
        self.norm = nn.LayerNorm(hidden)

    def forward(self, x_f: torch.Tensor):
        """Resample visual embeddings into a fixed number of text-space tokens.

        Args:
            x_f: visual embeddings of shape (batch_size, num_tokens, vision hidden size).

        Returns:
            Tensor of shape (batch_size, num_queries, text hidden size).
        """
        assert x_f.ndim == 3
        media = self.linear(x_f)
        batch_size, _, width = media.shape
        assert width == self.dim

        # One copy of the latents per batch element.
        latents = repeat(self.latents, 'q d -> b q d', b=batch_size)

        for cross_attn, feed_forward in self.layers:
            latents = latents + cross_attn(media, latents)
            latents = latents + feed_forward(latents)

        assert latents.shape == torch.Size([batch_size, self.num_queries, self.dim])
        return self.norm(latents)
class GeckoMLPProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden size.

    ``linear_1`` (vision size -> text size), then the activation named by
    ``config.projector_hidden_act`` (looked up in ``ACT2FN``), then
    ``linear_2`` (text size -> text size). Both linears have biases.
    """

    def __init__(self, config: GeckoConfig):
        super().__init__()
        vision_size = config.vision_config.hidden_size
        text_size = config.text_config.hidden_size
        # Attribute names are kept as-is: they are the state_dict keys.
        self.linear_1 = nn.Linear(vision_size, text_size)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_size, text_size)

    def forward(self, image_features):
        """Project ``image_features`` (last dim = vision hidden size) to the text hidden size."""
        return self.linear_2(self.act(self.linear_1(image_features)))