#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified Model definition
"""Video models."""

import math

import torch
import torch.nn as nn
from einops import rearrange, repeat
from timm.layers import to_2tuple
from torch import einsum
from torch.nn import functional as F

default_cfgs = {
    'vit_1k': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
    'vit_1k_large': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
}


def qkv_attn(q, k, v, tok_mask: torch.Tensor = None):
    sim = einsum('b i d, b j d -> b i j', q, k)
    # apply masking if provided; tok_mask is (B*S*H, N) with 1 = keep, sim is (B*S*H, N, N)
    if tok_mask is not None:
        BSH, N = tok_mask.shape
        # the singleton dim broadcasts across the query positions
        sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0, float('-inf'))
    attn = sim.softmax(dim=-1)
    out = einsum('b i j, b j d -> b i d', attn, v)
    return out


class DividedAttention(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # special init: zero the qkv projection, fill the output projection weight with ones
        self.qkv.weight.data.fill_(0)
        if self.qkv.bias is not None:  # qkv_bias=False leaves the bias as None
            self.qkv.bias.data.fill_(0)
        self.proj.weight.data.fill_(1)
        self.proj.bias.data.fill_(0)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
        # num of heads variable
        h = self.num_heads

        # project x to q, k, v values
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        if tok_mask is not None:
            # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d
            assert len(tok_mask.shape) == 2
            tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1])

        # Scale q
        q *= self.scale

        # Take out cls_q, cls_k, cls_v
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
        # the same for masking
        if tok_mask is not None:
            cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:]
        else:
            cls_mask, mask_ = None, None

        # let CLS token attend to keys / values of all patches across time and space
        cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask)

        # rearrange across time or space
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims),
                         (q_, k_, v_))

        # expand CLS token keys and values across time or space and concat
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim=1)
        v_ = torch.cat((cls_v, v_), dim=1)

        # the same for masking (if provided)
        if tok_mask is not None:
            # since the mask does not have the latent dim (d), remove it from the einops patterns
            mask_ = rearrange(mask_, f'{einops_from} -> {einops_to}'.replace(' d', ''),
                              **einops_dims)
            # expand cls_mask across time or space
            cls_mask = repeat(cls_mask, 'b () -> (b r) ()', r=r)
            mask_ = torch.cat((cls_mask, mask_), dim=1)

        # attention
        out = qkv_attn(q_, k_, v_, tok_mask=mask_)

        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concat back the cls token
        out = torch.cat((cls_out, out), dim=1)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        # to out
        x = self.proj(out)
        x = self.proj_drop(x)
        return x
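
# Illustrative usage sketch (mirrors how DividedSpaceTimeBlock below calls this module;
# the concrete shapes are assumptions, not part of the original file).
# x: (B, 1 + num_frames * seq_len, dim) with a leading CLS token,
# tok_mask: (B, 1 + num_frames * seq_len) with 1 = keep.
#
#   attn = DividedAttention(dim=768, num_heads=12, qkv_bias=True)
#   space_out = attn(x, 'b (f n) d', '(b f) n d', f=num_frames, tok_mask=tok_mask)  # attend within each frame
#   time_out = attn(x, 'b (f n) d', '(b n) f d', n=seq_len, tok_mask=tok_mask)      # attend across frames per patch
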

class DividedSpaceTimeBlock(nn.Module):

    def __init__(self, dim=768, num_heads=12, attn_type='divided', mlp_ratio=4., qkv_bias=False,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()

        self.einops_from_space = 'b (f n) d'
        self.einops_to_space = '(b f) n d'

        self.einops_from_time = 'b (f n) d'
        self.einops_to_time = '(b n) f d'

        self.norm1 = norm_layer(dim)

        self.attn = DividedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                     attn_drop=attn_drop, proj_drop=drop)

        self.timeattn = DividedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                         attn_drop=attn_drop, proj_drop=drop)

        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
                       drop=drop)
        self.norm3 = norm_layer(dim)

    def forward(self, x, seq_len=196, num_frames=8, approx='none', num_landmarks=128,
                tok_mask: torch.Tensor = None):
        time_output = self.timeattn(self.norm3(x), self.einops_from_time, self.einops_to_time,
                                    n=seq_len, tok_mask=tok_mask)
        time_residual = x + time_output

        space_output = self.attn(self.norm1(time_residual), self.einops_from_space,
                                 self.einops_to_space, f=num_frames, tok_mask=tok_mask)
        space_residual = time_residual + self.drop_path(space_output)

        x = space_residual
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = img_size if type(img_size) is tuple else to_2tuple(img_size)
        patch_size = patch_size if type(patch_size) is tuple else to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class PatchEmbed3D(nn.Module):
    """ Image to Patch Embedding """

    def __init__(self, img_size=224, temporal_resolution=4, in_chans=3, patch_size=16,
                 z_block_size=2, embed_dim=768, flatten=True):
        super().__init__()
        self.height = (img_size // patch_size)
        self.width = (img_size // patch_size)
        ### v-iashin: these two are incorrect
        # self.frames = (temporal_resolution // z_block_size)
        # self.num_patches = self.height * self.width * self.frames
        self.z_block_size = z_block_size
        ###
        self.proj = nn.Conv3d(in_chans, embed_dim,
                              kernel_size=(z_block_size, patch_size, patch_size),
                              stride=(z_block_size, patch_size, patch_size))
        self.flatten = flatten

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)
        return x
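
# Illustrative shape sketch for PatchEmbed3D (assumed example values, not part of the original file):
# with patch_size=16 and z_block_size=2, an input of (B, 3, T, 224, 224) yields
# (T // 2) * 14 * 14 tokens of size embed_dim after flattening, e.g.
#
#   emb = PatchEmbed3D(img_size=224, patch_size=16, z_block_size=2, embed_dim=768)
#   tokens = emb(torch.randn(2, 3, 8, 224, 224))  # -> (2, 4 * 14 * 14, 768) = (2, 784, 768)
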

class HeadMLP(nn.Module):

    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super(HeadMLP, self).__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden
        if n_hidden is None:
            # use linear classifier
            self.block_forward = nn.Sequential(nn.Dropout(p=p),
                                               nn.Linear(n_input, n_classes, bias=True))
        else:
            # use simple MLP classifier
            self.block_forward = nn.Sequential(nn.Dropout(p=p),
                                               nn.Linear(n_input, n_hidden, bias=True),
                                               nn.BatchNorm1d(n_hidden),
                                               nn.ReLU(inplace=True),
                                               nn.Dropout(p=p),
                                               nn.Linear(n_hidden, n_classes, bias=True))
        print(f"Dropout-MLP: {p}")

    def forward(self, x):
        return self.block_forward(x)


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict


def adapt_input_conv(in_chans, conv_weight, agg='sum'):
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            assert conv_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            if agg == 'sum':
                print("Summing conv1 weights")
                conv_weight = conv_weight.sum(dim=1, keepdim=True)
            else:
                print("Averaging conv1 weights")
                conv_weight = conv_weight.mean(dim=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        else:
            if agg == 'sum':
                print("Summing conv1 weights")
                n_repeat = int(math.ceil(in_chans / 3))
                conv_weight = conv_weight.repeat(1, n_repeat, 1, 1)[:, :in_chans, :, :]
                conv_weight *= (3 / float(in_chans))
            else:
                print("Averaging conv1 weights")
                conv_weight = conv_weight.mean(dim=1, keepdim=True)
                conv_weight = conv_weight.repeat(1, in_chans, 1, 1)
    conv_weight = conv_weight.to(conv_type)
    return conv_weight


def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
                    progress=False):
    # Load state dict
    assert cfg.VIT.PRETRAINED_WEIGHTS in default_cfgs, \
        f"{cfg.VIT.PRETRAINED_WEIGHTS} not in {list(default_cfgs.keys())}"
    state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS],
                                                    progress=progress)

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    input_convs = 'patch_embed.proj'
    if input_convs is not None and in_chans != 3:
        if isinstance(input_convs, str):
            input_convs = (input_convs, )
        for input_conv_name in input_convs:
            weight_name = input_conv_name + '.weight'
            try:
                state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name],
                                                            agg='avg')
                print(f'Converted input conv {input_conv_name} pretrained weights '
                      f'from 3 to {in_chans} channel(s)')
            except NotImplementedError:
                del state_dict[weight_name]
                strict = False
                print(f'Unable to convert pretrained {input_conv_name} weights, '
                      f'using random init for this layer.')

    classifier_name = 'head'
    label_offset = cfg.get('label_offset', 0)
    pretrain_classes = 1000
    if num_classes != pretrain_classes:
        # completely discard fully connected head if model num_classes doesn't match pretrained weights
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False
    elif label_offset > 0:
        # special case for pretrained weights with an extra background class
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]

    loaded_state = state_dict
    self_state = model.state_dict()
    all_names = set(self_state.keys())
    saved_names = set()
    for name, param in loaded_state.items():
        if 'module.' in name:
            name = name.replace('module.', '')
        if name in self_state.keys() and param.shape == self_state[name].shape:
            saved_names.add(name)
            self_state[name].copy_(param)
        else:
            print(f"didn't load: {name} of shape: {param.shape}")
    print("Missing Keys:")
    print(all_names - saved_names)
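

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch with assumed shapes, not a training entry point):
    # run one DividedSpaceTimeBlock over a CLS token plus 8 frames of 14 * 14 = 196 patch tokens.
    batch, num_frames, seq_len, dim = 2, 8, 196, 768
    block = DividedSpaceTimeBlock(dim=dim, num_heads=12, qkv_bias=True)
    x = torch.randn(batch, 1 + num_frames * seq_len, dim)
    tok_mask = torch.ones(batch, 1 + num_frames * seq_len, dtype=torch.long)  # 1 = keep every token
    with torch.no_grad():
        y = block(x, seq_len=seq_len, num_frames=num_frames, tok_mask=tok_mask)
    print(y.shape)  # expected: torch.Size([2, 1569, 768])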