File size: 5,255 Bytes
36d9761
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import copy
from torch import nn as nn
from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights

class DEResNet(nn.Module):
    """Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore.

    As shown in paper 'Towards Flexible Blind JPEG Artifacts Removal',
    resnet arch works for image quality estimation.

    Args:
        num_in_ch (int): channel number of inputs. Default: 3.
        num_degradation (int): num of degradation the DE should estimate.
            Default: 2 (blur + noise).
        degradation_degree_actv (str): activation for the degradation degree
            scalar, 'sigmoid' or 'tanh'. Default: 'sigmoid'.
        num_feats (Sequence[int]): channel number of each stage.
        num_blocks (Sequence[int]): residual block count of each stage.
        downscales (Sequence[int]): downscale factor (1 or 2) of each stage.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_degradation=2,
                 degradation_degree_actv='sigmoid',
                 num_feats=(64, 64, 64, 128),
                 num_blocks=(2, 2, 2, 2),
                 downscales=(1, 1, 2, 1)):
        super(DEResNet, self).__init__()

        # Accept any sequence (tuple defaults avoid the mutable-default
        # pitfall); copy to lists so later caller-side mutation is harmless.
        num_feats = list(num_feats)
        num_blocks = list(num_blocks)
        downscales = list(downscales)
        # Explicit validation instead of `assert`: asserts vanish under -O.
        if not (len(num_feats) == len(num_blocks) == len(downscales)):
            raise ValueError(
                'num_feats, num_blocks and downscales must have the same length, '
                f'got {len(num_feats)}, {len(num_blocks)} and {len(downscales)}.')

        num_stage = len(num_feats)

        # One independent branch (first conv + residual body) per degradation.
        self.conv_first = nn.ModuleList(
            nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1) for _ in range(num_degradation))
        self.body = nn.ModuleList()
        for _ in range(num_degradation):
            body = []
            for stage in range(num_stage):
                for _ in range(num_blocks[stage]):
                    body.append(ResidualBlockNoBN(num_feats[stage]))
                if downscales[stage] == 1:
                    # No downscaling: only add a conv when the channel count
                    # changes between consecutive stages.
                    if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]:
                        body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1))
                elif downscales[stage] == 2:
                    # Strided conv halves the spatial size and moves to the
                    # next stage's channel count (clamped at the last stage).
                    body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1))
                else:
                    raise NotImplementedError(f'downscale {downscales[stage]} is not supported (only 1 or 2).')
            self.body.append(nn.Sequential(*body))

        self.num_degradation = num_degradation
        if degradation_degree_actv == 'sigmoid':
            actv = nn.Sigmoid
        elif degradation_degree_actv == 'tanh':
            actv = nn.Tanh
        else:
            raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, '
                                      f'{degradation_degree_actv} is not supported yet.')
        # One scalar regression head per degradation, fed by pooled features.
        self.fc_degree = nn.ModuleList(
            nn.Sequential(
                nn.Linear(num_feats[-1], 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, 1),
                actv(),
            ) for _ in range(num_degradation))

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1)

    def clone_module(self, module):
        """Return a deep copy of ``module`` (no parameter sharing)."""
        return copy.deepcopy(module)

    def average_parameters(self, modules):
        """Build a module whose parameters are the element-wise mean of ``modules``.

        All modules must share an identical architecture. Only parameters are
        averaged; buffers (if any) keep the first module's deep-copied values.
        """
        avg_module = self.clone_module(modules[0])
        # Hoist: one state_dict() per module instead of one per parameter.
        states = [mod.state_dict() for mod in modules]
        for name, param in avg_module.named_parameters():
            avg_param = sum(state[name] for state in states) / len(modules)
            param.data.copy_(avg_param)
        return avg_module

    def expand_degradation_modules(self, new_num_degradation):
        """Grow each per-degradation branch list to ``new_num_degradation``.

        New branches are initialized with the parameter average of the first
        two existing branches; a no-op when not actually expanding.
        NOTE(review): averaging is hard-coded to ``modules[:2]`` — confirm
        this is intended when more than two branches already exist.
        """
        if new_num_degradation <= self.num_degradation:
            return
        for modules in (self.conv_first, self.body, self.fc_degree):
            avg_module = self.average_parameters(modules[:2])
            while len(modules) < new_num_degradation:
                modules.append(self.clone_module(avg_module))

    def load_and_expand_model(self, path, num_degradation):
        """Strictly load a checkpoint, then expand to ``num_degradation`` branches.

        The checkpoint must match the current (pre-expansion) architecture.
        """
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (or pass weights_only=True on torch >= 1.13).
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        self.load_state_dict(state_dict, strict=True)

        self.expand_degradation_modules(num_degradation)
        self.num_degradation = num_degradation

    def load_model(self, path):
        """Strictly load a CPU-mapped checkpoint into the current model."""
        # NOTE(review): same torch.load trust caveat as load_and_expand_model.
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        self.load_state_dict(state_dict, strict=True)

    def set_train(self):
        """Enable grads on conv_first/fc_degree; in body, only LoRA params."""
        self.conv_first.requires_grad_(True)
        self.fc_degree.requires_grad_(True)
        for name, param in self.body.named_parameters():
            if "lora" in name:
                param.requires_grad = True

    def forward(self, x):
        """Estimate one degree scalar per degradation.

        Args:
            x (Tensor): input batch of shape (b, num_in_ch, h, w).

        Returns:
            Tensor: degradation degrees of shape (b, num_degradation).
        """
        degrees = []
        for i in range(self.num_degradation):
            feat = self.conv_first[i](x)
            feat = self.body[i](feat)
            feat = self.avg_pool(feat)
            feat = feat.squeeze(-1).squeeze(-1)  # (b, c, 1, 1) -> (b, c)
            degrees.append(self.fc_degree[i](feat).squeeze(-1))
        return torch.stack(degrees, dim=1)