File size: 17,392 Bytes
8e1010d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 |
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
import torch.nn.functional as F
import re
import math
import numpy as np
import random
class IdentityMap(nn.Module):
    """Pass-through projector: hands its input back untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` itself; extra positional/keyword arguments are ignored."""
        return x

    @property
    def config(self):
        """Configuration dict that labels this module as the identity projector."""
        return dict(mm_projector_type='identity')
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear-GELU-Linear block of constant width."""

    def __init__(self, channels):
        """Build the block.

        Args:
            channels: Feature size; input, hidden and output widths are all equal to it.
        """
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``norm(x) + proj(norm(x))``."""
        normed = self.pre_norm(x)
        # The skip connection starts from the normalized tensor, not the raw input.
        return normed + self.proj(normed)
class SlotAttention(nn.Module):
    """Slot Attention with learnable slot initializations and a linear output head."""

    def __init__(self, num_slots, encoder_dims, iters=3, hidden_dim=128, out_dim=128, eps=1e-4):
        """Create the layers of the module.

        Args:
            num_slots: Number of slots.
            encoder_dims: Size of the input feature vectors; slots share this size.
            iters: Number of attention iterations.
            hidden_dim: Hidden width of the slot MLP (raised to ``encoder_dims``
                when a smaller value is given).
            out_dim: Output size of the final linear head.
            eps: Offset added to the attention coefficients before normalization.
        """
        super(SlotAttention, self).__init__()
        self.eps = eps
        self.iters = iters
        self.num_slots = num_slots
        self.scale = encoder_dims ** -0.5

        self.norm_input = nn.LayerNorm(encoder_dims)
        self.norm_slots = nn.LayerNorm(encoder_dims)
        self.norm_pre_ff = nn.LayerNorm(encoder_dims)

        # One learnable initial vector per slot, shared by every sample in the batch.
        self.slots_embedding = nn.Parameter(torch.randn(1, num_slots, encoder_dims))

        self.project_q = nn.Linear(encoder_dims, encoder_dims)
        self.project_k = nn.Linear(encoder_dims, encoder_dims)
        self.project_v = nn.Linear(encoder_dims, encoder_dims)

        mlp_width = max(encoder_dims, hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(encoder_dims, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, encoder_dims),
        )
        self.head = nn.Linear(encoder_dims, out_dim)

    def forward(self, inputs):
        """Summarize ``inputs`` into ``num_slots`` vectors.

        Args:
            inputs: Tensor of shape [batch_size, num_inputs, encoder_dims].

        Returns:
            Tensor of shape [batch_size, num_slots, out_dim].
        """
        feats = self.norm_input(inputs)
        keys = self.project_k(feats)    # [batch_size, num_inputs, encoder_dims]
        values = self.project_v(feats)  # [batch_size, num_inputs, encoder_dims]

        # The three-way unpack also rejects anything that is not exactly 3-D.
        batch, _, _ = feats.shape

        # Every sample starts from the same learnable slots.
        seed_slots = self.slots_embedding.expand(batch, -1, -1)
        slots = seed_slots

        truncate_at = self.iters - 2
        for step in range(self.iters):
            residual = slots
            queries = self.project_q(self.norm_slots(slots))  # [batch_size, num_slots, encoder_dims]

            logits = torch.einsum('bid,bjd->bij', queries, keys) * self.scale
            # Softmax runs over the slot axis; the result is then renormalized
            # over the input axis so each slot takes a weighted mean of the values.
            weights = logits.softmax(dim=1) + self.eps
            weights = weights / weights.sum(dim=-1, keepdim=True)
            updates = torch.einsum('bjd,bij->bid', values, weights)  # [batch_size, num_slots, encoder_dims]

            # Residual slot update followed by a residual MLP.
            slots = residual + updates
            slots = slots + self.mlp(self.norm_pre_ff(slots))

            if step == truncate_at:
                # Right before the last iteration: keep the forward value (up to
                # rounding) but cut the autograd graph of the earlier iterations;
                # gradient can still reach `slots_embedding` through `seed_slots`.
                slots = slots.detach() - seed_slots.detach() + seed_slots

        return self.head(slots)
