# FunSR/models/transformer_neck.py
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder, LayerNorm
from einops import rearrange

from models import register
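
# TransformerNeck: a Transformer-encoder "neck" that fuses one or more CNN
# feature maps into (i) a global content vector and (ii) a spatial function
# map at the first feature map's resolution (restored via PixelShuffle when
# downsampling is enabled).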
@register('transformer_neck')
class TransformerNeck(nn.Module):
    def __init__(self,
                 in_dim,            # channels of each incoming feature map
                 d_dim=256,         # embedding dim used inside the encoder
                 downsample=True,   # halve the first feature map before encoding
                 has_pe=True,       # add interpolated positional embeddings
                 has_norm=True,     # final LayerNorm after the encoder stack
                 class_token=True,  # prepend a learnable [CLS] token
                 num_encoder_layers=3,
                 dim_feedforward=512,
                 drop_rate=0.1
                 ):
        super().__init__()
        # Project every input feature map to the shared embedding dim.
        self.input_proj = nn.Conv2d(in_dim, d_dim, kernel_size=1)
        self.downsample = downsample
        if self.downsample:
            # Strided conv halves the spatial size of the first feature map.
            self.sampler = nn.Conv2d(d_dim, d_dim, kernel_size=3, stride=2, padding=1)
        encoder_layer = TransformerEncoderLayer(
            d_model=d_dim,
            nhead=8,
            dim_feedforward=dim_feedforward,
            dropout=drop_rate,
            activation='gelu',
            batch_first=True
        )
        encoder_norm = LayerNorm(d_dim) if has_norm else None
        self.trans_encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        if self.downsample:
            # PixelShuffle(2) undoes the earlier stride-2 downsampling:
            # (B, 4*d_dim, H, W) -> (B, d_dim, 2H, 2W).
            self.uplayer = nn.Sequential(
                nn.Conv2d(d_dim, d_dim * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.Conv2d(d_dim, d_dim, kernel_size=3, padding=1)
            )
        self.d_dim = d_dim
        self.class_token = class_token
        self.has_pe = has_pe
        if self.class_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_dim))
        if self.has_pe:
            # Learnable 24x24 positional grid; interpolated in forward() to
            # match each feature map's spatial size.
            self.pos_embed = nn.Parameter(
                torch.randn(1, d_dim, 24, 24) * 0.02
            )
            self.drop_after_pos = nn.Dropout(p=drop_rate)
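
    # forward() takes a LIST of feature maps (e.g. multi-scale CNN outputs).
    # Each map is projected to d_dim, optionally given positional embeddings,
    # flattened to tokens, and all token sequences are concatenated before a
    # single shared Transformer encoder pass.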
    def forward(self, x):
        # x: list of feature maps, each (B, in_dim, h, w)
        x = [self.input_proj(x_tmp) for x_tmp in x]
        if self.downsample:
            x[0] = self.sampler(x[0])
        # Spatial size of the first map; used to slice its tokens back out.
        B, C, H, W = x[0].shape
        if self.has_pe:
            # The first map is expected to match the pos_embed grid width.
            assert W == self.pos_embed.shape[-1]
            x_with_pe = []
            for x_tmp in x:
                b_tmp, c_tmp, h_tmp, w_tmp = x_tmp.shape
                # Resize the learned 24x24 grid to this feature map's size.
                pe = nn.functional.interpolate(
                    self.pos_embed, size=[h_tmp, w_tmp],
                    mode='bicubic', align_corners=True
                )
                x_with_pe.append(x_tmp + pe)
        else:
            x_with_pe = x
        # Flatten each map to a token sequence and concatenate along tokens.
        x_flatten = [rearrange(x_tmp, 'b c h w -> b (h w) c') for x_tmp in x_with_pe]
        x_flatten = torch.cat(x_flatten, dim=1)
        if self.has_pe:
            x_flatten = self.drop_after_pos(x_flatten)
        if self.class_token:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x_flatten), dim=1)
        else:
            x = x_flatten
        x = self.trans_encoder(x)
        if self.class_token:
            global_content = x[:, 0]
            func_map = x[:, 1:(1 + H * W), :]
        else:
            global_content = x.mean(dim=1)  # global pool without cls token
            func_map = x[:, :(H * W), :]
        func_map = rearrange(func_map, 'b (h w) c -> b c h w', h=H)
        if self.downsample:
            func_map = self.uplayer(func_map)
        return global_content, func_map
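
# A minimal smoke-test sketch (an assumption, not part of the original file:
# it presumes this module is run standalone with the `models` registry
# importable, and uses arbitrary example shapes). Two feature maps are fed in;
# with downsample=True the first is strided 48 -> 24, matching the 24x24
# pos_embed grid, so the assert in forward() passes.
if __name__ == '__main__':
    neck = TransformerNeck(in_dim=64, d_dim=256, downsample=True)
    feats = [torch.randn(2, 64, 48, 48), torch.randn(2, 64, 24, 24)]
    global_content, func_map = neck(feats)
    print(global_content.shape)  # torch.Size([2, 256])
    print(func_map.shape)        # torch.Size([2, 256, 48, 48]) after uplayer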