from collections import OrderedDict
from typing import Tuple, Union
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import DropPath, trunc_normal_

from .backbone import Backbone
from .build import BACKBONE_REGISTRY
from .det_swin import SwinTransformer
from ..text_encoder import build_text_encoder
from ..text_encoder import build_tokenizer


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        # compute the normalization statistics in fp32 for stability, then cast back
        pdtype = x.dtype
        x = x.float()
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x.to(pdtype) + self.bias


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path: float = 0.0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x
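
# Usage sketch (illustrative, with assumed sizes): the block operates on sequence-first
# tensors, since nn.MultiheadAttention defaults to the (seq_len, batch, embed_dim) layout.
#   block = ResidualAttentionBlock(d_model=512, n_head=8)
#   x = torch.randn(77, 2, 512)   # (L, N, D)
#   y = block(x)                  # same shape: torch.Size([77, 2, 512])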


class Transformer(nn.Module):
    def __init__(self,
                 context_length: int,
                 vocab_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 drop_path: float = 0.0):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, width)

        self.context_length = context_length
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, width)
        )

        self.width = width
        self.layers = layers
        attn_mask = self.build_attention_mask()
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]  # stochastic depth decay rule
        self.resblocks = nn.Sequential(
            *[
                ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
                for i in range(layers)
            ]
        )

        self.ln_final = LayerNorm(width)

        trunc_normal_(self.positional_embedding, std=.02)
        # nn.init.normal_(self.token_embedding, std=.02)
        trunc_normal_(self.token_embedding.weight, std=.02)
        self.apply(self._init_weights)

    def build_attention_mask(self):
        # lazily create the causal attention mask for the text tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask
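
    # For illustration: with context_length = 4 the additive mask returned above is
    #   [[0., -inf, -inf, -inf],
    #    [0.,   0., -inf, -inf],
    #    [0.,   0.,   0., -inf],
    #    [0.,   0.,   0.,   0.]]
    # so token i may only attend to tokens 0..i.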

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)

    def no_weight_decay(self):
        return {
            'positional_embedding',
            'token_embedding',
        }

    def forward(self, text: torch.Tensor):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.resblocks(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]

        return x
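
# Usage sketch (illustrative sizes, assuming the CLIP BPE vocabulary where the end-of-text
# token has the largest id): `text` holds token ids of shape (batch, context_length), and
# forward() pools the sequence at the highest-id position, giving one (batch, width)
# embedding per caption.
#   txt_enc = Transformer(context_length=77, vocab_size=49408, width=512, layers=12, heads=8)
#   tokens = torch.randint(0, 49408, (2, 77))
#   emb = txt_enc(tokens)   # torch.Size([2, 512])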


class CLIP(Backbone):
    def __init__(self, config: dict):
        super().__init__()

        spec_text = config['MODEL']['SPEC']['TEXT']
        assert spec_text['TOKENIZER'] == 'clip', 'Only support clip tokenizer'
        self.tokenizer_style = spec_text['TOKENIZER']
        self.tokenizer = build_tokenizer(spec_text)

        self.text_encoder = build_text_encoder(spec_text, self.tokenizer, True)

        embed_dim = config['MODEL']['SPEC']['EMBED_DIM']
        self.text_projection = nn.Parameter(
            torch.empty(spec_text['WIDTH'], embed_dim)
        )

        spec_vision = config['MODEL']['SPEC']['VISION']
        self.image_encoder = SwinTransformer(
            patch_size=spec_vision['PATCH_SIZE'],
            in_chans=spec_vision['IN_CHANS'],
            embed_dim=spec_vision['EMBED_DIM'],
            depths=spec_vision['DEPTHS'],
            num_heads=spec_vision['NUM_HEADS'],
            window_size=spec_vision['WINDOW_SIZE'],
            mlp_ratio=spec_vision['MLP_RATIO'],
            qkv_bias=spec_vision['QKV_BIAS'],
            qk_scale=spec_vision.get('QK_SCALE', None),
            drop_rate=spec_vision['DROP_RATE'],
            attn_drop_rate=spec_vision['ATTN_DROP_RATE'],
            drop_path_rate=spec_vision['DROP_PATH_RATE'],
            ape=spec_vision['APE'],
            patch_norm=spec_vision['PATCH_NORM'],
            out_indices=(0, 1, 2, 3),
            frozen_stages=-1,
            use_checkpoint=False,
        )

        # channel width of the last Swin stage: embed_dim doubles at every stage after the first
        width = spec_vision['EMBED_DIM'] * 2 ** (len(spec_vision['DEPTHS']) - 1)
        self.image_projection = nn.Parameter(
            torch.empty(width, embed_dim)
        )

        # self.logit_scale = nn.Parameter(torch.FloatTensor([np.log(1 / 0.07)]))
        self.logit_scale = nn.Parameter(torch.ones([]))

        trunc_normal_(self.text_projection, std=.02)
        trunc_normal_(self.image_projection, std=.02)

    def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            logging.info(f'=> loading pretrained model {pretrained}')
            model_dict = self.state_dict()
            # keep only the weights that also exist in the current model
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            need_init_state_dict = {}
            for k, v in pretrained_dict.items():
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or (len(pretrained_layers) > 0 and pretrained_layers[0] == '*')
                )
                if need_init:
                    if verbose:
                        logging.info(f'=> init {k} from {pretrained}')
                    need_init_state_dict[k] = v
            self.load_state_dict(need_init_state_dict, strict=False)

    def no_weight_decay(self):
        no_weight_decay = {'logit_scale'}
        for k in self.text_encoder.no_weight_decay():
            no_weight_decay.add('text.' + k)

        for k in self.image_encoder.no_weight_decay():
            no_weight_decay.add('visual.' + k)

        return no_weight_decay

    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def dtype(self):
        # NOTE (assumption): the Swin image encoder has no ResNet-style `conv1` stem; assuming
        # the standard Swin patch-embedding layout, `patch_embed.proj` is the Conv2d whose
        # weight dtype stands in for the model dtype.
        return self.image_encoder.patch_embed.proj.weight.dtype

    def encode_image(self, image, norm=True):
        x = self.image_encoder(image)
        return x

    def encode_text(self, text, norm=True):
        assert isinstance(text, str), "only support single query"
        tokens = self.tokenizer(
            text, padding='max_length', truncation=True, max_length=77, return_tensors='pt'
        )
        tokens = {
            key: (val.cuda() if next(self.parameters()).is_cuda else val)
            for key, val in tokens.items()
        }
        x = self.text_encoder(**tokens)
        x = x['last_hidden_state']

        # pool at the position of the highest token id (the end-of-text token of the clip tokenizer)
        x = x[torch.arange(x.size(0)), tokens['input_ids'].argmax(dim=-1)]

        x = x @ self.text_projection

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)

        return x

    def forward(self, image):
        features_image = self.image_encoder(image)
        return features_image
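
# Similarity sketch (illustrative, not prescribed by this module): forward() and
# encode_image() return the multi-scale Swin feature maps used for detection, while
# encode_text() already returns projected, L2-normalized text embeddings. A CLIP-style
# score against a pooled, projected image embedding (here a hypothetical `image_emb`
# of shape (B, embed_dim)) would look like:
#   txt = model.encode_text("a photo of a cat")              # (1, embed_dim)
#   logits = model.logit_scale.exp() * image_emb @ txt.t()   # (B, 1)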


def build_clip_swin_backbone(cfg, input_shape):
    """
    Create the Swin vision backbone of CLIP from config.

    Returns:
        SwinTransformer: a :class:`SwinTransformer` instance.
    """
    spec_vision = cfg.MODEL.CLIP.VISION
    return SwinTransformer(
        patch_size=spec_vision['PATCH_SIZE'],
        in_chans=spec_vision['IN_CHANS'],
        embed_dim=spec_vision['EMBED_DIM'],
        depths=spec_vision['DEPTHS'],
        num_heads=spec_vision['NUM_HEADS'],
        window_size=spec_vision['WINDOW_SIZE'],
        mlp_ratio=spec_vision['MLP_RATIO'],
        qkv_bias=spec_vision['QKV_BIAS'],
        qk_scale=spec_vision.get('QK_SCALE', None),
        drop_rate=spec_vision['DROP_RATE'],
        attn_drop_rate=spec_vision['ATTN_DROP_RATE'],
        drop_path_rate=spec_vision['DROP_PATH_RATE'],
        ape=spec_vision['APE'],
        patch_norm=spec_vision['PATCH_NORM'],
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        use_checkpoint=False,
    )
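
# Config sketch (values are illustrative Swin-T defaults, not taken from the project config):
# build_clip_swin_backbone expects a yacs/detectron2-style cfg whose MODEL.CLIP.VISION node
# supplies at least the keys read above, e.g.
#   MODEL:
#     CLIP:
#       VISION:
#         PATCH_SIZE: 4
#         IN_CHANS: 3
#         EMBED_DIM: 96
#         DEPTHS: [2, 2, 6, 2]
#         NUM_HEADS: [3, 6, 12, 24]
#         WINDOW_SIZE: 7
#         MLP_RATIO: 4.0
#         QKV_BIAS: True
#         DROP_RATE: 0.0
#         ATTN_DROP_RATE: 0.0
#         DROP_PATH_RATE: 0.2
#         APE: False
#         PATCH_NORM: True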


def build_clip_swin(cfg, input_shape):
    """
    Create a full CLIP model (Swin image encoder + text encoder) from config.

    Returns:
        CLIP: a :class:`CLIP` instance.
    """
    return CLIP(cfg)
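
# Config sketch (derived from how CLIP.__init__ indexes the config): `cfg` must provide at
# least MODEL.SPEC.EMBED_DIM, MODEL.SPEC.TEXT (with TOKENIZER set to 'clip', WIDTH, plus
# whatever build_tokenizer/build_text_encoder consume), and MODEL.SPEC.VISION with the same
# Swin keys listed for build_clip_swin_backbone above.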