diff --git a/model/CoordAttention.py b/model/CoordAttention.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc82722043b4c5ac8909f4cd3059920b81366134
--- /dev/null
+++ b/model/CoordAttention.py
@@ -0,0 +1,110 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class h_sigmoid(nn.Module):
+    def __init__(self, inplace=True):
+        super(h_sigmoid, self).__init__()
+        self.relu = nn.ReLU6(inplace=inplace)
+
+    def forward(self, x):
+        return self.relu(x + 3) / 6
+
+
+class h_swish(nn.Module):
+    def __init__(self, inplace=True):
+        super(h_swish, self).__init__()
+        self.sigmoid = h_sigmoid(inplace=inplace)
+
+    def forward(self, x):
+        return x * self.sigmoid(x)
+
+
+class CoordAtt(nn.Module):
+    def __init__(self, inp, oup, reduction=32):
+        super(CoordAtt, self).__init__()
+        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
+        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
+
+        mip = max(8, inp // reduction)
+
+        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
+        self.bn1 = nn.BatchNorm2d(mip)
+
+        self.bn2 = nn.BatchNorm2d(1)
+        self.bn3 = nn.BatchNorm2d(1)
+        self.act = h_swish()
+
+        self.bn4 = nn.BatchNorm2d(mip)
+        self.bn5 = nn.BatchNorm2d(mip)
+
+        self.bn6 = nn.BatchNorm2d(1)
+        self.bn7 = nn.BatchNorm2d(1)
+
+        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
+        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        x = torch.unsqueeze(x, 1)  # 2 1 2304 196
+        identity = x
+
+        n, c, h, w = x.size()  # 2 1 2304 196
+        x_h = self.bn2(self.pool_h(x))  # 2 1 2304 1
+        x_w = self.bn3(self.pool_w(x).permute(0, 1, 3, 2))  # 2 1 196 1
+        identity_x_w = x_w
+        identity_x_h = x_h
+        y = torch.cat([x_h, x_w], dim=2)
+        y = self.conv1(y)  # 2 8 2500 1
+        y = self.bn1(y)
+        y = self.act(y)
+
+        x_h, x_w = torch.split(y, [h, w], dim=2)  # 2 8 2304 1 | 2 8 196 1
+        x_h = self.bn4(x_h) + identity_x_h
+        x_w = self.bn5(x_w) + identity_x_w
+        x_w = x_w.permute(0, 1, 3, 2)
+
+        a_h = self.bn6(self.conv_h(x_h)).sigmoid()  # 2 1 2304 1
+        a_w = self.bn7(self.conv_w(x_w)).sigmoid()  # 24 1 1 196
+
+        out = identity * a_w * a_h  # element-wise multiplication
+        out = torch.squeeze(out, 1)
+        return out
+
+class CoordAtt_ori(nn.Module):
+    def __init__(self, inp, oup, reduction=32):
+        super(CoordAtt_ori, self).__init__()
+        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
+        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
+
+        mip = max(8, inp // reduction)
+
+        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
+        self.bn1 = nn.BatchNorm2d(mip)
+        self.act = h_swish()
+
+        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
+        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        x = torch.unsqueeze(x, 1)
+        identity = x
+
+        n, c, h, w = x.size()
+        x_h = self.pool_h(x)
+        x_w = self.pool_w(x).permute(0, 1, 3, 2)
+
+        y = torch.cat([x_h, x_w], dim=2)
+        y = self.conv1(y)
+        y = self.bn1(y)
+        y = self.act(y)
+
+        x_h, x_w = torch.split(y, [h, w], dim=2)
+        x_w = x_w.permute(0, 1, 3, 2)
+
+        a_h = self.conv_h(x_h).sigmoid()
+        a_w = self.conv_w(x_w).sigmoid()
+
+        out = identity * a_w * a_h
+        out = torch.squeeze(out, 1)
+        return out
\ No newline at end of file
diff --git a/model/Vision_Transformer_with_mask.py b/model/Vision_Transformer_with_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..deec6341cd9434e52f7b976a081143084418f888
--- /dev/null
+++ b/model/Vision_Transformer_with_mask.py
@@ -0,0 +1,990 @@
+""" Vision Transformer (ViT) in PyTorch
+
+A PyTorch implementation of Vision Transformers
as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math +import logging +from functools import partial +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from .registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (weights from official Google JAX impl) + 'vit_tiny_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_tiny_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 
'vit_base_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_base_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + ), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_tiny_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_large_patch32_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + num_classes=21843), + 'vit_large_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + num_classes=21843), + 'vit_huge_patch14_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', + hf_hub='timm/vit_huge_patch14_224_in21k', + num_classes=21843), + + # deit models (FB weights) + 'deit_tiny_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_base_patch16_384': _cfg( + 
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), + 'deit_tiny_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_small_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, + classifier=('head', 'head_dist')), + + # ViT ImageNet-21K-P pretraining by MILL + 'vit_base_patch16_224_miil_in21k': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + 'vit_base_patch16_224_miil': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' + '/vit_base_patch16_224_1k_miil_84_4.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), +} + + +class CrossAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 #这行多了个qk_scale #0.125 + + self.wq = nn.Linear(dim, dim, bias=qkv_bias) + self.wk = nn.Linear(dim, dim, bias=qkv_bias) + self.wv = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + + B, N, C = x.shape #2 512 768 + q = self.wq(x[:, 0:int(N/2), ...]).reshape(B, int(N/2), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)#2 12 256 64 + k = self.wk(x[:, (int(N/2)):, ...]).reshape(B, int(N/2), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.wv(x[:, (int(N/2)):, ...]).reshape(B, int(N/2), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, int(N/2), C) #变成了B/2 2 256 768 + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, data): + b,c,h = data.shape + x,atten_mask = data[:,0:int(c/2),...],data[:,int(c/2):,...] 
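+        # `data` packs the token sequence and the additive attention mask along
+        # dim 1: the first half holds the tokens, the second half holds the mask
+        # rows that are broadcast over the attention heads below.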
+ B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] #2,12,49,64 # make torchscript happy (cannot use tensor as tuple) + + + attn = (q @ k.transpose(-2, -1)) * self.scale #2,12,49,49 #mask 2,1,49,49 + if atten_mask.sum() != 0: + atten_mask = atten_mask.unsqueeze(1) # 2,1,49,49 + attn = attn + atten_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Attention_ori(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] #2,12,49,64 # make torchscript happy (cannot use tensor as tuple) + attn = (q @ k.transpose(-2, -1)) * self.scale #2,12,49,49 #mask 2,1,49,49 + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, data): + b,c,h = data.shape + x,mask = data[:,0:int(c/2),...],data[:,int(c/2):,...] + x = x + self.drop_path(self.attn(torch.cat([self.norm1(x),mask],dim=1))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return torch.cat([x,mask],dim=1) + +class mask_PatchEmbed(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, norm_layer=None, flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.flatten = flatten + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Conv2d(in_chans, 1, kernel_size=patch_size, stride=patch_size).requires_grad_(False) + nn.init.ones_(self.proj.weight) + nn.init.zeros_(self.proj.bias) + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
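+        # self.proj is a frozen all-ones convolution, so each output value is the
+        # sum of the mask pixels inside one patch; a non-zero entry marks a masked patch.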
+ x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + return x + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, + act_layer=None,as_backbone=True, weight_init=''): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + self.num_heads = num_heads + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + self.as_backbone = as_backbone #是否分类任务,如果不是,class不加上去 + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + self.mask_embed = mask_PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + if not self.as_backbone: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + else: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + if not self.as_backbone: + self.avgpool = nn.AdaptiveAvgPool1d(1) + # Classifier head(s) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + 
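+        # A second classifier head is attached only for distilled (DeiT-style)
+        # models; it consumes the distillation token.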
self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) + if mode.startswith('jax'): + # leave cls token as zeros to match jax impl + named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) + else: + trunc_normal_(self.cls_token, std=.02) + self.apply(_init_vit_weights) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def get_classifier(self): + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.num_tokens == 2: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, data): + x,mask = data[:,0,:,:].unsqueeze(1),data[:,1,:,:].unsqueeze(1) + x = self.patch_embed(x)#B N C + atten_mask = torch.zeros_like(x) # 2 49 768 + if mask.sum() != 0: + mask = self.mask_embed(mask) ### + mask.squeeze_(dim=2) + mask[mask != 0] = 1 ### H W数目token C编码长度 + k1 = mask[:, None, :] + k2 = torch.ones_like(mask)[:, :, None] + k3 = k1 * k2 + atten_mask = (1.0 - k3) * (-1e6) + atten_mask.requires_grad_(True) + self.atten_mask = atten_mask + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if not self.as_backbone: + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) #2 49 768 + x = self.blocks(torch.cat([x,atten_mask],dim=1)) + b,c,h = x.shape + x = x[:,0:int(c/2),...] + x = self.norm(x) + if self.as_backbone: + # x = self.avgpool(x.transpose(1, 2)) # B C 1 + # x = torch.flatten(x, 1) + return x + if self.dist_token is None: + return self.pre_logits(x[:, 0]) + else: + return x[:, 0], x[:, 1] + + def forward(self, data): + x = self.forward_features(data) #2 49 768 + if self.as_backbone: + return x + else: + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 + else: + x = self.head(x) + return x + + +def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
+ * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + 
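+    # Classifier weights are copied only when the checkpoint head matches the
+    # model's number of classes; otherwise the freshly initialised head is kept.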
if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, default_cfg=None, 
**kwargs): + default_cfg = default_cfg or default_cfgs[variant] + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + # NOTE this extra code to support handling of repr size for in21k pretrained models + default_num_classes = default_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") + repr_size = None + + model = build_model_with_cfg( + VisionTransformer, variant, pretrained, + default_cfg=default_cfg, + representation_size=repr_size, + pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in default_cfg['url'], + **kwargs) + return model + + +@register_model +def vit_tiny_patch16_224(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_patch16_384(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) @ 384x384. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_224(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) at 384x384. + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 
+ """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
+ """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
+ NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights + """ + model_kwargs = dict( + patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_tiny_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_base_patch16_224(pretrained=False, **kwargs): + """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_base_patch16_384(pretrained=False, **kwargs): + """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. 
+ """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer( + 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_small_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer( + 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_base_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer( + 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_base_distilled_patch16_384(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer( + 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) + return model + + diff --git a/model/__pycache__/CoordAttention.cpython-38.pyc b/model/__pycache__/CoordAttention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..597faea287f5aa9a343b049fecef1b8c21d7ed61 Binary files /dev/null and b/model/__pycache__/CoordAttention.cpython-38.pyc differ diff --git a/model/__pycache__/Vision_Transformer_with_mask.cpython-38.pyc b/model/__pycache__/Vision_Transformer_with_mask.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9a39c33373ffe564cf0ae7906d724a7e2c3d970 Binary files /dev/null and b/model/__pycache__/Vision_Transformer_with_mask.cpython-38.pyc differ diff --git a/model/__pycache__/features.cpython-38.pyc b/model/__pycache__/features.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7608fbde80d8091347292327d8be744f1ae3062 Binary files /dev/null and b/model/__pycache__/features.cpython-38.pyc differ diff --git a/model/__pycache__/helpers.cpython-38.pyc b/model/__pycache__/helpers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c33e7d14712c0b579353a8ff278d6f5cbcd28a9 Binary files /dev/null and b/model/__pycache__/helpers.cpython-38.pyc differ diff --git a/model/__pycache__/hub.cpython-38.pyc b/model/__pycache__/hub.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73f79fde4efdd4cdfd8852bb7fc6d3539ceb0a51 Binary files /dev/null and b/model/__pycache__/hub.cpython-38.pyc differ diff --git a/model/__pycache__/registry.cpython-38.pyc b/model/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c51c9455b52127336689cc24143843a8394fa451 Binary files /dev/null and b/model/__pycache__/registry.cpython-38.pyc differ diff --git a/model/features.py b/model/features.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d6890f3ed07311c5484b4a397c3b1da555880a --- /dev/null +++ b/model/features.py @@ -0,0 +1,284 @@ +""" PyTorch Feature Extraction Helpers + +A collection of classes, functions, modules to help extract features from models +and provide a common interface for describing them. 
+ +The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter +https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict, defaultdict +from copy import deepcopy +from functools import partial +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn + + +class FeatureInfo: + + def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): + prev_reduction = 1 + for fi in feature_info: + # sanity check the mandatory fields, there may be additional fields depending on the model + assert 'num_chs' in fi and fi['num_chs'] > 0 + assert 'reduction' in fi and fi['reduction'] >= prev_reduction + prev_reduction = fi['reduction'] + assert 'module' in fi + self.out_indices = out_indices + self.info = feature_info + + def from_other(self, out_indices: Tuple[int]): + return FeatureInfo(deepcopy(self.info), out_indices) + + def get(self, key, idx=None): + """ Get value by key at specified index (indices) + if idx == None, returns value for key at each output index + if idx is an integer, return value for that feature module index (ignoring output indices) + if idx is a list/tupple, return value for each module index (ignoring output indices) + """ + if idx is None: + return [self.info[i][key] for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i][key] for i in idx] + else: + return self.info[idx][key] + + def get_dicts(self, keys=None, idx=None): + """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) + """ + if idx is None: + if keys is None: + return [self.info[i] for i in self.out_indices] + else: + return [{k: self.info[i][k] for k in keys} for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] + else: + return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} + + def channels(self, idx=None): + """ feature channels accessor + """ + return self.get('num_chs', idx) + + def reduction(self, idx=None): + """ feature reduction (output stride) accessor + """ + return self.get('reduction', idx) + + def module_name(self, idx=None): + """ feature module name accessor + """ + return self.get('module', idx) + + def __getitem__(self, item): + return self.info[item] + + def __len__(self): + return len(self.info) + + +class FeatureHooks: + """ Feature Hook Helper + + This module helps with the setup and extraction of hooks for extracting features from + internal nodes in a model by node name. This works quite well in eager Python but needs + redesign for torcscript. 
+ """ + + def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): + # setup feature hooks + modules = {k: v for k, v in named_modules} + for i, h in enumerate(hooks): + hook_name = h['module'] + m = modules[hook_name] + hook_id = out_map[i] if out_map else hook_name + hook_fn = partial(self._collect_output_hook, hook_id) + hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type + if hook_type == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + elif hook_type == 'forward': + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + self._feature_outputs = defaultdict(OrderedDict) + + def _collect_output_hook(self, hook_id, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][hook_id] = x + + def get_output(self, device) -> Dict[str, torch.tensor]: + output = self._feature_outputs[device] + self._feature_outputs[device] = OrderedDict() # clear after reading + return output + + +def _module_list(module, flatten_sequential=False): + # a yield/iter would be better for this but wouldn't be compatible with torchscript + ml = [] + for name, module in module.named_children(): + if flatten_sequential and isinstance(module, nn.Sequential): + # first level of Sequential containers is flattened into containing model + for child_name, child_module in module.named_children(): + combined = [name, child_name] + ml.append(('_'.join(combined), '.'.join(combined), child_module)) + else: + ml.append((name, name, module)) + return ml + + +def _get_feature_info(net, out_indices): + feature_info = getattr(net, 'feature_info') + if isinstance(feature_info, FeatureInfo): + return feature_info.from_other(out_indices) + elif isinstance(feature_info, (list, tuple)): + return FeatureInfo(net.feature_info, out_indices) + else: + assert False, "Provided feature_info is not valid" + + +def _get_return_layers(feature_info, out_map): + module_names = feature_info.module_name() + return_layers = {} + for i, name in enumerate(module_names): + return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] + return return_layers + + +class FeatureDictNet(nn.ModuleDict): + """ Feature extractor with OrderedDict return + + Wrap a model and extract features as specified by the out indices, the network is + partially re-built from contained modules. + + There is a strong assumption that the modules have been registered into the model in the same + order as they are used. There should be no reuse of the same nn.Module more than once, including + trivial modules like `self.relu = nn.ReLU`. + + Only submodules that are directly assigned to the model class (`model.feature1`) or at most + one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. 
+ All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + + Arguments: + model (nn.Module): model from which we will extract the features + out_indices (tuple[int]): model output indices to extract features for + out_map (sequence): list or tuple specifying desired return id for each out index, + otherwise str(index) is used + feature_concat (bool): whether to concatenate intermediate features that are lists or tuples + vs select element [0] + flatten_sequential (bool): whether to flatten sequential modules assigned to model + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureDictNet, self).__init__() + self.feature_info = _get_feature_info(model, out_indices) + self.concat = feature_concat + self.return_layers = {} + return_layers = _get_return_layers(self.feature_info, out_map) + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = set(return_layers.keys()) + layers = OrderedDict() + for new_name, old_name, module in modules: + layers[new_name] = module + if old_name in remaining: + # return id has to be consistently str type for torchscript + self.return_layers[new_name] = str(return_layers[old_name]) + remaining.remove(old_name) + if not remaining: + break + assert not remaining and len(self.return_layers) == len(return_layers), \ + f'Return layers ({remaining}) are not present in model' + self.update(layers) + + def _collect(self, x) -> (Dict[str, torch.Tensor]): + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_id = self.return_layers[name] + if isinstance(x, (tuple, list)): + # If model tap is a tuple or list, concat or select first element + # FIXME this may need to be more generic / flexible for some nets + out[out_id] = torch.cat(x, 1) if self.concat else x[0] + else: + out[out_id] = x + return out + + def forward(self, x) -> Dict[str, torch.Tensor]: + return self._collect(x) + + +class FeatureListNet(FeatureDictNet): + """ Feature extractor with list return + + See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. + In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureListNet, self).__init__( + model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, + flatten_sequential=flatten_sequential) + + def forward(self, x) -> (List[torch.Tensor]): + return list(self._collect(x).values()) + + +class FeatureHookNet(nn.ModuleDict): + """ FeatureHookNet + + Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. + + If `no_rewrite` is True, features are extracted via hooks without modifying the underlying + network in any way. + + If `no_rewrite` is False, the model will be re-written as in the + FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. 
+ + FIXME this does not currently work with Torchscript, see FeatureHooks class + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, + feature_concat=False, flatten_sequential=False, default_hook_type='forward'): + super(FeatureHookNet, self).__init__() + assert not torch.jit.is_scripting() + self.feature_info = _get_feature_info(model, out_indices) + self.out_as_dict = out_as_dict + layers = OrderedDict() + hooks = [] + if no_rewrite: + assert not flatten_sequential + if hasattr(model, 'reset_classifier'): # make sure classifier is removed? + model.reset_classifier(0) + layers['body'] = model + hooks.extend(self.feature_info.get_dicts()) + else: + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type + for f in self.feature_info.get_dicts()} + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert not remaining, f'Return layers ({remaining}) are not present in model' + self.update(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) + + def forward(self, x): + for name, module in self.items(): + x = module(x) + out = self.hooks.get_output(x.device) + return out if self.out_as_dict else list(out.values()) diff --git a/model/helpers.py b/model/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..662a7a483b1e40f9f00d931e84762878c612c0c6 --- /dev/null +++ b/model/helpers.py @@ -0,0 +1,508 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import os +import math +from collections import OrderedDict +from copy import deepcopy +from typing import Any, Callable, Optional, Tuple + +import torch +import torch.nn as nn + + +from .features import FeatureListNet, FeatureDictNet, FeatureHookNet +from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url +from .layers import Conv2dSame, Linear + + +_logger = logging.getLogger(__name__) + + +def load_state_dict(checkpoint_path, use_ema=False): + if checkpoint_path and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + state_dict_key = 'state_dict' + if isinstance(checkpoint, dict): + if use_ema and 'state_dict_ema' in checkpoint: + state_dict_key = 'state_dict_ema' + if state_dict_key and state_dict_key in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint[state_dict_key].items(): + # strip `module.` prefix + name = k[7:] if k.startswith('module') else k + new_state_dict[name] = v + state_dict = new_state_dict + else: + state_dict = checkpoint + _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) + return state_dict + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): + if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, 'load_pretrained'): + model.load_pretrained(checkpoint_path) + else: + raise NotImplementedError('Model cannot load numpy checkpoint') + return + 
state_dict = load_state_dict(checkpoint_path, use_ema) + model.load_state_dict(state_dict, strict=strict) + + +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): + resume_epoch = None + if os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + name = k[7:] if k.startswith('module') else k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + else: + model.load_state_dict(checkpoint) + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False): + r"""Loads a custom (read non .pth) weight file + + Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls + a passed in custom load fun, or the `load_pretrained` model member fn. + + If the object is already present in `model_dir`, it's deserialized and returned. + The default value of `model_dir` is ``/checkpoints`` where + `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + model: The instantiated model to load weights into + default_cfg (dict): Default pretrained model cfg + load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named + 'laod_pretrained' on the model will be called if it exists + progress (bool, optional): whether or not to display a progress bar to stderr. Default: False + check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention + ``filename-.ext`` where ```` is the first eight or more + digits of the SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. Default: False + """ + default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} + pretrained_url = default_cfg.get('url', None) + if not pretrained_url: + _logger.warning("No pretrained weights exist for this model. 
Using random initialization.") + return + cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress) + + if load_fn is not None: + load_fn(model, cached_file) + elif hasattr(model, 'load_pretrained'): + model.load_pretrained(cached_file) + else: + _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + +def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): + """ Load pretrained checkpoint + + Args: + model (nn.Module) : PyTorch model module + default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset + num_classes (int): num_classes for model + in_chans (int): in_chans for model + filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) + strict (bool): strict load of checkpoint + progress (bool): enable progress bar for weight download + + """ + default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} + pretrained_url = default_cfg.get('url', None) + hf_hub_id = default_cfg.get('hf_hub', None) + if not pretrained_url and not hf_hub_id: + _logger.warning("No pretrained weights exist for this model. 
Using random initialization.") + return + if hf_hub_id and has_hf_hub(necessary=not pretrained_url): + _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})') + state_dict = load_state_dict_from_hf(hf_hub_id) + else: + _logger.info(f'Loading pretrained weights from url ({pretrained_url})') + state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu') + if filter_fn is not None: + # for backwards compat with filter fn that take one arg, try one first, the two + try: + state_dict = filter_fn(state_dict) + except TypeError: + state_dict = filter_fn(state_dict, model) + + input_convs = default_cfg.get('first_conv', None) + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) + _logger.info( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + _logger.warning( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') + + classifiers = default_cfg.get('classifier', None) + label_offset = default_cfg.get('label_offset', 0) + if classifiers is not None: + if isinstance(classifiers, str): + classifiers = (classifiers,) + if num_classes != default_cfg['num_classes']: + for classifier_name in classifiers: + # completely discard fully connected if model num_classes doesn't match pretrained weights + del state_dict[classifier_name + '.weight'] + del state_dict[classifier_name + '.bias'] + strict = False + elif label_offset > 0: + for classifier_name in classifiers: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + + model.load_state_dict(state_dict, strict=strict) + + +def extract_layer(model, layer): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + if not hasattr(model, 'module') and layer[0] == 'module': + layer = layer[1:] + for l in layer: + if hasattr(module, l): + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + else: + return module + return module + + +def set_layer(model, layer, val): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + lst_index = 0 + module2 = module + for l in layer: + if hasattr(module2, l): + if not l.isdigit(): + module2 = getattr(module2, l) + else: + module2 = module2[int(l)] + lst_index += 1 + lst_index -= 1 + for l in layer[:lst_index]: + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + l = layer[lst_index] + setattr(module, l, val) + + +def adapt_model_from_string(parent_module, model_string): + separator = '***' + state_dict = {} + lst_shape = model_string.split(separator) + for k in lst_shape: + k = k.split(':') + key = k[0] + shape = k[1][1:-1].split(',') + if shape[0] != '': + state_dict[key] = [int(i) for i in shape] + + new_module = deepcopy(parent_module) + for n, m in 
parent_module.named_modules(): + old_module = extract_layer(parent_module, n) + if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): + if isinstance(old_module, Conv2dSame): + conv = Conv2dSame + else: + conv = nn.Conv2d + s = state_dict[n + '.weight'] + in_channels = s[1] + out_channels = s[0] + g = 1 + if old_module.groups > 1: + in_channels = out_channels + g = in_channels + new_conv = conv( + in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, + bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, + groups=g, stride=old_module.stride) + set_layer(new_module, n, new_conv) + if isinstance(old_module, nn.BatchNorm2d): + new_bn = nn.BatchNorm2d( + num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + set_layer(new_module, n, new_bn) + if isinstance(old_module, nn.Linear): + # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? + num_features = state_dict[n + '.weight'][1] + new_fc = Linear( + in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) + set_layer(new_module, n, new_fc) + if hasattr(new_module, 'num_features'): + new_module.num_features = num_features + new_module.eval() + parent_module.eval() + + return new_module + + +def adapt_model_from_file(parent_module, model_variant): + adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') + with open(adapt_file, 'r') as f: + return adapt_model_from_string(parent_module, f.read().strip()) + + +def default_cfg_for_features(default_cfg): + default_cfg = deepcopy(default_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size? + for tr in to_remove: + default_cfg.pop(tr, None) + return default_cfg + + +def overlay_external_default_cfg(default_cfg, kwargs): + """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg. 
+ """ + external_default_cfg = kwargs.pop('external_default_cfg', None) + if external_default_cfg: + default_cfg.pop('url', None) # url should come from external cfg + default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg + default_cfg.update(external_default_cfg) + + +def set_default_kwargs(kwargs, names, default_cfg): + for n in names: + # for legacy reasons, model __init__args uses img_size + in_chans as separate args while + # default_cfg has one input_size=(C, H ,W) entry + if n == 'img_size': + input_size = default_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[-2:]) + elif n == 'in_chans': + input_size = default_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[0]) + else: + default_val = default_cfg.get(n, None) + if default_val is not None: + kwargs.setdefault(n, default_cfg[n]) + + +def filter_kwargs(kwargs, names): + if not kwargs or not names: + return + for n in names: + kwargs.pop(n, None) + + +def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): + """ Update the default_cfg and kwargs before passing to model + + FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs + could/should be replaced by an improved configuration mechanism + + Args: + default_cfg: input default_cfg (updated in-place) + kwargs: keyword args passed to model build fn (updated in-place) + kwargs_filter: keyword arg keys that must be removed before model __init__ + """ + # Overlay default cfg values from `external_default_cfg` if it exists in kwargs + overlay_external_default_cfg(default_cfg, kwargs) + # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if default_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) + # Filter keyword args for task specific model variants (some 'features only' models, etc.) 
+ filter_kwargs(kwargs, names=kwargs_filter) + + +def build_model_with_cfg( + model_cls: Callable, + variant: str, + pretrained: bool, + default_cfg: dict, + model_cfg: Optional[Any] = None, + feature_cfg: Optional[dict] = None, + pretrained_strict: bool = True, + pretrained_filter_fn: Optional[Callable] = None, + pretrained_custom_load: bool = False, + kwargs_filter: Optional[Tuple[str]] = None, + **kwargs): + """ Build model with specified default_cfg and optional model_cfg + + This helper fn aids in the construction of a model including: + * handling default_cfg and associated pretained weight loading + * passing through optional model_cfg for models with config based arch spec + * features_only model adaptation + * pruning config / model adaptation + + Args: + model_cls (nn.Module): model class + variant (str): model variant name + pretrained (bool): load pretrained weights + default_cfg (dict): model's default pretrained/task config + model_cfg (Optional[Dict]): model's architecture config + feature_cfg (Optional[Dict]: feature extraction adapter config + pretrained_strict (bool): load pretrained weights strictly + pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights + pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights + kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model + **kwargs: model args passed through to model __init__ + """ + pruned = kwargs.pop('pruned', False) + features = False + feature_cfg = feature_cfg or {} + default_cfg = deepcopy(default_cfg) if default_cfg else {} + update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter) + default_cfg.setdefault('architecture', variant) + + # Setup for feature extraction wrapper done at end of this fn + if kwargs.pop('features_only', False): + features = True + feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) + if 'out_indices' in kwargs: + feature_cfg['out_indices'] = kwargs.pop('out_indices') + + # Build the model + model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) + model.default_cfg = default_cfg + + if pruned: + model = adapt_model_from_file(model, variant) + + # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) + if pretrained: + if pretrained_custom_load: + load_custom_pretrained(model) + else: + load_pretrained( + model, + num_classes=num_classes_pretrained, + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, + strict=pretrained_strict) + + # Wrap the model in a feature extraction module if enabled + if features: + feature_cls = FeatureListNet + if 'feature_cls' in feature_cfg: + feature_cls = feature_cfg.pop('feature_cls') + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if 'hook' in feature_cls: + feature_cls = FeatureHookNet + else: + assert False, f'Unknown feature class {feature_cls}' + model = feature_cls(model, **feature_cfg) + model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg + + return model + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() + + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> 
nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): + if not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module diff --git a/model/hub.py b/model/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9b553031fb9d1846338990cd3b6f77228174c6 --- /dev/null +++ b/model/hub.py @@ -0,0 +1,96 @@ +import json +import logging +import os +from functools import partial +from typing import Union, Optional + +import torch +from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX +try: + from torch.hub import get_dir +except ImportError: + from torch.hub import _get_torch_home as get_dir + +from timm import __version__ +try: + from huggingface_hub import hf_hub_url + from huggingface_hub import cached_download + cached_download = partial(cached_download, library_name="timm", library_version=__version__) +except ImportError: + hf_hub_url = None + cached_download = None + +_logger = logging.getLogger(__name__) + + +def get_cache_dir(child_dir=''): + """ + Returns the location of the directory where models are cached (and creates it if necessary). + """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + hub_dir = get_dir() + child_dir = () if not child_dir else (child_dir,) + model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) + os.makedirs(model_dir, exist_ok=True) + return model_dir + + +def download_cached_file(url, check_hash=True, progress=False): + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(get_cache_dir(), filename) + if not os.path.exists(cached_file): + _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + return cached_file + + +def has_hf_hub(necessary=False): + if hf_hub_url is None and necessary: + # if no HF Hub module installed and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return hf_hub_url is not None + + +def hf_split(hf_id): + rev_split = hf_id.split('@') + assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' 
+ hf_model_id = rev_split[0] + hf_revision = rev_split[-1] if len(rev_split) > 1 else None + return hf_model_id, hf_revision + + +def load_cfg_from_json(json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + +def _download_from_hf(model_id: str, filename: str): + hf_model_id, hf_revision = hf_split(model_id) + url = hf_hub_url(hf_model_id, filename, revision=hf_revision) + return cached_download(url, cache_dir=get_cache_dir('hf')) + + +def load_model_config_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'config.json') + default_cfg = load_cfg_from_json(cached_file) + default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation + model_name = default_cfg.get('architecture') + return default_cfg, model_name + + +def load_state_dict_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'pytorch_model.bin') + state_dict = torch.load(cached_file, map_location='cpu') + return state_dict diff --git a/model/layers/__init__.py b/model/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77d1026e8c88c466250df4fe8e2bbcc9e2b42290 --- /dev/null +++ b/model/layers/__init__.py @@ -0,0 +1,40 @@ +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .blur_pool import BlurPool2d +from .classifier import ClassifierHead, create_classifier +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ + set_layer_config +from .conv2d_same import Conv2dSame, conv2d_same +from .conv_bn_act import ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import get_attn, create_attn +from .create_conv2d import create_conv2d +from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible +from .inplace_abn import InplaceAbn +from .involution import Involution +from .linear import Linear +from .mixed_conv2d import MixedConv2d +from .mlp import Mlp, GluMlp, GatedMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .norm import GroupNorm, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct +from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import PatchEmbed +from .pool2d_same import AvgPool2dSame, create_pool2d +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from .selective_kernel import SelectiveKernel +from .separable_conv import SeparableConv2d, SeparableConvBnAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttn +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .weight_init import trunc_normal_, 
variance_scaling_, lecun_normal_ diff --git a/model/layers/__pycache__/__init__.cpython-38.pyc b/model/layers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db2de0cc6a503e2a956e73f82a46a8944cfc1fea Binary files /dev/null and b/model/layers/__pycache__/__init__.cpython-38.pyc differ diff --git a/model/layers/__pycache__/activations.cpython-38.pyc b/model/layers/__pycache__/activations.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b54dbd8471be1fd5ecbcf564f7cd30754bf1f363 Binary files /dev/null and b/model/layers/__pycache__/activations.cpython-38.pyc differ diff --git a/model/layers/__pycache__/activations_jit.cpython-38.pyc b/model/layers/__pycache__/activations_jit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14052cf232fb0cbb2583c568fda16f600b885249 Binary files /dev/null and b/model/layers/__pycache__/activations_jit.cpython-38.pyc differ diff --git a/model/layers/__pycache__/activations_me.cpython-38.pyc b/model/layers/__pycache__/activations_me.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1af6d17c34808ade51ce15fac6bbf71130882c23 Binary files /dev/null and b/model/layers/__pycache__/activations_me.cpython-38.pyc differ diff --git a/model/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc b/model/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f083535a25fb8d6580d230a4d38acb7a11aad07 Binary files /dev/null and b/model/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc differ diff --git a/model/layers/__pycache__/blur_pool.cpython-38.pyc b/model/layers/__pycache__/blur_pool.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c24b21a9f9ed8beff996577901d3df21b332eafe Binary files /dev/null and b/model/layers/__pycache__/blur_pool.cpython-38.pyc differ diff --git a/model/layers/__pycache__/bottleneck_attn.cpython-38.pyc b/model/layers/__pycache__/bottleneck_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51c8a2c139447acea017be2bd485893e9e0f82db Binary files /dev/null and b/model/layers/__pycache__/bottleneck_attn.cpython-38.pyc differ diff --git a/model/layers/__pycache__/cbam.cpython-38.pyc b/model/layers/__pycache__/cbam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc471a412b2c5bf8ae48ede1490b47d85cff6566 Binary files /dev/null and b/model/layers/__pycache__/cbam.cpython-38.pyc differ diff --git a/model/layers/__pycache__/classifier.cpython-38.pyc b/model/layers/__pycache__/classifier.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5889f225cf66bd8b3222224a95703222810a11c Binary files /dev/null and b/model/layers/__pycache__/classifier.cpython-38.pyc differ diff --git a/model/layers/__pycache__/cond_conv2d.cpython-38.pyc b/model/layers/__pycache__/cond_conv2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37c1ed4bf7ad335cc35f808b256938bff69d1a85 Binary files /dev/null and b/model/layers/__pycache__/cond_conv2d.cpython-38.pyc differ diff --git a/model/layers/__pycache__/config.cpython-38.pyc b/model/layers/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d128a639f93a1f48b7d32f2330f7dd946c097a75 Binary files /dev/null and b/model/layers/__pycache__/config.cpython-38.pyc differ diff --git 
a/model/layers/__pycache__/conv2d_same.cpython-38.pyc b/model/layers/__pycache__/conv2d_same.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0694f24db820537ffd85472ecc6ae1414a0fc04 Binary files /dev/null and b/model/layers/__pycache__/conv2d_same.cpython-38.pyc differ diff --git a/model/layers/__pycache__/conv_bn_act.cpython-38.pyc b/model/layers/__pycache__/conv_bn_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89fef189b78f5a8cdfad0c918e438b6e56f07b7c Binary files /dev/null and b/model/layers/__pycache__/conv_bn_act.cpython-38.pyc differ diff --git a/model/layers/__pycache__/create_act.cpython-38.pyc b/model/layers/__pycache__/create_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad2f401db090872237dadc82c7908704b1afce6a Binary files /dev/null and b/model/layers/__pycache__/create_act.cpython-38.pyc differ diff --git a/model/layers/__pycache__/create_attn.cpython-38.pyc b/model/layers/__pycache__/create_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c5541433d2b30e9b9eeceebb862dce16e4020e4 Binary files /dev/null and b/model/layers/__pycache__/create_attn.cpython-38.pyc differ diff --git a/model/layers/__pycache__/create_conv2d.cpython-38.pyc b/model/layers/__pycache__/create_conv2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f51598af64479a447affbd1d36b1ec10f78b8a5 Binary files /dev/null and b/model/layers/__pycache__/create_conv2d.cpython-38.pyc differ diff --git a/model/layers/__pycache__/create_norm_act.cpython-38.pyc b/model/layers/__pycache__/create_norm_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..587de848537403f6a7cd85b537c173c942581f49 Binary files /dev/null and b/model/layers/__pycache__/create_norm_act.cpython-38.pyc differ diff --git a/model/layers/__pycache__/drop.cpython-38.pyc b/model/layers/__pycache__/drop.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a7411084e69378c1e23826cf75673299c6b0698 Binary files /dev/null and b/model/layers/__pycache__/drop.cpython-38.pyc differ diff --git a/model/layers/__pycache__/eca.cpython-38.pyc b/model/layers/__pycache__/eca.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29210d53150e8569d879ba6b7ff711adcc1e6d90 Binary files /dev/null and b/model/layers/__pycache__/eca.cpython-38.pyc differ diff --git a/model/layers/__pycache__/evo_norm.cpython-38.pyc b/model/layers/__pycache__/evo_norm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8599f8ae269df2b356a43b3acad29164ba1e8600 Binary files /dev/null and b/model/layers/__pycache__/evo_norm.cpython-38.pyc differ diff --git a/model/layers/__pycache__/gather_excite.cpython-38.pyc b/model/layers/__pycache__/gather_excite.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..645c5fc1de56775d9c42d1d6b30b513166bdfabe Binary files /dev/null and b/model/layers/__pycache__/gather_excite.cpython-38.pyc differ diff --git a/model/layers/__pycache__/global_context.cpython-38.pyc b/model/layers/__pycache__/global_context.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..190ca09ca3d43b09d6a4a5a0eca7891a2715b51a Binary files /dev/null and b/model/layers/__pycache__/global_context.cpython-38.pyc differ diff --git a/model/layers/__pycache__/halo_attn.cpython-38.pyc 
b/model/layers/__pycache__/halo_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2064d645da45ff732ff4e184c6ab16131d03f5e Binary files /dev/null and b/model/layers/__pycache__/halo_attn.cpython-38.pyc differ diff --git a/model/layers/__pycache__/helpers.cpython-38.pyc b/model/layers/__pycache__/helpers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8a9bcd112ec399513dc4f73ae3d2829be8ae854 Binary files /dev/null and b/model/layers/__pycache__/helpers.cpython-38.pyc differ diff --git a/model/layers/__pycache__/inplace_abn.cpython-38.pyc b/model/layers/__pycache__/inplace_abn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22fb07ba40f6a334f72cd025186a1061a1a8a8e3 Binary files /dev/null and b/model/layers/__pycache__/inplace_abn.cpython-38.pyc differ diff --git a/model/layers/__pycache__/involution.cpython-38.pyc b/model/layers/__pycache__/involution.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..131b7774e7489e9633710eab4f159991a7fb992e Binary files /dev/null and b/model/layers/__pycache__/involution.cpython-38.pyc differ diff --git a/model/layers/__pycache__/lambda_layer.cpython-38.pyc b/model/layers/__pycache__/lambda_layer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d39173c6e37c8ddf69e9f9cf5b639e329fd42a7 Binary files /dev/null and b/model/layers/__pycache__/lambda_layer.cpython-38.pyc differ diff --git a/model/layers/__pycache__/linear.cpython-38.pyc b/model/layers/__pycache__/linear.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a692202b565fcd2d453bec64c7c7a9cd1b5bd38 Binary files /dev/null and b/model/layers/__pycache__/linear.cpython-38.pyc differ diff --git a/model/layers/__pycache__/mixed_conv2d.cpython-38.pyc b/model/layers/__pycache__/mixed_conv2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c014a5ff4fdc8b9da0614752f22cef392141f91d Binary files /dev/null and b/model/layers/__pycache__/mixed_conv2d.cpython-38.pyc differ diff --git a/model/layers/__pycache__/mlp.cpython-38.pyc b/model/layers/__pycache__/mlp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cf6f06d6bc545ac6e5c45650da58002589fa18b Binary files /dev/null and b/model/layers/__pycache__/mlp.cpython-38.pyc differ diff --git a/model/layers/__pycache__/non_local_attn.cpython-38.pyc b/model/layers/__pycache__/non_local_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62e5bebabc5a5c5001dba43c162b53da7531f4fc Binary files /dev/null and b/model/layers/__pycache__/non_local_attn.cpython-38.pyc differ diff --git a/model/layers/__pycache__/norm.cpython-38.pyc b/model/layers/__pycache__/norm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd1e1c30d95318ba4143cbb7f3efbee8c76af030 Binary files /dev/null and b/model/layers/__pycache__/norm.cpython-38.pyc differ diff --git a/model/layers/__pycache__/norm_act.cpython-38.pyc b/model/layers/__pycache__/norm_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b44adbeb28b0ce1ef9ce963850f7f5169c3ff43 Binary files /dev/null and b/model/layers/__pycache__/norm_act.cpython-38.pyc differ diff --git a/model/layers/__pycache__/padding.cpython-38.pyc b/model/layers/__pycache__/padding.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..be8bc32e7ba18cc9f0c61837fe67f1b877bffed0 Binary files /dev/null and b/model/layers/__pycache__/padding.cpython-38.pyc differ diff --git a/model/layers/__pycache__/patch_embed.cpython-38.pyc b/model/layers/__pycache__/patch_embed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4bd85a4695e30dccf20a7f0ea0a67c9f3fb066f Binary files /dev/null and b/model/layers/__pycache__/patch_embed.cpython-38.pyc differ diff --git a/model/layers/__pycache__/pool2d_same.cpython-38.pyc b/model/layers/__pycache__/pool2d_same.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05829d2865f957675c58b5cb478b282a58fb64f6 Binary files /dev/null and b/model/layers/__pycache__/pool2d_same.cpython-38.pyc differ diff --git a/model/layers/__pycache__/selective_kernel.cpython-38.pyc b/model/layers/__pycache__/selective_kernel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d7f40e5fbc05330ab5dca9a965309648e78bd0a Binary files /dev/null and b/model/layers/__pycache__/selective_kernel.cpython-38.pyc differ diff --git a/model/layers/__pycache__/separable_conv.cpython-38.pyc b/model/layers/__pycache__/separable_conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a11b2ec0e1a1c183e7b73c93bcd1db19ae53442 Binary files /dev/null and b/model/layers/__pycache__/separable_conv.cpython-38.pyc differ diff --git a/model/layers/__pycache__/space_to_depth.cpython-38.pyc b/model/layers/__pycache__/space_to_depth.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ad5b3fdd2ba853732c9d333aad976e294e46139 Binary files /dev/null and b/model/layers/__pycache__/space_to_depth.cpython-38.pyc differ diff --git a/model/layers/__pycache__/split_attn.cpython-38.pyc b/model/layers/__pycache__/split_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeea45a38d677d23888f1269217359614f84a9ca Binary files /dev/null and b/model/layers/__pycache__/split_attn.cpython-38.pyc differ diff --git a/model/layers/__pycache__/split_batchnorm.cpython-38.pyc b/model/layers/__pycache__/split_batchnorm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a276d1be87b3c962366e01fb982507f2d21aada Binary files /dev/null and b/model/layers/__pycache__/split_batchnorm.cpython-38.pyc differ diff --git a/model/layers/__pycache__/squeeze_excite.cpython-38.pyc b/model/layers/__pycache__/squeeze_excite.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a4ec4fdda1755998eef7335012566a1ad67c395 Binary files /dev/null and b/model/layers/__pycache__/squeeze_excite.cpython-38.pyc differ diff --git a/model/layers/__pycache__/std_conv.cpython-38.pyc b/model/layers/__pycache__/std_conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9618fe4e5699fc9c9ff401a18718dac72383bb1 Binary files /dev/null and b/model/layers/__pycache__/std_conv.cpython-38.pyc differ diff --git a/model/layers/__pycache__/swin_attn.cpython-38.pyc b/model/layers/__pycache__/swin_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7843298527950867d9ec5b33ad49f49454f9ef2c Binary files /dev/null and b/model/layers/__pycache__/swin_attn.cpython-38.pyc differ diff --git a/model/layers/__pycache__/test_time_pool.cpython-38.pyc b/model/layers/__pycache__/test_time_pool.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..378a21a16a409244fe1281f45362b702413e68f2 Binary files /dev/null and b/model/layers/__pycache__/test_time_pool.cpython-38.pyc differ diff --git a/model/layers/__pycache__/weight_init.cpython-38.pyc b/model/layers/__pycache__/weight_init.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cc48716b757a9c699e717fddb00686ef3496ac4 Binary files /dev/null and b/model/layers/__pycache__/weight_init.cpython-38.pyc differ diff --git a/model/layers/activations.py b/model/layers/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..e16b3bd3a1898365530c1ffc5154a0a4746a136e --- /dev/null +++ b/model/layers/activations.py @@ -0,0 +1,145 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + NOTE: I don't have a working inplace variant + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + + def forward(self, x): + return mish(x) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. 
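# Illustrative sanity check (a minimal sketch, not part of the module above):
# the hand-rolled swish / hard_sigmoid / hard_swish defined earlier in this file
# should agree numerically with PyTorch's built-in F.silu / F.hardsigmoid /
# F.hardswish. Assumes a torch version (>= 1.7) where those built-ins exist.
import torch
import torch.nn.functional as F

def _check_activation_formulas():
    x = torch.linspace(-6., 6., steps=121)
    # swish(x) = x * sigmoid(x), i.e. SiLU
    assert torch.allclose(swish(x), F.silu(x), atol=1e-6)
    # hard_sigmoid(x) = relu6(x + 3) / 6
    assert torch.allclose(hard_sigmoid(x), F.hardsigmoid(x), atol=1e-6)
    # hard_swish(x) = x * hard_sigmoid(x)
    assert torch.allclose(hard_swish(x), F.hardswish(x), atol=1e-6)

_check_activation_formulas()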
+ + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + +def hard_mish(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + if inplace: + return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) + else: + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_mish(x, self.inplace) + + +class PReLU(nn.PReLU): + """Applies PReLU (w/ dummy inplace arg) + """ + def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: + super(PReLU, self).__init__(num_parameters=num_parameters, init=init) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.prelu(input, self.weight) + + +def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x) + + +class GELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) + """ + def __init__(self, inplace: bool = False): + super(GELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input) diff --git a/model/layers/activations_jit.py b/model/layers/activations_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a516530ad0abf41f720ac83d02791179bb7b67 --- /dev/null +++ b/model/layers/activations_jit.py @@ -0,0 +1,90 @@ +""" Activations + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
+ + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) + + +@torch.jit.script +def hard_mish_jit(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishJit, self).__init__() + + def forward(self, x): + return hard_mish_jit(x) diff --git a/model/layers/activations_me.py b/model/layers/activations_me.py new file mode 100644 index 0000000000000000000000000000000000000000..9a12bb7ebbfef02c508801742d38da6b48dd1bb6 --- /dev/null +++ b/model/layers/activations_me.py @@ -0,0 +1,218 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + """ + @staticmethod + def symbolic(g, x): + return g.op("Mul", x, g.op("Sigmoid", x)) + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) 
+ + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + @staticmethod + def symbolic(g, self): + input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float))) + hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + return g.op("Mul", self, hardtanh_) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_mish_jit_fwd(x): + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +@torch.jit.script +def hard_mish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= -2.) + m = torch.where((x >= -2.) 
& (x <= 0.), x + 1., m) + return grad_output * m + + +class HardMishJitAutoFn(torch.autograd.Function): + """ A memory efficient, jit scripted variant of Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_mish_jit_bwd(x, grad_output) + + +def hard_mish_me(x, inplace: bool = False): + return HardMishJitAutoFn.apply(x) + + +class HardMishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishMe, self).__init__() + + def forward(self, x): + return HardMishJitAutoFn.apply(x) + + + diff --git a/model/layers/adaptive_avgmax_pool.py b/model/layers/adaptive_avgmax_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc6ada8c5b28c7eac5785b0cc2933eb01a15d46 --- /dev/null +++ b/model/layers/adaptive_avgmax_pool.py @@ -0,0 +1,118 @@ +""" PyTorch selectable adaptive pooling +Adaptive pooling with the ability to select the type of pooling from: + * 'avg' - Average pooling + * 'max' - Max pooling + * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 + * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim + +Both a functional and a nn.Module version of the pooling is provided. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def adaptive_pool_feat_mult(pool_type='avg'): + if pool_type == 'catavgmax': + return 2 + else: + return 1 + + +def adaptive_avgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return 0.5 * (x_avg + x_max) + + +def adaptive_catavgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return torch.cat((x_avg, x_max), 1) + + +def select_adaptive_pool2d(x, pool_type='avg', output_size=1): + """Selectable global pooling function with dynamic input kernel size + """ + if pool_type == 'avg': + x = F.adaptive_avg_pool2d(x, output_size) + elif pool_type == 'avgmax': + x = adaptive_avgmax_pool2d(x, output_size) + elif pool_type == 'catavgmax': + x = adaptive_catavgmax_pool2d(x, output_size) + elif pool_type == 'max': + x = F.adaptive_max_pool2d(x, output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + return x + + +class FastAdaptiveAvgPool2d(nn.Module): + def __init__(self, flatten=False): + super(FastAdaptiveAvgPool2d, self).__init__() + self.flatten = flatten + + def forward(self, x): + return x.mean((2, 3), keepdim=not self.flatten) + + +class AdaptiveAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_avgmax_pool2d(x, self.output_size) + + +class AdaptiveCatAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveCatAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_catavgmax_pool2d(x, self.output_size) + + +class SelectAdaptivePool2d(nn.Module): + """Selectable global pooling layer with dynamic input kernel size + """ + def __init__(self, output_size=1, pool_type='fast', flatten=False): + super(SelectAdaptivePool2d, 
self).__init__() + self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + if pool_type == '': + self.pool = nn.Identity() # pass through + elif pool_type == 'fast': + assert output_size == 1 + self.pool = FastAdaptiveAvgPool2d(flatten) + self.flatten = nn.Identity() + elif pool_type == 'avg': + self.pool = nn.AdaptiveAvgPool2d(output_size) + elif pool_type == 'avgmax': + self.pool = AdaptiveAvgMaxPool2d(output_size) + elif pool_type == 'catavgmax': + self.pool = AdaptiveCatAvgMaxPool2d(output_size) + elif pool_type == 'max': + self.pool = nn.AdaptiveMaxPool2d(output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + + def is_identity(self): + return not self.pool_type + + def forward(self, x): + x = self.pool(x) + x = self.flatten(x) + return x + + def feat_mult(self): + return adaptive_pool_feat_mult(self.pool_type) + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + 'pool_type=' + self.pool_type \ + + ', flatten=' + str(self.flatten) + ')' + diff --git a/model/layers/blur_pool.py b/model/layers/blur_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4ce756e434d577c38a20e2e8de2909777862d4 --- /dev/null +++ b/model/layers/blur_pool.py @@ -0,0 +1,42 @@ +""" +BlurPool layer inspired by + - Kornia's Max_BlurPool2d + - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + +Hacked together by Chris Ha and Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .padding import get_padding + + +class BlurPool2d(nn.Module): + r"""Creates a module that computes blurs and downsample a given feature map. + See :cite:`zhang2019shiftinvar` for more details. + Corresponds to the Downsample class, which does blurring and subsampling + + Args: + channels = Number of input channels + filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. + stride (int): downsampling filter stride + + Returns: + torch.Tensor: the transformed tensor. 
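+
+    Example (editor's sketch; shapes and values are illustrative, not part of the original)::
+
+        blur = BlurPool2d(channels=64, filt_size=3, stride=2)
+        x = torch.randn(2, 64, 56, 56)
+        y = blur(x)   # anti-aliased downsample -> (2, 64, 28, 28)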
+ """ + def __init__(self, channels, filt_size=3, stride=2) -> None: + super(BlurPool2d, self).__init__() + assert filt_size > 1 + self.channels = channels + self.filt_size = filt_size + self.stride = stride + self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 + coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) + blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) + self.register_buffer('filt', blur_filter, persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding, 'reflect') + return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1]) diff --git a/model/layers/bottleneck_attn.py b/model/layers/bottleneck_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..9604e8a6cfb992c50bc1fc15c54979f30b1d2c94 --- /dev/null +++ b/model/layers/bottleneck_attn.py @@ -0,0 +1,126 @@ +""" Bottleneck Self Attention (Bottleneck Transformers) + +Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + +@misc{2101.11605, +Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani}, +Title = {Bottleneck Transformers for Visual Recognition}, +Year = {2021}, +} + +Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + +This impl is a WIP but given that it is based on the ref gist likely not too far off. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import to_2tuple +from .weight_init import trunc_normal_ + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, heads, height, width, dim) + rel_k: (2 * width - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, 2 * W -1) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, W - 1]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1) + x = x_pad[:, :W, W - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + """ + def __init__(self, feat_size, dim_head, scale): + super().__init__() + self.height, self.width = to_2tuple(feat_size) + self.dim_head = dim_head + self.scale = scale + self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale) + self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) + + def forward(self, q): + B, num_heads, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(B * num_heads, self.height, self.width, -1) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. 
+        q = q.transpose(1, 2)
+        rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
+
+        rel_logits = rel_logits_h + rel_logits_w
+        rel_logits = rel_logits.reshape(B, num_heads, HW, HW)
+        return rel_logits
+
+
+class BottleneckAttn(nn.Module):
+    """ Bottleneck Attention
+    Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
+    """
+    def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False):
+        super().__init__()
+        assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
+        dim_out = dim_out or dim
+        assert dim_out % num_heads == 0
+        self.num_heads = num_heads
+        self.dim_out = dim_out
+        self.dim_head = dim_out // num_heads
+        self.scale = self.dim_head ** -0.5
+
+        self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias)
+
+        # NOTE I'm only supporting relative pos embedding for now
+        self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale)
+
+        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
+
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        assert H == self.pos_embed.height and W == self.pos_embed.width
+
+        x = self.qkv(x)  # B, 3 * num_heads * dim_head, H, W
+        x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
+        q, k, v = torch.split(x, self.num_heads, dim=1)
+
+        attn_logits = (q @ k.transpose(-1, -2)) * self.scale
+        attn_logits = attn_logits + self.pos_embed(q)  # B, num_heads, H * W, H * W
+
+        attn_out = attn_logits.softmax(dim=-1)
+        # fold heads back into channels: (B, num_heads, H * W, dim_head) -> (B, dim_out, H, W)
+        attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W)  # B, dim_out, H, W
+        attn_out = self.pool(attn_out)
+        return attn_out
+
+
diff --git a/model/layers/cbam.py b/model/layers/cbam.py
new file mode 100644
index 0000000000000000000000000000000000000000..bacf5cf07b695ce6c5fd87facc79f6a5773e6ecf
--- /dev/null
+++ b/model/layers/cbam.py
@@ -0,0 +1,112 @@
+""" CBAM (sort-of) Attention
+
+Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
+
+WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
+some tasks, especially fine-grained it seems. I may end up removing this impl.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+from torch import nn as nn
+import torch.nn.functional as F
+
+from .conv_bn_act import ConvBnAct
+from .create_act import create_act_layer, get_act_layer
+from .helpers import make_divisible
+
+
+class ChannelAttn(nn.Module):
+    """ Original CBAM channel attention module, currently avg + max pool variant only.
+    """
+    def __init__(
+            self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
+            act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
+        super(ChannelAttn, self).__init__()
+        if not rd_channels:
+            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
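+        # shared bottleneck MLP (1x1 convs): both the avg- and max-pooled channel descriptors
+        # pass through the same fc1 -> act -> fc2 stack in forward() before being summed and gated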
+ self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) + x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) + return x * self.gate(x_avg + x_max) + + +class LightChannelAttn(ChannelAttn): + """An experimental 'lightweight' that sums avg + max pool first + """ + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightChannelAttn, self).__init__( + channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) + + def forward(self, x): + x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) + x_attn = self.fc2(self.act(self.fc1(x_pool))) + return x * F.sigmoid(x_attn) + + +class SpatialAttn(nn.Module): + """ Original CBAM spatial attention module + """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(SpatialAttn, self).__init__() + self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class LightSpatialAttn(nn.Module): + """An experimental 'lightweight' variant that sums avg_pool and max_pool results. + """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(LightSpatialAttn, self).__init__() + self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class CbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(CbamModule, self).__init__() + self.channel = ChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + + +class LightCbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightCbamModule, self).__init__() + self.channel = LightChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = LightSpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + diff --git a/model/layers/classifier.py b/model/layers/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..2b74541341ad24bfb97f7ea90ac6470b83a73aa3 --- /dev/null +++ b/model/layers/classifier.py @@ -0,0 +1,56 @@ +""" Classifier head and layer factory + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + +from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .linear import Linear 
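+
+# A minimal usage sketch (illustrative shapes and values, assuming a backbone that emits
+# 2048-channel 7x7 features and a 1000-class head; not part of the original module):
+#
+#     global_pool, fc = create_classifier(num_features=2048, num_classes=1000, pool_type='avg')
+#     feats = torch.randn(2, 2048, 7, 7)         # backbone output
+#     logits = fc(global_pool(feats))            # -> (2, 1000)
+#
+# ClassifierHead below bundles the same pool + fc pair with optional dropout.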
+ + +def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): + flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling + if not pool_type: + assert num_classes == 0 or use_conv,\ + 'Pooling can only be disabled if classifier is also removed or conv classifier is used' + flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) + global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) + num_pooled_features = num_features * global_pool.feat_mult() + return global_pool, num_pooled_features + + +def _create_fc(num_features, num_classes, use_conv=False): + if num_classes <= 0: + fc = nn.Identity() # pass-through (no classifier) + elif use_conv: + fc = nn.Conv2d(num_features, num_classes, 1, bias=True) + else: + # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue + fc = Linear(num_features, num_classes, bias=True) + return fc + + +def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): + global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) + fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + return global_pool, fc + + +class ClassifierHead(nn.Module): + """Classifier head w/ configurable global pooling and dropout.""" + + def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): + super(ClassifierHead, self).__init__() + self.drop_rate = drop_rate + self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) + self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() + + def forward(self, x): + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + x = self.fc(x) + x = self.flatten(x) + return x diff --git a/model/layers/cond_conv2d.py b/model/layers/cond_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4bbca84d6f12e0fb875b4edb435b976fc649d6 --- /dev/null +++ b/model/layers/cond_conv2d.py @@ -0,0 +1,122 @@ +""" PyTorch Conditionally Parameterized Convolution (CondConv) + +Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference +(https://arxiv.org/abs/1904.04971) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .helpers import to_2tuple +from .conv2d_same import conv2d_same +from .padding import get_padding_value + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditionally Parameterized Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + 
https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = to_2tuple(padding_val) + self.dilation = to_2tuple(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/model/layers/config.py b/model/layers/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f07b9d782ba0597c174dee81097c28280335fdba --- /dev/null +++ b/model/layers/config.py @@ -0,0 +1,115 @@ +""" Model / Layer Config singleton state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 
'is_scriptable', 'is_no_jit', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. 
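+
+    Example (illustrative; `build_model` is a placeholder for any model constructor)::
+
+        with set_layer_config(scriptable=True, no_jit=True):
+            model = build_model()   # layers created inside see the temporary flags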
+ """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False diff --git a/model/layers/conv2d_same.py b/model/layers/conv2d_same.py new file mode 100644 index 0000000000000000000000000000000000000000..75f0f98d4ec1e3f4a0dc004b977815afaa25e7fc --- /dev/null +++ b/model/layers/conv2d_same.py @@ -0,0 +1,42 @@ +""" Conv2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + +from .padding import pad_same, get_padding_value + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + x = pad_same(x, weight.shape[-2:], stride, dilation) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + diff --git a/model/layers/conv_bn_act.py b/model/layers/conv_bn_act.py new file mode 100644 index 0000000000000000000000000000000000000000..33005c37b752bd995aeb983ad8480c36b94d0a0c --- /dev/null +++ b/model/layers/conv_bn_act.py @@ -0,0 +1,40 @@ +""" Conv2d + BN + Act + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import convert_norm_act + + +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, + drop_block=None): + super(ConvBnAct, self).__init__() + use_aa = aa_layer is not None + + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = 
convert_norm_act(norm_layer, act_layer) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) + self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.aa is not None: + x = self.aa(x) + return x diff --git a/model/layers/create_act.py b/model/layers/create_act.py new file mode 100644 index 0000000000000000000000000000000000000000..aa557692accff431fe1f9cfb7a5c6d94314b14f6 --- /dev/null +++ b/model/layers/create_act.py @@ -0,0 +1,153 @@ +""" Activation Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from typing import Union, Callable, Type + +from .activations import * +from .activations_jit import * +from .activations_me import * +from .config import is_exportable, is_scriptable, is_no_jit + +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. +# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. +# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. +_has_silu = 'silu' in dir(torch.nn.functional) +_has_hardswish = 'hardswish' in dir(torch.nn.functional) +_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) +_has_mish = 'mish' in dir(torch.nn.functional) + + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=F.mish if _has_mish else mish, + relu=F.relu, + relu6=F.relu6, + leaky_relu=F.leaky_relu, + elu=F.elu, + celu=F.celu, + selu=F.selu, + gelu=gelu, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, + hard_swish=F.hardswish if _has_hardswish else hard_swish, + hard_mish=hard_mish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=F.mish if _has_mish else mish_jit, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, + hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, + hard_mish=hard_mish_jit +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=F.mish if _has_mish else mish_me, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, + hard_swish=F.hardswish if _has_hardswish else hard_swish_me, + hard_mish=hard_mish_me, +) + +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) +for a in _ACT_FNS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=nn.Mish if _has_mish else Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + leaky_relu=nn.LeakyReLU, + elu=nn.ELU, + prelu=PReLU, + celu=nn.CELU, + selu=nn.SELU, + gelu=GELU, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, + hard_swish=nn.Hardswish if _has_hardswish else HardSwish, + hard_mish=HardMish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=nn.Mish if _has_mish else MishJit, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, + hard_mish=HardMishJit +) + +_ACT_LAYER_ME = dict( + 
silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=nn.Mish if _has_mish else MishMe, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, + hard_mish=HardMishMe, +) + +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) +for a in _ACT_LAYERS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +def get_act_fn(name: Union[Callable, str] = 'relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if isinstance(name, Callable): + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + # If not exporting or scripting the model, first look for a memory-efficient version with + # custom autograd, then fallback + if name in _ACT_FN_ME: + return _ACT_FN_ME[name] + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + if not (is_no_jit() or is_exportable()): + if name in _ACT_FN_JIT: + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if isinstance(name, type): + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + if name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + if not (is_no_jit() or is_exportable()): + if name in _ACT_LAYER_JIT: + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + +def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): + act_layer = get_act_layer(name) + if act_layer is None: + return None + return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) diff --git a/model/layers/create_attn.py b/model/layers/create_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..3fed646b3378774c3155cf0d01651a043164ef21 --- /dev/null +++ b/model/layers/create_attn.py @@ -0,0 +1,93 @@ +""" Attention Factory + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from functools import partial + +from .bottleneck_attn import BottleneckAttn +from .cbam import CbamModule, LightCbamModule +from .eca import EcaModule, CecaModule +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .halo_attn import HaloAttn +from .involution import Involution +from .lambda_layer import LambdaLayer +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .selective_kernel import SelectiveKernel +from .split_attn import SplitAttn +from .squeeze_excite import SEModule, EffectiveSEModule +from .swin_attn import WindowAttention + + +def get_attn(attn_type): + if isinstance(attn_type, torch.nn.Module): + return attn_type + module_cls = None + if attn_type is not None: + if isinstance(attn_type, str): + attn_type = attn_type.lower() + # Lightweight attention modules (channel and/or coarse spatial). 
+ # Typically added to existing network architecture blocks in addition to existing convolutions. + if attn_type == 'se': + module_cls = SEModule + elif attn_type == 'ese': + module_cls = EffectiveSEModule + elif attn_type == 'eca': + module_cls = EcaModule + elif attn_type == 'ecam': + module_cls = partial(EcaModule, use_mlp=True) + elif attn_type == 'ceca': + module_cls = CecaModule + elif attn_type == 'ge': + module_cls = GatherExcite + elif attn_type == 'gc': + module_cls = GlobalContext + elif attn_type == 'cbam': + module_cls = CbamModule + elif attn_type == 'lcbam': + module_cls = LightCbamModule + + # Attention / attention-like modules w/ significant params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'sk': + module_cls = SelectiveKernel + elif attn_type == 'splat': + module_cls = SplitAttn + + # Self-attention / attention-like modules w/ significant compute and/or params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'lambda': + return LambdaLayer + elif attn_type == 'bottleneck': + return BottleneckAttn + elif attn_type == 'halo': + return HaloAttn + elif attn_type == 'swin': + return WindowAttention + elif attn_type == 'involution': + return Involution + elif attn_type == 'nl': + module_cls = NonLocalAttn + elif attn_type == 'bat': + module_cls = BatNonLocalAttn + + # Woops! + else: + assert False, "Invalid attn module (%s)" % attn_type + elif isinstance(attn_type, bool): + if attn_type: + module_cls = SEModule + else: + module_cls = attn_type + return module_cls + + +def create_attn(attn_type, channels, **kwargs): + module_cls = get_attn(attn_type) + if module_cls is not None: + # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels + return module_cls(channels, **kwargs) + return None diff --git a/model/layers/create_conv2d.py b/model/layers/create_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0cc03a5c8c23fe047d1d3c24782700422e2e6e --- /dev/null +++ b/model/layers/create_conv2d.py @@ -0,0 +1,31 @@ +""" Create Conv2d Factory Method + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): + """ Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. + """ + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + assert 'groups' not in kwargs # MixedConv groups are defined by kernel list + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
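+        # e.g. kernel_size=[3, 5, 7] builds a MixedConv2d that splits channels across three kernel
+        # groups, while kernel_size=3 (or any non-list) falls through to the CondConv2d / plain conv path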
+ m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 + groups = in_channels if depthwise else kwargs.pop('groups', 1) + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + return m diff --git a/model/layers/create_norm_act.py b/model/layers/create_norm_act.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5629457dc14b5da3b9673b7e21d7d80f7cda4c --- /dev/null +++ b/model/layers/create_norm_act.py @@ -0,0 +1,83 @@ +""" NormAct (Normalizaiton + Activation Layer) Factory + +Create norm + act combo modules that attempt to be backwards compatible with separate norm + act +isntances in models. Where these are used it will be possible to swap separate BN + act layers with +combined modules like IABN or EvoNorms. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import types +import functools + +import torch +import torch.nn as nn + +from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .norm_act import BatchNormAct2d, GroupNormAct +from .inplace_abn import InplaceAbn + +_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} +_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type + + +def get_norm_act_layer(layer_class): + layer_class = layer_class.replace('_', '').lower() + if layer_class.startswith("batchnorm"): + layer = BatchNormAct2d + elif layer_class.startswith("groupnorm"): + layer = GroupNormAct + elif layer_class == "evonormbatch": + layer = EvoNormBatch2d + elif layer_class == "evonormsample": + layer = EvoNormSample2d + elif layer_class == "iabn" or layer_class == "inplaceabn": + layer = InplaceAbn + else: + assert False, "Invalid norm_act layer (%s)" % layer_class + return layer + + +def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): + layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu + assert len(layer_parts) in (1, 2) + layer = get_norm_act_layer(layer_parts[0]) + #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? 
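+    # e.g. layer_type='batchnorm-leaky_relu' resolves to BatchNormAct2d here; the '-leaky_relu'
+    # suffix is currently ignored (see the FIXME above), so the activation comes in via **kwargs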
+ layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + if jit: + layer_instance = torch.jit.script(layer_instance) + return layer_instance + + +def convert_norm_act(norm_layer, act_layer): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) + norm_act_kwargs = {} + + # unbind partial fn, so args can be rebound later + if isinstance(norm_layer, functools.partial): + norm_act_kwargs.update(norm_layer.keywords) + norm_layer = norm_layer.func + + if isinstance(norm_layer, str): + norm_act_layer = get_norm_act_layer(norm_layer) + elif norm_layer in _NORM_ACT_TYPES: + norm_act_layer = norm_layer + elif isinstance(norm_layer, types.FunctionType): + # if function type, must be a lambda/fn that creates a norm_act layer + norm_act_layer = norm_layer + else: + type_name = norm_layer.__name__.lower() + if type_name.startswith('batchnorm'): + norm_act_layer = BatchNormAct2d + elif type_name.startswith('groupnorm'): + norm_act_layer = GroupNormAct + else: + assert False, f"No equivalent norm_act layer for {type_name}" + + if norm_act_layer in _NORM_ACT_REQUIRES_ARG: + # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. + # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types + norm_act_kwargs.setdefault('act_layer', act_layer) + if norm_act_kwargs: + norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args + return norm_act_layer diff --git a/model/layers/drop.py b/model/layers/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..6de9e3f729f7f1ca29d4511f6c64733d3169fbec --- /dev/null +++ b/model/layers/drop.py @@ -0,0 +1,168 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. + +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def drop_block_2d( + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + # seed_drop_rate, the gamma parameter + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + # Forces the block to be inside the feature map. 
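+    # gamma above is the per-position seed probability, scaled so roughly drop_prob of the
+    # activations end up inside a dropped block; valid_block below prevents seed centres whose
+    # clipped_block_size window would cross the feature-map border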
+ w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) + else: + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +def drop_block_fast_2d( + x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + if batchwise: + # one mask for whole batch, quite a bit faster + block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma + else: + # mask per batch element + block_mask = torch.rand_like(x) < gamma + block_mask = F.max_pool2d( + block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(1. - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1. - block_mask) + normal_noise * block_mask + else: + block_mask = 1 - block_mask + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. 
See https://arxiv.org/pdf/1810.12890.pdf + """ + def __init__(self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False, + inplace=False, + batchwise=False, + fast=True): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + else: + return drop_block_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/model/layers/eca.py b/model/layers/eca.py new file mode 100644 index 0000000000000000000000000000000000000000..e29be6ac3c95bb61229cdcdd659ec89d541f1a53 --- /dev/null +++ b/model/layers/eca.py @@ -0,0 +1,145 @@ +""" +ECA module from ECAnet + +paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks +https://arxiv.org/abs/1910.03151 + +Original ECA model borrowed from https://github.com/BangguWu/ECANet + +Modified circular ECA implementation and adaption for use in timm package +by Chris Ha https://github.com/VRandme + +Original License: + +MIT License + +Copyright (c) 2019 BangguWu, Qilong Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import math +from torch import nn +import torch.nn.functional as F + + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class EcaModule(nn.Module): + """Constructs an ECA module. + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) + kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use + """ + def __init__( + self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', + rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): + super(EcaModule, self).__init__() + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + assert kernel_size % 2 == 1 + padding = (kernel_size - 1) // 2 + if use_mlp: + # NOTE 'mlp' mode is a timm experiment, not in paper + assert channels is not None + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) + act_layer = act_layer or nn.ReLU + self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) + self.act = create_act_layer(act_layer) + self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) + else: + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.act = None + self.conv2 = None + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv + y = self.conv(y) + if self.conv2 is not None: + y = self.act(y) + y = self.conv2(y) + y = self.gate(y).view(x.shape[0], -1, 1, 1) + return x * y.expand_as(x) + + +EfficientChannelAttn = EcaModule # alias + + +class CecaModule(nn.Module): + """Constructs a circular ECA module. + + ECA module where the conv uses circular padding rather than zero padding. + Unlike the spatial dimension, the channels do not have inherent ordering nor + locality. Although this module in essence, applies such an assumption, it is unnecessary + to limit the channels on either "edge" from being circularly adapted to each other. + This will fundamentally increase connectivity and possibly increase performance metrics + (accuracy, robustness), without significantly impacting resource metrics + (parameter size, throughput,latency, etc) + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) 
+        kernel_size: Adaptive selection of kernel size (default=3)
+        gamma: used in kernel_size calc, see above
+        beta: used in kernel_size calc, see above
+        act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
+        gate_layer: gating non-linearity to use
+    """
+
+    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
+        super(CecaModule, self).__init__()
+        if channels is not None:
+            t = int(abs(math.log(channels, 2) + beta) / gamma)
+            kernel_size = max(t if t % 2 else t + 1, 3)
+        has_act = act_layer is not None
+        assert kernel_size % 2 == 1
+
+        # PyTorch circular padding mode is buggy as of pytorch 1.4
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # implement manual circular padding
+        self.padding = (kernel_size - 1) // 2
+        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
+        self.gate = create_act_layer(gate_layer)
+
+    def forward(self, x):
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)
+        # manually implement circular padding; F.pad itself does not seem to be affected by the bug
+        y = F.pad(y, (self.padding, self.padding), mode='circular')
+        y = self.conv(y)
+        y = self.gate(y).view(x.shape[0], -1, 1, 1)
+        return x * y.expand_as(x)
+
+
+CircularEfficientChannelAttn = CecaModule
diff --git a/model/layers/evo_norm.py b/model/layers/evo_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9023afd0e81dc8a76871d03141866217d59f4770
--- /dev/null
+++ b/model/layers/evo_norm.py
@@ -0,0 +1,83 @@
+"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
+
+An attempt at getting decent performing EvoNorms running in PyTorch.
+While currently faster than other impl, still quite a ways off the built-in BN
+in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed).
+
+Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.
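+
+Example (editor's sketch; channel count and shapes are illustrative only):
+
+    norm = EvoNormSample2d(64, groups=8)          # drop-in for a BatchNorm2d(64) + activation pair
+    y = norm(torch.randn(2, 64, 56, 56))          # -> (2, 64, 56, 56)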
+ +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +import torch.nn as nn + + +class EvoNormBatch2d(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): + super(EvoNormBatch2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + param_shape = (1, num_features, 1, 1) + self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + if apply_act: + self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.apply_act: + nn.init.ones_(self.v) + + def forward(self, x): + assert x.dim() == 4, 'expected 4D input' + x_type = x.dtype + if self.training: + var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + n = x.numel() / x.shape[1] + self.running_var.copy_( + var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) + else: + var = self.running_var + + if self.apply_act: + v = self.v.to(dtype=x_type) + d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) + d = d.max((var + self.eps).sqrt().to(dtype=x_type)) + x = x / d + return x * self.weight + self.bias + + +class EvoNormSample2d(nn.Module): + def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None): + super(EvoNormSample2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.groups = groups + self.eps = eps + param_shape = (1, num_features, 1, 1) + self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + if apply_act: + self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.apply_act: + nn.init.ones_(self.v) + + def forward(self, x): + assert x.dim() == 4, 'expected 4D input' + B, C, H, W = x.shape + assert C % self.groups == 0 + if self.apply_act: + n = x * (x * self.v).sigmoid() + x = x.reshape(B, self.groups, -1) + x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() + x = x.reshape(B, C, H, W) + return x * self.weight + self.bias diff --git a/model/layers/gather_excite.py b/model/layers/gather_excite.py new file mode 100644 index 0000000000000000000000000000000000000000..2d60dc961e2b5e135d38e290b8fa5820ef0fe18f --- /dev/null +++ b/model/layers/gather_excite.py @@ -0,0 +1,90 @@ +""" Gather-Excite Attention Block + +Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 + +Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet + +I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another +impl that covers all of the cases. 
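+
+Example (editor's sketch; channel count and spatial size are arbitrary):
+
+    ge = GatherExcite(256, extent=2, extra_params=False)   # gather via a stride-2 average pool
+    y = ge(torch.randn(2, 256, 28, 28))                    # -> (2, 256, 28, 28)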
+ +NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math + +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .create_conv2d import create_conv2d +from .helpers import make_divisible +from .mlp import ConvMlp + + +class GatherExcite(nn.Module): + """ Gather-Excite Attention Module + """ + def __init__( + self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, + rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): + super(GatherExcite, self).__init__() + self.add_maxpool = add_maxpool + act_layer = get_act_layer(act_layer) + self.extent = extent + if extra_params: + self.gather = nn.Sequential() + if extent == 0: + assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' + self.gather.add_module( + 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) + else: + assert extent % 2 == 0 + num_conv = int(math.log2(extent)) + for i in range(num_conv): + self.gather.add_module( + f'conv{i + 1}', + create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) + if i != num_conv - 1: + self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) + else: + self.gather = None + if self.extent == 0: + self.gk = 0 + self.gs = 0 + else: + assert extent % 2 == 0 + self.gk = self.extent * 2 - 1 + self.gs = self.extent + + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
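+        # 'excite' step: an SE-style MLP with hidden width rd_channels transforms the gathered
+        # descriptor before the gate; with use_mlp=False the gathered features are gated directly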
+ self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + size = x.shape[-2:] + if self.gather is not None: + x_ge = self.gather(x) + else: + if self.extent == 0: + # global extent + x_ge = x.mean(dim=(2, 3), keepdims=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) + else: + x_ge = F.avg_pool2d( + x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) + x_ge = self.mlp(x_ge) + if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: + x_ge = F.interpolate(x_ge, size=size) + return x * self.gate(x_ge) diff --git a/model/layers/global_context.py b/model/layers/global_context.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2c82f3aa75f8fedad49305952667f9a6fd5363 --- /dev/null +++ b/model/layers/global_context.py @@ -0,0 +1,67 @@ +""" Global Context Attention Block + +Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` + - https://arxiv.org/abs/1904.11492 + +Official code consulted as reference: https://github.com/xvjiarui/GCNet + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible +from .mlp import ConvMlp +from .norm import LayerNorm2d + + +class GlobalContext(nn.Module): + + def __init__(self, channels, use_attn=True, fuse_add=True, fuse_scale=False, init_last_zero=False, + rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): + super(GlobalContext, self).__init__() + act_layer = get_act_layer(act_layer) + + self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None + + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
+ if fuse_add: + self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_add = None + if fuse_scale: + self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_scale = None + + self.gate = create_act_layer(gate_layer) + self.init_last_zero = init_last_zero + self.reset_parameters() + + def reset_parameters(self): + if self.conv_attn is not None: + nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') + if self.mlp_add is not None: + nn.init.zeros_(self.mlp_add.fc2.weight) + + def forward(self, x): + B, C, H, W = x.shape + + if self.conv_attn is not None: + attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) + attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) + context = x.reshape(B, C, H * W).unsqueeze(1) @ attn + context = context.view(B, C, 1, 1) + else: + context = x.mean(dim=(2, 3), keepdim=True) + + if self.mlp_scale is not None: + mlp_x = self.mlp_scale(context) + x = x * self.gate(mlp_x) + if self.mlp_add is not None: + mlp_x = self.mlp_add(context) + x = x + mlp_x + + return x diff --git a/model/layers/halo_attn.py b/model/layers/halo_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..87cae8952cb7318cbec9bc513e7b2010ede7312d --- /dev/null +++ b/model/layers/halo_attn.py @@ -0,0 +1,166 @@ +""" Halo Self Attention + +Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + +@misc{2103.12731, +Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and + Jonathon Shlens}, +Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones}, +Year = {2021}, +} + +Status: +This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me. + +Trying to match the 'H1' variant in the paper, my parameter counts are 2M less and the model +is extremely slow. Something isn't right. However, the models do appear to train and experimental +variants with attn in C4 and/or C5 stages are tolerable speed. 
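Looking back at GlobalContext just above: its context vector is an attention-weighted pool rather than a plain mean. A standalone sketch of that pooling step, with random logits standing in for conv_attn (plain torch, illustrative sizes):

import torch
import torch.nn.functional as F

B, C, H, W = 2, 16, 7, 7
x = torch.randn(B, C, H, W)
attn_logits = torch.randn(B, 1, H * W)                 # stand-in for conv_attn(x).reshape(B, 1, H*W)
attn = F.softmax(attn_logits, dim=-1).unsqueeze(3)     # B x 1 x HW x 1, one weight per position
context = x.reshape(B, C, H * W).unsqueeze(1) @ attn   # B x 1 x C x 1, attention-weighted sum
context = context.view(B, C, 1, 1)
print(context.shape)                                   # torch.Size([2, 16, 1, 1])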
+ +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import Tuple, List + +import torch +from torch import nn +import torch.nn.functional as F + +from .weight_init import trunc_normal_ + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, height, width, dim) + rel_k: (2 * window - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + rel_size = rel_k.shape[0] + win_size = (rel_size + 1) // 2 + + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, rel_size) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, rel_size - W]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, rel_size) + x = x_pad[:, :W, win_size - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + """ + def __init__(self, block_size, win_size, dim_head, scale): + """ + Args: + block_size (int): block size + win_size (int): neighbourhood window size + dim_head (int): attention head dim + scale (float): scale factor (for init) + """ + super().__init__() + self.block_size = block_size + self.dim_head = dim_head + self.scale = scale + self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) + self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) + + def forward(self, q): + B, BB, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(-1, self.block_size, self.block_size, self.dim_head) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. + q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, BB, HW, -1) + return rel_logits + + +class HaloAttn(nn.Module): + """ Halo Attention + + Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + """ + def __init__( + self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False): + super().__init__() + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + self.stride = stride + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_qk = num_heads * dim_head + self.dim_v = dim_out + self.block_size = block_size + self.halo_size = halo_size + self.win_size = block_size + halo_size * 2 # neighbourhood window size + self.scale = self.dim_head ** -0.5 + + # FIXME not clear if this stride behaviour is what the paper intended + # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving + # data in unfolded block form. I haven't wrapped my head around how that'd look. 
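On the blocking/gather question raised in the comment above, the current code keeps everything in 2D and relies on F.unfold: each non-overlapping block_size query block reads keys/values from a larger, overlapping window of block_size + 2*halo_size. A standalone sketch of that unfold (plain torch, illustrative sizes):

import torch
import torch.nn.functional as F

B, C, H, W = 1, 4, 32, 32
block_size, halo_size = 8, 3
win_size = block_size + 2 * halo_size               # 14

kv = torch.randn(B, C, H, W)
cols = F.unfold(kv, kernel_size=win_size, stride=block_size, padding=halo_size)
num_blocks = (H // block_size) * (W // block_size)  # 16
print(cols.shape)                                   # torch.Size([1, 784, 16]) = (B, C * 14 * 14, num_blocks)
assert cols.shape == (B, C * win_size * win_size, num_blocks)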
+ self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias) + self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias) + + self.pos_embed = PosEmbedRel( + block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale) + + def reset_parameters(self): + std = self.q.weight.shape[1] ** -0.5 # fan-in + trunc_normal_(self.q.weight, std=std) + trunc_normal_(self.kv.weight, std=std) + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + assert H % self.block_size == 0 and W % self.block_size == 0 + num_h_blocks = H // self.block_size + num_w_blocks = W // self.block_size + num_blocks = num_h_blocks * num_w_blocks + + q = self.q(x) + q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) + # B, num_heads * dim_head * block_size ** 2, num_blocks + q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3) + # B * num_heads, num_blocks, block_size ** 2, dim_head + + kv = self.kv(x) + # FIXME I 'think' this unfold does what I want it to, but I should investigate + kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + kv = kv.reshape( + B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) + k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1) + + attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied? + attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2 + + attn_out = attn_logits.softmax(dim=-1) + attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks + attn_out = F.fold( + attn_out.reshape(B, -1, num_blocks), + (H // self.stride, W // self.stride), + kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) + # B, dim_out, H // stride, W // stride + return attn_out diff --git a/model/layers/helpers.py b/model/layers/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..cc54ca7f8a24de7e1ee0e5d27decf3e88c55ece3 --- /dev/null +++ b/model/layers/helpers.py @@ -0,0 +1,31 @@ +""" Layer/Module Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +from itertools import repeat +import collections.abc + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < round_limit * v: + new_v += divisor + return new_v diff --git a/model/layers/inplace_abn.py b/model/layers/inplace_abn.py new file mode 100644 index 0000000000000000000000000000000000000000..3aae7cf563edfe6c9d2bf1a9f3994d911aacea23 --- /dev/null +++ b/model/layers/inplace_abn.py @@ -0,0 +1,87 @@ +import torch +from torch import nn as nn + +try: + from inplace_abn.functions import inplace_abn, inplace_abn_sync + has_iabn = True +except ImportError: + has_iabn = False + + def inplace_abn(x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): + raise ImportError( + "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") + + def inplace_abn_sync(**kwargs): + inplace_abn(**kwargs) + + +class InplaceAbn(nn.Module): + """Activated Batch Normalization + + This gathers a BatchNorm and an activation function in a single module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + act_layer : str or nn.Module type + Name or type of the activation functions, one of: `leaky_relu`, `elu` + act_param : float + Negative slope for the `leaky_relu` activation. + """ + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, + act_layer="leaky_relu", act_param=0.01, drop_block=None): + super(InplaceAbn, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.momentum = momentum + if apply_act: + if isinstance(act_layer, str): + assert act_layer in ('leaky_relu', 'elu', 'identity', '') + self.act_name = act_layer if act_layer else 'identity' + else: + # convert act layer passed as type to string + if act_layer == nn.ELU: + self.act_name = 'elu' + elif act_layer == nn.LeakyReLU: + self.act_name = 'leaky_relu' + elif act_layer == nn.Identity: + self.act_name = 'identity' + else: + assert False, f'Invalid act layer {act_layer.__name__} for IABN' + else: + self.act_name = 'identity' + self.act_param = act_param + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.running_mean, 0) + nn.init.constant_(self.running_var, 1) + if self.affine: + nn.init.constant_(self.weight, 1) + nn.init.constant_(self.bias, 0) + + def forward(self, x): + output = inplace_abn( + x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.act_name, self.act_param) + if isinstance(output, tuple): + output = output[0] + return output diff --git a/model/layers/involution.py b/model/layers/involution.py new file mode 100644 index 0000000000000000000000000000000000000000..ccdeefcbe96cabb9285e08408a447ce8a89435db --- /dev/null +++ b/model/layers/involution.py @@ -0,0 +1,50 @@ +""" PyTorch Involution Layer + +Official impl: https://github.com/d-li14/involution/blob/main/cls/mmcls/models/utils/involution_naive.py +Paper: `Involution: 
Inverting the Inherence of Convolution for Visual Recognition` - https://arxiv.org/abs/2103.06255 +""" +import torch.nn as nn +from .conv_bn_act import ConvBnAct +from .create_conv2d import create_conv2d + + +class Involution(nn.Module): + + def __init__( + self, + channels, + kernel_size=3, + stride=1, + group_size=16, + rd_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(Involution, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.channels = channels + self.group_size = group_size + self.groups = self.channels // self.group_size + self.conv1 = ConvBnAct( + in_channels=channels, + out_channels=channels // rd_ratio, + kernel_size=1, + norm_layer=norm_layer, + act_layer=act_layer) + self.conv2 = self.conv = create_conv2d( + in_channels=channels // rd_ratio, + out_channels=kernel_size**2 * self.groups, + kernel_size=1, + stride=1) + self.avgpool = nn.AvgPool2d(stride, stride) if stride == 2 else nn.Identity() + self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride) + + def forward(self, x): + weight = self.conv2(self.conv1(self.avgpool(x))) + B, C, H, W = weight.shape + KK = int(self.kernel_size ** 2) + weight = weight.view(B, self.groups, KK, H, W).unsqueeze(2) + out = self.unfold(x).view(B, self.groups, self.group_size, KK, H, W) + out = (weight * out).sum(dim=3).view(B, self.channels, H, W) + return out diff --git a/model/layers/lambda_layer.py b/model/layers/lambda_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1027a18146f3171724a45d82107c77e1297e5c --- /dev/null +++ b/model/layers/lambda_layer.py @@ -0,0 +1,84 @@ +""" Lambda Layer + +Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + +@misc{2102.08602, +Author = {Irwan Bello}, +Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, +Year = {2021}, +} + +Status: +This impl is a WIP. Code snippets in the paper were used as reference but +good chance some details are missing/wrong. + +I've only implemented local lambda conv based pos embeddings. 
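For the Involution layer just above, the core multiply-accumulate can be isolated as follows: the kernel tensor carries one K*K weight map per group and per output position, and nn.Unfold lays the input out to match. A standalone sketch in plain torch (illustrative sizes):

import torch
import torch.nn as nn

B, C, H, W = 1, 32, 8, 8
K, group_size = 3, 16
groups = C // group_size                                     # 2

kernels = torch.randn(B, groups * K * K, H, W)               # what conv2 produces at stride 1
kernels = kernels.view(B, groups, K * K, H, W).unsqueeze(2)  # B, groups, 1, K*K, H, W

unfold = nn.Unfold(K, dilation=1, padding=(K - 1) // 2, stride=1)
patches = unfold(torch.randn(B, C, H, W)).view(B, groups, group_size, K * K, H, W)

out = (kernels * patches).sum(dim=3).view(B, C, H, W)        # weighted sum over each K*K window
print(out.shape)                                             # torch.Size([1, 32, 8, 8])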
+ +For a PyTorch impl that includes other embedding options checkout +https://github.com/lucidrains/lambda-networks + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from torch import nn +import torch.nn.functional as F + +from .weight_init import trunc_normal_ + + +class LambdaLayer(nn.Module): + """Lambda Layer w/ lambda conv position embedding + + Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + """ + def __init__( + self, + dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): + super().__init__() + self.dim = dim + self.dim_out = dim_out or dim + self.dim_k = dim_head # query depth 'k' + self.num_heads = num_heads + assert self.dim_out % num_heads == 0, ' should be divided by num_heads' + self.dim_v = self.dim_out // num_heads # value depth 'v' + self.r = r # relative position neighbourhood (lambda conv kernel size) + + self.qkv = nn.Conv2d( + dim, + num_heads * dim_head + dim_head + self.dim_v, + kernel_size=1, bias=qkv_bias) + self.norm_q = nn.BatchNorm2d(num_heads * dim_head) + self.norm_v = nn.BatchNorm2d(self.dim_v) + + # NOTE currently only supporting the local lambda convolutions for positional + self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.dim ** -0.5) + trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) + + def forward(self, x): + B, C, H, W = x.shape + M = H * W + + qkv = self.qkv(x) + q, k, v = torch.split(qkv, [ + self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) + q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K + v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V + k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M + + content_lam = k @ v # B, K, V + content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V + + position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K + position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V + + out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W) # B, C (num_heads * V), H, W + out = self.pool(out) + return out diff --git a/model/layers/linear.py b/model/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..38fe3380b067ea0b275c45ffd689afdeb4598f3c --- /dev/null +++ b/model/layers/linear.py @@ -0,0 +1,19 @@ +""" Linear layer (alternate definition) +""" +import torch +import torch.nn.functional as F +from torch import nn as nn + + +class Linear(nn.Linear): + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` + + Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting + weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 
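Back in the LambdaLayer above, the content path summarises all positions into a single dim_k x dim_v matrix that every query then reads, which is where the linear (rather than quadratic) cost in sequence length comes from. A standalone sketch of that path (plain torch, illustrative sizes):

import torch
import torch.nn.functional as F

B, num_heads, M, dim_k, dim_v = 2, 4, 49, 16, 32
q = torch.randn(B, num_heads, M, dim_k)
k = F.softmax(torch.randn(B, dim_k, M), dim=-1)   # B, K, M (softmax over positions)
v = torch.randn(B, M, dim_v)                      # B, M, V

content_lam = k @ v                               # B, K, V summary shared by all queries
content_out = q @ content_lam.unsqueeze(1)        # B, num_heads, M, V
print(content_out.shape)                          # torch.Size([2, 4, 49, 32])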
+ """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None + return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) + else: + return F.linear(input, self.weight, self.bias) diff --git a/model/layers/median_pool.py b/model/layers/median_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..40bd71a7a3840aaebefd2af0a99605b845054cd7 --- /dev/null +++ b/model/layers/median_pool.py @@ -0,0 +1,49 @@ +""" Median Pool +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F +from .helpers import to_2tuple, to_4tuple + + +class MedianPool2d(nn.Module): + """ Median pool (usable as median filter when stride=1) module. + + Args: + kernel_size: size of pooling kernel, int or 2-tuple + stride: pool stride, int or 2-tuple + padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad + same: override padding and enforce same padding, boolean + """ + def __init__(self, kernel_size=3, stride=1, padding=0, same=False): + super(MedianPool2d, self).__init__() + self.k = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + self.padding = to_4tuple(padding) # convert to l, r, t, b + self.same = same + + def _padding(self, x): + if self.same: + ih, iw = x.size()[2:] + if ih % self.stride[0] == 0: + ph = max(self.k[0] - self.stride[0], 0) + else: + ph = max(self.k[0] - (ih % self.stride[0]), 0) + if iw % self.stride[1] == 0: + pw = max(self.k[1] - self.stride[1], 0) + else: + pw = max(self.k[1] - (iw % self.stride[1]), 0) + pl = pw // 2 + pr = pw - pl + pt = ph // 2 + pb = ph - pt + padding = (pl, pr, pt, pb) + else: + padding = self.padding + return padding + + def forward(self, x): + x = F.pad(x, self._padding(x), mode='reflect') + x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) + x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] + return x diff --git a/model/layers/mixed_conv2d.py b/model/layers/mixed_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0ce565c0a9d348d4e68165960fa77fcf7f70d7 --- /dev/null +++ b/model/layers/mixed_conv2d.py @@ -0,0 +1,51 @@ +""" PyTorch Mixed Convolution + +Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv2d_same import create_conv2d_pad + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = in_ch if depthwise else 1 + # use add_module to keep key space 
clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x diff --git a/model/layers/mlp.py b/model/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..05d076527cfb6f15bcf5f2830fa36777abbc5a1e --- /dev/null +++ b/model/layers/mlp.py @@ -0,0 +1,108 @@ +""" MLP module w/ dropout and configurable activation layer + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GluMlp(nn.Module): + """ MLP w/ GLU style gating + See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + assert hidden_features % 2 == 0 + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features // 2, out_features) + self.drop = nn.Dropout(drop) + + def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + fc1_mid = self.fc1.bias.shape[0] // 2 + nn.init.ones_(self.fc1.bias[fc1_mid:]) + nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x, gates = x.chunk(2, dim=-1) + x = x * self.act(gates) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GatedMlp(nn.Module): + """ MLP as used in gMLP + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + gate_layer=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + if gate_layer is not None: + assert hidden_features % 2 == 0 + self.gate = gate_layer(hidden_features) + hidden_features = hidden_features // 2 # FIXME base reduction on gate property? 
+ else: + self.gate = nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.gate(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvMlp(nn.Module): + """ MLP using 1x1 convs that keeps spatial dims + """ + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x diff --git a/model/layers/non_local_attn.py b/model/layers/non_local_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..a537d60e6e575f093b93a146b83fb8e6398f6288 --- /dev/null +++ b/model/layers/non_local_attn.py @@ -0,0 +1,143 @@ +""" Bilinear-Attention-Transform and Non-Local Attention + +Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms` + - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html +Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification +""" +import torch +from torch import nn +from torch.nn import functional as F + +from .conv_bn_act import ConvBnAct +from .helpers import make_divisible + + +class NonLocalAttn(nn.Module): + """Spatial NL block for image classification. + + This was adapted from https://github.com/BA-Transform/BAT-Image-Classification + Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net. 
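Two of the MLP variants above in action (a sketch; both classes are self-contained as defined in this file, and the sizes are illustrative). GluMlp gates one half of the hidden features with the other, so fc2 only sees hidden_features // 2; ConvMlp keeps the BCHW layout with 1x1 convs, which is why the attention blocks earlier in this diff (GatherExcite, GlobalContext) apply it to pooled context tensors.

import torch

glu = GluMlp(in_features=64, hidden_features=128)    # fc1: 64 -> 128, fc2: 64 -> 64
tokens = torch.randn(2, 196, 64)                     # B x N x C token layout
print(glu(tokens).shape)                             # torch.Size([2, 196, 64])

cmlp = ConvMlp(in_features=64, hidden_features=16)   # squeeze-and-expand on BCHW
ctx = torch.randn(2, 64, 1, 1)                       # e.g. a pooled context vector
print(cmlp(ctx).shape)                               # torch.Size([2, 64, 1, 1])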
+ """ + + def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs): + super(NonLocalAttn, self).__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.scale = in_channels ** -0.5 if use_scale else 1.0 + self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True) + self.norm = nn.BatchNorm2d(in_channels) + self.reset_parameters() + + def forward(self, x): + shortcut = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + B, C, H, W = t.size() + t = t.view(B, C, -1).permute(0, 2, 1) + p = p.view(B, C, -1) + g = g.view(B, C, -1).permute(0, 2, 1) + + att = torch.bmm(t, p) * self.scale + att = F.softmax(att, dim=2) + x = torch.bmm(att, g) + + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.z(x) + x = self.norm(x) + shortcut + + return x + + def reset_parameters(self): + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if len(list(m.parameters())) > 1: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + + +class BilinearAttnTransform(nn.Module): + + def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(BilinearAttnTransform, self).__init__() + + self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) + self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) + self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.block_size = block_size + self.groups = groups + self.in_channels = in_channels + + def resize_mat(self, x, t: int): + B, C, block_size, block_size1 = x.shape + assert block_size == block_size1 + if t <= 1: + return x + x = x.view(B * C, -1, 1, 1) + x = x * torch.eye(t, t, dtype=x.dtype, device=x.device) + x = x.view(B * C, block_size, block_size, t, t) + x = torch.cat(torch.split(x, 1, dim=1), dim=3) + x = torch.cat(torch.split(x, 1, dim=2), dim=4) + x = x.view(B, C, block_size * t, block_size * t) + return x + + def forward(self, x): + assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0 + B, C, H, W = x.shape + out = self.conv1(x) + rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) + cp = F.adaptive_max_pool2d(out, (1, self.block_size)) + p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid() + q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid() + p = p / p.sum(dim=3, keepdim=True) + q = q / q.sum(dim=2, keepdim=True) + p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + p = p.view(B, C, self.block_size, self.block_size) + q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, 
self.block_size).contiguous() + q = q.view(B, C, self.block_size, self.block_size) + p = self.resize_mat(p, H // self.block_size) + q = self.resize_mat(q, W // self.block_size) + y = p.matmul(x) + y = y.matmul(q) + + y = self.conv2(y) + return y + + +class BatNonLocalAttn(nn.Module): + """ BAT + Adapted from: https://github.com/BA-Transform/BAT-Image-Classification + """ + + def __init__( + self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_): + super().__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) + self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.dropout = nn.Dropout2d(p=drop_rate) + + def forward(self, x): + xl = self.conv1(x) + y = self.ba(xl) + y = self.conv2(y) + y = self.dropout(y) + return y + x diff --git a/model/layers/norm.py b/model/layers/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..433552b4cec1e901147d61b05ed6c68ea9c3799f --- /dev/null +++ b/model/layers/norm.py @@ -0,0 +1,23 @@ +""" Normalization layers and wrappers +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GroupNorm(nn.GroupNorm): + def __init__(self, num_channels, num_groups, eps=1e-5, affine=True): + # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN + super().__init__(num_groups, num_channels, eps=eps, affine=affine) + + def forward(self, x): + return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + + +class LayerNorm2d(nn.LayerNorm): + """ Layernorm for channels of '2d' spatial BCHW tensors """ + def __init__(self, num_channels): + super().__init__([num_channels, 1, 1]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) diff --git a/model/layers/norm_act.py b/model/layers/norm_act.py new file mode 100644 index 0000000000000000000000000000000000000000..02cabe88861f96345599b71a4a96edd8d115f6d3 --- /dev/null +++ b/model/layers/norm_act.py @@ -0,0 +1,85 @@ +""" Normalization + Activation Layers +""" +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .create_act import get_act_layer + + +class BatchNormAct2d(nn.BatchNorm2d): + """BatchNorm + Activation + + This module performs BatchNorm + Activation in a manner that will remain backwards + compatible with weights trained with separate bn, act. This is why we inherit from BN + instead of composing it as a .bn member. 
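To see the point of the inheritance note above (a sketch, assuming the sibling create_act module resolves): because BatchNormAct2d is itself an nn.BatchNorm2d, its state_dict keys are exactly those of a plain BN layer, so checkpoints trained with separate bn + act modules load without any key remapping.

import torch.nn as nn

bn_act = BatchNormAct2d(64)                 # default act_layer=nn.ReLU adds no parameters or buffers
print(sorted(bn_act.state_dict().keys()))
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']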
+ """ + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def _forward_jit(self, x): + """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function + """ + # exponential_average_factor is self.momentum set to + # (when it is available) only so that if gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + x = F.batch_norm( + x, self.running_mean, self.running_var, self.weight, self.bias, + self.training or not self.track_running_stats, + exponential_average_factor, self.eps) + return x + + @torch.jit.ignore + def _forward_python(self, x): + return super(BatchNormAct2d, self).forward(x) + + def forward(self, x): + # FIXME cannot call parent forward() and maintain jit.script compatibility? + if torch.jit.is_scripting(): + x = self._forward_jit(x) + else: + x = self._forward_python(x) + x = self.act(x) + return x + + +class GroupNormAct(nn.GroupNorm): + # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args + def __init__(self, num_channels, num_groups, eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): + super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + x = self.act(x) + return x diff --git a/model/layers/padding.py b/model/layers/padding.py new file mode 100644 index 0000000000000000000000000000000000000000..34afc37c6c59c8782ad29c7a779f58177011f891 --- /dev/null +++ b/model/layers/padding.py @@ -0,0 +1,56 @@ +""" Padding Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +from typing import List, Tuple + +import torch.nn.functional as F + + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +# Can SAME padding for given args be done statically? 
+def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +# Dynamically pad input x with 'SAME' padding for conv with specified args +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) + return x + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic diff --git a/model/layers/patch_embed.py b/model/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..42997fb89f10d518028e064c46387f694dce9026 --- /dev/null +++ b/model/layers/patch_embed.py @@ -0,0 +1,39 @@ +""" Image to Patch Embedding using Conv2d + +A convolution based approach to patchifying a 2D image w/ embedding projection. + +Based on the impl in https://github.com/google-research/vision_transformer + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from torch import nn as nn + +from .helpers import to_2tuple + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
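# Editorial sketch of the patch arithmetic this assert guards: with the default
# img_size=224 and patch_size=16, the projection below produces a 14x14 grid
# (224 // 16 = 14 per side), i.e. num_patches = 196 tokens of embed_dim channels,
# and flatten(2).transpose(1, 2) turns BCHW into the B x 196 x embed_dim layout
# the transformer blocks expect.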
+ x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x diff --git a/model/layers/pool2d_same.py b/model/layers/pool2d_same.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2a1c44713e552be850865ada9623a1c3b1d836 --- /dev/null +++ b/model/layers/pool2d_same.py @@ -0,0 +1,73 @@ +""" AvgPool2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple, Optional + +from .helpers import to_2tuple +from .padding import pad_same, get_padding_value + + +def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + ceil_mode: bool = False, count_include_pad: bool = True): + # FIXME how to deal with count_include_pad vs not for external padding? + x = pad_same(x, kernel_size, stride) + return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + +class AvgPool2dSame(nn.AvgPool2d): + """ Tensorflow like 'SAME' wrapper for 2D average pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + def forward(self, x): + x = pad_same(x, self.kernel_size, self.stride) + return F.avg_pool2d( + x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) + + +def max_pool2d_same( + x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + dilation: List[int] = (1, 1), ceil_mode: bool = False): + x = pad_same(x, kernel_size, stride, value=-float('inf')) + return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) + + +class MaxPool2dSame(nn.MaxPool2d): + """ Tensorflow like 'SAME' wrapper for 2D max pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) + + def forward(self, x): + x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) + return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) + + +def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): + stride = stride or kernel_size + padding = kwargs.pop('padding', '') + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) + if is_dynamic: + if pool_type == 'avg': + return AvgPool2dSame(kernel_size, stride=stride, **kwargs) + elif pool_type == 'max': + return MaxPool2dSame(kernel_size, stride=stride, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' + else: + if pool_type == 'avg': + return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + elif pool_type == 'max': + return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' diff --git a/model/layers/selective_kernel.py b/model/layers/selective_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..f28b8d2e9ad49740081d4e1da5287e45f5ee76b8 --- /dev/null +++ b/model/layers/selective_kernel.py @@ -0,0 +1,119 @@ +""" Selective Kernel Convolution/Attention + +Paper: Selective Kernel Networks 
(https://arxiv.org/abs/1903.06586) + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn + +from .conv_bn_act import ConvBnAct +from .helpers import make_divisible + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + """ Selective Kernel Attention Module + + Selective Kernel attention mechanism factored out into its own module. + + """ + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = x.sum(1).mean((2, 3), keepdim=True) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernel(nn.Module): + + def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, + rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): + """ Selective Kernel Convolution Module + + As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. + + Largest change is the input split, which divides the input channels across each convolution path, this can + be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps + the parameter count from ballooning when the convolutions themselves don't have groups, but still provides + a noteworthy increase in performance over similar param count models without this attention layer. -Ross W + + Args: + in_channels (int): module input (feature) channel count + out_channels (int): module output (feature) channel count + kernel_size (int, list): kernel size for each convolution branch + stride (int): stride for convolutions + dilation (int): dilation for module as a whole, impacts dilation of each branch + groups (int): number of groups for each branch + rd_ratio (int, float): reduction factor for attention features + keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations + split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, + can be viewed as grouping by path, output expands to module out_channels count + drop_block (nn.Module): drop block module + act_layer (nn.Module): activation layer to use + norm_layer (nn.Module): batchnorm/norm layer to use + """ + super(SelectiveKernel, self).__init__() + out_channels = out_channels or in_channels + kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 
5x5 -> 3x3 + dilation + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + self.num_paths = len(kernel_size) + self.in_channels = in_channels + self.out_channels = out_channels + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths + groups = min(out_channels, groups) + + conv_kwargs = dict( + stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer) + self.paths = nn.ModuleList([ + ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) + + attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + self.drop_block = drop_block + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + x = torch.sum(x, dim=1) + return x diff --git a/model/layers/separable_conv.py b/model/layers/separable_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddcb4e62409492f898ab963027a9c2229b72f64 --- /dev/null +++ b/model/layers/separable_conv.py @@ -0,0 +1,73 @@ +""" Depthwise Separable Conv Modules + +Basic DWS convs. Other variations of DWS exist with batch norm or activations between the +DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
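Two details of the selective-kernel module above, illustrated (a sketch; sizes are illustrative and SelectiveKernelAttn is the class defined earlier in this file): with keep_3x3=True the default [3, 5] branch spec becomes two 3x3 convs with dilations [1, 2], and the attention module returns a softmax over branches for every channel, so branch outputs are blended rather than concatenated.

import torch

# keep_3x3: larger kernels are rewritten as dilated 3x3s
kernel_size, dilation = [3, 5], 1
print([dilation * (k - 1) // 2 for k in kernel_size], [3] * len(kernel_size))   # [1, 2] [3, 3]

# branch attention sums to 1 across the path dimension
attn = SelectiveKernelAttn(channels=64, num_paths=2, attn_channels=32)
stacked = torch.randn(2, 2, 64, 28, 28)         # B x num_paths x C x H x W
a = attn(stacked)
print(a.shape, torch.allclose(a.sum(dim=1), torch.ones_like(a[:, 0])))          # (2, 2, 64, 1, 1) True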
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import convert_norm_act + + +class SeparableConvBnAct(nn.Module): + """ Separable Conv w/ trailing Norm and Activation + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, + apply_act=True, drop_block=None): + super(SeparableConvBnAct, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + norm_act_layer = convert_norm_act(norm_layer, act_layer) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + if self.bn is not None: + x = self.bn(x) + return x + + +class SeparableConv2d(nn.Module): + """ Separable Conv + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1): + super(SeparableConv2d, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + return x diff --git a/model/layers/space_to_depth.py b/model/layers/space_to_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e8e0b2a486d51fe3e4ab0472d89b7f1b92e1dc --- /dev/null +++ b/model/layers/space_to_depth.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + + +class SpaceToDepth(nn.Module): + def __init__(self, block_size=4): + super().__init__() + assert block_size == 4 + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) + return x + + +@torch.jit.script +class SpaceToDepthJit(object): + def __call__(self, x: torch.Tensor): + # assuming hard-coded that block_size==4 for acceleration + N, C, H, W = x.size() + x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) + return x + + +class SpaceToDepthModule(nn.Module): + def __init__(self, no_jit=False): + super().__init__() + if not no_jit: + self.op = SpaceToDepthJit() + else: + self.op = SpaceToDepth() + + def forward(self, x): + return self.op(x) + + +class DepthToSpace(nn.Module): + + def __init__(self, block_size): + 
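# Editorial sketch: for the fixed block_size of 4, SpaceToDepth above and the
# DepthToSpace module being defined here are exact inverses; a (N, C, H, W) map
# becomes (N, 16*C, H//4, W//4) under SpaceToDepth and is restored by
# DepthToSpace(block_size=4).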
super().__init__() + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) + x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) + return x diff --git a/model/layers/split_attn.py b/model/layers/split_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..dde601befa933727e169d9b84b035cf1f035e67c --- /dev/null +++ b/model/layers/split_attn.py @@ -0,0 +1,85 @@ +""" Split Attention Conv2d (for ResNeSt Models) + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, performance, and consistency with timm by Ross Wightman +""" +import torch +import torch.nn.functional as F +from torch import nn + +from .helpers import make_divisible + + +class RadixSoftmax(nn.Module): + def __init__(self, radix, cardinality): + super(RadixSoftmax, self).__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttn(nn.Module): + """Split-Attention (aka Splat) + """ + def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, + dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): + super(SplitAttn, self).__init__() + out_channels = out_channels or in_channels + self.radix = radix + self.drop_block = drop_block + mid_chs = out_channels * radix + if rd_channels is None: + attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) + else: + attn_chs = rd_channels * radix + + padding = kernel_size // 2 if padding is None else padding + self.conv = nn.Conv2d( + in_channels, mid_chs, kernel_size, stride, padding, dilation, + groups=groups * radix, bias=bias, **kwargs) + self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() + self.act0 = act_layer(inplace=True) + self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) + self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() + self.act1 = act_layer(inplace=True) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) + self.rsoftmax = RadixSoftmax(radix, groups) + + def forward(self, x): + x = self.conv(x) + x = self.bn0(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act0(x) + + B, RC, H, W = x.shape + if self.radix > 1: + x = x.reshape((B, self.radix, RC // self.radix, H, W)) + x_gap = x.sum(dim=1) + else: + x_gap = x + x_gap = x_gap.mean((2, 3), keepdim=True) + x_gap = self.fc1(x_gap) + x_gap = self.bn1(x_gap) + x_gap = self.act1(x_gap) + x_attn = self.fc2(x_gap) + + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + if self.radix > 1: + out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) + else: + out = x * x_attn + return out.contiguous() diff --git a/model/layers/split_batchnorm.py b/model/layers/split_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..830781b335161f8d6dd74c9458070bb1fa88a918 --- /dev/null +++ b/model/layers/split_batchnorm.py @@ -0,0 +1,75 
@@ +""" Split BatchNorm + +A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through +a separate BN layer. The first split is passed through the parent BN layers with weight/bias +keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' +namespace. + +This allows easily removing the auxiliary BN layers after training to efficiently +achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, +'Disentangled Learning via An Auxiliary BN' + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn + + +class SplitBatchNorm2d(torch.nn.BatchNorm2d): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True, num_splits=2): + super().__init__(num_features, eps, momentum, affine, track_running_stats) + assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' + self.num_splits = num_splits + self.aux_bn = nn.ModuleList([ + nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) + + def forward(self, input: torch.Tensor): + if self.training: # aux BN only relevant while training + split_size = input.shape[0] // self.num_splits + assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" + split_input = input.split(split_size) + x = [super().forward(split_input[0])] + for i, a in enumerate(self.aux_bn): + x.append(a(split_input[i + 1])) + return torch.cat(x, dim=0) + else: + return super().forward(input) + + +def convert_splitbn_model(module, num_splits=2): + """ + Recursively traverse module and its children to replace all instances of + ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. + Args: + module (torch.nn.Module): input module + num_splits: number of separate batchnorm layers to split input across + Example:: + >>> # model is an instance of torch.nn.Module + >>> model = timm.models.convert_splitbn_model(model, num_splits=2) + """ + mod = module + if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): + return module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + mod = SplitBatchNorm2d( + module.num_features, module.eps, module.momentum, module.affine, + module.track_running_stats, num_splits=num_splits) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + mod.num_batches_tracked = module.num_batches_tracked + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + for aux in mod.aux_bn: + aux.running_mean = module.running_mean.clone() + aux.running_var = module.running_var.clone() + aux.num_batches_tracked = module.num_batches_tracked.clone() + if module.affine: + aux.weight.data = module.weight.data.clone().detach() + aux.bias.data = module.bias.data.clone().detach() + for name, child in module.named_children(): + mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) + del module + return mod diff --git a/model/layers/squeeze_excite.py b/model/layers/squeeze_excite.py new file mode 100644 index 0000000000000000000000000000000000000000..e5da29ef166de27705cc160f729b6e3b45061c59 --- /dev/null +++ b/model/layers/squeeze_excite.py @@ -0,0 +1,74 @@ +""" Squeeze-and-Excitation Channel Attention + +An SE implementation originally based on PyTorch SE-Net impl. +Has since evolved with additional functionality / configuration. 
+ +Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 + +Also included is Effective Squeeze-Excitation (ESE). +Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class SEModule(nn.Module): + """ SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): + super(SEModule, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = create_act_layer(act_layer, inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +SqueezeExcite = SEModule # alias + + +class EffectiveSEModule(nn.Module): + """ 'Effective Squeeze-Excitation + From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): + super(EffectiveSEModule, self).__init__() + self.add_maxpool = add_maxpool + self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc(x_se) + return x * self.gate(x_se) + + +EffectiveSqueezeExcite = EffectiveSEModule # alias diff --git a/model/layers/std_conv.py b/model/layers/std_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccc16e1197a41440add454a40ed3146ed0b6211 --- /dev/null +++ b/model/layers/std_conv.py @@ -0,0 +1,133 @@ +""" Convolution with Weight Standardization (StdConv and ScaledStdConv) + +StdConv: +@article{weightstandardization, + author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille}, + title = {Weight Standardization}, + journal = {arXiv preprint arXiv:1903.10520}, + year = {2019}, +} +Code: https://github.com/joe-siyuan-qiao/WeightStandardization + +ScaledStdConv: +Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Hacked together by / copyright Ross Wightman, 2021. 
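A quick shape-level sketch of the squeeze-excite variants defined above. The channel sizes and the import path (mirroring model/layers/squeeze_excite.py in this diff) are illustrative assumptions:

import torch

from model.layers.squeeze_excite import SEModule, EffectiveSEModule  # assumed path

x = torch.randn(2, 64, 14, 14)

# classic SE: 64 * 1/16 = 4 reduction channels, rounded up to the divisor of 8
se = SEModule(64, rd_ratio=1. / 16)
print(se(x).shape)   # torch.Size([2, 64, 14, 14]); channel-wise gating, spatial dims untouched

# effective SE: a single full-width 1x1 conv, no channel reduction
ese = EffectiveSEModule(64)
print(ese(x).shape)  # torch.Size([2, 64, 14, 14])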
+""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .padding import get_padding, get_padding_value, pad_same + + +class StdConv2d(nn.Conv2d): + """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. + + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=False, eps=1e-6): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + self.eps = eps + + def forward(self, x): + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class StdConv2dSame(nn.Conv2d): + """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model. + + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=False, eps=1e-6): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.same_pad = is_dynamic + self.eps = eps + + def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class ScaledStdConv2d(nn.Conv2d): + """Conv2d layer with Scaled Weight Standardization. + + Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - + https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. 
+ """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) + self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) + self.eps = eps + + def forward(self, x): + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class ScaledStdConv2dSame(nn.Conv2d): + """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support + + Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - + https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. + """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) + self.scale = gamma * self.weight[0].numel() ** -0.5 + self.same_pad = is_dynamic + self.eps = eps + + def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) diff --git a/model/layers/swin_attn.py b/model/layers/swin_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..02131bbc4dec3f726a23da0444bec108f9c3903a --- /dev/null +++ b/model/layers/swin_attn.py @@ -0,0 +1,182 @@ +""" Shifted Window Attn + +This is a WIP experiment to apply windowed attention from the Swin Transformer +to a stand-alone module for use as an attn block in conv nets. 
+ +Based on original swin window code at https://github.com/microsoft/Swin-Transformer +Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf +""" +from typing import Optional + +import torch +import torch.nn as nn + +from .drop import DropPath +from .helpers import to_2tuple +from .weight_init import trunc_normal_ + + +def window_partition(x, win_size: int): + """ + Args: + x: (B, H, W, C) + win_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) + return windows + + +def window_reverse(windows, win_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + win_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / win_size / win_size)) + x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + win_size (int): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + """ + + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8, + qkv_bias=True, attn_drop=0.): + + super().__init__() + self.dim_out = dim_out or dim + self.feat_size = to_2tuple(feat_size) + self.win_size = win_size + self.shift_size = shift_size or win_size // 2 + if min(self.feat_size) <= win_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.win_size = min(self.feat_size) + assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size" + self.num_heads = num_heads + head_dim = self.dim_out // num_heads + self.scale = head_dim ** -0.5 + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.feat_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.win_size) # num_win, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.win_size * self.win_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + # 2 * Wh - 1 * 2 * Ww - 1, nH + torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads)) + trunc_normal_(self.relative_position_bias_table, std=.02) + + # get pair-wise relative 
position index for each token inside the window + coords_h = torch.arange(self.win_size) + coords_w = torch.arange(self.win_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.win_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.win_size - 1 + relative_coords[:, :, 0] *= 2 * self.win_size - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.softmax = nn.Softmax(dim=-1) + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) + trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self, x): + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + win_size_sq = self.win_size * self.win_size + x_windows = window_partition(shifted_x, self.win_size) # num_win * B, window_size, window_size, C + x_windows = x_windows.view(-1, win_size_sq, C) # num_win * B, window_size*window_size, C + BW, N, _ = x_windows.shape + + qkv = self.qkv(x_windows) + qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh * Ww, Wh * Ww + attn = attn + relative_position_bias.unsqueeze(0) + if self.attn_mask is not None: + num_win = self.attn_mask.shape[0] + attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(BW, N, self.dim_out) + + # merge windows + x = x.view(-1, self.win_size, self.win_size, self.dim_out) + shifted_x = window_reverse(x, self.win_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2) + x = self.pool(x) + return x + + diff --git a/model/layers/test_time_pool.py b/model/layers/test_time_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..98c0bf53a74eb954a25b96d84712ef974eb8ea3b --- /dev/null +++ b/model/layers/test_time_pool.py @@ -0,0 +1,52 @@ +""" Test Time Pooling (Average-Max Pool) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import logging +from torch import nn +import torch.nn.functional as F + +from .adaptive_avgmax_pool import adaptive_avgmax_pool2d + + +_logger = logging.getLogger(__name__) + + +class TestTimePoolHead(nn.Module): + def __init__(self, base, original_pool=7): + super(TestTimePoolHead, self).__init__() + 
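Back to the stand-alone WindowAttention block completed above (swin_attn.py): it takes NCHW feature maps whose spatial size is fixed at construction through feat_size, because the shift mask and relative position index are precomputed there. A minimal sketch with illustrative sizes, assuming the module is importable from this diff's path:

import torch

from model.layers.swin_attn import WindowAttention  # assumed path

# 16x16 feature map, 8x8 windows, cyclic shift of win_size // 2 = 4
attn = WindowAttention(dim=64, dim_out=64, feat_size=(16, 16), win_size=8, num_heads=4)
x = torch.randn(2, 64, 16, 16)
print(attn(x).shape)  # torch.Size([2, 64, 16, 16]); stride=2 would append an AvgPool2d and halve H and W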
self.base = base + self.original_pool = original_pool + base_fc = self.base.get_classifier() + if isinstance(base_fc, nn.Conv2d): + self.fc = base_fc + else: + self.fc = nn.Conv2d( + self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) + self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) + self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) + self.base.reset_classifier(0) # delete original fc layer + + def forward(self, x): + x = self.base.forward_features(x) + x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) + x = self.fc(x) + x = adaptive_avgmax_pool2d(x, 1) + return x.view(x.size(0), -1) + + +def apply_test_time_pool(model, config, use_test_size=True): + test_time_pool = False + if not hasattr(model, 'default_cfg') or not model.default_cfg: + return model, False + if use_test_size and 'test_input_size' in model.default_cfg: + df_input_size = model.default_cfg['test_input_size'] + else: + df_input_size = model.default_cfg['input_size'] + if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: + _logger.info('Target input size %s > pretrained default %s, using test time pooling' % + (str(config['input_size'][-2:]), str(df_input_size[-2:]))) + model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) + test_time_pool = True + return model, test_time_pool diff --git a/model/layers/weight_init.py b/model/layers/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..305a2fd067e7104e58b9b5ff70d96e89a06050af --- /dev/null +++ b/model/layers/weight_init.py @@ -0,0 +1,89 @@ +import torch +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
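The test-time pooling head above swaps global pooling for an average pool at the training-time feature resolution, applies the 1x1 classifier to every resulting window, and then combines the per-window logits with an average plus max pool. A shape-only sketch of that pipeline that avoids assuming a concrete backbone; the 0.5 * (avg + max) line stands in for adaptive_avgmax_pool2d:

import torch
import torch.nn as nn
import torch.nn.functional as F

# pretend backbone features from a larger-than-default input:
# 2048 channels on a 10x10 grid instead of the 7x7 pool size the model was trained with
feats = torch.randn(2, 2048, 10, 10)
classifier = nn.Conv2d(2048, 1000, kernel_size=1, bias=True)  # stand-in for the converted fc layer

x = F.avg_pool2d(feats, kernel_size=7, stride=1)   # pool at the original 7x7 window -> 4x4 grid of views
x = classifier(x)                                  # per-view logits: (2, 1000, 4, 4)
x = 0.5 * (F.adaptive_avg_pool2d(x, 1) + F.adaptive_max_pool2d(x, 1))  # combine views
logits = x.view(x.size(0), -1)                     # (2, 1000)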
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == 'fan_in': + denom = fan_in + elif mode == 'fan_out': + denom = fan_out + elif mode == 'fan_avg': + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') diff --git a/model/model_epoch62.pth.tar b/model/model_epoch62.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..9dadd925469e23f8fc4a6fab89ea2bf2d67d84a4 --- /dev/null +++ b/model/model_epoch62.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:257c061b84e7b675a2296e7c452ff9cd5363ed462684e3cff470bc777021af64 +size 805719721 diff --git a/model/registry.py b/model/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..f92219b218228baf09ef7ee596c0b1f360347d47 --- /dev/null +++ b/model/registry.py @@ -0,0 +1,149 @@ +""" Model Registry +Hacked together by / Copyright 2020 Ross Wightman +""" + +import sys +import re +import fnmatch +from collections import defaultdict +from copy import deepcopy + +__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained'] + +_module_to_models = defaultdict(set) # dict of sets to check membership of model in module +_model_to_module = {} # mapping of model names to module names +_model_entrypoints = {} # mapping of model names to entrypoint fns +_model_has_pretrained = set() # set of model names that have pretrained weight url present +_model_default_cfgs = dict() # central repo for model default_cfgs + + +def register_model(fn): + # lookup containing module + mod = sys.modules[fn.__module__] + module_name_split = fn.__module__.split('.') + module_name = module_name_split[-1] if len(module_name_split) else '' + + # add model to __all__ in module + model_name = fn.__name__ + if hasattr(mod, '__all__'): + mod.__all__.append(model_name) + else: + mod.__all__ = [model_name] + + # add entries to registry dict/sets + _model_entrypoints[model_name] = fn + _model_to_module[model_name] = module_name + _module_to_models[module_name].add(model_name) + has_pretrained = False # check if model has a pretrained url to allow filtering on this + if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: + # this will catch all models that have entrypoint matching cfg key, but miss any aliasing + # entrypoints or non-matching combos + has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] + _model_default_cfgs[model_name] = 
deepcopy(mod.default_cfgs[model_name]) + if has_pretrained: + _model_has_pretrained.add(model_name) + return fn + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): + """ Return list of available model names, sorted alphabetically + + Args: + filter (str) - Wildcard filter string that works with fnmatch + module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') + pretrained (bool) - Include only models with pretrained weights if True + exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) + + Example: + model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' + model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module + """ + if module: + all_models = list(_module_to_models[module]) + else: + all_models = _model_entrypoints.keys() + if filter: + models = [] + include_filters = filter if isinstance(filter, (tuple, list)) else [filter] + for f in include_filters: + include_models = fnmatch.filter(all_models, f) # include these models + if len(include_models): + models = set(models).union(include_models) + else: + models = all_models + if exclude_filters: + if not isinstance(exclude_filters, (tuple, list)): + exclude_filters = [exclude_filters] + for xf in exclude_filters: + exclude_models = fnmatch.filter(models, xf) # exclude these models + if len(exclude_models): + models = set(models).difference(exclude_models) + if pretrained: + models = _model_has_pretrained.intersection(models) + if name_matches_cfg: + models = set(_model_default_cfgs).intersection(models) + return list(sorted(models, key=_natural_key)) + + +def is_model(model_name): + """ Check if a model name exists + """ + return model_name in _model_entrypoints + + +def model_entrypoint(model_name): + """Fetch a model entrypoint for specified model name + """ + return _model_entrypoints[model_name] + + +def list_modules(): + """ Return list of module names that contain models / model entrypoints + """ + modules = _module_to_models.keys() + return list(sorted(modules)) + + +def is_model_in_modules(model_name, module_names): + """Check if a model exists within a subset of modules + Args: + model_name (str) - name of model to check + module_names (tuple, list, set) - names of modules to search in + """ + assert isinstance(module_names, (tuple, list, set)) + return any(model_name in _module_to_models[n] for n in module_names) + + +def has_model_default_key(model_name, cfg_key): + """ Query model default_cfgs for existence of a specific key. + """ + if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]: + return True + return False + + +def is_model_default_key(model_name, cfg_key): + """ Return truthy value for specified model default_cfg key, False if does not exist. + """ + if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False): + return True + return False + + +def get_model_default_value(model_name, cfg_key): + """ Get a specific model default_cfg value by key. None if it doesn't exist. 
+ """ + if model_name in _model_default_cfgs: + return _model_default_cfgs[model_name].get(cfg_key, None) + else: + return None + + +def is_model_pretrained(model_name): + return model_name in _model_has_pretrained