"""FSRCNN super-resolution model (Dong et al., "Accelerating the
Super-Resolution Convolutional Neural Network", ECCV 2016)."""

import math
from argparse import Namespace

import torch.nn as nn

from models import register

from . import common
|
|
def make_model(args, parent=False):
    return FSRCNN(args)


@register('FSRCNN')
def fsrcnn(scale_ratio, rgb_range=1):
    args = Namespace()
    args.scale = [scale_ratio]
    args.n_colors = 3
    args.rgb_range = rgb_range
    return FSRCNN(args)
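
# Usage sketch (illustrative; the parameter value is arbitrary):
#   model = fsrcnn(scale_ratio=2)  # 2x FSRCNN, RGB range [0, 1]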
|
|
class FSRCNN(nn.Module):
    def __init__(self, args, conv=common.default_conv, d=56, s=12 * 3, m=8):
        super(FSRCNN, self).__init__()

        scale = args.scale[0]

        # Feature extraction: one 5x5 convolution. A fresh nn.PReLU() is
        # created for every activation so the learnable slopes are not
        # shared across layers.
        m_first_part = []
        m_first_part.append(conv(args.n_colors, d, kernel_size=5))
        m_first_part.append(nn.PReLU())
        self.first_part = nn.Sequential(*m_first_part)

        # Shrinking (d -> s), m mapping layers (s -> s), expanding (s -> d).
        m_mid_part = []
        m_mid_part.append(conv(d, s, kernel_size=1))
        m_mid_part.append(nn.PReLU())
        for _ in range(m):
            m_mid_part.append(conv(s, s, kernel_size=3))
            m_mid_part.append(nn.PReLU())
        m_mid_part.append(conv(s, d, kernel_size=1))
        m_mid_part.append(nn.PReLU())
        self.mid_part = nn.Sequential(*m_mid_part)

        # Deconvolution: upsamples by the scale factor in a single step.
        self.last_part = nn.ConvTranspose2d(d, args.n_colors, kernel_size=9,
                                            stride=scale, padding=9 // 2,
                                            output_padding=scale - 1)

        self._initialize_weights()
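
    # Size check for the transposed convolution above: for input height H,
    #   out = (H - 1) * stride - 2 * padding + kernel_size + output_padding
    #       = (H - 1) * scale - 8 + 9 + (scale - 1) = scale * H,
    # so last_part upsamples by exactly the scale factor.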
|
    def _initialize_weights(self):
        # He-style init for the convolutions (variance scaled by
        # out_channels times the kernel area); small normal init for the
        # deconvolution.
        for m in self.first_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0,
                                std=math.sqrt(2 / (m.out_channels * m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        for m in self.mid_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0,
                                std=math.sqrt(2 / (m.out_channels * m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)
|
    def forward(self, x, out_size=None):
        # out_size is unused: FSRCNN's output size is fixed by the
        # deconvolution stride. It is kept for interface compatibility.
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.last_part(x)
        return x
|
    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    # Mismatched 'tail' parameters are tolerated (an
                    # EDSR-style convention for swapping the upsampler);
                    # this model's upsampler is named 'last_part', so these
                    # branches are effectively inert here.
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))
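

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the training pipeline. It
    # assumes common.default_conv pads by kernel_size // 2 (so the 5x5,
    # 3x3 and 1x1 convolutions preserve spatial size); run with `python -m`
    # from the package root so the relative import of `common` resolves.
    import torch

    model = fsrcnn(scale_ratio=2)
    x = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 3, 32, 32]) at scale 2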