class PerceiverSampler(nn.Module):
    """Resampler that cross-attends learnable query tokens to vision features.

    NOTE(review): ``BertConfig`` and ``BertLMHeadModel`` are not imported in the
    visible part of this file; they look like a Q-Former style BERT (the model is
    called with ``query_embeds`` and configured with ``cross_attention_freq``) --
    confirm where they are expected to come from.
    """

    def __init__(self, num_query_token, num_vision_features, out_size):
        """Build the Q-Former, drop the sub-modules that are set to ``None`` below, add the output head.

        Args:
            num_query_token: Number of learnable query tokens.
            num_vision_features: Channel size of the incoming vision features.
            out_size: Output size of the final linear head.
        """
        super(PerceiverSampler, self).__init__()
        self.Qformer, self.query_tokens = self.init_qformer(
            num_query_token, num_vision_features)

        # Discard the text-side embeddings, the per-layer feed-forward parts and
        # the LM head by replacing them with None.
        embeddings = self.Qformer.bert.embeddings
        embeddings.word_embeddings = None
        embeddings.position_embeddings = None
        for block in self.Qformer.bert.encoder.layer:
            block.output = None
            block.intermediate = None
        self.Qformer.cls = None

        self.ln_vision = nn.LayerNorm(num_vision_features)
        self.head = nn.Linear(self.Qformer.config.hidden_size, out_size)

    @classmethod
    def init_qformer(cls,
                     num_query_token,
                     vision_width,
                     cross_attention_freq=2,
                     pretrain=True):
        """Create the Q-Former model and its query-token parameter.

        Args:
            num_query_token: Number of query tokens (stored as ``query_length``).
            vision_width: Width of the encoder states (stored as ``encoder_width``).
            cross_attention_freq: Stored on the config under the same name.
            pretrain: Accepted but not used by this method.

        Returns:
            Tuple ``(Qformer, query_tokens)``.
        """
        cfg = BertConfig()
        cfg.encoder_width = vision_width
        # insert cross-attention layer every other block
        cfg.add_cross_attention = True
        cfg.cross_attention_freq = cross_attention_freq
        cfg.query_length = num_query_token

        qformer = BertLMHeadModel(config=cfg)

        tokens = nn.Parameter(torch.randn(1, num_query_token, cfg.hidden_size))
        # Re-draw the values with the std configured for the model.
        tokens.data.normal_(mean=0.0, std=cfg.initializer_range)
        return qformer, tokens

    def forward(self, inputs):
        """Project vision features to ``num_query_token`` output vectors per sample.

        Args:
            inputs: Vision features whose last dimension is ``num_vision_features``.

        Returns:
            ``head`` applied to the Q-Former's ``last_hidden_state``.
        """
        vision = self.ln_vision(inputs)
        # All-ones mask: every vision position may be attended to.
        vision_mask = torch.ones(vision.shape[:-1], dtype=torch.long, device=inputs.device)
        queries = self.query_tokens.expand(vision.shape[0], -1, -1)
        result = self.Qformer.bert(
            query_embeds=queries,
            encoder_hidden_states=vision,
            encoder_attention_mask=vision_mask,
            return_dict=True,
        )
        return self.head(result.last_hidden_state)
class MultiProjector(nn.Module):
    """Holds an MLP projector and a slot-attention projector and picks one per sample."""

    def __init__(self, mlp, sa):
        """Store the two sub-projectors.

        Args:
            mlp: Module used for samples flagged 0 and for all inputs in eval mode.
            sa: Module used for samples flagged 1 (training only).
        """
        super(MultiProjector, self).__init__()
        self.mlp = mlp
        self.sa = sa

    def forward(self, inputs, projector=None):
        """Project ``inputs``.

        Args:
            inputs: Batch of features; the first dimension is the batch.
            projector: Per-sample routing flags, read only in training mode.
                It is indexed as ``projector[row, 0]``, so it must be at least
                2-D (assumed shape [batch, 1] with values 0 / 1 -- TODO confirm).
                It may be modified in place, see below.

        Returns:
            In eval mode, ``self.mlp(inputs)``. In training mode, a Python list
            with one tensor per sample, taken from the MLP or the slot-attention
            output depending on the sample's flag.
        """
        if not self.training:
            return self.mlp(inputs)

        mlp_rows = torch.where(projector == 0)[0]
        sa_rows = torch.where(projector == 1)[0]

        # Both branches always run on the full batch.
        mlp_out = self.mlp(inputs)
        sa_out = self.sa(inputs)

        # When one branch got no sample, one random row of `projector` is
        # overwritten IN PLACE so that branch is represented too (presumably so
        # both sub-modules take part in the loss -- confirm with the training code).
        forced_flag = None
        if len(mlp_rows) == 0:
            forced_flag = 0
        elif len(sa_rows) == 0:
            forced_flag = 1
        if forced_flag is not None:
            projector[random.randint(0, projector.shape[0] - 1), 0] = forced_flag

        chosen_mlp = set(torch.where(projector == 0)[0].tolist())
        chosen_sa = set(torch.where(projector == 1)[0].tolist())

        output = []
        for row in range(inputs.shape[0]):
            if row in chosen_mlp:
                output.append(mlp_out[row])
            if row in chosen_sa:
                output.append(sa_out[row])
        assert len(output) == inputs.shape[0]
        return output
