imsuperkong's picture
Upload 31 files
dc47947
raw
history blame
3.44 kB
import timm
import torch
import torch.nn as nn
import numpy as np
from .utils import activations, get_activation, Transpose
def forward_levit(pretrained, x):
pretrained.model.forward_features(x)
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_1 = pretrained.act_postprocess1(layer_1)
layer_2 = pretrained.act_postprocess2(layer_2)
layer_3 = pretrained.act_postprocess3(layer_3)
return layer_1, layer_2, layer_3
def _make_levit_backbone(
model,
hooks=[3, 11, 21],
patch_grid=[14, 14]
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.activations = activations
patch_grid_size = np.array(patch_grid, dtype=int)
pretrained.act_postprocess1 = nn.Sequential(
Transpose(1, 2),
nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
)
pretrained.act_postprocess2 = nn.Sequential(
Transpose(1, 2),
nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
)
pretrained.act_postprocess3 = nn.Sequential(
Transpose(1, 2),
nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
)
return pretrained
class ConvTransposeNorm(nn.Sequential):
"""
Modification of
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
such that ConvTranspose2d is used instead of Conv2d.
"""
def __init__(
self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1):
super().__init__()
self.add_module('c',
nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
self.add_module('bn', nn.BatchNorm2d(out_chs))
nn.init.constant_(self.bn.weight, bn_weight_init)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = nn.ConvTranspose2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def stem_b4_transpose(in_chs, out_chs, activation):
"""
Modification of
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16
such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half.
"""
return nn.Sequential(
ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
activation(),
ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
activation())
def _make_pretrained_levit_384(pretrained, hooks=None):
model = timm.create_model("levit_384", pretrained=pretrained)
hooks = [3, 11, 21] if hooks == None else hooks
return _make_levit_backbone(
model,
hooks=hooks
)