File size: 3,488 Bytes
02c5426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

# code ref: https://github.com/yjn870/FSRCNN-pytorch/blob/master/models.py

from . import common
import math

from argparse import Namespace
import torch.nn as nn
from models import register
import torch.nn.functional as F


def make_model(args, parent=False):
    """Build an ``FSRCNN`` network from an already-populated ``args`` object.

    ``args`` is passed straight to the ``FSRCNN`` class constructor, which
    reads ``args.scale[0]`` and ``args.n_colors``. ``parent`` is unused.
    """
    # Resolved at call time: the module-level name ``FSRCNN`` is bound to the
    # nn.Module class by the time this runs.
    model = FSRCNN(args)
    return model


@register('FSRCNN')
def FSRCNN(scale_ratio, rgb_range=1):
    """Registry entry point: build an FSRCNN for one upscale factor.

    Args:
        scale_ratio: upscale factor, stored as ``args.scale = [scale_ratio]``.
        rgb_range: stored as ``args.rgb_range``; the model class in this file
            does not read it.

    Returns:
        An instance of the ``FSRCNN`` nn.Module class, configured for
        3-channel input/output.
    """
    # NOTE(review): ``class FSRCNN`` later in this module rebinds the name
    # ``FSRCNN``, so the lookup below (done at call time) reaches the class,
    # not this function. Presumably ``register`` keeps a reference to this
    # function under the key 'FSRCNN' -- verify in ``models.register``.
    config = Namespace(scale=[scale_ratio], n_colors=3, rgb_range=rgb_range)
    return FSRCNN(config)

class FSRCNN(nn.Module):
    """FSRCNN super-resolution network.

    Layout:
        * ``first_part``: 5x5 feature extraction (``n_colors -> d``) + PReLU.
        * ``mid_part``: 1x1 shrink (``d -> s``), ``m`` 3x3 mapping layers in
          ``s`` channels, 1x1 expand (``s -> d``), each followed by PReLU.
        * ``last_part``: 9x9 transposed convolution with ``stride=scale`` that
          maps ``d -> n_colors`` and upsamples by ``scale``.

    Args:
        args: configuration object; reads ``args.scale[0]`` (upscale factor)
            and ``args.n_colors`` (input/output channel count).
        conv: factory called as ``conv(in_ch, out_ch, kernel_size=k)`` that
            returns a convolution module.
        d: feature channels after extraction / before the upsampler.
        s: channels in the shrunk mapping space.
        m: number of 3x3 mapping layers.
    """

    def __init__(self, args, conv=common.default_conv, d=56, s=12 * 3, m=8):
        super().__init__()

        scale = args.scale[0]

        # Every activation site gets its own nn.PReLU. PReLU owns a learnable
        # slope, so reusing one instance would tie every activation in the
        # network to a single shared parameter. Module indices (hence
        # state-dict keys and shapes) are identical to the shared layout, so
        # checkpoints saved from it still load.
        m_first_part = [conv(args.n_colors, d, kernel_size=5), nn.PReLU()]
        self.first_part = nn.Sequential(*m_first_part)

        m_mid_part = [conv(d, s, kernel_size=1), nn.PReLU()]
        for _ in range(m):
            m_mid_part.append(conv(s, s, kernel_size=3))
            m_mid_part.append(nn.PReLU())
        m_mid_part.append(conv(s, d, kernel_size=1))
        m_mid_part.append(nn.PReLU())
        self.mid_part = nn.Sequential(*m_mid_part)

        # With kernel 9 and padding 4, output_padding=scale-1 makes each
        # spatial output dimension exactly scale * input dimension.
        self.last_part = nn.ConvTranspose2d(d, args.n_colors, kernel_size=9, stride=scale, padding=9 // 2,
                                            output_padding=scale - 1)

        # Custom initialisation is intentionally disabled; call
        # self._initialize_weights() here to enable it.

    def _initialize_weights(self):
        """Re-initialise weights (not called from ``__init__``).

        Each ``nn.Conv2d`` in ``first_part`` / ``mid_part`` gets normal weights
        with std ``sqrt(2 / (out_channels * kernel_h * kernel_w))`` and zero
        bias. ``last_part`` gets normal weights with std 0.001 and zero bias.
        """
        for layer in (*self.first_part, *self.mid_part):
            if isinstance(layer, nn.Conv2d):
                fan_out = layer.out_channels * layer.weight.data[0][0].numel()
                nn.init.normal_(layer.weight.data, mean=0.0, std=math.sqrt(2 / fan_out))
                # The conv factory may build bias-free convolutions.
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x, out_size=None):
        """Upsample ``x`` by ``args.scale[0]``.

        ``out_size`` is accepted for interface compatibility and is ignored.
        """
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.last_part(x)
        return x

    def load_state_dict(self, state_dict, strict=False):
        """Copy matching tensors from ``state_dict`` into this model in place.

        Differs from ``nn.Module.load_state_dict``: ``strict`` defaults to
        False and nothing is returned.

        Args:
            state_dict: mapping of parameter/buffer name to tensor or
                ``nn.Parameter``.
            strict: if True, raise on checkpoint keys the model lacks (keys
                containing ``'tail'`` are tolerated) and on model keys the
                checkpoint lacks.

        Raises:
            RuntimeError: a tensor could not be copied (e.g. shape mismatch)
                and its name does not contain ``'tail'``.
            KeyError: ``strict`` is True and keys are unexpected or missing.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception as err:
                    # 'tail' entries (an upsampler in other models of this
                    # codebase, presumably) may mismatch and are skipped.
                    # NOTE: this model has no 'tail' modules.
                    if 'tail' in name:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size())) from err
            elif strict:
                if 'tail' not in name:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))