class CompressProjector(nn.Module):
    """MLP projector that appends ``num_slot`` learnable query tokens to every sequence."""

    def __init__(self, mlp, num_slot, embed_dim):
        """Store the MLP and create the query parameter.

        Args:
            mlp: Module applied to the incoming features.
            num_slot: Number of learnable query tokens appended per sequence.
            embed_dim: Size of each query token; must equal the MLP output size
                because the queries are concatenated to the MLP output.
        """
        super(CompressProjector, self).__init__()
        self.mlp = mlp
        self.num_slot = num_slot
        self.query = nn.Parameter(torch.zeros(num_slot, embed_dim))
        trunc_normal_(self.query, std=.02)

    def _append_queries(self, tokens):
        """Concatenate the shared query tokens after ``tokens`` along dim 1."""
        queries = self.query.expand(tokens.shape[0], -1, -1)
        return torch.cat([tokens, queries], dim=1)

    def forward(self, inputs, projector=None):
        """Project features and append the query tokens.

        Args:
            inputs: Either one 3-D tensor [n, k, c], or a pair
                ``(concat_images, concat_features)`` of 3-D tensors given as a
                list or a tuple.
            projector: Unused; kept for interface parity with the other projectors.

        Returns:
            A tensor [n, k + num_slot, embed_dim] for tensor input, or a tuple
            of two such tensors for pair input.
        """
        # isinstance (rather than `type(...) is list`) also accepts tuples and
        # list subclasses instead of sending them down the tensor path.
        if isinstance(inputs, (list, tuple)):
            images, features = inputs
            image_rows = images.shape[0] * images.shape[1]

            # Flatten both tensors to rows so the MLP runs once over everything.
            flat = torch.cat(
                [images.reshape(-1, images.shape[-1]),
                 features.reshape(-1, features.shape[-1])],
                dim=0)
            flat = self.mlp(flat)

            images = flat[:image_rows].contiguous().view(*images.shape[:2], -1)
            features = flat[image_rows:].contiguous().view(*features.shape[:2], -1)
            return self._append_queries(images), self._append_queries(features)

        return self._append_queries(self.mlp(inputs))
