# internlm-xcomposer2d5-ol-7b / memory / builder_projector.py
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
import torch.nn.functional as F
import re
import math
import numpy as np
import random
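
# Projectors used by the memory module: identity / MLP projectors, slot attention,
# a Q-Former-style perceiver resampler, and several pooling / compression variants,
# selected via `config.mm_projector_type` in `build_vision_projector` below.
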
class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}
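

# Pre-norm residual block: LayerNorm followed by a two-layer GELU MLP with a skip connection.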
class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
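

# Slot Attention with a learnable slot initialization (`slots_embedding`, rather than
# slots sampled from a learned Gaussian) and a linear head mapping slots to `out_dim`.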
class SlotAttention(nn.Module):
    """Slot Attention module."""

    def __init__(self, num_slots, encoder_dims, iters=3, hidden_dim=128, out_dim=128, eps=1e-4):
        """Builds the Slot Attention module.

        Args:
            num_slots: Number of slots.
            encoder_dims: Dimensionality of the slot feature vectors.
            iters: Number of attention iterations.
            hidden_dim: Hidden layer size of the MLP.
            out_dim: Dimensionality of the output head.
            eps: Offset added to attention coefficients before normalization.
        """
        super(SlotAttention, self).__init__()
        self.eps = eps
        self.iters = iters
        self.num_slots = num_slots
        self.scale = encoder_dims ** -0.5

        self.norm_input = nn.LayerNorm(encoder_dims)
        self.norm_slots = nn.LayerNorm(encoder_dims)
        self.norm_pre_ff = nn.LayerNorm(encoder_dims)

        self.slots_embedding = nn.Parameter(torch.randn(1, num_slots, encoder_dims))
        self.project_q = nn.Linear(encoder_dims, encoder_dims)
        self.project_k = nn.Linear(encoder_dims, encoder_dims)
        self.project_v = nn.Linear(encoder_dims, encoder_dims)

        hidden_dim = max(encoder_dims, hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(encoder_dims, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, encoder_dims)
        )
        self.head = nn.Linear(encoder_dims, out_dim)

    def forward(self, inputs):
        # `inputs` has shape [batch_size, num_inputs, inputs_size].
        inputs = self.norm_input(inputs)  # Apply layer norm to the input.
        k = self.project_k(inputs)  # Shape: [batch_size, num_inputs, slot_size].
        v = self.project_v(inputs)  # Shape: [batch_size, num_inputs, slot_size].

        # Initialize the slots. Shape: [batch_size, num_slots, slot_size].
        b, n, d = inputs.shape
        n_s = self.num_slots
        # Learnable slot initialization.
        init_slots = self.slots_embedding.expand(b, -1, -1)
        slots = init_slots

        # Multiple rounds of attention.
        for t in range(self.iters):
            slots_prev = slots
            slots = self.norm_slots(slots)

            # Attention.
            q = self.project_q(slots)  # Shape: [batch_size, num_slots, slot_size].
            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            attn = dots.softmax(dim=1) + self.eps
            attn = attn / attn.sum(dim=-1, keepdim=True)  # Weighted mean.
            updates = torch.einsum('bjd,bij->bid', v, attn)
            # `updates` has shape: [batch_size, num_slots, slot_size].

            # Slot update.
            slots = slots_prev + updates
            slots = slots + self.mlp(self.norm_pre_ff(slots))

            if t == self.iters - 2:
                # Detach before the final iteration so gradients flow only through the
                # last update and, via the straight-through term, the learnable init.
                slots = slots.detach() - init_slots.detach() + init_slots

        output = self.head(slots)
        return output
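

# BLIP-2-style Q-Former resampler: a fixed set of learnable query tokens cross-attends to
# the vision features and is projected to `out_size`. Note that `BertConfig` and
# `BertLMHeadModel` are not imported in this file; they presumably come from a BLIP-2-style
# Qformer implementation (something like `from .Qformer import BertConfig, BertLMHeadModel`),
# since the stock `transformers` BertLMHeadModel does not accept `query_embeds`.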
class PerceiverSampler(nn.Module):
    def __init__(self, num_query_token, num_vision_features, out_size):
        super(PerceiverSampler, self).__init__()
        self.Qformer, self.query_tokens = self.init_qformer(
            num_query_token, num_vision_features)
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.Qformer.cls = None
        self.ln_vision = nn.LayerNorm(num_vision_features)
        self.head = nn.Linear(self.Qformer.config.hidden_size, out_size)

    @classmethod
    def init_qformer(cls,
                     num_query_token,
                     vision_width,
                     cross_attention_freq=2,
                     pretrain=True):
        encoder_config = BertConfig()
        encoder_config.encoder_width = vision_width
        # Insert a cross-attention layer every other block.
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.randn(1, num_query_token, encoder_config.hidden_size))
        query_tokens.data.normal_(mean=0.0,
                                  std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def forward(self, inputs):
        image_embeds = self.ln_vision(inputs)
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(inputs.device)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        output = self.head(query_output.last_hidden_state)
        return output
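

# Routes each sample in the batch to either the MLP projector or the slot-attention
# projector according to the per-sample `projector` flag (0 = MLP, 1 = slot attention).
# At inference time every sample goes through the MLP branch; during training the output
# is a list of per-sample feature tensors.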
class MultiProjector(nn.Module):
    def __init__(self, mlp, sa):
        super(MultiProjector, self).__init__()
        self.mlp = mlp
        self.sa = sa

    def forward(self, inputs, projector=None):
        if not self.training:
            output = self.mlp(inputs)
            return output
        idx_mlp = torch.where(projector == 0)[0]
        idx_sa = torch.where(projector == 1)[0]
        feat_mlp = self.mlp(inputs)
        feat_sa = self.sa(inputs)
        # Make sure each branch sees at least one sample per batch.
        if len(idx_mlp) == 0:
            projector[random.randint(0, projector.shape[0] - 1), 0] = 0
        elif len(idx_sa) == 0:
            projector[random.randint(0, projector.shape[0] - 1), 0] = 1
        idx_mlp = torch.where(projector == 0)[0]
        idx_sa = torch.where(projector == 1)[0]
        output = []
        for i in range(inputs.shape[0]):
            if i in idx_mlp:
                output.append(feat_mlp[i])
            if i in idx_sa:
                output.append(feat_sa[i])
        assert len(output) == inputs.shape[0]
        return output
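

# Projects features with the shared MLP and appends `num_slot` learnable query tokens to
# every sequence; the list branch handles an (image features, video features) pair that is
# projected in a single fused MLP call.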
class CompressProjector(nn.Module):
    def __init__(self, mlp, num_slot, embed_dim):
        super(CompressProjector, self).__init__()
        self.mlp = mlp
        self.num_slot = num_slot
        self.query = nn.Parameter(torch.zeros(num_slot, embed_dim))
        trunc_normal_(self.query, std=.02)

    def forward(self, inputs, projector=None):
        if type(inputs) is list:
            concat_images, concat_features = inputs
            concat_combine = torch.cat(
                [concat_images.reshape(-1, concat_images.shape[-1]), concat_features.reshape(-1, concat_features.shape[-1])], dim=0)
            concat_combine = self.mlp(concat_combine)
            concat_images = concat_combine[:concat_images.shape[0]*concat_images.shape[1]].contiguous().view(*concat_images.shape[:2], -1)
            concat_features = concat_combine[concat_images.shape[0]*concat_images.shape[1]:].contiguous().view(*concat_features.shape[:2], -1)
            image_query = self.query.expand(concat_images.shape[0], -1, -1)
            concat_images = torch.cat([concat_images, image_query], dim=1)
            feature_query = self.query.expand(concat_features.shape[0], -1, -1)
            concat_features = torch.cat([concat_features, feature_query], dim=1)
            return concat_images, concat_features
        output = self.mlp(inputs)
        query = self.query.expand(output.shape[0], -1, -1)
        output = torch.cat([output, query], dim=1)
        return output
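

# Adds grid-pooled tokens (and, in stage 2, a global mean token) to the patch tokens before
# projecting everything with the shared MLP. `resolution` is the number of patch tokens per
# frame and `pool_num` the number of pooled tokens per frame; both must be perfect squares
# for the grid reshapes below.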
class PoolProjector(nn.Module):
    def __init__(self, mlp, resolution, pool_num):
        super(PoolProjector, self).__init__()
        self.mlp = mlp
        self.pool_num = pool_num
        self.resolution = resolution

    def forward(self, inputs, projector=None):
        if type(inputs) is list:
            concat_images, concat_features = inputs
            assert concat_images.shape[1] == self.resolution
            h = int(np.sqrt(self.resolution))
            grid = int(np.sqrt(self.pool_num))

            n, k, c = concat_images.shape
            image_maps = concat_images.view(n, h, h, c)
            image_maps = image_maps.view(n, grid, h//grid, grid, h//grid, c)
            image_maps = image_maps.permute(0, 1, 3, 2, 4, 5).contiguous()
            image_maps = image_maps.view(n, self.pool_num, self.resolution//self.pool_num, c)
            image_slot = torch.mean(image_maps, dim=-2)
            image_global = torch.mean(concat_images, dim=1, keepdim=True)

            n, k, c = concat_features.shape
            video_maps = concat_features.view(n, k//self.resolution, h, h, c)
            video_maps = video_maps.view(n, k//self.resolution, grid, h//grid, grid, h//grid, c)
            video_maps = video_maps.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
            video_maps = video_maps.view(n, k//self.resolution, self.pool_num, self.resolution//self.pool_num, c)
            video_slot = torch.mean(video_maps, dim=-2).view(n, k//self.resolution*self.pool_num, c)
            video_global = torch.mean(concat_features, dim=1, keepdim=True)

            concat_images = torch.cat([concat_images, image_slot, image_global], dim=1)  # stage 2: n, k+pool_num+1, c
            concat_features = torch.cat([concat_features, video_slot, video_global], dim=1)  # stage 2: n, tk+t*pool_num+1, c
            # concat_images = torch.cat([concat_images, image_slot], dim=1)  # stage 1: n, k+pool_num, c
            # concat_features = torch.cat([concat_features, video_slot], dim=1)  # stage 1: n, tk+t*pool_num, c
            concat_combine = torch.cat(
                [concat_images.reshape(-1, concat_images.shape[-1]), concat_features.reshape(-1, concat_features.shape[-1])], dim=0)
            concat_combine = self.mlp(concat_combine)
            concat_images = concat_combine[:concat_images.shape[0]*concat_images.shape[1]].contiguous().view(*concat_images.shape[:2], -1)
            concat_features = concat_combine[concat_images.shape[0]*concat_images.shape[1]:].contiguous().view(*concat_features.shape[:2], -1)
            return concat_images, concat_features

        n, k, c = inputs.shape
        h = int(np.sqrt(self.resolution))
        grid = int(np.sqrt(self.pool_num))
        maps = inputs.view(n, k//self.resolution, h, h, c)
        maps = maps.view(n, k//self.resolution, grid, h//grid, grid, h//grid, c)
        maps = maps.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
        maps = maps.view(n, k//self.resolution, self.pool_num, self.resolution//self.pool_num, c)
        slot = torch.mean(maps, dim=-2).view(n, k//self.resolution*self.pool_num, c)
        global_pool = torch.mean(inputs, dim=1, keepdim=True)
        output = self.mlp(torch.cat([inputs, slot, global_pool], dim=1))  # stage 2
        # output = self.mlp(torch.cat([inputs, slot], dim=1))  # stage 1
        return output
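

# Video branch pools a [n, t, k, c] feature volume into t temporal tokens (mean over
# patches) plus k spatial tokens (mean over frames) before projection; plain 3-D image
# features are projected directly.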
class BaseProjector(nn.Module):
    def __init__(self, mlp):
        super(BaseProjector, self).__init__()
        self.mlp = mlp

    def forward(self, inputs, projector=None):
        if type(inputs) is list:
            concat_images, concat_features = inputs
            time_token = torch.mean(concat_features, dim=2)  # n, t, c
            spatial_token = torch.mean(concat_features, dim=1)  # n, k, c
            concat_features = torch.cat([time_token, spatial_token], dim=1)  # n, t+k, c
            concat_combine = torch.cat(
                [concat_images.reshape(-1, concat_images.shape[-1]), concat_features.reshape(-1, concat_features.shape[-1])], dim=0)
            concat_combine = self.mlp(concat_combine)
            concat_images = concat_combine[:concat_images.shape[0]*concat_images.shape[1]].contiguous().view(*concat_images.shape[:2], -1)
            concat_features = concat_combine[concat_images.shape[0]*concat_images.shape[1]:].contiguous().view(*concat_features.shape[:2], -1)
            return concat_images, concat_features
        if inputs.ndim == 3:
            output = self.mlp(inputs)
            return output
        if inputs.ndim == 4:
            time_token = torch.mean(inputs, dim=2)  # n, t, c
            spatial_token = torch.mean(inputs, dim=1)  # n, k, c
            token = torch.cat([time_token, spatial_token], dim=1)
            output = self.mlp(token)  # n, t+k, c
            return output
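

# Like BaseProjector, but keeps every frame-patch token: a [n, t, k, c] video volume is
# flattened to [n, t*k, c] and projected without any pooling.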
class BaseMixProjector(nn.Module):
    def __init__(self, mlp):
        super(BaseMixProjector, self).__init__()
        self.mlp = mlp

    def forward(self, inputs, projector=None):
        if type(inputs) is list:
            concat_images, concat_features = inputs
            n, t, k, c = concat_features.shape
            concat_features = concat_features.view(n, t*k, c)  # n, t*k, c
            concat_combine = torch.cat(
                [concat_images.reshape(-1, concat_images.shape[-1]), concat_features.reshape(-1, concat_features.shape[-1])], dim=0)
            concat_combine = self.mlp(concat_combine)
            concat_images = concat_combine[:concat_images.shape[0]*concat_images.shape[1]].contiguous().view(*concat_images.shape[:2], -1)
            concat_features = concat_combine[concat_images.shape[0]*concat_images.shape[1]:].contiguous().view(*concat_features.shape[:2], -1)
            return concat_images, concat_features
        if inputs.ndim == 3:
            output = self.mlp(inputs)
            return output
        if inputs.ndim == 4:
            n, t, k, c = inputs.shape
            token = inputs.view(n, t*k, c)
            output = self.mlp(token)  # n, t*k, c
            return output
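

# Factory: maps `config.mm_projector_type` to one of the projector modules above. Expected
# config attributes include `mm_hidden_size` and `hidden_size`, plus `n_slot`, `resolution`,
# and `pool_num` depending on the projector type.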
def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    if projector_type == 'slot':
        return SlotAttention(config.n_slot, config.mm_hidden_size, 3, config.hidden_size, config.hidden_size)

    if projector_type == 'perceiver':
        return PerceiverSampler(config.n_slot, config.mm_hidden_size, config.hidden_size)

    if projector_type == 'mlpslot':
        mlp_depth = 2
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        mlp = nn.Sequential(*modules)
        sa = SlotAttention(config.n_slot, config.mm_hidden_size, 3, config.hidden_size, config.hidden_size)
        return MultiProjector(mlp, sa)

    if projector_type == 'compress':
        mlp_depth = 2
        modules = [nn.Linear(4*config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        mlp = nn.Sequential(*modules)
        return CompressProjector(mlp, config.n_slot, config.hidden_size)

    if projector_type == 'pool':
        mlp_depth = 2
        modules = [nn.Linear(4*config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        mlp = nn.Sequential(*modules)
        pool_num = config.pool_num if hasattr(config, 'pool_num') else 1
        return PoolProjector(mlp, config.resolution, pool_num)

    if projector_type == 'base':
        mlp_depth = 2
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        mlp = nn.Sequential(*modules)
        return BaseProjector(mlp)

    if projector_type == 'base_mix':
        mlp_depth = 2
        modules = [nn.Linear(4*config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        mlp = nn.Sequential(*modules)
        return BaseMixProjector(mlp)

    raise ValueError(f'Unknown projector type: {projector_type}')
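

if __name__ == '__main__':
    # Minimal smoke-test sketch of `build_vision_projector`, kept outside the import path.
    # The attribute values below are illustrative assumptions, not the values used by the
    # released checkpoints; only `mm_projector_type`, `mm_hidden_size`, `hidden_size`,
    # `n_slot`, `resolution`, and `pool_num` are read by the factory above.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        mm_projector_type='mlp2x_gelu',  # also accepts e.g. 'linear', 'slot', 'pool', 'base'
        mm_hidden_size=1024,             # assumed vision feature width
        hidden_size=4096,                # assumed LLM hidden width
        n_slot=32,
        resolution=576,
        pool_num=4,
    )
    projector = build_vision_projector(cfg)
    feats = torch.randn(2, 576, cfg.mm_hidden_size)  # [batch, num_patches, vision_dim]
    out = projector(feats)
    print(out.shape)  # expected: torch.Size([2, 576, 4096])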