import math
from functools import reduce
from operator import mul

import torch
import torch.nn as nn
from mmcls.models.backbones import VisionTransformer as _VisionTransformer
from mmcls.models.utils import to_2tuple
from mmcv.cnn.bricks.transformer import PatchEmbed
from torch.nn.modules.batchnorm import _BatchNorm


def build_2d_sincos_position_embedding(patches_resolution,
                                       embed_dims,
                                       temperature=10000.,
                                       cls_token=False):
    """Build a fixed 2D sine-cosine position embedding so that the model
    obtains positional information about the image patches."""
    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)

    h, w = patches_resolution
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    assert embed_dims % 4 == 0, \
        'Embed dimension must be divisible by 4.'
    pos_dim = embed_dims // 4

    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])

    pos_emb = torch.cat(
        [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h)
        ],
        dim=1,
    )[None, :, :]

    if cls_token:
        cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)

    return pos_emb
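

# ---------------------------------------------------------------------------
# Illustrative sanity check (not part of the original file): a minimal sketch
# that assumes only torch and the function above. `_demo_sincos_pos_embed` is
# a hypothetical helper showing the expected output shape for a 14x14 patch
# grid (e.g. a /16 patch size at 224x224 input) with a class token prepended.
# ---------------------------------------------------------------------------
def _demo_sincos_pos_embed():
    """Check the shape of the fixed sin-cos position embedding."""
    pos_emb = build_2d_sincos_position_embedding(
        patches_resolution=(14, 14), embed_dims=768, cls_token=True)
    # 14 * 14 = 196 patch tokens plus 1 cls token -> (1, 197, 768)
    assert pos_emb.shape == (1, 197, 768)
    return pos_emb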


class VisionTransformer(_VisionTransformer):
    """Vision Transformer.

    A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale `_.

    Part of the code is modified from: ``_.

    Args:
        stop_grad_conv1 (bool, optional): Whether to stop the gradient of the
            convolution layer in `PatchEmbed`. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['mocov3-s', 'mocov3-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 1536,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
    }

    def __init__(self,
                 stop_grad_conv1=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 init_cfg=None,
                 **kwargs):
        super(VisionTransformer, self).__init__(init_cfg=init_cfg, **kwargs)
        self.patch_size = kwargs['patch_size']
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.init_cfg = init_cfg

        if isinstance(self.patch_embed, PatchEmbed):
            if stop_grad_conv1:
                self.patch_embed.projection.weight.requires_grad = False
                self.patch_embed.projection.bias.requires_grad = False

        self._freeze_stages()

    def init_weights(self):
        super(VisionTransformer, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):

            # Use fixed 2D sin-cos position embedding
            pos_emb = build_2d_sincos_position_embedding(
                patches_resolution=self.patch_resolution,
                embed_dims=self.embed_dims,
                cls_token=True)
            self.pos_embed.data.copy_(pos_emb)
            self.pos_embed.requires_grad = False

            # xavier_uniform initialization for PatchEmbed
            if isinstance(self.patch_embed, PatchEmbed):
                val = math.sqrt(
                    6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) +
                               self.embed_dims))
                nn.init.uniform_(self.patch_embed.projection.weight, -val, val)
                nn.init.zeros_(self.patch_embed.projection.bias)

            # initialization for linear layers
            for name, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    if 'qkv' in name:
                        # treat the weights of Q, K, V separately
                        val = math.sqrt(
                            6. /
                            float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                        nn.init.uniform_(m.weight, -val, val)
                    else:
                        nn.init.xavier_uniform_(m.weight)
                    nn.init.zeros_(m.bias)
            nn.init.normal_(self.cls_token, std=1e-6)

    def _freeze_stages(self):
        """Freeze patch_embed layer, some parameters and stages."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            self.cls_token.requires_grad = False
            self.pos_embed.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            if i == self.num_layers and self.final_norm:
                for param in getattr(self, 'norm1').parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(VisionTransformer, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers here
                if isinstance(m, _BatchNorm):
                    m.eval()
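

# ---------------------------------------------------------------------------
# Hedged usage sketch (not from the original file). It assumes an mmcls/mmcv
# environment in which the parent VisionTransformer accepts `arch`, `img_size`
# and `patch_size` keyword arguments, as the `arch_zoo` override and the
# `kwargs['patch_size']` access above suggest; argument names and output
# formats may differ across mmcls versions, so treat this as illustration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = VisionTransformer(
        arch='mocov3-small',   # resolved through the arch_zoo defined above
        img_size=224,
        patch_size=16,
        stop_grad_conv1=True,  # freeze the patch-embedding projection
        frozen_stages=-1)
    model.init_weights()       # installs the fixed sin-cos position embedding
    feats = model(torch.randn(2, 3, 224, 224))  # tuple of per-stage outputs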