class PoolProjector(nn.Module):
    """MLP projector that appends grid-pooled and globally pooled tokens to each sequence."""

    def __init__(self, mlp, resolution, pool_num):
        """Store the MLP and the pooling geometry.

        Args:
            mlp: Module applied to the augmented token sequence.
            resolution: Number of tokens per frame; treated as a square
                ``sqrt(resolution) x sqrt(resolution)`` map.
            pool_num: Number of pooled tokens per frame; treated as a square
                ``sqrt(pool_num) x sqrt(pool_num)`` grid of cells.
        """
        super(PoolProjector, self).__init__()
        self.mlp = mlp
        self.pool_num = pool_num
        self.resolution = resolution

    def _grid_pool(self, tokens):
        """Average every frame of ``tokens`` over a grid of square cells.

        Args:
            tokens: Contiguous tensor [n, k, c] with ``k`` a multiple of
                ``resolution`` (``k // resolution`` frames).

        Returns:
            Tensor [n, (k // resolution) * pool_num, c].
        """
        n, k, c = tokens.shape
        h = int(np.sqrt(self.resolution))
        grid = int(np.sqrt(self.pool_num))
        frames = k // self.resolution

        maps = tokens.view(n, frames, h, h, c)
        # Split rows and columns into (cell index, offset inside the cell) ...
        maps = maps.view(n, frames, grid, h // grid, grid, h // grid, c)
        # ... and bring the two cell indices next to each other.
        maps = maps.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
        maps = maps.view(n, frames, self.pool_num, self.resolution // self.pool_num, c)
        return torch.mean(maps, dim=-2).view(n, frames * self.pool_num, c)

    def _with_pooled_tokens(self, tokens):
        """Return ``tokens`` followed by its grid-pooled tokens and one global-mean token."""
        slot = self._grid_pool(tokens)
        global_pool = torch.mean(tokens, dim=1, keepdim=True)
        # "Stage 2" layout. The original source kept a commented-out "stage 1"
        # variant that concatenated only `slot`, without the global token.
        return torch.cat([tokens, slot, global_pool], dim=1)

    def forward(self, inputs, projector=None):
        """Augment the tokens with pooled summaries and project them.

        Args:
            inputs: Either one tensor [n, k, c], or a pair
                ``(concat_images, concat_features)`` given as a list or tuple,
                where ``concat_images`` is [n, resolution, c] and
                ``concat_features`` is [n, t * resolution, c].
            projector: Unused; kept for interface parity with the other projectors.

        Returns:
            For tensor input, the MLP output for
            [n, k + (k // resolution) * pool_num + 1, c]. For pair input, a
            tuple of two tensors built the same way.
        """
        # isinstance (rather than `type(...) is list`) also accepts tuples and
        # list subclasses instead of sending them down the tensor path.
        if isinstance(inputs, (list, tuple)):
            images, features = inputs
            assert images.shape[1] == self.resolution

            images = self._with_pooled_tokens(images)      # n, k + pool_num + 1, c
            features = self._with_pooled_tokens(features)  # n, t*k + t*pool_num + 1, c
            image_rows = images.shape[0] * images.shape[1]

            # Flatten both tensors to rows so the MLP runs once over everything.
            flat = torch.cat(
                [images.reshape(-1, images.shape[-1]),
                 features.reshape(-1, features.shape[-1])],
                dim=0)
            flat = self.mlp(flat)

            images = flat[:image_rows].contiguous().view(*images.shape[:2], -1)
            features = flat[image_rows:].contiguous().view(*features.shape[:2], -1)
            return images, features

        return self.mlp(self._with_pooled_tokens(inputs))
class BaseProjector(nn.Module):
    """MLP projector that reduces 4-D video features to per-frame and per-position means."""

    def __init__(self, mlp):
        """Store the MLP applied to the (pooled) tokens."""
        super(BaseProjector, self).__init__()
        self.mlp = mlp

    @staticmethod
    def _pool_video(features):
        """Turn [n, t, k, c] features into [n, t + k, c] tokens.

        The first ``t`` tokens are means over dim 2 (one per frame); the last
        ``k`` tokens are means over dim 1 (one per position).
        """
        time_token = torch.mean(features, dim=2)     # n t c
        spatial_token = torch.mean(features, dim=1)  # n k c
        return torch.cat([time_token, spatial_token], dim=1)

    def forward(self, inputs, projector=None):
        """Project image and/or video features.

        Args:
            inputs: One of
                * a 3-D tensor [n, k, c], projected as is;
                * a 4-D tensor [n, t, k, c], pooled to [n, t + k, c] first;
                * a pair ``(concat_images, concat_features)`` given as a list or
                  tuple, with 3-D images and 4-D video features.
            projector: Unused; kept for interface parity with the other projectors.

        Returns:
            The projected tensor, or a tuple ``(images, features)`` for pair input.

        Raises:
            ValueError: If a tensor input is neither 3-D nor 4-D (the original
                code silently returned ``None`` in that case).
        """
        # isinstance (rather than `type(...) is list`) also accepts tuples and
        # list subclasses instead of sending them down the tensor path.
        if isinstance(inputs, (list, tuple)):
            images, features = inputs
            features = self._pool_video(features)  # n t+k c
            image_rows = images.shape[0] * images.shape[1]

            # Flatten both tensors to rows so the MLP runs once over everything.
            flat = torch.cat(
                [images.reshape(-1, images.shape[-1]),
                 features.reshape(-1, features.shape[-1])],
                dim=0)
            flat = self.mlp(flat)

            images = flat[:image_rows].contiguous().view(*images.shape[:2], -1)
            features = flat[image_rows:].contiguous().view(*features.shape[:2], -1)
            return images, features

        if inputs.ndim == 3:
            return self.mlp(inputs)
        if inputs.ndim == 4:
            return self.mlp(self._pool_video(inputs))  # n t+k c
        raise ValueError(
            f'BaseProjector expects a 3-D or 4-D tensor, got {inputs.ndim} dimensions')
class BaseMixProjector(nn.Module):
    """MLP projector that flattens 4-D video features [n, t, k, c] to [n, t*k, c]."""

    def __init__(self, mlp):
        """Store the MLP applied to the tokens."""
        super(BaseMixProjector, self).__init__()
        self.mlp = mlp

    def forward(self, inputs, projector=None):
        """Project image and/or video features.

        Args:
            inputs: One of
                * a 3-D tensor [n, k, c], projected as is;
                * a 4-D tensor [n, t, k, c], viewed as [n, t*k, c] first;
                * a pair ``(concat_images, concat_features)`` given as a list or
                  tuple, with 3-D images and 4-D video features.
            projector: Unused; kept for interface parity with the other projectors.

        Returns:
            The projected tensor, or a tuple ``(images, features)`` for pair input.

        Raises:
            ValueError: If a tensor input is neither 3-D nor 4-D (the original
                code silently returned ``None`` in that case).
        """
        # isinstance (rather than `type(...) is list`) also accepts tuples and
        # list subclasses instead of sending them down the tensor path.
        if isinstance(inputs, (list, tuple)):
            images, features = inputs
            n, t, k, c = features.shape
            features = features.view(n, t * k, c)  # n t*k c
            image_rows = images.shape[0] * images.shape[1]

            # Flatten both tensors to rows so the MLP runs once over everything.
            flat = torch.cat(
                [images.reshape(-1, images.shape[-1]),
                 features.reshape(-1, features.shape[-1])],
                dim=0)
            flat = self.mlp(flat)

            images = flat[:image_rows].contiguous().view(*images.shape[:2], -1)
            features = flat[image_rows:].contiguous().view(*features.shape[:2], -1)
            return images, features

        if inputs.ndim == 3:
            return self.mlp(inputs)
        if inputs.ndim == 4:
            n, t, k, c = inputs.shape
            return self.mlp(inputs.view(n, t * k, c))  # n t*k c
        raise ValueError(
            f'BaseMixProjector expects a 3-D or 4-D tensor, got {inputs.ndim} dimensions')
def _build_mlp(in_dim, out_dim, depth=2):
    """Return ``depth`` Linear layers mapping in_dim -> out_dim with GELU in between."""
    layers = [nn.Linear(in_dim, out_dim)]
    for _ in range(1, depth):
        layers.append(nn.GELU())
        layers.append(nn.Linear(out_dim, out_dim))
    return nn.Sequential(*layers)


def build_vision_projector(config, delay_load=False, **kwargs):
    """Create the vision projector named by ``config.mm_projector_type``.

    Supported types: ``linear`` (the default when the attribute is missing),
    ``mlp<N>x_gelu``, ``identity``, ``slot``, ``perceiver``, ``mlpslot``,
    ``compress``, ``pool``, ``base`` and ``base_mix``.

    Args:
        config: Object providing ``mm_hidden_size`` and ``hidden_size`` and,
            depending on the type, ``n_slot``, ``resolution`` and the optional
            ``pool_num`` (default 1).
        delay_load: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        The projector module.

    Raises:
        ValueError: If the projector type is not one of the supported names.
    """
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        depth = int(mlp_gelu_match.group(1))
        return _build_mlp(config.mm_hidden_size, config.hidden_size, depth)

    if projector_type == 'identity':
        return IdentityMap()

    if projector_type == 'slot':
        return SlotAttention(config.n_slot, config.mm_hidden_size, 3, config.hidden_size, config.hidden_size)

    if projector_type == 'perceiver':
        return PerceiverSampler(config.n_slot, config.mm_hidden_size, config.hidden_size)

    if projector_type == 'mlpslot':
        # The MLP is created before the slot attention so that seeded
        # parameter initialization stays the same as before.
        mlp = _build_mlp(config.mm_hidden_size, config.hidden_size)
        sa = SlotAttention(config.n_slot, config.mm_hidden_size, 3, config.hidden_size, config.hidden_size)
        return MultiProjector(mlp, sa)

    # NOTE(review): 'compress', 'pool' and 'base_mix' use an MLP input width of
    # 4 * mm_hidden_size while 'base' uses mm_hidden_size -- presumably their
    # callers feed 4x-wide features; confirm against the feature extraction code.
    if projector_type == 'compress':
        mlp = _build_mlp(4 * config.mm_hidden_size, config.hidden_size)
        return CompressProjector(mlp, config.n_slot, config.hidden_size)

    if projector_type == 'pool':
        mlp = _build_mlp(4 * config.mm_hidden_size, config.hidden_size)
        pool_num = getattr(config, 'pool_num', 1)
        return PoolProjector(mlp, config.resolution, pool_num)

    if projector_type == 'base':
        return BaseProjector(_build_mlp(config.mm_hidden_size, config.hidden_size))

    if projector_type == 'base_mix':
        return BaseMixProjector(_build_mlp(4 * config.mm_hidden_size, config.hidden_size))

    raise ValueError(f'Unknown projector type: {projector_type}')
|