from argparse import Namespace

import torch
import torch.nn as nn
from einops import rearrange

from models import register

from . import common
from .transformer import TransformerEncoder, TransformerDecoder

MIN_NUM_PATCHES = 12


def make_model(args, parent=False):
    return TransENet(args)


class BasicModule(nn.Module):
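    """Stack of convolutional blocks used for per-stage feature extraction.

    With block_type='basic' the stack is plain conv blocks and a skip
    connection is added around the whole stack; with block_type='residual'
    the stack is common.ResBlock modules and their output is returned
    directly, without an extra global skip.
    """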

    def __init__(self, conv, n_feat, kernel_size, block_type='basic', bias=True,
                 bn=False, act=nn.ReLU(True)):
        super(BasicModule, self).__init__()

        self.block_type = block_type

        m_body = []
        if block_type == 'basic':
            n_blocks = 10
            m_body = [
                common.BasicBlock(conv, n_feat, n_feat, kernel_size, bias=bias, bn=bn)
                for _ in range(n_blocks)
            ]
        elif block_type == 'residual':
            n_blocks = 5
            m_body = [
                common.ResBlock(conv, n_feat, kernel_size)
                for _ in range(n_blocks)
            ]
        else:
            raise ValueError(f'unsupported block_type: {block_type}')
        self.body = nn.Sequential(*m_body)

    def forward(self, x):
        res = self.body(x)
        if self.block_type == 'basic':
            out = res + x
        elif self.block_type == 'residual':
            out = res

        return out


@register('TransENet')
def make_transenet(scale_ratio, n_feats=64, rgb_range=1):
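    """Build TransENet with the default registry settings.

    The low-resolution training patch is 48 px per side, so args.patch_size
    (the high-resolution patch) is 48 * scale_ratio.
    """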
    args = Namespace()
    args.n_feats = n_feats
    args.scale = [scale_ratio]
    args.patch_size = 48 * args.scale[0]

    args.rgb_range = rgb_range
    args.n_colors = 3
    args.en_depth = 6
    args.de_depth = 1
    return TransENet(args)
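
# Minimal usage sketch: the factory above is registered under the name
# 'TransENet' and can also be called directly, e.g.
#
#     model = make_transenet(scale_ratio=4)
#     sr = model(torch.rand(1, 3, 48, 48))  # -> (1, 3, 192, 192)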


class TransENet(nn.Module):
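    """Transformer-based enhancement network (TransENet) for super-resolution.

    A shallow conv head is followed by three residual feature-extraction
    branches; the third branch is additionally upsampled to the target
    resolution. Each of the four streams is reduced with a 1x1 conv, split
    into patches, embedded, and passed through its own transformer encoder.
    Transformer decoders then fuse the three low-resolution streams into the
    upsampled stream, which is folded back into a feature map, expanded with
    a 1x1 conv, and reconstructed by the tail conv.
    """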

    def __init__(self, args, conv=common.default_conv):
        super(TransENet, self).__init__()

        self.args = args
        self.scale = args.scale[0]
        n_feats = args.n_feats
        kernel_size = 3
        act = nn.ReLU(True)

        # shallow feature extraction head
        m_head = [
            conv(args.n_colors, n_feats, kernel_size),
        ]
        self.head = nn.Sequential(*m_head)

        # three residual feature-extraction stages
        self.feat_extrat_stage1 = BasicModule(conv, n_feats, kernel_size, block_type='residual', act=act)
        self.feat_extrat_stage2 = BasicModule(conv, n_feats, kernel_size, block_type='residual', act=act)
        self.feat_extrat_stage3 = BasicModule(conv, n_feats, kernel_size, block_type='residual', act=act)

        # 1x1 convs: reduce channels before the transformer, expand them afterwards
        reduction = 4
        self.stage1_conv1x1 = conv(n_feats, n_feats // reduction, 1)
        self.stage2_conv1x1 = conv(n_feats, n_feats // reduction, 1)
        self.stage3_conv1x1 = conv(n_feats, n_feats // reduction, 1)
        self.up_conv1x1 = conv(n_feats, n_feats // reduction, 1)
        self.span_conv1x1 = conv(n_feats // reduction, n_feats, 1)

        self.upsampler = common.Upsampler(conv, self.scale, n_feats, act=False)

        # reconstruction tail
        self.tail = conv(n_feats, args.n_colors, kernel_size)

        # transformer configuration
        image_size = args.patch_size // self.scale
        patch_size = 4
        dim = 512
        en_depth = args.en_depth
        de_depth = args.de_depth
        heads = 6
        mlp_dim = 512
        channels = n_feats // reduction
        dim_head = 32
        dropout = 0.0

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2

        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, (
            f'your number of patches ({num_patches}) is too small for attention to be effective '
            f'(needs more than {MIN_NUM_PATCHES}). Try decreasing your patch size.'
        )
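
        # With the registry defaults (scale_ratio=4, n_feats=64, reduction=4):
        # image_size = 192 // 4 = 48, num_patches = (48 // 4) ** 2 = 144, and
        # patch_dim = 16 * 4 ** 2 = 256.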

        # patch embeddings for the three low-resolution streams and the upsampled stream
        self.patch_size = patch_size
        self.patch_to_embedding_low1 = nn.Linear(patch_dim, dim)
        self.patch_to_embedding_low2 = nn.Linear(patch_dim, dim)
        self.patch_to_embedding_low3 = nn.Linear(patch_dim, dim)
        self.patch_to_embedding_high = nn.Linear(patch_dim, dim)

        self.embedding_to_patch = nn.Linear(dim, patch_dim)

        # one encoder per stream; decoders fuse the encoded low-resolution
        # streams into the upsampled stream
        self.encoder_stage1 = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
        self.encoder_stage2 = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
        self.encoder_stage3 = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
        self.encoder_up = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)

        self.decoder1 = TransformerDecoder(dim, de_depth, heads, dim_head, mlp_dim, dropout)
        self.decoder2 = TransformerDecoder(dim, de_depth, heads, dim_head, mlp_dim, dropout)
        self.decoder3 = TransformerDecoder(dim, de_depth, heads, dim_head, mlp_dim, dropout)

    def forward(self, x, out_size=None):
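        """Super-resolve a batch of low-resolution images.

        out_size is currently unused; the output resolution is fixed by the
        configured scale factor.
        """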
        x = self.head(x)

        # feature-extraction stages; the stage-3 features are also upsampled
        feat_stage1 = self.feat_extrat_stage1(x)
        feat_stage2 = self.feat_extrat_stage2(x)
        feat_stage3 = self.feat_extrat_stage3(x)
        feat_ups = self.upsampler(feat_stage3)

        # reduce channels before patch embedding
        feat_stage1 = self.stage1_conv1x1(feat_stage1)
        feat_stage2 = self.stage2_conv1x1(feat_stage2)
        feat_stage3 = self.stage3_conv1x1(feat_stage3)
        feat_ups = self.up_conv1x1(feat_ups)

        # split each feature map into non-overlapping p x p patches and embed them
        p = self.patch_size
        feat_stage1 = rearrange(feat_stage1, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        feat_stage2 = rearrange(feat_stage2, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        feat_stage3 = rearrange(feat_stage3, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        feat_ups = rearrange(feat_ups, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        feat_stage1 = self.patch_to_embedding_low1(feat_stage1)
        feat_stage2 = self.patch_to_embedding_low2(feat_stage2)
        feat_stage3 = self.patch_to_embedding_low3(feat_stage3)
        feat_ups = self.patch_to_embedding_high(feat_ups)

        # encode each stream independently
        feat_stage1 = self.encoder_stage1(feat_stage1)
        feat_stage2 = self.encoder_stage2(feat_stage2)
        feat_stage3 = self.encoder_stage3(feat_stage3)
        feat_ups = self.encoder_up(feat_ups)

        # fuse the low-resolution streams into the upsampled stream, stage 3 first
        feat_ups = self.decoder3(feat_ups, feat_stage3)
        feat_ups = self.decoder2(feat_ups, feat_stage2)
        feat_ups = self.decoder1(feat_ups, feat_stage1)

        # project embeddings back to patches and fold them into a feature map
        feat_ups = self.embedding_to_patch(feat_ups)
        feat_ups = rearrange(feat_ups, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
                             h=self.args.patch_size // p, p1=p, p2=p)

        feat_ups = self.span_conv1x1(feat_ups)

        x = self.tail(feat_ups)

        return x

    def load_state_dict(self, state_dict, strict=False):
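        """Copy parameters from state_dict, tolerating a mismatched tail.

        Size mismatches in parameters whose name contains 'tail' are skipped
        (the model keeps its own tail); other size mismatches raise. With
        strict=True, unexpected non-'tail' keys and missing keys also raise.
        """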
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Skipping size-mismatched tail parameter: {}'.format(name))
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))


if __name__ == "__main__":
    # quick shape check with the registry defaults
    model = make_transenet(scale_ratio=4)
    model.eval()
    lr = torch.rand(1, 3, 48, 48)
    with torch.no_grad():
        sr = model(lr)
    print(sr.size())
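    # expected: torch.Size([1, 3, 192, 192]) for scale_ratio=4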