FunSR / models /cnn_models /transenet.py
KyanChen's picture
add
02c5426
raw
history blame
8.4 kB
from . import common
import torch
import torch.nn as nn
from einops import rearrange, repeat
from models import register
from .transformer import TransformerEncoder, TransformerDecoder
from argparse import Namespace
MIN_NUM_PATCHES = 12
def make_model(args, parent=False):
return TransENet(args)
class BasicModule(nn.Module):
def __init__(self, conv, n_feat, kernel_size, block_type='basic', bias=True,
bn=False, act=nn.ReLU(True)):
super(BasicModule, self).__init__()
self.block_type = block_type
m_body = []
if block_type == 'basic':
n_blocks = 10
m_body = [
common.BasicBlock(conv, n_feat, n_feat, kernel_size, bias=bias, bn=bn)
# common.ResBlock(conv, n_feat, kernel_size)
for _ in range(n_blocks)
]
elif block_type == 'residual':
n_blocks = 5
m_body = [
common.ResBlock(conv, n_feat, kernel_size)
for _ in range(n_blocks)
]
else:
print('Error: not support this type')
self.body = nn.Sequential(*m_body)
def forward(self, x):
res = self.body(x)
if self.block_type == 'basic':
out = res + x
elif self.block_type == 'residual':
out = res
return out
@register('TransENet')
def TransENet(scale_ratio, n_feats=64, rgb_range=1):
args = Namespace()
args.n_feats = n_feats
args.scale = [scale_ratio]
args.patch_size = 48 * args.scale[0]
args.rgb_range = rgb_range
args.n_colors = 3
args.en_depth = 6
args.de_depth = 1
return TransENet(args)
class TransENet(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(TransENet, self).__init__()
self.args = args
self.scale = args.scale[0]
n_feats = args.n_feats
kernel_size = 3
act = nn.ReLU(True)
# rgb_mean = (0.4916, 0.4991, 0.4565) # UCMerced data
# rgb_std = (1.0, 1.0, 1.0)
#
# self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
# define head body
m_head = [
conv(args.n_colors, n_feats, kernel_size),
]
self.head = nn.Sequential(*m_head)
# define main body
self.feat_extrat_stage1 = BasicModule(conv, n_feats, kernel_size, block_type='residual', act=act)
self.feat_extrat_stage2 = BasicModule(conv, n_feats, kernel_size, block_type='residual', act=act)
self.feat_extrat_stage3 = BasicModule(conv, n_feats, kernel_size, block_type='residual', act=act)
reduction = 4
self.stage1_conv1x1 = conv(n_feats, n_feats // reduction, 1)
self.stage2_conv1x1 = conv(n_feats, n_feats // reduction, 1)
self.stage3_conv1x1 = conv(n_feats, n_feats // reduction, 1)
self.up_conv1x1 = conv(n_feats, n_feats // reduction, 1)
self.span_conv1x1 = conv(n_feats // reduction, n_feats, 1)
self.upsampler = common.Upsampler(conv, self.scale, n_feats, act=False)
# define tail body
self.tail = conv(n_feats, args.n_colors, kernel_size)
# self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
# define transformer
image_size = args.patch_size // self.scale
patch_size = 4
dim = 512
en_depth = args.en_depth
de_depth = args.de_depth
heads = 6
mlp_dim = 512
channels = n_feats // reduction
dim_head = 32
dropout = 0.0
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
self.patch_size = patch_size
self.patch_to_embedding_low1 = nn.Linear(patch_dim, dim)
self.patch_to_embedding_low2 = nn.Linear(patch_dim, dim)
self.patch_to_embedding_low3 = nn.Linear(patch_dim, dim)
self.patch_to_embedding_high = nn.Linear(patch_dim, dim)
self.embedding_to_patch = nn.Linear(dim, patch_dim)
self.encoder_stage1 = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
self.encoder_stage2 = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
self.encoder_stage3 = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
self.encoder_up = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
self.decoder1 = TransformerDecoder(dim, de_depth, heads, dim_head, mlp_dim, dropout)
self.decoder2 = TransformerDecoder(dim, de_depth, heads, dim_head, mlp_dim, dropout)
self.decoder3 = TransformerDecoder(dim, de_depth, heads, dim_head, mlp_dim, dropout)
def forward(self, x, out_size=None):
# x = self.sub_mean(x)
x = self.head(x)
# feature extraction part
feat_stage1 = self.feat_extrat_stage1(x)
feat_stage2 = self.feat_extrat_stage2(x)
feat_stage3 = self.feat_extrat_stage3(x)
feat_ups = self.upsampler(feat_stage3)
feat_stage1 = self.stage1_conv1x1(feat_stage1)
feat_stage2 = self.stage2_conv1x1(feat_stage2)
feat_stage3 = self.stage3_conv1x1(feat_stage3)
feat_ups = self.up_conv1x1(feat_ups)
# transformer part:
p = self.patch_size
feat_stage1 = rearrange(feat_stage1, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
feat_stage2 = rearrange(feat_stage2, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
feat_stage3 = rearrange(feat_stage3, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
feat_ups = rearrange(feat_ups, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
feat_stage1 = self.patch_to_embedding_low1(feat_stage1)
feat_stage2 = self.patch_to_embedding_low2(feat_stage2)
feat_stage3 = self.patch_to_embedding_low3(feat_stage3)
feat_ups = self.patch_to_embedding_high(feat_ups)
# encoder
feat_stage1 = self.encoder_stage1(feat_stage1)
feat_stage2 = self.encoder_stage2(feat_stage2)
feat_stage3 = self.encoder_stage3(feat_stage3)
feat_ups = self.encoder_up(feat_ups)
feat_ups = self.decoder3(feat_ups, feat_stage3)
feat_ups = self.decoder2(feat_ups, feat_stage2)
feat_ups = self.decoder1(feat_ups, feat_stage1)
feat_ups = self.embedding_to_patch(feat_ups)
feat_ups = rearrange(feat_ups, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=self.args.patch_size // p, p1=p, p2=p)
feat_ups = self.span_conv1x1(feat_ups)
x = self.tail(feat_ups)
# x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=False):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') >= 0:
print('Replace pre-trained upsampler to new one...')
else:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if strict:
missing = set(own_state.keys()) - set(state_dict.keys())
if len(missing) > 0:
raise KeyError('missing keys in state_dict: "{}"'.format(missing))
if __name__ == "__main__":
from option import args
model = TransENet(args)
model.eval()
input = torch.rand(1, 3, 48, 48)
sr = model(input)
print(sr.size())