# FunSR/models/edsr.py
# modified from: https://github.com/thstkdgus35/EDSR-PyTorch
import math
import os
from argparse import Namespace
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import register
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias=bias)
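# MeanShift: a frozen 1x1 convolution that subtracts (sign=-1) or adds back
# (sign=+1) the dataset RGB mean, scaled by rgb_range and a per-channel std.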
class MeanShift(nn.Conv2d):
def __init__(
self, rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
for p in self.parameters():
p.requires_grad = False
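# ResBlock: the EDSR residual block, conv -> ReLU -> conv (BatchNorm optional
# and off by default, as in the EDSR paper), with the branch output multiplied
# by res_scale before being added to the identity skip.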
class ResBlock(nn.Module):
def __init__(
self, conv, n_feats, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(ResBlock, self).__init__()
m = []
for i in range(2):
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if i == 0:
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x).mul(self.res_scale)
res += x
return res
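# Upsampler: conv + PixelShuffle stages. Power-of-two scales are built from
# repeated x2 shuffles; x3 uses a single 9*n_feats conv with PixelShuffle(3);
# other scales raise NotImplementedError.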
class Upsampler(nn.Sequential):
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feats, 4 * n_feats, 3, bias))
m.append(nn.PixelShuffle(2))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias))
m.append(nn.PixelShuffle(3))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
else:
raise NotImplementedError
super(Upsampler, self).__init__(*m)
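# official pretrained EDSR checkpoints, keyed as 'r{n_resblocks}f{n_feats}x{scale}'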
url = {
'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
}
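# EDSR backbone: head conv -> residual body (with a global skip) -> PixelShuffle
# tail. When args.no_upsampling is set, forward() returns the body features
# (out_dim = n_feats) instead of the upsampled RGB image, the mode typically
# used when EDSR serves as the encoder for an implicit-function SR head.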
class EDSR(nn.Module):
def __init__(self, args, conv=default_conv):
super(EDSR, self).__init__()
self.args = args
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = args.scale[0]
act = nn.ReLU(True)
url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
if url_name in url:
self.url = url[url_name]
else:
self.url = None
self.sub_mean = MeanShift(args.rgb_range)
self.add_mean = MeanShift(args.rgb_range, sign=1)
# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [
ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
if args.no_upsampling:
self.out_dim = n_feats
else:
self.out_dim = args.n_colors
# define tail module
m_tail = [
Upsampler(conv, scale, n_feats, act=False),
conv(n_feats, args.n_colors, kernel_size)
]
self.tail = nn.Sequential(*m_tail)
        # load pretrained weights only when a matching checkpoint exists under
        # `pretrained/`; otherwise keep the randomly initialized weights
        if self.url is not None:
            ckpt_path = 'pretrained/' + self.url.split('/')[-1]
            if os.path.exists(ckpt_path):
                self.load_state_dict(ckpt_path)
            else:
                print('pretrained checkpoint not found: ' + ckpt_path)
def forward(self, x):
#x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
if self.args.no_upsampling:
x = res
else:
x = self.tail(res)
#x = self.add_mean(x)
return x
    def load_state_dict(self, state_dict, strict=True):
        # unlike nn.Module.load_state_dict, this override accepts a checkpoint
        # *path*: it loads the file, copies parameters that match by name, and
        # tolerates size mismatches in the `tail` (e.g. a different scale)
        state_dict = torch.load(state_dict, map_location='cpu')
        own_state = self.state_dict()
        print('loading pretrained model')
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
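# registered factory functions: 'edsr-baseline' is the 16-block / 64-feature
# configuration and 'edsr' the 32-block / 256-feature one (res_scale=0.1),
# matching the checkpoints in the `url` table above.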
@register('edsr-baseline')
def make_edsr_baseline(n_resblocks=16, n_feats=64, res_scale=1,
scale=2, no_upsampling=False, rgb_range=1):
args = Namespace()
args.n_resblocks = n_resblocks
args.n_feats = n_feats
args.res_scale = res_scale
args.scale = [scale]
args.no_upsampling = no_upsampling
args.rgb_range = rgb_range
args.n_colors = 3
return EDSR(args)
@register('edsr')
def make_edsr(n_resblocks=32, n_feats=256, res_scale=0.1,
scale=2, no_upsampling=False, rgb_range=1):
args = Namespace()
args.n_resblocks = n_resblocks
args.n_feats = n_feats
args.res_scale = res_scale
args.scale = [scale]
args.no_upsampling = no_upsampling
args.rgb_range = rgb_range
args.n_colors = 3
return EDSR(args)
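# Minimal smoke-test sketch (not part of the original file): builds the
# baseline model and checks output shapes. It assumes the FunSR
# `models.register` decorator returns the wrapped function unchanged and
# that the module is run from the repo root, e.g. `python -m models.edsr`.
# If no checkpoint is present under `pretrained/`, the model simply starts
# from random weights (see the guarded load in EDSR.__init__).
if __name__ == '__main__':
    # x2 baseline EDSR used as an end-to-end super-resolution network
    model = make_edsr_baseline(scale=2, no_upsampling=False)
    lr = torch.rand(1, 3, 32, 32)
    sr = model(lr)
    print(sr.shape)  # expected: torch.Size([1, 3, 64, 64])

    # feature-encoder mode: the tail is skipped, spatial size is preserved
    encoder = make_edsr_baseline(no_upsampling=True)
    feat = encoder(lr)
    print(feat.shape)  # expected: torch.Size([1, 64, 32, 32])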