zhangap commited on
Commit
36d9761
1 Parent(s): a0f7b49

Upload 213 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. assets/mm-realsr/de_net.pth +3 -0
  3. assets/pic/div2k_comparison.jpg +3 -0
  4. assets/pic/gradio.png +0 -0
  5. assets/pic/london2.jpg +3 -0
  6. assets/pic/main_framework.jpg +3 -0
  7. assets/pic/realsr_vis3.jpg +3 -0
  8. basicsr/__init__.py +12 -0
  9. basicsr/archs/__init__.py +24 -0
  10. basicsr/archs/__pycache__/__init__.cpython-310.pyc +0 -0
  11. basicsr/archs/__pycache__/arch_util.cpython-310.pyc +0 -0
  12. basicsr/archs/__pycache__/basicvsr_arch.cpython-310.pyc +0 -0
  13. basicsr/archs/__pycache__/basicvsrpp_arch.cpython-310.pyc +0 -0
  14. basicsr/archs/__pycache__/degradat_arch.cpython-310.pyc +0 -0
  15. basicsr/archs/__pycache__/dfdnet_arch.cpython-310.pyc +0 -0
  16. basicsr/archs/__pycache__/dfdnet_util.cpython-310.pyc +0 -0
  17. basicsr/archs/__pycache__/discriminator_arch.cpython-310.pyc +0 -0
  18. basicsr/archs/__pycache__/duf_arch.cpython-310.pyc +0 -0
  19. basicsr/archs/__pycache__/ecbsr_arch.cpython-310.pyc +0 -0
  20. basicsr/archs/__pycache__/edsr_arch.cpython-310.pyc +0 -0
  21. basicsr/archs/__pycache__/edvr_arch.cpython-310.pyc +0 -0
  22. basicsr/archs/__pycache__/hifacegan_arch.cpython-310.pyc +0 -0
  23. basicsr/archs/__pycache__/hifacegan_util.cpython-310.pyc +0 -0
  24. basicsr/archs/__pycache__/rcan_arch.cpython-310.pyc +0 -0
  25. basicsr/archs/__pycache__/ridnet_arch.cpython-310.pyc +0 -0
  26. basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc +0 -0
  27. basicsr/archs/__pycache__/spynet_arch.cpython-310.pyc +0 -0
  28. basicsr/archs/__pycache__/srresnet_arch.cpython-310.pyc +0 -0
  29. basicsr/archs/__pycache__/srvgg_arch.cpython-310.pyc +0 -0
  30. basicsr/archs/__pycache__/stylegan2_arch.cpython-310.pyc +0 -0
  31. basicsr/archs/__pycache__/stylegan2_bilinear_arch.cpython-310.pyc +0 -0
  32. basicsr/archs/__pycache__/swinir_arch.cpython-310.pyc +0 -0
  33. basicsr/archs/__pycache__/tof_arch.cpython-310.pyc +0 -0
  34. basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc +0 -0
  35. basicsr/archs/arch_util.py +355 -0
  36. basicsr/archs/basicvsr_arch.py +336 -0
  37. basicsr/archs/basicvsrpp_arch.py +417 -0
  38. basicsr/archs/degradat_arch.py +90 -0
  39. basicsr/archs/dfdnet_arch.py +169 -0
  40. basicsr/archs/dfdnet_util.py +162 -0
  41. basicsr/archs/discriminator_arch.py +150 -0
  42. basicsr/archs/duf_arch.py +276 -0
  43. basicsr/archs/ecbsr_arch.py +275 -0
  44. basicsr/archs/edsr_arch.py +61 -0
  45. basicsr/archs/edvr_arch.py +382 -0
  46. basicsr/archs/hifacegan_arch.py +260 -0
  47. basicsr/archs/hifacegan_util.py +255 -0
  48. basicsr/archs/inception.py +307 -0
  49. basicsr/archs/rcan_arch.py +135 -0
  50. basicsr/archs/ridnet_arch.py +180 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/pic/div2k_comparison.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/pic/london2.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/pic/main_framework.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/pic/realsr_vis3.jpg filter=lfs diff=lfs merge=lfs -text
assets/mm-realsr/de_net.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6e77c1cb0dd51e01ef60bbf7b9d85b3f9399c3c1253889ddb73de1436231b1a
3
+ size 9424338
assets/pic/div2k_comparison.jpg ADDED

Git LFS Details

  • SHA256: d52daed3de6be211bda5e1c11b8d9352fee84fca53af53e10b3cddc86036db99
  • Pointer size: 132 Bytes
  • Size of remote file: 6.22 MB
assets/pic/gradio.png ADDED
assets/pic/london2.jpg ADDED

Git LFS Details

  • SHA256: 6055d986cff9fe82d97b60f2ed728f00fb81600aa69541a97db80aa73cd906fa
  • Pointer size: 132 Bytes
  • Size of remote file: 6.39 MB
assets/pic/main_framework.jpg ADDED

Git LFS Details

  • SHA256: f089c6f09a44300d36cb3c9b6b4905ca4f801d63ae1907fa83ab66aec70ec893
  • Pointer size: 132 Bytes
  • Size of remote file: 4.68 MB
assets/pic/realsr_vis3.jpg ADDED

Git LFS Details

  • SHA256: b07d0ffe838adf48958e3e910b467e1726d4a7a2f1623a9ae530f8db622fbb06
  • Pointer size: 132 Bytes
  • Size of remote file: 5.33 MB
basicsr/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/xinntao/BasicSR
2
+ # flake8: noqa
3
+ from .archs import *
4
+ from .data import *
5
+ from .losses import *
6
+ from .metrics import *
7
+ from .models import *
8
+ from .ops import *
9
+ from .test import *
10
+ from .train import *
11
+ from .utils import *
12
+ # from .version import __gitsha__, __version__
basicsr/archs/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from basicsr.utils import get_root_logger, scandir
6
+ from basicsr.utils.registry import ARCH_REGISTRY
7
+
8
+ __all__ = ['build_network']
9
+
10
+ # automatically scan and import arch modules for registry
11
+ # scan all the files under the 'archs' folder and collect files ending with '_arch.py'
12
+ arch_folder = osp.dirname(osp.abspath(__file__))
13
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
14
+ # import all the arch modules
15
+ _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
16
+
17
+
18
+ def build_network(opt):
19
+ opt = deepcopy(opt)
20
+ network_type = opt.pop('type')
21
+ net = ARCH_REGISTRY.get(network_type)(**opt)
22
+ logger = get_root_logger()
23
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
24
+ return net
basicsr/archs/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.15 kB). View file
 
basicsr/archs/__pycache__/arch_util.cpython-310.pyc ADDED
Binary file (12.1 kB). View file
 
basicsr/archs/__pycache__/basicvsr_arch.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
basicsr/archs/__pycache__/basicvsrpp_arch.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
basicsr/archs/__pycache__/degradat_arch.cpython-310.pyc ADDED
Binary file (3.01 kB). View file
 
basicsr/archs/__pycache__/dfdnet_arch.cpython-310.pyc ADDED
Binary file (5.44 kB). View file
 
basicsr/archs/__pycache__/dfdnet_util.cpython-310.pyc ADDED
Binary file (5.5 kB). View file
 
basicsr/archs/__pycache__/discriminator_arch.cpython-310.pyc ADDED
Binary file (4.94 kB). View file
 
basicsr/archs/__pycache__/duf_arch.cpython-310.pyc ADDED
Binary file (9.17 kB). View file
 
basicsr/archs/__pycache__/ecbsr_arch.cpython-310.pyc ADDED
Binary file (8.39 kB). View file
 
basicsr/archs/__pycache__/edsr_arch.cpython-310.pyc ADDED
Binary file (2.32 kB). View file
 
basicsr/archs/__pycache__/edvr_arch.cpython-310.pyc ADDED
Binary file (11.3 kB). View file
 
basicsr/archs/__pycache__/hifacegan_arch.cpython-310.pyc ADDED
Binary file (7.6 kB). View file
 
basicsr/archs/__pycache__/hifacegan_util.cpython-310.pyc ADDED
Binary file (8.5 kB). View file
 
basicsr/archs/__pycache__/rcan_arch.cpython-310.pyc ADDED
Binary file (4.96 kB). View file
 
basicsr/archs/__pycache__/ridnet_arch.cpython-310.pyc ADDED
Binary file (6.49 kB). View file
 
basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc ADDED
Binary file (4.44 kB). View file
 
basicsr/archs/__pycache__/spynet_arch.cpython-310.pyc ADDED
Binary file (3.91 kB). View file
 
basicsr/archs/__pycache__/srresnet_arch.cpython-310.pyc ADDED
Binary file (2.52 kB). View file
 
basicsr/archs/__pycache__/srvgg_arch.cpython-310.pyc ADDED
Binary file (2.41 kB). View file
 
basicsr/archs/__pycache__/stylegan2_arch.cpython-310.pyc ADDED
Binary file (25.2 kB). View file
 
basicsr/archs/__pycache__/stylegan2_bilinear_arch.cpython-310.pyc ADDED
Binary file (18 kB). View file
 
basicsr/archs/__pycache__/swinir_arch.cpython-310.pyc ADDED
Binary file (28.7 kB). View file
 
basicsr/archs/__pycache__/tof_arch.cpython-310.pyc ADDED
Binary file (6.29 kB). View file
 
basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc ADDED
Binary file (4.85 kB). View file
 
basicsr/archs/arch_util.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import torchvision
5
+ import warnings
6
+ from distutils.version import LooseVersion
7
+ from itertools import repeat
8
+ from torch import nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.nn import init as init
11
+ from torch.nn.modules.batchnorm import _BatchNorm
12
+
13
+ from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
14
+ from basicsr.utils import get_root_logger
15
+
16
+
17
+ @torch.no_grad()
18
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
19
+ """Initialize network weights.
20
+
21
+ Args:
22
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
23
+ scale (float): Scale initialized weights, especially for residual
24
+ blocks. Default: 1.
25
+ bias_fill (float): The value to fill bias. Default: 0
26
+ kwargs (dict): Other arguments for initialization function.
27
+ """
28
+ if not isinstance(module_list, list):
29
+ module_list = [module_list]
30
+ for module in module_list:
31
+ for m in module.modules():
32
+ if isinstance(m, nn.Conv2d):
33
+ init.kaiming_normal_(m.weight, **kwargs)
34
+ m.weight.data *= scale
35
+ if m.bias is not None:
36
+ m.bias.data.fill_(bias_fill)
37
+ elif isinstance(m, nn.Linear):
38
+ init.kaiming_normal_(m.weight, **kwargs)
39
+ m.weight.data *= scale
40
+ if m.bias is not None:
41
+ m.bias.data.fill_(bias_fill)
42
+ elif isinstance(m, _BatchNorm):
43
+ init.constant_(m.weight, 1)
44
+ if m.bias is not None:
45
+ m.bias.data.fill_(bias_fill)
46
+
47
+
48
+ def make_layer(basic_block, num_basic_block, **kwarg):
49
+ """Make layers by stacking the same blocks.
50
+
51
+ Args:
52
+ basic_block (nn.module): nn.module class for basic block.
53
+ num_basic_block (int): number of blocks.
54
+
55
+ Returns:
56
+ nn.Sequential: Stacked blocks in nn.Sequential.
57
+ """
58
+ layers = []
59
+ for _ in range(num_basic_block):
60
+ layers.append(basic_block(**kwarg))
61
+ return nn.Sequential(*layers)
62
+
63
+ class PixelShufflePack(nn.Module):
64
+ """Pixel Shuffle upsample layer.
65
+ Args:
66
+ in_channels (int): Number of input channels.
67
+ out_channels (int): Number of output channels.
68
+ scale_factor (int): Upsample ratio.
69
+ upsample_kernel (int): Kernel size of Conv layer to expand channels.
70
+ Returns:
71
+ Upsampled feature map.
72
+ """
73
+
74
+ def __init__(self, in_channels, out_channels, scale_factor,
75
+ upsample_kernel):
76
+ super().__init__()
77
+ self.in_channels = in_channels
78
+ self.out_channels = out_channels
79
+ self.scale_factor = scale_factor
80
+ self.upsample_kernel = upsample_kernel
81
+ self.upsample_conv = nn.Conv2d(
82
+ self.in_channels,
83
+ self.out_channels * scale_factor * scale_factor,
84
+ self.upsample_kernel,
85
+ padding=(self.upsample_kernel - 1) // 2)
86
+ self.init_weights()
87
+
88
+ def init_weights(self):
89
+ """Initialize weights for PixelShufflePack."""
90
+ default_init_weights(self, 1)
91
+
92
+ def forward(self, x):
93
+ """Forward function for PixelShufflePack.
94
+ Args:
95
+ x (Tensor): Input tensor with shape (n, c, h, w).
96
+ Returns:
97
+ Tensor: Forward results.
98
+ """
99
+ x = self.upsample_conv(x)
100
+ x = F.pixel_shuffle(x, self.scale_factor)
101
+ return x
102
+
103
+ class ResidualBlockNoBN(nn.Module):
104
+ """Residual block without BN.
105
+
106
+ Args:
107
+ num_feat (int): Channel number of intermediate features.
108
+ Default: 64.
109
+ res_scale (float): Residual scale. Default: 1.
110
+ pytorch_init (bool): If set to True, use pytorch default init,
111
+ otherwise, use default_init_weights. Default: False.
112
+ """
113
+
114
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
115
+ super(ResidualBlockNoBN, self).__init__()
116
+ self.res_scale = res_scale
117
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
118
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
119
+ self.relu = nn.ReLU()
120
+
121
+ if not pytorch_init:
122
+ default_init_weights([self.conv1, self.conv2], 0.1)
123
+
124
+ def forward(self, x):
125
+ identity = x
126
+ x = self.conv1(x)
127
+ x = self.relu(x)
128
+ out = self.conv2(x)
129
+ # out = self.conv2(self.relu(self.conv1(x)))
130
+ return identity + out * self.res_scale
131
+
132
+
133
+ class Upsample(nn.Sequential):
134
+ """Upsample module.
135
+
136
+ Args:
137
+ scale (int): Scale factor. Supported scales: 2^n and 3.
138
+ num_feat (int): Channel number of intermediate features.
139
+ """
140
+
141
+ def __init__(self, scale, num_feat):
142
+ m = []
143
+ if (scale & (scale - 1)) == 0: # scale = 2^n
144
+ for _ in range(int(math.log(scale, 2))):
145
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
146
+ m.append(nn.PixelShuffle(2))
147
+ elif scale == 3:
148
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
149
+ m.append(nn.PixelShuffle(3))
150
+ else:
151
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
152
+ super(Upsample, self).__init__(*m)
153
+
154
+
155
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
156
+ """Warp an image or feature map with optical flow.
157
+
158
+ Args:
159
+ x (Tensor): Tensor with size (n, c, h, w).
160
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
161
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
162
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
163
+ Default: 'zeros'.
164
+ align_corners (bool): Before pytorch 1.3, the default value is
165
+ align_corners=True. After pytorch 1.3, the default value is
166
+ align_corners=False. Here, we use the True as default.
167
+
168
+ Returns:
169
+ Tensor: Warped image or feature map.
170
+ """
171
+ assert x.size()[-2:] == flow.size()[1:3]
172
+ _, _, h, w = x.size()
173
+ # create mesh grid
174
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
175
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
176
+ grid.requires_grad = False
177
+
178
+ vgrid = grid + flow
179
+ # scale grid to [-1,1]
180
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
181
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
182
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
183
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
184
+
185
+ # TODO, what if align_corners=False
186
+ return output
187
+
188
+
189
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
190
+ """Resize a flow according to ratio or shape.
191
+
192
+ Args:
193
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
194
+ size_type (str): 'ratio' or 'shape'.
195
+ sizes (list[int | float]): the ratio for resizing or the final output
196
+ shape.
197
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
198
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
199
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
200
+ ratio > 1.0).
201
+ 2) The order of output_size should be [out_h, out_w].
202
+ interp_mode (str): The mode of interpolation for resizing.
203
+ Default: 'bilinear'.
204
+ align_corners (bool): Whether align corners. Default: False.
205
+
206
+ Returns:
207
+ Tensor: Resized flow.
208
+ """
209
+ _, _, flow_h, flow_w = flow.size()
210
+ if size_type == 'ratio':
211
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
212
+ elif size_type == 'shape':
213
+ output_h, output_w = sizes[0], sizes[1]
214
+ else:
215
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
216
+
217
+ input_flow = flow.clone()
218
+ ratio_h = output_h / flow_h
219
+ ratio_w = output_w / flow_w
220
+ input_flow[:, 0, :, :] *= ratio_w
221
+ input_flow[:, 1, :, :] *= ratio_h
222
+ resized_flow = F.interpolate(
223
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
224
+ return resized_flow
225
+
226
+
227
+ # TODO: may write a cpp file
228
+ def pixel_unshuffle(x, scale):
229
+ """ Pixel unshuffle.
230
+
231
+ Args:
232
+ x (Tensor): Input feature with shape (b, c, hh, hw).
233
+ scale (int): Downsample ratio.
234
+
235
+ Returns:
236
+ Tensor: the pixel unshuffled feature.
237
+ """
238
+ b, c, hh, hw = x.size()
239
+ out_channel = c * (scale**2)
240
+ assert hh % scale == 0 and hw % scale == 0
241
+ h = hh // scale
242
+ w = hw // scale
243
+ x_view = x.view(b, c, h, scale, w, scale)
244
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
245
+
246
+
247
+ class DCNv2Pack(ModulatedDeformConvPack):
248
+ """Modulated deformable conv for deformable alignment.
249
+
250
+ Different from the official DCNv2Pack, which generates offsets and masks
251
+ from the preceding features, this DCNv2Pack takes another different
252
+ features to generate offsets and masks.
253
+
254
+ ``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
255
+ """
256
+
257
+ def forward(self, x, feat):
258
+ out = self.conv_offset(feat)
259
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
260
+ offset = torch.cat((o1, o2), dim=1)
261
+ mask = torch.sigmoid(mask)
262
+
263
+ offset_absmean = torch.mean(torch.abs(offset))
264
+ if offset_absmean > 50:
265
+ logger = get_root_logger()
266
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
267
+
268
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
269
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
270
+ self.dilation, mask)
271
+ else:
272
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
273
+ self.dilation, self.groups, self.deformable_groups)
274
+
275
+
276
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
277
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
278
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
279
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
280
+ def norm_cdf(x):
281
+ # Computes standard normal cumulative distribution function
282
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
283
+
284
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
285
+ warnings.warn(
286
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
287
+ 'The distribution of values may be incorrect.',
288
+ stacklevel=2)
289
+
290
+ with torch.no_grad():
291
+ # Values are generated by using a truncated uniform distribution and
292
+ # then using the inverse CDF for the normal distribution.
293
+ # Get upper and lower cdf values
294
+ low = norm_cdf((a - mean) / std)
295
+ up = norm_cdf((b - mean) / std)
296
+
297
+ # Uniformly fill tensor with values from [low, up], then translate to
298
+ # [2l-1, 2u-1].
299
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
300
+
301
+ # Use inverse cdf transform for normal distribution to get truncated
302
+ # standard normal
303
+ tensor.erfinv_()
304
+
305
+ # Transform to proper mean, std
306
+ tensor.mul_(std * math.sqrt(2.))
307
+ tensor.add_(mean)
308
+
309
+ # Clamp to ensure it's in the proper range
310
+ tensor.clamp_(min=a, max=b)
311
+ return tensor
312
+
313
+
314
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
315
+ r"""Fills the input Tensor with values drawn from a truncated
316
+ normal distribution.
317
+
318
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
319
+
320
+ The values are effectively drawn from the
321
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
322
+ with values outside :math:`[a, b]` redrawn until they are within
323
+ the bounds. The method used for generating the random values works
324
+ best when :math:`a \leq \text{mean} \leq b`.
325
+
326
+ Args:
327
+ tensor: an n-dimensional `torch.Tensor`
328
+ mean: the mean of the normal distribution
329
+ std: the standard deviation of the normal distribution
330
+ a: the minimum cutoff value
331
+ b: the maximum cutoff value
332
+
333
+ Examples:
334
+ >>> w = torch.empty(3, 5)
335
+ >>> nn.init.trunc_normal_(w)
336
+ """
337
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
338
+
339
+
340
+ # From PyTorch
341
+ def _ntuple(n):
342
+
343
+ def parse(x):
344
+ if isinstance(x, collections.abc.Iterable):
345
+ return x
346
+ return tuple(repeat(x, n))
347
+
348
+ return parse
349
+
350
+
351
+ to_1tuple = _ntuple(1)
352
+ to_2tuple = _ntuple(2)
353
+ to_3tuple = _ntuple(3)
354
+ to_4tuple = _ntuple(4)
355
+ to_ntuple = _ntuple
basicsr/archs/basicvsr_arch.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
7
+ from .edvr_arch import PCDAlignment, TSAFusion
8
+ from .spynet_arch import SpyNet
9
+
10
+
11
+ @ARCH_REGISTRY.register()
12
+ class BasicVSR(nn.Module):
13
+ """A recurrent network for video SR. Now only x4 is supported.
14
+
15
+ Args:
16
+ num_feat (int): Number of channels. Default: 64.
17
+ num_block (int): Number of residual blocks for each branch. Default: 15
18
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
19
+ """
20
+
21
+ def __init__(self, num_feat=64, num_block=15, spynet_path=None):
22
+ super().__init__()
23
+ self.num_feat = num_feat
24
+
25
+ # alignment
26
+ self.spynet = SpyNet(spynet_path)
27
+
28
+ # propagation
29
+ self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
30
+ self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
31
+
32
+ # reconstruction
33
+ self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
34
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
35
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
36
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
37
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
38
+
39
+ self.pixel_shuffle = nn.PixelShuffle(2)
40
+
41
+ # activation functions
42
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
43
+
44
+ def get_flow(self, x):
45
+ b, n, c, h, w = x.size()
46
+
47
+ x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
48
+ x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
49
+
50
+ flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
51
+ flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
52
+
53
+ return flows_forward, flows_backward
54
+
55
+ def forward(self, x):
56
+ """Forward function of BasicVSR.
57
+
58
+ Args:
59
+ x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
60
+ """
61
+ flows_forward, flows_backward = self.get_flow(x)
62
+ b, n, _, h, w = x.size()
63
+
64
+ # backward branch
65
+ out_l = []
66
+ feat_prop = x.new_zeros(b, self.num_feat, h, w)
67
+ for i in range(n - 1, -1, -1):
68
+ x_i = x[:, i, :, :, :]
69
+ if i < n - 1:
70
+ flow = flows_backward[:, i, :, :, :]
71
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
72
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
73
+ feat_prop = self.backward_trunk(feat_prop)
74
+ out_l.insert(0, feat_prop)
75
+
76
+ # forward branch
77
+ feat_prop = torch.zeros_like(feat_prop)
78
+ for i in range(0, n):
79
+ x_i = x[:, i, :, :, :]
80
+ if i > 0:
81
+ flow = flows_forward[:, i - 1, :, :, :]
82
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
83
+
84
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
85
+ feat_prop = self.forward_trunk(feat_prop)
86
+
87
+ # upsample
88
+ out = torch.cat([out_l[i], feat_prop], dim=1)
89
+ out = self.lrelu(self.fusion(out))
90
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
91
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
92
+ out = self.lrelu(self.conv_hr(out))
93
+ out = self.conv_last(out)
94
+ base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
95
+ out += base
96
+ out_l[i] = out
97
+
98
+ return torch.stack(out_l, dim=1)
99
+
100
+
101
+ class ConvResidualBlocks(nn.Module):
102
+ """Conv and residual block used in BasicVSR.
103
+
104
+ Args:
105
+ num_in_ch (int): Number of input channels. Default: 3.
106
+ num_out_ch (int): Number of output channels. Default: 64.
107
+ num_block (int): Number of residual blocks. Default: 15.
108
+ """
109
+
110
+ def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
111
+ super().__init__()
112
+ self.main = nn.Sequential(
113
+ nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
114
+ make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
115
+
116
+ def forward(self, fea):
117
+ return self.main(fea)
118
+
119
+
120
+ @ARCH_REGISTRY.register()
121
+ class IconVSR(nn.Module):
122
+ """IconVSR, proposed also in the BasicVSR paper.
123
+
124
+ Args:
125
+ num_feat (int): Number of channels. Default: 64.
126
+ num_block (int): Number of residual blocks for each branch. Default: 15.
127
+ keyframe_stride (int): Keyframe stride. Default: 5.
128
+ temporal_padding (int): Temporal padding. Default: 2.
129
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
130
+ edvr_path (str): Path to the pretrained EDVR model. Default: None.
131
+ """
132
+
133
+ def __init__(self,
134
+ num_feat=64,
135
+ num_block=15,
136
+ keyframe_stride=5,
137
+ temporal_padding=2,
138
+ spynet_path=None,
139
+ edvr_path=None):
140
+ super().__init__()
141
+
142
+ self.num_feat = num_feat
143
+ self.temporal_padding = temporal_padding
144
+ self.keyframe_stride = keyframe_stride
145
+
146
+ # keyframe_branch
147
+ self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
148
+ # alignment
149
+ self.spynet = SpyNet(spynet_path)
150
+
151
+ # propagation
152
+ self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
153
+ self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
154
+
155
+ self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
156
+ self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
157
+
158
+ # reconstruction
159
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
160
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
161
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
162
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
163
+
164
+ self.pixel_shuffle = nn.PixelShuffle(2)
165
+
166
+ # activation functions
167
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
168
+
169
+ def pad_spatial(self, x):
170
+ """Apply padding spatially.
171
+
172
+ Since the PCD module in EDVR requires that the resolution is a multiple
173
+ of 4, we apply padding to the input LR images if their resolution is
174
+ not divisible by 4.
175
+
176
+ Args:
177
+ x (Tensor): Input LR sequence with shape (n, t, c, h, w).
178
+ Returns:
179
+ Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
180
+ """
181
+ n, t, c, h, w = x.size()
182
+
183
+ pad_h = (4 - h % 4) % 4
184
+ pad_w = (4 - w % 4) % 4
185
+
186
+ # padding
187
+ x = x.view(-1, c, h, w)
188
+ x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
189
+
190
+ return x.view(n, t, c, h + pad_h, w + pad_w)
191
+
192
+ def get_flow(self, x):
193
+ b, n, c, h, w = x.size()
194
+
195
+ x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
196
+ x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
197
+
198
+ flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
199
+ flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
200
+
201
+ return flows_forward, flows_backward
202
+
203
+ def get_keyframe_feature(self, x, keyframe_idx):
204
+ if self.temporal_padding == 2:
205
+ x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
206
+ elif self.temporal_padding == 3:
207
+ x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
208
+ x = torch.cat(x, dim=1)
209
+
210
+ num_frames = 2 * self.temporal_padding + 1
211
+ feats_keyframe = {}
212
+ for i in keyframe_idx:
213
+ feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
214
+ return feats_keyframe
215
+
216
+ def forward(self, x):
217
+ b, n, _, h_input, w_input = x.size()
218
+
219
+ x = self.pad_spatial(x)
220
+ h, w = x.shape[3:]
221
+
222
+ keyframe_idx = list(range(0, n, self.keyframe_stride))
223
+ if keyframe_idx[-1] != n - 1:
224
+ keyframe_idx.append(n - 1) # last frame is a keyframe
225
+
226
+ # compute flow and keyframe features
227
+ flows_forward, flows_backward = self.get_flow(x)
228
+ feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
229
+
230
+ # backward branch
231
+ out_l = []
232
+ feat_prop = x.new_zeros(b, self.num_feat, h, w)
233
+ for i in range(n - 1, -1, -1):
234
+ x_i = x[:, i, :, :, :]
235
+ if i < n - 1:
236
+ flow = flows_backward[:, i, :, :, :]
237
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
238
+ if i in keyframe_idx:
239
+ feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
240
+ feat_prop = self.backward_fusion(feat_prop)
241
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
242
+ feat_prop = self.backward_trunk(feat_prop)
243
+ out_l.insert(0, feat_prop)
244
+
245
+ # forward branch
246
+ feat_prop = torch.zeros_like(feat_prop)
247
+ for i in range(0, n):
248
+ x_i = x[:, i, :, :, :]
249
+ if i > 0:
250
+ flow = flows_forward[:, i - 1, :, :, :]
251
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
252
+ if i in keyframe_idx:
253
+ feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
254
+ feat_prop = self.forward_fusion(feat_prop)
255
+
256
+ feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
257
+ feat_prop = self.forward_trunk(feat_prop)
258
+
259
+ # upsample
260
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
261
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
262
+ out = self.lrelu(self.conv_hr(out))
263
+ out = self.conv_last(out)
264
+ base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
265
+ out += base
266
+ out_l[i] = out
267
+
268
+ return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
269
+
270
+
271
+ class EDVRFeatureExtractor(nn.Module):
272
+ """EDVR feature extractor used in IconVSR.
273
+
274
+ Args:
275
+ num_input_frame (int): Number of input frames.
276
+ num_feat (int): Number of feature channels
277
+ load_path (str): Path to the pretrained weights of EDVR. Default: None.
278
+ """
279
+
280
+ def __init__(self, num_input_frame, num_feat, load_path):
281
+
282
+ super(EDVRFeatureExtractor, self).__init__()
283
+
284
+ self.center_frame_idx = num_input_frame // 2
285
+
286
+ # extract pyramid features
287
+ self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
288
+ self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
289
+ self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
290
+ self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
291
+ self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
292
+ self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
293
+
294
+ # pcd and tsa module
295
+ self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
296
+ self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
297
+
298
+ # activation function
299
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
300
+
301
+ if load_path:
302
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
303
+
304
+ def forward(self, x):
305
+ b, n, c, h, w = x.size()
306
+
307
+ # extract features for each frame
308
+ # L1
309
+ feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
310
+ feat_l1 = self.feature_extraction(feat_l1)
311
+ # L2
312
+ feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
313
+ feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
314
+ # L3
315
+ feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
316
+ feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
317
+
318
+ feat_l1 = feat_l1.view(b, n, -1, h, w)
319
+ feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
320
+ feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
321
+
322
+ # PCD alignment
323
+ ref_feat_l = [ # reference feature list
324
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
325
+ feat_l3[:, self.center_frame_idx, :, :, :].clone()
326
+ ]
327
+ aligned_feat = []
328
+ for i in range(n):
329
+ nbr_feat_l = [ # neighboring feature list
330
+ feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
331
+ ]
332
+ aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
333
+ aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
334
+
335
+ # TSA fusion
336
+ return self.fusion(aligned_feat)
basicsr/archs/basicvsrpp_arch.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+ import warnings
6
+
7
+ from basicsr.archs.arch_util import flow_warp
8
+ from basicsr.archs.basicvsr_arch import ConvResidualBlocks
9
+ from basicsr.archs.spynet_arch import SpyNet
10
+ from basicsr.ops.dcn import ModulatedDeformConvPack
11
+ from basicsr.utils.registry import ARCH_REGISTRY
12
+
13
+
14
+ @ARCH_REGISTRY.register()
15
+ class BasicVSRPlusPlus(nn.Module):
16
+ """BasicVSR++ network structure.
17
+
18
+ Support either x4 upsampling or same size output. Since DCN is used in this
19
+ model, it can only be used with CUDA enabled. If CUDA is not enabled,
20
+ feature alignment will be skipped. Besides, we adopt the official DCN
21
+ implementation and the version of torch need to be higher than 1.9.
22
+
23
+ ``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``
24
+
25
+ Args:
26
+ mid_channels (int, optional): Channel number of the intermediate
27
+ features. Default: 64.
28
+ num_blocks (int, optional): The number of residual blocks in each
29
+ propagation branch. Default: 7.
30
+ max_residue_magnitude (int): The maximum magnitude of the offset
31
+ residue (Eq. 6 in paper). Default: 10.
32
+ is_low_res_input (bool, optional): Whether the input is low-resolution
33
+ or not. If False, the output resolution is equal to the input
34
+ resolution. Default: True.
35
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
36
+ cpu_cache_length (int, optional): When the length of sequence is larger
37
+ than this value, the intermediate features are sent to CPU. This
38
+ saves GPU memory, but slows down the inference speed. You can
39
+ increase this number if you have a GPU with large memory.
40
+ Default: 100.
41
+ """
42
+
43
+ def __init__(self,
44
+ mid_channels=64,
45
+ num_blocks=7,
46
+ max_residue_magnitude=10,
47
+ is_low_res_input=True,
48
+ spynet_path=None,
49
+ cpu_cache_length=100):
50
+
51
+ super().__init__()
52
+ self.mid_channels = mid_channels
53
+ self.is_low_res_input = is_low_res_input
54
+ self.cpu_cache_length = cpu_cache_length
55
+
56
+ # optical flow
57
+ self.spynet = SpyNet(spynet_path)
58
+
59
+ # feature extraction module
60
+ if is_low_res_input:
61
+ self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
62
+ else:
63
+ self.feat_extract = nn.Sequential(
64
+ nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
65
+ nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
66
+ ConvResidualBlocks(mid_channels, mid_channels, 5))
67
+
68
+ # propagation branches
69
+ self.deform_align = nn.ModuleDict()
70
+ self.backbone = nn.ModuleDict()
71
+ modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
72
+ for i, module in enumerate(modules):
73
+ if torch.cuda.is_available():
74
+ self.deform_align[module] = SecondOrderDeformableAlignment(
75
+ 2 * mid_channels,
76
+ mid_channels,
77
+ 3,
78
+ padding=1,
79
+ deformable_groups=16,
80
+ max_residue_magnitude=max_residue_magnitude)
81
+ self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
82
+
83
+ # upsampling module
84
+ self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
85
+
86
+ self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
87
+ self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
88
+
89
+ self.pixel_shuffle = nn.PixelShuffle(2)
90
+
91
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
92
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
93
+ self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
94
+
95
+ # activation function
96
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
97
+
98
+ # check if the sequence is augmented by flipping
99
+ self.is_mirror_extended = False
100
+
101
+ if len(self.deform_align) > 0:
102
+ self.is_with_alignment = True
103
+ else:
104
+ self.is_with_alignment = False
105
+ warnings.warn('Deformable alignment module is not added. '
106
+ 'Probably your CUDA is not configured correctly. DCN can only '
107
+ 'be used with CUDA enabled. Alignment is skipped now.')
108
+
109
+ def check_if_mirror_extended(self, lqs):
110
+ """Check whether the input is a mirror-extended sequence.
111
+
112
+ If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame.
113
+
114
+ Args:
115
+ lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
116
+ """
117
+
118
+ if lqs.size(1) % 2 == 0:
119
+ lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
120
+ if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
121
+ self.is_mirror_extended = True
122
+
123
+ def compute_flow(self, lqs):
124
+ """Compute optical flow using SPyNet for feature alignment.
125
+
126
+ Note that if the input is an mirror-extended sequence, 'flows_forward'
127
+ is not needed, since it is equal to 'flows_backward.flip(1)'.
128
+
129
+ Args:
130
+ lqs (tensor): Input low quality (LQ) sequence with
131
+ shape (n, t, c, h, w).
132
+
133
+ Return:
134
+ tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \
135
+ (current to previous). 'flows_backward' corresponds to the flows used for backward-time \
136
+ propagation (current to next).
137
+ """
138
+
139
+ n, t, c, h, w = lqs.size()
140
+ lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
141
+ lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
142
+
143
+ flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
144
+
145
+ if self.is_mirror_extended: # flows_forward = flows_backward.flip(1)
146
+ flows_forward = flows_backward.flip(1)
147
+ else:
148
+ flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
149
+
150
+ if self.cpu_cache:
151
+ flows_backward = flows_backward.cpu()
152
+ flows_forward = flows_forward.cpu()
153
+
154
+ return flows_forward, flows_backward
155
+
156
+ def propagate(self, feats, flows, module_name):
157
+ """Propagate the latent features throughout the sequence.
158
+
159
+ Args:
160
+ feats dict(list[tensor]): Features from previous branches. Each
161
+ component is a list of tensors with shape (n, c, h, w).
162
+ flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
163
+ module_name (str): The name of the propgation branches. Can either
164
+ be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
165
+
166
+ Return:
167
+ dict(list[tensor]): A dictionary containing all the propagated \
168
+ features. Each key in the dictionary corresponds to a \
169
+ propagation branch, which is represented by a list of tensors.
170
+ """
171
+
172
+ n, t, _, h, w = flows.size()
173
+
174
+ frame_idx = range(0, t + 1)
175
+ flow_idx = range(-1, t)
176
+ mapping_idx = list(range(0, len(feats['spatial'])))
177
+ mapping_idx += mapping_idx[::-1]
178
+
179
+ if 'backward' in module_name:
180
+ frame_idx = frame_idx[::-1]
181
+ flow_idx = frame_idx
182
+
183
+ feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
184
+ for i, idx in enumerate(frame_idx):
185
+ feat_current = feats['spatial'][mapping_idx[idx]]
186
+ if self.cpu_cache:
187
+ feat_current = feat_current.cuda()
188
+ feat_prop = feat_prop.cuda()
189
+ # second-order deformable alignment
190
+ if i > 0 and self.is_with_alignment:
191
+ flow_n1 = flows[:, flow_idx[i], :, :, :]
192
+ if self.cpu_cache:
193
+ flow_n1 = flow_n1.cuda()
194
+
195
+ cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
196
+
197
+ # initialize second-order features
198
+ feat_n2 = torch.zeros_like(feat_prop)
199
+ flow_n2 = torch.zeros_like(flow_n1)
200
+ cond_n2 = torch.zeros_like(cond_n1)
201
+
202
+ if i > 1: # second-order features
203
+ feat_n2 = feats[module_name][-2]
204
+ if self.cpu_cache:
205
+ feat_n2 = feat_n2.cuda()
206
+
207
+ flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
208
+ if self.cpu_cache:
209
+ flow_n2 = flow_n2.cuda()
210
+
211
+ flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
212
+ cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
213
+
214
+ # flow-guided deformable convolution
215
+ cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
216
+ feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
217
+ feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)
218
+
219
+ # concatenate and residual blocks
220
+ feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
221
+ if self.cpu_cache:
222
+ feat = [f.cuda() for f in feat]
223
+
224
+ feat = torch.cat(feat, dim=1)
225
+ feat_prop = feat_prop + self.backbone[module_name](feat)
226
+ feats[module_name].append(feat_prop)
227
+
228
+ if self.cpu_cache:
229
+ feats[module_name][-1] = feats[module_name][-1].cpu()
230
+ torch.cuda.empty_cache()
231
+
232
+ if 'backward' in module_name:
233
+ feats[module_name] = feats[module_name][::-1]
234
+
235
+ return feats
236
+
237
+ def upsample(self, lqs, feats):
238
+ """Compute the output image given the features.
239
+
240
+ Args:
241
+ lqs (tensor): Input low quality (LQ) sequence with
242
+ shape (n, t, c, h, w).
243
+ feats (dict): The features from the propagation branches.
244
+
245
+ Returns:
246
+ Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
247
+ """
248
+
249
+ outputs = []
250
+ num_outputs = len(feats['spatial'])
251
+
252
+ mapping_idx = list(range(0, num_outputs))
253
+ mapping_idx += mapping_idx[::-1]
254
+
255
+ for i in range(0, lqs.size(1)):
256
+ hr = [feats[k].pop(0) for k in feats if k != 'spatial']
257
+ hr.insert(0, feats['spatial'][mapping_idx[i]])
258
+ hr = torch.cat(hr, dim=1)
259
+ if self.cpu_cache:
260
+ hr = hr.cuda()
261
+
262
+ hr = self.reconstruction(hr)
263
+ hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
264
+ hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
265
+ hr = self.lrelu(self.conv_hr(hr))
266
+ hr = self.conv_last(hr)
267
+ if self.is_low_res_input:
268
+ hr += self.img_upsample(lqs[:, i, :, :, :])
269
+ else:
270
+ hr += lqs[:, i, :, :, :]
271
+
272
+ if self.cpu_cache:
273
+ hr = hr.cpu()
274
+ torch.cuda.empty_cache()
275
+
276
+ outputs.append(hr)
277
+
278
+ return torch.stack(outputs, dim=1)
279
+
280
+ def forward(self, lqs):
281
+ """Forward function for BasicVSR++.
282
+
283
+ Args:
284
+ lqs (tensor): Input low quality (LQ) sequence with
285
+ shape (n, t, c, h, w).
286
+
287
+ Returns:
288
+ Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
289
+ """
290
+
291
+ n, t, c, h, w = lqs.size()
292
+
293
+ # whether to cache the features in CPU
294
+ self.cpu_cache = True if t > self.cpu_cache_length else False
295
+
296
+ if self.is_low_res_input:
297
+ lqs_downsample = lqs.clone()
298
+ else:
299
+ lqs_downsample = F.interpolate(
300
+ lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
301
+
302
+ # check whether the input is an extended sequence
303
+ self.check_if_mirror_extended(lqs)
304
+
305
+ feats = {}
306
+ # compute spatial features
307
+ if self.cpu_cache:
308
+ feats['spatial'] = []
309
+ for i in range(0, t):
310
+ feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
311
+ feats['spatial'].append(feat)
312
+ torch.cuda.empty_cache()
313
+ else:
314
+ feats_ = self.feat_extract(lqs.view(-1, c, h, w))
315
+ h, w = feats_.shape[2:]
316
+ feats_ = feats_.view(n, t, -1, h, w)
317
+ feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
318
+
319
+ # compute optical flow using the low-res inputs
320
+ assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
321
+ 'The height and width of low-res inputs must be at least 64, '
322
+ f'but got {h} and {w}.')
323
+ flows_forward, flows_backward = self.compute_flow(lqs_downsample)
324
+
325
+ # feature propgation
326
+ for iter_ in [1, 2]:
327
+ for direction in ['backward', 'forward']:
328
+ module = f'{direction}_{iter_}'
329
+
330
+ feats[module] = []
331
+
332
+ if direction == 'backward':
333
+ flows = flows_backward
334
+ elif flows_forward is not None:
335
+ flows = flows_forward
336
+ else:
337
+ flows = flows_backward.flip(1)
338
+
339
+ feats = self.propagate(feats, flows, module)
340
+ if self.cpu_cache:
341
+ del flows
342
+ torch.cuda.empty_cache()
343
+
344
+ return self.upsample(lqs, feats)
345
+
346
+
347
+ class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
348
+ """Second-order deformable alignment module.
349
+
350
+ Args:
351
+ in_channels (int): Same as nn.Conv2d.
352
+ out_channels (int): Same as nn.Conv2d.
353
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
354
+ stride (int or tuple[int]): Same as nn.Conv2d.
355
+ padding (int or tuple[int]): Same as nn.Conv2d.
356
+ dilation (int or tuple[int]): Same as nn.Conv2d.
357
+ groups (int): Same as nn.Conv2d.
358
+ bias (bool or str): If specified as `auto`, it will be decided by the
359
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
360
+ False.
361
+ max_residue_magnitude (int): The maximum magnitude of the offset
362
+ residue (Eq. 6 in paper). Default: 10.
363
+ """
364
+
365
+ def __init__(self, *args, **kwargs):
366
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
367
+
368
+ super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
369
+
370
+ self.conv_offset = nn.Sequential(
371
+ nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
372
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
373
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
374
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
375
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
376
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
377
+ nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
378
+ )
379
+
380
+ self.init_offset()
381
+
382
+ def init_offset(self):
383
+
384
+ def _constant_init(module, val, bias=0):
385
+ if hasattr(module, 'weight') and module.weight is not None:
386
+ nn.init.constant_(module.weight, val)
387
+ if hasattr(module, 'bias') and module.bias is not None:
388
+ nn.init.constant_(module.bias, bias)
389
+
390
+ _constant_init(self.conv_offset[-1], val=0, bias=0)
391
+
392
+ def forward(self, x, extra_feat, flow_1, flow_2):
393
+ extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
394
+ out = self.conv_offset(extra_feat)
395
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
396
+
397
+ # offset
398
+ offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
399
+ offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
400
+ offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
401
+ offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
402
+ offset = torch.cat([offset_1, offset_2], dim=1)
403
+
404
+ # mask
405
+ mask = torch.sigmoid(mask)
406
+
407
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
408
+ self.dilation, mask)
409
+
410
+
411
+ # if __name__ == '__main__':
412
+ # spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
413
+ # model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
414
+ # input = torch.rand(1, 2, 3, 64, 64).cuda()
415
+ # output = model(input)
416
+ # print('===================')
417
+ # print(output.shape)
basicsr/archs/degradat_arch.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn as nn
2
+
3
+ from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights
4
+ from basicsr.utils.registry import ARCH_REGISTRY
5
+
6
+ @ARCH_REGISTRY.register()
7
+ class DEResNet(nn.Module):
8
+ """Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore
9
+ As shown in paper 'Towards Flexible Blind JPEG Artifacts Removal',
10
+ resnet arch works for image quality estimation.
11
+ Args:
12
+ num_in_ch (int): channel number of inputs. Default: 3.
13
+ num_degradation (int): num of degradation the DE should estimate. Default: 2(blur+noise).
14
+ degradation_embed_size (int): embedding size of each degradation vector.
15
+ degradation_degree_actv (int): activation function for degradation degree scalar. Default: sigmoid.
16
+ num_feats (list): channel number of each stage.
17
+ num_blocks (list): residual block of each stage.
18
+ downscales (list): downscales of each stage.
19
+ """
20
+
21
+ def __init__(self,
22
+ num_in_ch=3,
23
+ num_degradation=2,
24
+ degradation_degree_actv='sigmoid',
25
+ num_feats=(64, 128, 256, 512),
26
+ num_blocks=(2, 2, 2, 2),
27
+ downscales=(2, 2, 2, 1)):
28
+ super(DEResNet, self).__init__()
29
+
30
+ assert isinstance(num_feats, list)
31
+ assert isinstance(num_blocks, list)
32
+ assert isinstance(downscales, list)
33
+ assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales)
34
+
35
+ num_stage = len(num_feats)
36
+
37
+ self.conv_first = nn.ModuleList()
38
+ for _ in range(num_degradation):
39
+ self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1))
40
+ self.body = nn.ModuleList()
41
+ for _ in range(num_degradation):
42
+ body = list()
43
+ for stage in range(num_stage):
44
+ for _ in range(num_blocks[stage]):
45
+ body.append(ResidualBlockNoBN(num_feats[stage]))
46
+ if downscales[stage] == 1:
47
+ if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]:
48
+ body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1))
49
+ continue
50
+ elif downscales[stage] == 2:
51
+ body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1))
52
+ else:
53
+ raise NotImplementedError
54
+ self.body.append(nn.Sequential(*body))
55
+
56
+ # self.body = nn.Sequential(*body)
57
+
58
+ self.num_degradation = num_degradation
59
+ self.fc_degree = nn.ModuleList()
60
+ if degradation_degree_actv == 'sigmoid':
61
+ actv = nn.Sigmoid
62
+ elif degradation_degree_actv == 'tanh':
63
+ actv = nn.Tanh
64
+ else:
65
+ raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, '
66
+ f'{degradation_degree_actv} is not supported yet.')
67
+ for _ in range(num_degradation):
68
+ self.fc_degree.append(
69
+ nn.Sequential(
70
+ nn.Linear(num_feats[-1], 512),
71
+ nn.ReLU(inplace=True),
72
+ nn.Linear(512, 1),
73
+ actv(),
74
+ ))
75
+
76
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
77
+
78
+ default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1)
79
+
80
+ def forward(self, x):
81
+ degrees = []
82
+ for i in range(self.num_degradation):
83
+ x_out = self.conv_first[i](x)
84
+ feat = self.body[i](x_out)
85
+ feat = self.avg_pool(feat)
86
+ feat = feat.squeeze(-1).squeeze(-1)
87
+ # for i in range(self.num_degradation):
88
+ degrees.append(self.fc_degree[i](feat).squeeze(-1))
89
+
90
+ return degrees
basicsr/archs/dfdnet_arch.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn.utils.spectral_norm import spectral_norm
6
+
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+ from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
9
+ from .vgg_arch import VGGFeatureExtractor
10
+
11
+
12
+ class SFTUpBlock(nn.Module):
13
+ """Spatial feature transform (SFT) with upsampling block.
14
+
15
+ Args:
16
+ in_channel (int): Number of input channels.
17
+ out_channel (int): Number of output channels.
18
+ kernel_size (int): Kernel size in convolutions. Default: 3.
19
+ padding (int): Padding in convolutions. Default: 1.
20
+ """
21
+
22
+ def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
23
+ super(SFTUpBlock, self).__init__()
24
+ self.conv1 = nn.Sequential(
25
+ Blur(in_channel),
26
+ spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
27
+ nn.LeakyReLU(0.04, True),
28
+ # The official codes use two LeakyReLU here, so 0.04 for equivalent
29
+ )
30
+ self.convup = nn.Sequential(
31
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
32
+ spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
33
+ nn.LeakyReLU(0.2, True),
34
+ )
35
+
36
+ # for SFT scale and shift
37
+ self.scale_block = nn.Sequential(
38
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
39
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
40
+ self.shift_block = nn.Sequential(
41
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
42
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
43
+ # The official codes use sigmoid for shift block, do not know why
44
+
45
+ def forward(self, x, updated_feat):
46
+ out = self.conv1(x)
47
+ # SFT
48
+ scale = self.scale_block(updated_feat)
49
+ shift = self.shift_block(updated_feat)
50
+ out = out * scale + shift
51
+ # upsample
52
+ out = self.convup(out)
53
+ return out
54
+
55
+
56
+ @ARCH_REGISTRY.register()
57
+ class DFDNet(nn.Module):
58
+ """DFDNet: Deep Face Dictionary Network.
59
+
60
+ It only processes faces with 512x512 size.
61
+
62
+ Args:
63
+ num_feat (int): Number of feature channels.
64
+ dict_path (str): Path to the facial component dictionary.
65
+ """
66
+
67
+ def __init__(self, num_feat, dict_path):
68
+ super().__init__()
69
+ self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
70
+ # part_sizes: [80, 80, 50, 110]
71
+ channel_sizes = [128, 256, 512, 512]
72
+ self.feature_sizes = np.array([256, 128, 64, 32])
73
+ self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
74
+ self.flag_dict_device = False
75
+
76
+ # dict
77
+ self.dict = torch.load(dict_path)
78
+
79
+ # vgg face extractor
80
+ self.vgg_extractor = VGGFeatureExtractor(
81
+ layer_name_list=self.vgg_layers,
82
+ vgg_type='vgg19',
83
+ use_input_norm=True,
84
+ range_norm=True,
85
+ requires_grad=False)
86
+
87
+ # attention block for fusing dictionary features and input features
88
+ self.attn_blocks = nn.ModuleDict()
89
+ for idx, feat_size in enumerate(self.feature_sizes):
90
+ for name in self.parts:
91
+ self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
92
+
93
+ # multi scale dilation block
94
+ self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
95
+
96
+ # upsampling and reconstruction
97
+ self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
98
+ self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
99
+ self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
100
+ self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
101
+ self.upsample4 = nn.Sequential(
102
+ spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
103
+ UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
104
+
105
+ def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
106
+ """swap the features from the dictionary."""
107
+ # get the original vgg features
108
+ part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
109
+ # resize original vgg features
110
+ part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
111
+ # use adaptive instance normalization to adjust color and illuminations
112
+ dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
113
+ # get similarity scores
114
+ similarity_score = F.conv2d(part_resize_feat, dict_feat)
115
+ similarity_score = F.softmax(similarity_score.view(-1), dim=0)
116
+ # select the most similar features in the dict (after norm)
117
+ select_idx = torch.argmax(similarity_score)
118
+ swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
119
+ # attention
120
+ attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
121
+ attn_feat = attn * swap_feat
122
+ # update features
123
+ updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
124
+ return updated_feat
125
+
126
+ def put_dict_to_device(self, x):
127
+ if self.flag_dict_device is False:
128
+ for k, v in self.dict.items():
129
+ for kk, vv in v.items():
130
+ self.dict[k][kk] = vv.to(x)
131
+ self.flag_dict_device = True
132
+
133
+ def forward(self, x, part_locations):
134
+ """
135
+ Now only support testing with batch size = 0.
136
+
137
+ Args:
138
+ x (Tensor): Input faces with shape (b, c, 512, 512).
139
+ part_locations (list[Tensor]): Part locations.
140
+ """
141
+ self.put_dict_to_device(x)
142
+ # extract vggface features
143
+ vgg_features = self.vgg_extractor(x)
144
+ # update vggface features using the dictionary for each part
145
+ updated_vgg_features = []
146
+ batch = 0 # only supports testing with batch size = 0
147
+ for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
148
+ dict_features = self.dict[f'{f_size}']
149
+ vgg_feat = vgg_features[vgg_layer]
150
+ updated_feat = vgg_feat.clone()
151
+
152
+ # swap features from dictionary
153
+ for part_idx, part_name in enumerate(self.parts):
154
+ location = (part_locations[part_idx][batch] // (512 / f_size)).int()
155
+ updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
156
+ f_size)
157
+
158
+ updated_vgg_features.append(updated_feat)
159
+
160
+ vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
161
+ # use updated vgg features to modulate the upsampled features with
162
+ # SFT (Spatial Feature Transform) scaling and shifting manner.
163
+ upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
164
+ upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
165
+ upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
166
+ upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
167
+ out = self.upsample4(upsampled_feat)
168
+
169
+ return out
basicsr/archs/dfdnet_util.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Function
5
+ from torch.nn.utils.spectral_norm import spectral_norm
6
+
7
+
8
+ class BlurFunctionBackward(Function):
9
+
10
+ @staticmethod
11
+ def forward(ctx, grad_output, kernel, kernel_flip):
12
+ ctx.save_for_backward(kernel, kernel_flip)
13
+ grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
14
+ return grad_input
15
+
16
+ @staticmethod
17
+ def backward(ctx, gradgrad_output):
18
+ kernel, _ = ctx.saved_tensors
19
+ grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
20
+ return grad_input, None, None
21
+
22
+
23
+ class BlurFunction(Function):
24
+
25
+ @staticmethod
26
+ def forward(ctx, x, kernel, kernel_flip):
27
+ ctx.save_for_backward(kernel, kernel_flip)
28
+ output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
29
+ return output
30
+
31
+ @staticmethod
32
+ def backward(ctx, grad_output):
33
+ kernel, kernel_flip = ctx.saved_tensors
34
+ grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
35
+ return grad_input, None, None
36
+
37
+
38
+ blur = BlurFunction.apply
39
+
40
+
41
+ class Blur(nn.Module):
42
+
43
+ def __init__(self, channel):
44
+ super().__init__()
45
+ kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
46
+ kernel = kernel.view(1, 1, 3, 3)
47
+ kernel = kernel / kernel.sum()
48
+ kernel_flip = torch.flip(kernel, [2, 3])
49
+
50
+ self.kernel = kernel.repeat(channel, 1, 1, 1)
51
+ self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)
52
+
53
+ def forward(self, x):
54
+ return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))
55
+
56
+
57
+ def calc_mean_std(feat, eps=1e-5):
58
+ """Calculate mean and std for adaptive_instance_normalization.
59
+
60
+ Args:
61
+ feat (Tensor): 4D tensor.
62
+ eps (float): A small value added to the variance to avoid
63
+ divide-by-zero. Default: 1e-5.
64
+ """
65
+ size = feat.size()
66
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
67
+ n, c = size[:2]
68
+ feat_var = feat.view(n, c, -1).var(dim=2) + eps
69
+ feat_std = feat_var.sqrt().view(n, c, 1, 1)
70
+ feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
71
+ return feat_mean, feat_std
72
+
73
+
74
+ def adaptive_instance_normalization(content_feat, style_feat):
75
+ """Adaptive instance normalization.
76
+
77
+ Adjust the reference features to have the similar color and illuminations
78
+ as those in the degradate features.
79
+
80
+ Args:
81
+ content_feat (Tensor): The reference feature.
82
+ style_feat (Tensor): The degradate features.
83
+ """
84
+ size = content_feat.size()
85
+ style_mean, style_std = calc_mean_std(style_feat)
86
+ content_mean, content_std = calc_mean_std(content_feat)
87
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
88
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
89
+
90
+
91
+ def AttentionBlock(in_channel):
92
+ return nn.Sequential(
93
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
94
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
95
+
96
+
97
+ def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
98
+ """Conv block used in MSDilationBlock."""
99
+
100
+ return nn.Sequential(
101
+ spectral_norm(
102
+ nn.Conv2d(
103
+ in_channels,
104
+ out_channels,
105
+ kernel_size=kernel_size,
106
+ stride=stride,
107
+ dilation=dilation,
108
+ padding=((kernel_size - 1) // 2) * dilation,
109
+ bias=bias)),
110
+ nn.LeakyReLU(0.2),
111
+ spectral_norm(
112
+ nn.Conv2d(
113
+ out_channels,
114
+ out_channels,
115
+ kernel_size=kernel_size,
116
+ stride=stride,
117
+ dilation=dilation,
118
+ padding=((kernel_size - 1) // 2) * dilation,
119
+ bias=bias)),
120
+ )
121
+
122
+
123
+ class MSDilationBlock(nn.Module):
124
+ """Multi-scale dilation block."""
125
+
126
+ def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
127
+ super(MSDilationBlock, self).__init__()
128
+
129
+ self.conv_blocks = nn.ModuleList()
130
+ for i in range(4):
131
+ self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
132
+ self.conv_fusion = spectral_norm(
133
+ nn.Conv2d(
134
+ in_channels * 4,
135
+ in_channels,
136
+ kernel_size=kernel_size,
137
+ stride=1,
138
+ padding=(kernel_size - 1) // 2,
139
+ bias=bias))
140
+
141
+ def forward(self, x):
142
+ out = []
143
+ for i in range(4):
144
+ out.append(self.conv_blocks[i](x))
145
+ out = torch.cat(out, 1)
146
+ out = self.conv_fusion(out) + x
147
+ return out
148
+
149
+
150
+ class UpResBlock(nn.Module):
151
+
152
+ def __init__(self, in_channel):
153
+ super(UpResBlock, self).__init__()
154
+ self.body = nn.Sequential(
155
+ nn.Conv2d(in_channel, in_channel, 3, 1, 1),
156
+ nn.LeakyReLU(0.2, True),
157
+ nn.Conv2d(in_channel, in_channel, 3, 1, 1),
158
+ )
159
+
160
+ def forward(self, x):
161
+ out = x + self.body(x)
162
+ return out
basicsr/archs/discriminator_arch.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn as nn
2
+ from torch.nn import functional as F
3
+ from torch.nn.utils import spectral_norm
4
+
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+
7
+
8
+ @ARCH_REGISTRY.register()
9
+ class VGGStyleDiscriminator(nn.Module):
10
+ """VGG style discriminator with input size 128 x 128 or 256 x 256.
11
+
12
+ It is used to train SRGAN, ESRGAN, and VideoGAN.
13
+
14
+ Args:
15
+ num_in_ch (int): Channel number of inputs. Default: 3.
16
+ num_feat (int): Channel number of base intermediate features.Default: 64.
17
+ """
18
+
19
+ def __init__(self, num_in_ch, num_feat, input_size=128):
20
+ super(VGGStyleDiscriminator, self).__init__()
21
+ self.input_size = input_size
22
+ assert self.input_size == 128 or self.input_size == 256, (
23
+ f'input size must be 128 or 256, but received {input_size}')
24
+
25
+ self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
26
+ self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
27
+ self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
28
+
29
+ self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
30
+ self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
31
+ self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
32
+ self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
33
+
34
+ self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
35
+ self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
36
+ self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
37
+ self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
38
+
39
+ self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
40
+ self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
41
+ self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
42
+ self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
43
+
44
+ self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
45
+ self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
46
+ self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
47
+ self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
48
+
49
+ if self.input_size == 256:
50
+ self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
51
+ self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
52
+ self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
53
+ self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
54
+
55
+ self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
56
+ self.linear2 = nn.Linear(100, 1)
57
+
58
+ # activation function
59
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
60
+
61
+ def forward(self, x):
62
+ assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
63
+
64
+ feat = self.lrelu(self.conv0_0(x))
65
+ feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2
66
+
67
+ feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
68
+ feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4
69
+
70
+ feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
71
+ feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8
72
+
73
+ feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
74
+ feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16
75
+
76
+ feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
77
+ feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32
78
+
79
+ if self.input_size == 256:
80
+ feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
81
+ feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64
82
+
83
+ # spatial size: (4, 4)
84
+ feat = feat.view(feat.size(0), -1)
85
+ feat = self.lrelu(self.linear1(feat))
86
+ out = self.linear2(feat)
87
+ return out
88
+
89
+
90
+ @ARCH_REGISTRY.register(suffix='basicsr')
91
+ class UNetDiscriminatorSN(nn.Module):
92
+ """Defines a U-Net discriminator with spectral normalization (SN)
93
+
94
+ It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
95
+
96
+ Arg:
97
+ num_in_ch (int): Channel number of inputs. Default: 3.
98
+ num_feat (int): Channel number of base intermediate features. Default: 64.
99
+ skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
100
+ """
101
+
102
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
103
+ super(UNetDiscriminatorSN, self).__init__()
104
+ self.skip_connection = skip_connection
105
+ norm = spectral_norm
106
+ # the first convolution
107
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
108
+ # downsample
109
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
110
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
111
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
112
+ # upsample
113
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
114
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
115
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
116
+ # extra convolutions
117
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
118
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
119
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
120
+
121
+ def forward(self, x):
122
+ # downsample
123
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
124
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
125
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
126
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
127
+
128
+ # upsample
129
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
130
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
131
+
132
+ if self.skip_connection:
133
+ x4 = x4 + x2
134
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
135
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
136
+
137
+ if self.skip_connection:
138
+ x5 = x5 + x1
139
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
140
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
141
+
142
+ if self.skip_connection:
143
+ x6 = x6 + x0
144
+
145
+ # extra convolutions
146
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
147
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
148
+ out = self.conv9(out)
149
+
150
+ return out
basicsr/archs/duf_arch.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ from basicsr.utils.registry import ARCH_REGISTRY
7
+
8
+
9
+ class DenseBlocksTemporalReduce(nn.Module):
10
+ """A concatenation of 3 dense blocks with reduction in temporal dimension.
11
+
12
+ Note that the output temporal dimension is 6 fewer the input temporal dimension, since there are 3 blocks.
13
+
14
+ Args:
15
+ num_feat (int): Number of channels in the blocks. Default: 64.
16
+ num_grow_ch (int): Growing factor of the dense blocks. Default: 32
17
+ adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
18
+ Set to false if you want to train from scratch. Default: False.
19
+ """
20
+
21
+ def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
22
+ super(DenseBlocksTemporalReduce, self).__init__()
23
+ if adapt_official_weights:
24
+ eps = 1e-3
25
+ momentum = 1e-3
26
+ else: # pytorch default values
27
+ eps = 1e-05
28
+ momentum = 0.1
29
+
30
+ self.temporal_reduce1 = nn.Sequential(
31
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
32
+ nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
33
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
34
+ nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
35
+
36
+ self.temporal_reduce2 = nn.Sequential(
37
+ nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
38
+ nn.Conv3d(
39
+ num_feat + num_grow_ch,
40
+ num_feat + num_grow_ch, (1, 1, 1),
41
+ stride=(1, 1, 1),
42
+ padding=(0, 0, 0),
43
+ bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
44
+ nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
45
+
46
+ self.temporal_reduce3 = nn.Sequential(
47
+ nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
48
+ nn.Conv3d(
49
+ num_feat + 2 * num_grow_ch,
50
+ num_feat + 2 * num_grow_ch, (1, 1, 1),
51
+ stride=(1, 1, 1),
52
+ padding=(0, 0, 0),
53
+ bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
54
+ nn.ReLU(inplace=True),
55
+ nn.Conv3d(
56
+ num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
57
+
58
+ def forward(self, x):
59
+ """
60
+ Args:
61
+ x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
62
+
63
+ Returns:
64
+ Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
65
+ """
66
+ x1 = self.temporal_reduce1(x)
67
+ x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
68
+
69
+ x2 = self.temporal_reduce2(x1)
70
+ x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
71
+
72
+ x3 = self.temporal_reduce3(x2)
73
+ x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
74
+
75
+ return x3
76
+
77
+
78
+ class DenseBlocks(nn.Module):
79
+ """ A concatenation of N dense blocks.
80
+
81
+ Args:
82
+ num_feat (int): Number of channels in the blocks. Default: 64.
83
+ num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
84
+ num_block (int): Number of dense blocks. The values are:
85
+ DUF-S (16 layers): 3
86
+ DUF-M (18 layers): 9
87
+ DUF-L (52 layers): 21
88
+ adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
89
+ Set to false if you want to train from scratch. Default: False.
90
+ """
91
+
92
+ def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
93
+ super(DenseBlocks, self).__init__()
94
+ if adapt_official_weights:
95
+ eps = 1e-3
96
+ momentum = 1e-3
97
+ else: # pytorch default values
98
+ eps = 1e-05
99
+ momentum = 0.1
100
+
101
+ self.dense_blocks = nn.ModuleList()
102
+ for i in range(0, num_block):
103
+ self.dense_blocks.append(
104
+ nn.Sequential(
105
+ nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
106
+ nn.Conv3d(
107
+ num_feat + i * num_grow_ch,
108
+ num_feat + i * num_grow_ch, (1, 1, 1),
109
+ stride=(1, 1, 1),
110
+ padding=(0, 0, 0),
111
+ bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
112
+ nn.ReLU(inplace=True),
113
+ nn.Conv3d(
114
+ num_feat + i * num_grow_ch,
115
+ num_grow_ch, (3, 3, 3),
116
+ stride=(1, 1, 1),
117
+ padding=(1, 1, 1),
118
+ bias=True)))
119
+
120
+ def forward(self, x):
121
+ """
122
+ Args:
123
+ x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
124
+
125
+ Returns:
126
+ Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
127
+ """
128
+ for i in range(0, len(self.dense_blocks)):
129
+ y = self.dense_blocks[i](x)
130
+ x = torch.cat((x, y), 1)
131
+ return x
132
+
133
+
134
+ class DynamicUpsamplingFilter(nn.Module):
135
+ """Dynamic upsampling filter used in DUF.
136
+
137
+ Reference: https://github.com/yhjo09/VSR-DUF
138
+
139
+ It only supports input with 3 channels. And it applies the same filters to 3 channels.
140
+
141
+ Args:
142
+ filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
143
+ """
144
+
145
+ def __init__(self, filter_size=(5, 5)):
146
+ super(DynamicUpsamplingFilter, self).__init__()
147
+ if not isinstance(filter_size, tuple):
148
+ raise TypeError(f'The type of filter_size must be tuple, but got type{filter_size}')
149
+ if len(filter_size) != 2:
150
+ raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
151
+ # generate a local expansion filter, similar to im2col
152
+ self.filter_size = filter_size
153
+ filter_prod = np.prod(filter_size)
154
+ expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw)
155
+ self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1) # repeat for all the 3 channels
156
+
157
+ def forward(self, x, filters):
158
+ """Forward function for DynamicUpsamplingFilter.
159
+
160
+ Args:
161
+ x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
162
+ filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w).
163
+ filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
164
+ upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling.
165
+ e.g., for x 4 upsampling, upsampling_square= 4*4 = 16
166
+
167
+ Returns:
168
+ Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
169
+ """
170
+ n, filter_prod, upsampling_square, h, w = filters.size()
171
+ kh, kw = self.filter_size
172
+ expanded_input = F.conv2d(
173
+ x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w)
174
+ expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
175
+ 2) # (n, h, w, 3, filter_prod)
176
+ filters = filters.permute(0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square]
177
+ out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square)
178
+ return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
179
+
180
+
181
+ @ARCH_REGISTRY.register()
182
+ class DUF(nn.Module):
183
+ """Network architecture for DUF
184
+
185
+ ``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation``
186
+
187
+ Reference: https://github.com/yhjo09/VSR-DUF
188
+
189
+ For all the models below, 'adapt_official_weights' is only necessary when
190
+ loading the weights converted from the official TensorFlow weights.
191
+ Please set it to False if you are training the model from scratch.
192
+
193
+ There are three models with different model size: DUF16Layers, DUF28Layers,
194
+ and DUF52Layers. This class is the base class for these models.
195
+
196
+ Args:
197
+ scale (int): The upsampling factor. Default: 4.
198
+ num_layer (int): The number of layers. Default: 52.
199
+ adapt_official_weights_weights (bool): Whether to adapt the weights
200
+ translated from the official implementation. Set to false if you
201
+ want to train from scratch. Default: False.
202
+ """
203
+
204
+ def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
205
+ super(DUF, self).__init__()
206
+ self.scale = scale
207
+ if adapt_official_weights:
208
+ eps = 1e-3
209
+ momentum = 1e-3
210
+ else: # pytorch default values
211
+ eps = 1e-05
212
+ momentum = 0.1
213
+
214
+ self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
215
+ self.dynamic_filter = DynamicUpsamplingFilter((5, 5))
216
+
217
+ if num_layer == 16:
218
+ num_block = 3
219
+ num_grow_ch = 32
220
+ elif num_layer == 28:
221
+ num_block = 9
222
+ num_grow_ch = 16
223
+ elif num_layer == 52:
224
+ num_block = 21
225
+ num_grow_ch = 16
226
+ else:
227
+ raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
228
+
229
+ self.dense_block1 = DenseBlocks(
230
+ num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
231
+ adapt_official_weights=adapt_official_weights) # T = 7
232
+ self.dense_block2 = DenseBlocksTemporalReduce(
233
+ 64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1
234
+ channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
235
+ self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
236
+ self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
237
+
238
+ self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
239
+ self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
240
+
241
+ self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
242
+ self.conv3d_f2 = nn.Conv3d(
243
+ 512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
244
+
245
+ def forward(self, x):
246
+ """
247
+ Args:
248
+ x (Tensor): Input with shape (b, 7, c, h, w)
249
+
250
+ Returns:
251
+ Tensor: Output with shape (b, c, h * scale, w * scale)
252
+ """
253
+ num_batches, num_imgs, _, h, w = x.size()
254
+
255
+ x = x.permute(0, 2, 1, 3, 4) # (b, c, 7, h, w) for Conv3D
256
+ x_center = x[:, :, num_imgs // 2, :, :]
257
+
258
+ x = self.conv3d1(x)
259
+ x = self.dense_block1(x)
260
+ x = self.dense_block2(x)
261
+ x = F.relu(self.bn3d2(x), inplace=True)
262
+ x = F.relu(self.conv3d2(x), inplace=True)
263
+
264
+ # residual image
265
+ res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))
266
+
267
+ # filter
268
+ filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
269
+ filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)
270
+
271
+ # dynamic filter
272
+ out = self.dynamic_filter(x_center, filter_)
273
+ out += res.squeeze_(2)
274
+ out = F.pixel_shuffle(out, self.scale)
275
+
276
+ return out
basicsr/archs/ecbsr_arch.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+
7
+
8
+ class SeqConv3x3(nn.Module):
9
+ """The re-parameterizable block used in the ECBSR architecture.
10
+
11
+ ``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices``
12
+
13
+ Reference: https://github.com/xindongzhang/ECBSR
14
+
15
+ Args:
16
+ seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
17
+ in_channels (int): Channel number of input.
18
+ out_channels (int): Channel number of output.
19
+ depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
20
+ """
21
+
22
+ def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
23
+ super(SeqConv3x3, self).__init__()
24
+ self.seq_type = seq_type
25
+ self.in_channels = in_channels
26
+ self.out_channels = out_channels
27
+
28
+ if self.seq_type == 'conv1x1-conv3x3':
29
+ self.mid_planes = int(out_channels * depth_multiplier)
30
+ conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
31
+ self.k0 = conv0.weight
32
+ self.b0 = conv0.bias
33
+
34
+ conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
35
+ self.k1 = conv1.weight
36
+ self.b1 = conv1.bias
37
+
38
+ elif self.seq_type == 'conv1x1-sobelx':
39
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
40
+ self.k0 = conv0.weight
41
+ self.b0 = conv0.bias
42
+
43
+ # init scale and bias
44
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
45
+ self.scale = nn.Parameter(scale)
46
+ bias = torch.randn(self.out_channels) * 1e-3
47
+ bias = torch.reshape(bias, (self.out_channels, ))
48
+ self.bias = nn.Parameter(bias)
49
+ # init mask
50
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
51
+ for i in range(self.out_channels):
52
+ self.mask[i, 0, 0, 0] = 1.0
53
+ self.mask[i, 0, 1, 0] = 2.0
54
+ self.mask[i, 0, 2, 0] = 1.0
55
+ self.mask[i, 0, 0, 2] = -1.0
56
+ self.mask[i, 0, 1, 2] = -2.0
57
+ self.mask[i, 0, 2, 2] = -1.0
58
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
59
+
60
+ elif self.seq_type == 'conv1x1-sobely':
61
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
62
+ self.k0 = conv0.weight
63
+ self.b0 = conv0.bias
64
+
65
+ # init scale and bias
66
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
67
+ self.scale = nn.Parameter(torch.FloatTensor(scale))
68
+ bias = torch.randn(self.out_channels) * 1e-3
69
+ bias = torch.reshape(bias, (self.out_channels, ))
70
+ self.bias = nn.Parameter(torch.FloatTensor(bias))
71
+ # init mask
72
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
73
+ for i in range(self.out_channels):
74
+ self.mask[i, 0, 0, 0] = 1.0
75
+ self.mask[i, 0, 0, 1] = 2.0
76
+ self.mask[i, 0, 0, 2] = 1.0
77
+ self.mask[i, 0, 2, 0] = -1.0
78
+ self.mask[i, 0, 2, 1] = -2.0
79
+ self.mask[i, 0, 2, 2] = -1.0
80
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
81
+
82
+ elif self.seq_type == 'conv1x1-laplacian':
83
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
84
+ self.k0 = conv0.weight
85
+ self.b0 = conv0.bias
86
+
87
+ # init scale and bias
88
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
89
+ self.scale = nn.Parameter(torch.FloatTensor(scale))
90
+ bias = torch.randn(self.out_channels) * 1e-3
91
+ bias = torch.reshape(bias, (self.out_channels, ))
92
+ self.bias = nn.Parameter(torch.FloatTensor(bias))
93
+ # init mask
94
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
95
+ for i in range(self.out_channels):
96
+ self.mask[i, 0, 0, 1] = 1.0
97
+ self.mask[i, 0, 1, 0] = 1.0
98
+ self.mask[i, 0, 1, 2] = 1.0
99
+ self.mask[i, 0, 2, 1] = 1.0
100
+ self.mask[i, 0, 1, 1] = -4.0
101
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
102
+ else:
103
+ raise ValueError('The type of seqconv is not supported!')
104
+
105
+ def forward(self, x):
106
+ if self.seq_type == 'conv1x1-conv3x3':
107
+ # conv-1x1
108
+ y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
109
+ # explicitly padding with bias
110
+ y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
111
+ b0_pad = self.b0.view(1, -1, 1, 1)
112
+ y0[:, :, 0:1, :] = b0_pad
113
+ y0[:, :, -1:, :] = b0_pad
114
+ y0[:, :, :, 0:1] = b0_pad
115
+ y0[:, :, :, -1:] = b0_pad
116
+ # conv-3x3
117
+ y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
118
+ else:
119
+ y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
120
+ # explicitly padding with bias
121
+ y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
122
+ b0_pad = self.b0.view(1, -1, 1, 1)
123
+ y0[:, :, 0:1, :] = b0_pad
124
+ y0[:, :, -1:, :] = b0_pad
125
+ y0[:, :, :, 0:1] = b0_pad
126
+ y0[:, :, :, -1:] = b0_pad
127
+ # conv-3x3
128
+ y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
129
+ return y1
130
+
131
+ def rep_params(self):
132
+ device = self.k0.get_device()
133
+ if device < 0:
134
+ device = None
135
+
136
+ if self.seq_type == 'conv1x1-conv3x3':
137
+ # re-param conv kernel
138
+ rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
139
+ # re-param conv bias
140
+ rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
141
+ rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
142
+ else:
143
+ tmp = self.scale * self.mask
144
+ k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
145
+ for i in range(self.out_channels):
146
+ k1[i, i, :, :] = tmp[i, 0, :, :]
147
+ b1 = self.bias
148
+ # re-param conv kernel
149
+ rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
150
+ # re-param conv bias
151
+ rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
152
+ rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
153
+ return rep_weight, rep_bias
154
+
155
+
156
+ class ECB(nn.Module):
157
+ """The ECB block used in the ECBSR architecture.
158
+
159
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
160
+ Ref git repo: https://github.com/xindongzhang/ECBSR
161
+
162
+ Args:
163
+ in_channels (int): Channel number of input.
164
+ out_channels (int): Channel number of output.
165
+ depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
166
+ act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
167
+ with_idt (bool): Whether to use identity connection. Default: False.
168
+ """
169
+
170
+ def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
171
+ super(ECB, self).__init__()
172
+
173
+ self.depth_multiplier = depth_multiplier
174
+ self.in_channels = in_channels
175
+ self.out_channels = out_channels
176
+ self.act_type = act_type
177
+
178
+ if with_idt and (self.in_channels == self.out_channels):
179
+ self.with_idt = True
180
+ else:
181
+ self.with_idt = False
182
+
183
+ self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
184
+ self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
185
+ self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
186
+ self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
187
+ self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)
188
+
189
+ if self.act_type == 'prelu':
190
+ self.act = nn.PReLU(num_parameters=self.out_channels)
191
+ elif self.act_type == 'relu':
192
+ self.act = nn.ReLU(inplace=True)
193
+ elif self.act_type == 'rrelu':
194
+ self.act = nn.RReLU(lower=-0.05, upper=0.05)
195
+ elif self.act_type == 'softplus':
196
+ self.act = nn.Softplus()
197
+ elif self.act_type == 'linear':
198
+ pass
199
+ else:
200
+ raise ValueError('The type of activation if not support!')
201
+
202
+ def forward(self, x):
203
+ if self.training:
204
+ y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
205
+ if self.with_idt:
206
+ y += x
207
+ else:
208
+ rep_weight, rep_bias = self.rep_params()
209
+ y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
210
+ if self.act_type != 'linear':
211
+ y = self.act(y)
212
+ return y
213
+
214
+ def rep_params(self):
215
+ weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
216
+ weight1, bias1 = self.conv1x1_3x3.rep_params()
217
+ weight2, bias2 = self.conv1x1_sbx.rep_params()
218
+ weight3, bias3 = self.conv1x1_sby.rep_params()
219
+ weight4, bias4 = self.conv1x1_lpl.rep_params()
220
+ rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
221
+ bias0 + bias1 + bias2 + bias3 + bias4)
222
+
223
+ if self.with_idt:
224
+ device = rep_weight.get_device()
225
+ if device < 0:
226
+ device = None
227
+ weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
228
+ for i in range(self.out_channels):
229
+ weight_idt[i, i, 1, 1] = 1.0
230
+ bias_idt = 0.0
231
+ rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
232
+ return rep_weight, rep_bias
233
+
234
+
235
+ @ARCH_REGISTRY.register()
236
+ class ECBSR(nn.Module):
237
+ """ECBSR architecture.
238
+
239
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
240
+ Ref git repo: https://github.com/xindongzhang/ECBSR
241
+
242
+ Args:
243
+ num_in_ch (int): Channel number of inputs.
244
+ num_out_ch (int): Channel number of outputs.
245
+ num_block (int): Block number in the trunk network.
246
+ num_channel (int): Channel number.
247
+ with_idt (bool): Whether use identity in convolution layers.
248
+ act_type (str): Activation type.
249
+ scale (int): Upsampling factor.
250
+ """
251
+
252
+ def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
253
+ super(ECBSR, self).__init__()
254
+ self.num_in_ch = num_in_ch
255
+ self.scale = scale
256
+
257
+ backbone = []
258
+ backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
259
+ for _ in range(num_block):
260
+ backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
261
+ backbone += [
262
+ ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
263
+ ]
264
+
265
+ self.backbone = nn.Sequential(*backbone)
266
+ self.upsampler = nn.PixelShuffle(scale)
267
+
268
+ def forward(self, x):
269
+ if self.num_in_ch > 1:
270
+ shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
271
+ else:
272
+ shortcut = x # will repeat the input in the channel dimension (repeat scale * scale times)
273
+ y = self.backbone(x) + shortcut
274
+ y = self.upsampler(y)
275
+ return y
basicsr/archs/edsr_arch.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+
4
+ from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+
7
+
8
+ @ARCH_REGISTRY.register()
9
+ class EDSR(nn.Module):
10
+ """EDSR network structure.
11
+
12
+ Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
13
+ Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch
14
+
15
+ Args:
16
+ num_in_ch (int): Channel number of inputs.
17
+ num_out_ch (int): Channel number of outputs.
18
+ num_feat (int): Channel number of intermediate features.
19
+ Default: 64.
20
+ num_block (int): Block number in the trunk network. Default: 16.
21
+ upscale (int): Upsampling factor. Support 2^n and 3.
22
+ Default: 4.
23
+ res_scale (float): Used to scale the residual in residual block.
24
+ Default: 1.
25
+ img_range (float): Image range. Default: 255.
26
+ rgb_mean (tuple[float]): Image mean in RGB orders.
27
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
28
+ """
29
+
30
+ def __init__(self,
31
+ num_in_ch,
32
+ num_out_ch,
33
+ num_feat=64,
34
+ num_block=16,
35
+ upscale=4,
36
+ res_scale=1,
37
+ img_range=255.,
38
+ rgb_mean=(0.4488, 0.4371, 0.4040)):
39
+ super(EDSR, self).__init__()
40
+
41
+ self.img_range = img_range
42
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
43
+
44
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
45
+ self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
46
+ self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
47
+ self.upsample = Upsample(upscale, num_feat)
48
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
49
+
50
+ def forward(self, x):
51
+ self.mean = self.mean.type_as(x)
52
+
53
+ x = (x - self.mean) * self.img_range
54
+ x = self.conv_first(x)
55
+ res = self.conv_after_body(self.body(x))
56
+ res += x
57
+
58
+ x = self.conv_last(self.upsample(res))
59
+ x = x / self.img_range + self.mean
60
+
61
+ return x
basicsr/archs/edvr_arch.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
7
+
8
+
9
+ class PCDAlignment(nn.Module):
10
+ """Alignment module using Pyramid, Cascading and Deformable convolution
11
+ (PCD). It is used in EDVR.
12
+
13
+ ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
14
+
15
+ Args:
16
+ num_feat (int): Channel number of middle features. Default: 64.
17
+ deformable_groups (int): Deformable groups. Defaults: 8.
18
+ """
19
+
20
+ def __init__(self, num_feat=64, deformable_groups=8):
21
+ super(PCDAlignment, self).__init__()
22
+
23
+ # Pyramid has three levels:
24
+ # L3: level 3, 1/4 spatial size
25
+ # L2: level 2, 1/2 spatial size
26
+ # L1: level 1, original spatial size
27
+ self.offset_conv1 = nn.ModuleDict()
28
+ self.offset_conv2 = nn.ModuleDict()
29
+ self.offset_conv3 = nn.ModuleDict()
30
+ self.dcn_pack = nn.ModuleDict()
31
+ self.feat_conv = nn.ModuleDict()
32
+
33
+ # Pyramids
34
+ for i in range(3, 0, -1):
35
+ level = f'l{i}'
36
+ self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
37
+ if i == 3:
38
+ self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
39
+ else:
40
+ self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
41
+ self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
42
+ self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
43
+
44
+ if i < 3:
45
+ self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
46
+
47
+ # Cascading dcn
48
+ self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
49
+ self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
50
+ self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
51
+
52
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
53
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
54
+
55
+ def forward(self, nbr_feat_l, ref_feat_l):
56
+ """Align neighboring frame features to the reference frame features.
57
+
58
+ Args:
59
+ nbr_feat_l (list[Tensor]): Neighboring feature list. It
60
+ contains three pyramid levels (L1, L2, L3),
61
+ each with shape (b, c, h, w).
62
+ ref_feat_l (list[Tensor]): Reference feature list. It
63
+ contains three pyramid levels (L1, L2, L3),
64
+ each with shape (b, c, h, w).
65
+
66
+ Returns:
67
+ Tensor: Aligned features.
68
+ """
69
+ # Pyramids
70
+ upsampled_offset, upsampled_feat = None, None
71
+ for i in range(3, 0, -1):
72
+ level = f'l{i}'
73
+ offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
74
+ offset = self.lrelu(self.offset_conv1[level](offset))
75
+ if i == 3:
76
+ offset = self.lrelu(self.offset_conv2[level](offset))
77
+ else:
78
+ offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
79
+ offset = self.lrelu(self.offset_conv3[level](offset))
80
+
81
+ feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
82
+ if i < 3:
83
+ feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
84
+ if i > 1:
85
+ feat = self.lrelu(feat)
86
+
87
+ if i > 1: # upsample offset and features
88
+ # x2: when we upsample the offset, we should also enlarge
89
+ # the magnitude.
90
+ upsampled_offset = self.upsample(offset) * 2
91
+ upsampled_feat = self.upsample(feat)
92
+
93
+ # Cascading
94
+ offset = torch.cat([feat, ref_feat_l[0]], dim=1)
95
+ offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
96
+ feat = self.lrelu(self.cas_dcnpack(feat, offset))
97
+ return feat
98
+
99
+
100
+ class TSAFusion(nn.Module):
101
+ """Temporal Spatial Attention (TSA) fusion module.
102
+
103
+ Temporal: Calculate the correlation between center frame and
104
+ neighboring frames;
105
+ Spatial: It has 3 pyramid levels, the attention is similar to SFT.
106
+ (SFT: Recovering realistic texture in image super-resolution by deep
107
+ spatial feature transform.)
108
+
109
+ Args:
110
+ num_feat (int): Channel number of middle features. Default: 64.
111
+ num_frame (int): Number of frames. Default: 5.
112
+ center_frame_idx (int): The index of center frame. Default: 2.
113
+ """
114
+
115
+ def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
116
+ super(TSAFusion, self).__init__()
117
+ self.center_frame_idx = center_frame_idx
118
+ # temporal attention (before fusion conv)
119
+ self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
120
+ self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
121
+ self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
122
+
123
+ # spatial attention (after fusion conv)
124
+ self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
125
+ self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
126
+ self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
127
+ self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
128
+ self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
129
+ self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
130
+ self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
131
+ self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
132
+ self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
133
+ self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
134
+ self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
135
+ self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)
136
+
137
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
138
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
139
+
140
+ def forward(self, aligned_feat):
141
+ """
142
+ Args:
143
+ aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
144
+
145
+ Returns:
146
+ Tensor: Features after TSA with the shape (b, c, h, w).
147
+ """
148
+ b, t, c, h, w = aligned_feat.size()
149
+ # temporal attention
150
+ embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
151
+ embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
152
+ embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w)
153
+
154
+ corr_l = [] # correlation list
155
+ for i in range(t):
156
+ emb_neighbor = embedding[:, i, :, :, :]
157
+ corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w)
158
+ corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w)
159
+ corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w)
160
+ corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
161
+ corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w)
162
+ aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob
163
+
164
+ # fusion
165
+ feat = self.lrelu(self.feat_fusion(aligned_feat))
166
+
167
+ # spatial attention
168
+ attn = self.lrelu(self.spatial_attn1(aligned_feat))
169
+ attn_max = self.max_pool(attn)
170
+ attn_avg = self.avg_pool(attn)
171
+ attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
172
+ # pyramid levels
173
+ attn_level = self.lrelu(self.spatial_attn_l1(attn))
174
+ attn_max = self.max_pool(attn_level)
175
+ attn_avg = self.avg_pool(attn_level)
176
+ attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
177
+ attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
178
+ attn_level = self.upsample(attn_level)
179
+
180
+ attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
181
+ attn = self.lrelu(self.spatial_attn4(attn))
182
+ attn = self.upsample(attn)
183
+ attn = self.spatial_attn5(attn)
184
+ attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
185
+ attn = torch.sigmoid(attn)
186
+
187
+ # after initialization, * 2 makes (attn * 2) to be close to 1.
188
+ feat = feat * attn * 2 + attn_add
189
+ return feat
190
+
191
+
192
+ class PredeblurModule(nn.Module):
193
+ """Pre-dublur module.
194
+
195
+ Args:
196
+ num_in_ch (int): Channel number of input image. Default: 3.
197
+ num_feat (int): Channel number of intermediate features. Default: 64.
198
+ hr_in (bool): Whether the input has high resolution. Default: False.
199
+ """
200
+
201
+ def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
202
+ super(PredeblurModule, self).__init__()
203
+ self.hr_in = hr_in
204
+
205
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
206
+ if self.hr_in:
207
+ # downsample x4 by stride conv
208
+ self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
209
+ self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
210
+
211
+ # generate feature pyramid
212
+ self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
213
+ self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
214
+
215
+ self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
216
+ self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
217
+ self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
218
+ self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])
219
+
220
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
221
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
222
+
223
+ def forward(self, x):
224
+ feat_l1 = self.lrelu(self.conv_first(x))
225
+ if self.hr_in:
226
+ feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
227
+ feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))
228
+
229
+ # generate feature pyramid
230
+ feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
231
+ feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))
232
+
233
+ feat_l3 = self.upsample(self.resblock_l3(feat_l3))
234
+ feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
235
+ feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))
236
+
237
+ for i in range(2):
238
+ feat_l1 = self.resblock_l1[i](feat_l1)
239
+ feat_l1 = feat_l1 + feat_l2
240
+ for i in range(2, 5):
241
+ feat_l1 = self.resblock_l1[i](feat_l1)
242
+ return feat_l1
243
+
244
+
245
+ @ARCH_REGISTRY.register()
246
+ class EDVR(nn.Module):
247
+ """EDVR network structure for video super-resolution.
248
+
249
+ Now only support X4 upsampling factor.
250
+
251
+ ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
252
+
253
+ Args:
254
+ num_in_ch (int): Channel number of input image. Default: 3.
255
+ num_out_ch (int): Channel number of output image. Default: 3.
256
+ num_feat (int): Channel number of intermediate features. Default: 64.
257
+ num_frame (int): Number of input frames. Default: 5.
258
+ deformable_groups (int): Deformable groups. Defaults: 8.
259
+ num_extract_block (int): Number of blocks for feature extraction.
260
+ Default: 5.
261
+ num_reconstruct_block (int): Number of blocks for reconstruction.
262
+ Default: 10.
263
+ center_frame_idx (int): The index of center frame. Frame counting from
264
+ 0. Default: Middle of input frames.
265
+ hr_in (bool): Whether the input has high resolution. Default: False.
266
+ with_predeblur (bool): Whether has predeblur module.
267
+ Default: False.
268
+ with_tsa (bool): Whether has TSA module. Default: True.
269
+ """
270
+
271
+ def __init__(self,
272
+ num_in_ch=3,
273
+ num_out_ch=3,
274
+ num_feat=64,
275
+ num_frame=5,
276
+ deformable_groups=8,
277
+ num_extract_block=5,
278
+ num_reconstruct_block=10,
279
+ center_frame_idx=None,
280
+ hr_in=False,
281
+ with_predeblur=False,
282
+ with_tsa=True):
283
+ super(EDVR, self).__init__()
284
+ if center_frame_idx is None:
285
+ self.center_frame_idx = num_frame // 2
286
+ else:
287
+ self.center_frame_idx = center_frame_idx
288
+ self.hr_in = hr_in
289
+ self.with_predeblur = with_predeblur
290
+ self.with_tsa = with_tsa
291
+
292
+ # extract features for each frame
293
+ if self.with_predeblur:
294
+ self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
295
+ self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
296
+ else:
297
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
298
+
299
+ # extract pyramid features
300
+ self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
301
+ self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
302
+ self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
303
+ self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
304
+ self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
305
+
306
+ # pcd and tsa module
307
+ self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
308
+ if self.with_tsa:
309
+ self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
310
+ else:
311
+ self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
312
+
313
+ # reconstruction
314
+ self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
315
+ # upsample
316
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
317
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
318
+ self.pixel_shuffle = nn.PixelShuffle(2)
319
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
320
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
321
+
322
+ # activation function
323
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
324
+
325
+ def forward(self, x):
326
+ b, t, c, h, w = x.size()
327
+ if self.hr_in:
328
+ assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
329
+ else:
330
+ assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')
331
+
332
+ x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
333
+
334
+ # extract features for each frame
335
+ # L1
336
+ if self.with_predeblur:
337
+ feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
338
+ if self.hr_in:
339
+ h, w = h // 4, w // 4
340
+ else:
341
+ feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
342
+
343
+ feat_l1 = self.feature_extraction(feat_l1)
344
+ # L2
345
+ feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
346
+ feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
347
+ # L3
348
+ feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
349
+ feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
350
+
351
+ feat_l1 = feat_l1.view(b, t, -1, h, w)
352
+ feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
353
+ feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)
354
+
355
+ # PCD alignment
356
+ ref_feat_l = [ # reference feature list
357
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
358
+ feat_l3[:, self.center_frame_idx, :, :, :].clone()
359
+ ]
360
+ aligned_feat = []
361
+ for i in range(t):
362
+ nbr_feat_l = [ # neighboring feature list
363
+ feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
364
+ ]
365
+ aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
366
+ aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
367
+
368
+ if not self.with_tsa:
369
+ aligned_feat = aligned_feat.view(b, -1, h, w)
370
+ feat = self.fusion(aligned_feat)
371
+
372
+ out = self.reconstruction(feat)
373
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
374
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
375
+ out = self.lrelu(self.conv_hr(out))
376
+ out = self.conv_last(out)
377
+ if self.hr_in:
378
+ base = x_center
379
+ else:
380
+ base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
381
+ out += base
382
+ return out
basicsr/archs/hifacegan_arch.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from basicsr.utils.registry import ARCH_REGISTRY
7
+ from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
8
+
9
+
10
+ class SPADEGenerator(BaseNetwork):
11
+ """Generator with SPADEResBlock"""
12
+
13
+ def __init__(self,
14
+ num_in_ch=3,
15
+ num_feat=64,
16
+ use_vae=False,
17
+ z_dim=256,
18
+ crop_size=512,
19
+ norm_g='spectralspadesyncbatch3x3',
20
+ is_train=True,
21
+ init_train_phase=3): # progressive training disabled
22
+ super().__init__()
23
+ self.nf = num_feat
24
+ self.input_nc = num_in_ch
25
+ self.is_train = is_train
26
+ self.train_phase = init_train_phase
27
+
28
+ self.scale_ratio = 5 # hardcoded now
29
+ self.sw = crop_size // (2**self.scale_ratio)
30
+ self.sh = self.sw # 20210519: By default use square image, aspect_ratio = 1.0
31
+
32
+ if use_vae:
33
+ # In case of VAE, we will sample from random z vector
34
+ self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
35
+ else:
36
+ # Otherwise, we make the network deterministic by starting with
37
+ # downsampled segmentation map instead of random z
38
+ self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)
39
+
40
+ self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
41
+
42
+ self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
43
+ self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
44
+
45
+ self.ups = nn.ModuleList([
46
+ SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
47
+ SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
48
+ SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
49
+ SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
50
+ ])
51
+
52
+ self.to_rgbs = nn.ModuleList([
53
+ nn.Conv2d(8 * self.nf, 3, 3, padding=1),
54
+ nn.Conv2d(4 * self.nf, 3, 3, padding=1),
55
+ nn.Conv2d(2 * self.nf, 3, 3, padding=1),
56
+ nn.Conv2d(1 * self.nf, 3, 3, padding=1)
57
+ ])
58
+
59
+ self.up = nn.Upsample(scale_factor=2)
60
+
61
+ def encode(self, input_tensor):
62
+ """
63
+ Encode input_tensor into feature maps, can be overridden in derived classes
64
+ Default: nearest downsampling of 2**5 = 32 times
65
+ """
66
+ h, w = input_tensor.size()[-2:]
67
+ sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
68
+ x = F.interpolate(input_tensor, size=(sh, sw))
69
+ return self.fc(x)
70
+
71
+ def forward(self, x):
72
+ # In oroginal SPADE, seg means a segmentation map, but here we use x instead.
73
+ seg = x
74
+
75
+ x = self.encode(x)
76
+ x = self.head_0(x, seg)
77
+
78
+ x = self.up(x)
79
+ x = self.g_middle_0(x, seg)
80
+ x = self.g_middle_1(x, seg)
81
+
82
+ if self.is_train:
83
+ phase = self.train_phase + 1
84
+ else:
85
+ phase = len(self.to_rgbs)
86
+
87
+ for i in range(phase):
88
+ x = self.up(x)
89
+ x = self.ups[i](x, seg)
90
+
91
+ x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
92
+ x = torch.tanh(x)
93
+
94
+ return x
95
+
96
+ def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
97
+ """
98
+ A helper class for subspace visualization. Input and seg are different images.
99
+ For the first n levels (including encoder) we use input, for the rest we use seg.
100
+
101
+ If mode = 'progressive', the output's like: AAABBB
102
+ If mode = 'one_plug', the output's like: AAABAA
103
+ If mode = 'one_ablate', the output's like: BBBABB
104
+ """
105
+
106
+ if seg is None:
107
+ return self.forward(input_x)
108
+
109
+ if self.is_train:
110
+ phase = self.train_phase + 1
111
+ else:
112
+ phase = len(self.to_rgbs)
113
+
114
+ if mode == 'progressive':
115
+ n = max(min(n, 4 + phase), 0)
116
+ guide_list = [input_x] * n + [seg] * (4 + phase - n)
117
+ elif mode == 'one_plug':
118
+ n = max(min(n, 4 + phase - 1), 0)
119
+ guide_list = [seg] * (4 + phase)
120
+ guide_list[n] = input_x
121
+ elif mode == 'one_ablate':
122
+ if n > 3 + phase:
123
+ return self.forward(input_x)
124
+ guide_list = [input_x] * (4 + phase)
125
+ guide_list[n] = seg
126
+
127
+ x = self.encode(guide_list[0])
128
+ x = self.head_0(x, guide_list[1])
129
+
130
+ x = self.up(x)
131
+ x = self.g_middle_0(x, guide_list[2])
132
+ x = self.g_middle_1(x, guide_list[3])
133
+
134
+ for i in range(phase):
135
+ x = self.up(x)
136
+ x = self.ups[i](x, guide_list[4 + i])
137
+
138
+ x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
139
+ x = torch.tanh(x)
140
+
141
+ return x
142
+
143
+
144
+ @ARCH_REGISTRY.register()
145
+ class HiFaceGAN(SPADEGenerator):
146
+ """
147
+ HiFaceGAN: SPADEGenerator with a learnable feature encoder
148
+ Current encoder design: LIPEncoder
149
+ """
150
+
151
+ def __init__(self,
152
+ num_in_ch=3,
153
+ num_feat=64,
154
+ use_vae=False,
155
+ z_dim=256,
156
+ crop_size=512,
157
+ norm_g='spectralspadesyncbatch3x3',
158
+ is_train=True,
159
+ init_train_phase=3):
160
+ super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
161
+ self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)
162
+
163
+ def encode(self, input_tensor):
164
+ return self.lip_encoder(input_tensor)
165
+
166
+
167
+ @ARCH_REGISTRY.register()
168
+ class HiFaceGANDiscriminator(BaseNetwork):
169
+ """
170
+ Inspired by pix2pixHD multiscale discriminator.
171
+
172
+ Args:
173
+ num_in_ch (int): Channel number of inputs. Default: 3.
174
+ num_out_ch (int): Channel number of outputs. Default: 3.
175
+ conditional_d (bool): Whether use conditional discriminator.
176
+ Default: True.
177
+ num_d (int): Number of Multiscale discriminators. Default: 3.
178
+ n_layers_d (int): Number of downsample layers in each D. Default: 4.
179
+ num_feat (int): Channel number of base intermediate features.
180
+ Default: 64.
181
+ norm_d (str): String to determine normalization layers in D.
182
+ Choices: [spectral][instance/batch/syncbatch]
183
+ Default: 'spectralinstance'.
184
+ keep_features (bool): Keep intermediate features for matching loss, etc.
185
+ Default: True.
186
+ """
187
+
188
+ def __init__(self,
189
+ num_in_ch=3,
190
+ num_out_ch=3,
191
+ conditional_d=True,
192
+ num_d=2,
193
+ n_layers_d=4,
194
+ num_feat=64,
195
+ norm_d='spectralinstance',
196
+ keep_features=True):
197
+ super().__init__()
198
+ self.num_d = num_d
199
+
200
+ input_nc = num_in_ch
201
+ if conditional_d:
202
+ input_nc += num_out_ch
203
+
204
+ for i in range(num_d):
205
+ subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
206
+ self.add_module(f'discriminator_{i}', subnet_d)
207
+
208
+ def downsample(self, x):
209
+ return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
210
+
211
+ # Returns list of lists of discriminator outputs.
212
+ # The final result is of size opt.num_d x opt.n_layers_D
213
+ def forward(self, x):
214
+ result = []
215
+ for _, _net_d in self.named_children():
216
+ out = _net_d(x)
217
+ result.append(out)
218
+ x = self.downsample(x)
219
+
220
+ return result
221
+
222
+
223
+ class NLayerDiscriminator(BaseNetwork):
224
+ """Defines the PatchGAN discriminator with the specified arguments."""
225
+
226
+ def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
227
+ super().__init__()
228
+ kw = 4
229
+ padw = int(np.ceil((kw - 1.0) / 2))
230
+ nf = num_feat
231
+ self.keep_features = keep_features
232
+
233
+ norm_layer = get_nonspade_norm_layer(norm_d)
234
+ sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]
235
+
236
+ for n in range(1, n_layers_d):
237
+ nf_prev = nf
238
+ nf = min(nf * 2, 512)
239
+ stride = 1 if n == n_layers_d - 1 else 2
240
+ sequence += [[
241
+ norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
242
+ nn.LeakyReLU(0.2, False)
243
+ ]]
244
+
245
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
246
+
247
+ # We divide the layers into groups to extract intermediate layer outputs
248
+ for n in range(len(sequence)):
249
+ self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
250
+
251
+ def forward(self, x):
252
+ results = [x]
253
+ for submodel in self.children():
254
+ intermediate_output = submodel(results[-1])
255
+ results.append(intermediate_output)
256
+
257
+ if self.keep_features:
258
+ return results[1:]
259
+ else:
260
+ return results[-1]
basicsr/archs/hifacegan_util.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn import init
6
+ # Warning: spectral norm could be buggy
7
+ # under eval mode and multi-GPU inference
8
+ # A workaround is sticking to single-GPU inference and train mode
9
+ from torch.nn.utils import spectral_norm
10
+
11
+
12
+ class SPADE(nn.Module):
13
+
14
+ def __init__(self, config_text, norm_nc, label_nc):
15
+ super().__init__()
16
+
17
+ assert config_text.startswith('spade')
18
+ parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
19
+ param_free_norm_type = str(parsed.group(1))
20
+ ks = int(parsed.group(2))
21
+
22
+ if param_free_norm_type == 'instance':
23
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
24
+ elif param_free_norm_type == 'syncbatch':
25
+ print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
26
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
27
+ elif param_free_norm_type == 'batch':
28
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
29
+ else:
30
+ raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')
31
+
32
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
33
+ nhidden = 128 if norm_nc > 128 else norm_nc
34
+
35
+ pw = ks // 2
36
+ self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
37
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
38
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
39
+
40
+ def forward(self, x, segmap):
41
+
42
+ # Part 1. generate parameter-free normalized activations
43
+ normalized = self.param_free_norm(x)
44
+
45
+ # Part 2. produce scaling and bias conditioned on semantic map
46
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
47
+ actv = self.mlp_shared(segmap)
48
+ gamma = self.mlp_gamma(actv)
49
+ beta = self.mlp_beta(actv)
50
+
51
+ # apply scale and bias
52
+ out = normalized * gamma + beta
53
+
54
+ return out
55
+
56
+
57
+ class SPADEResnetBlock(nn.Module):
58
+ """
59
+ ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
60
+ it takes in the segmentation map as input, learns the skip connection if necessary,
61
+ and applies normalization first and then convolution.
62
+ This architecture seemed like a standard architecture for unconditional or
63
+ class-conditional GAN architecture using residual block.
64
+ The code was inspired from https://github.com/LMescheder/GAN_stability.
65
+ """
66
+
67
+ def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
68
+ super().__init__()
69
+ # Attributes
70
+ self.learned_shortcut = (fin != fout)
71
+ fmiddle = min(fin, fout)
72
+
73
+ # create conv layers
74
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
75
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
76
+ if self.learned_shortcut:
77
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
78
+
79
+ # apply spectral norm if specified
80
+ if 'spectral' in norm_g:
81
+ self.conv_0 = spectral_norm(self.conv_0)
82
+ self.conv_1 = spectral_norm(self.conv_1)
83
+ if self.learned_shortcut:
84
+ self.conv_s = spectral_norm(self.conv_s)
85
+
86
+ # define normalization layers
87
+ spade_config_str = norm_g.replace('spectral', '')
88
+ self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
89
+ self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
90
+ if self.learned_shortcut:
91
+ self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
92
+
93
+ # note the resnet block with SPADE also takes in |seg|,
94
+ # the semantic segmentation map as input
95
+ def forward(self, x, seg):
96
+ x_s = self.shortcut(x, seg)
97
+ dx = self.conv_0(self.act(self.norm_0(x, seg)))
98
+ dx = self.conv_1(self.act(self.norm_1(dx, seg)))
99
+ out = x_s + dx
100
+ return out
101
+
102
+ def shortcut(self, x, seg):
103
+ if self.learned_shortcut:
104
+ x_s = self.conv_s(self.norm_s(x, seg))
105
+ else:
106
+ x_s = x
107
+ return x_s
108
+
109
+ def act(self, x):
110
+ return F.leaky_relu(x, 2e-1)
111
+
112
+
113
+ class BaseNetwork(nn.Module):
114
+ """ A basis for hifacegan archs with custom initialization """
115
+
116
+ def init_weights(self, init_type='normal', gain=0.02):
117
+
118
+ def init_func(m):
119
+ classname = m.__class__.__name__
120
+ if classname.find('BatchNorm2d') != -1:
121
+ if hasattr(m, 'weight') and m.weight is not None:
122
+ init.normal_(m.weight.data, 1.0, gain)
123
+ if hasattr(m, 'bias') and m.bias is not None:
124
+ init.constant_(m.bias.data, 0.0)
125
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
126
+ if init_type == 'normal':
127
+ init.normal_(m.weight.data, 0.0, gain)
128
+ elif init_type == 'xavier':
129
+ init.xavier_normal_(m.weight.data, gain=gain)
130
+ elif init_type == 'xavier_uniform':
131
+ init.xavier_uniform_(m.weight.data, gain=1.0)
132
+ elif init_type == 'kaiming':
133
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
134
+ elif init_type == 'orthogonal':
135
+ init.orthogonal_(m.weight.data, gain=gain)
136
+ elif init_type == 'none': # uses pytorch's default init method
137
+ m.reset_parameters()
138
+ else:
139
+ raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
140
+ if hasattr(m, 'bias') and m.bias is not None:
141
+ init.constant_(m.bias.data, 0.0)
142
+
143
+ self.apply(init_func)
144
+
145
+ # propagate to children
146
+ for m in self.children():
147
+ if hasattr(m, 'init_weights'):
148
+ m.init_weights(init_type, gain)
149
+
150
+ def forward(self, x):
151
+ pass
152
+
153
+
154
+ def lip2d(x, logit, kernel=3, stride=2, padding=1):
155
+ weight = logit.exp()
156
+ return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
157
+
158
+
159
+ class SoftGate(nn.Module):
160
+ COEFF = 12.0
161
+
162
+ def forward(self, x):
163
+ return torch.sigmoid(x).mul(self.COEFF)
164
+
165
+
166
+ class SimplifiedLIP(nn.Module):
167
+
168
+ def __init__(self, channels):
169
+ super(SimplifiedLIP, self).__init__()
170
+ self.logit = nn.Sequential(
171
+ nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
172
+ SoftGate())
173
+
174
+ def init_layer(self):
175
+ self.logit[0].weight.data.fill_(0.0)
176
+
177
+ def forward(self, x):
178
+ frac = lip2d(x, self.logit(x))
179
+ return frac
180
+
181
+
182
+ class LIPEncoder(BaseNetwork):
183
+ """Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)"""
184
+
185
+ def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
186
+ super().__init__()
187
+ self.sw = sw
188
+ self.sh = sh
189
+ self.max_ratio = 16
190
+ # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
191
+ kw = 3
192
+ pw = (kw - 1) // 2
193
+
194
+ model = [
195
+ nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
196
+ norm_layer(ngf),
197
+ nn.ReLU(),
198
+ ]
199
+ cur_ratio = 1
200
+ for i in range(n_2xdown):
201
+ next_ratio = min(cur_ratio * 2, self.max_ratio)
202
+ model += [
203
+ SimplifiedLIP(ngf * cur_ratio),
204
+ nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
205
+ norm_layer(ngf * next_ratio),
206
+ ]
207
+ cur_ratio = next_ratio
208
+ if i < n_2xdown - 1:
209
+ model += [nn.ReLU(inplace=True)]
210
+
211
+ self.model = nn.Sequential(*model)
212
+
213
+ def forward(self, x):
214
+ return self.model(x)
215
+
216
+
217
+ def get_nonspade_norm_layer(norm_type='instance'):
218
+ # helper function to get # output channels of the previous layer
219
+ def get_out_channel(layer):
220
+ if hasattr(layer, 'out_channels'):
221
+ return getattr(layer, 'out_channels')
222
+ return layer.weight.size(0)
223
+
224
+ # this function will be returned
225
+ def add_norm_layer(layer):
226
+ nonlocal norm_type
227
+ if norm_type.startswith('spectral'):
228
+ layer = spectral_norm(layer)
229
+ subnorm_type = norm_type[len('spectral'):]
230
+
231
+ if subnorm_type == 'none' or len(subnorm_type) == 0:
232
+ return layer
233
+
234
+ # remove bias in the previous layer, which is meaningless
235
+ # since it has no effect after normalization
236
+ if getattr(layer, 'bias', None) is not None:
237
+ delattr(layer, 'bias')
238
+ layer.register_parameter('bias', None)
239
+
240
+ if subnorm_type == 'batch':
241
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
242
+ elif subnorm_type == 'sync_batch':
243
+ print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
244
+ # norm_layer = SynchronizedBatchNorm2d(
245
+ # get_out_channel(layer), affine=True)
246
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
247
+ elif subnorm_type == 'instance':
248
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
249
+ else:
250
+ raise ValueError(f'normalization layer {subnorm_type} is not recognized')
251
+
252
+ return nn.Sequential(layer, norm_layer)
253
+
254
+ print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
255
+ return add_norm_layer
basicsr/archs/inception.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
2
+ # For FID metric
3
+
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.utils.model_zoo import load_url
9
+ from torchvision import models
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
+ LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
15
+
16
+
17
+ class InceptionV3(nn.Module):
18
+ """Pretrained InceptionV3 network returning feature maps"""
19
+
20
+ # Index of default block of inception to return,
21
+ # corresponds to output of final average pooling
22
+ DEFAULT_BLOCK_INDEX = 3
23
+
24
+ # Maps feature dimensionality to their output blocks indices
25
+ BLOCK_INDEX_BY_DIM = {
26
+ 64: 0, # First max pooling features
27
+ 192: 1, # Second max pooling features
28
+ 768: 2, # Pre-aux classifier features
29
+ 2048: 3 # Final average pooling features
30
+ }
31
+
32
+ def __init__(self,
33
+ output_blocks=(DEFAULT_BLOCK_INDEX),
34
+ resize_input=True,
35
+ normalize_input=True,
36
+ requires_grad=False,
37
+ use_fid_inception=True):
38
+ """Build pretrained InceptionV3.
39
+
40
+ Args:
41
+ output_blocks (list[int]): Indices of blocks to return features of.
42
+ Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input (bool): If true, bilinearly resizes input to width and
48
+ height 299 before feeding input to model. As the network
49
+ without fully connected layers is fully convolutional, it
50
+ should be able to handle inputs of arbitrary size, so resizing
51
+ might not be strictly needed. Default: True.
52
+ normalize_input (bool): If true, scales the input from range (0, 1)
53
+ to the range the pretrained Inception network expects,
54
+ namely (-1, 1). Default: True.
55
+ requires_grad (bool): If true, parameters of the model require
56
+ gradients. Possibly useful for finetuning the network.
57
+ Default: False.
58
+ use_fid_inception (bool): If true, uses the pretrained Inception
59
+ model used in Tensorflow's FID implementation.
60
+ If false, uses the pretrained Inception model available in
61
+ torchvision. The FID Inception model has different weights
62
+ and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get
65
+ comparable results. Default: True.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, ('Last possible output block index is 3')
75
+
76
+ self.blocks = nn.ModuleList()
77
+
78
+ if use_fid_inception:
79
+ inception = fid_inception_v3()
80
+ else:
81
+ try:
82
+ inception = models.inception_v3(pretrained=True, init_weights=False)
83
+ except TypeError:
84
+ # pytorch < 1.5 does not have init_weights for inception_v3
85
+ inception = models.inception_v3(pretrained=True)
86
+
87
+ # Block 0: input to maxpool1
88
+ block0 = [
89
+ inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
90
+ nn.MaxPool2d(kernel_size=3, stride=2)
91
+ ]
92
+ self.blocks.append(nn.Sequential(*block0))
93
+
94
+ # Block 1: maxpool1 to maxpool2
95
+ if self.last_needed_block >= 1:
96
+ block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
97
+ self.blocks.append(nn.Sequential(*block1))
98
+
99
+ # Block 2: maxpool2 to aux classifier
100
+ if self.last_needed_block >= 2:
101
+ block2 = [
102
+ inception.Mixed_5b,
103
+ inception.Mixed_5c,
104
+ inception.Mixed_5d,
105
+ inception.Mixed_6a,
106
+ inception.Mixed_6b,
107
+ inception.Mixed_6c,
108
+ inception.Mixed_6d,
109
+ inception.Mixed_6e,
110
+ ]
111
+ self.blocks.append(nn.Sequential(*block2))
112
+
113
+ # Block 3: aux classifier to final avgpool
114
+ if self.last_needed_block >= 3:
115
+ block3 = [
116
+ inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
117
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
118
+ ]
119
+ self.blocks.append(nn.Sequential(*block3))
120
+
121
+ for param in self.parameters():
122
+ param.requires_grad = requires_grad
123
+
124
+ def forward(self, x):
125
+ """Get Inception feature maps.
126
+
127
+ Args:
128
+ x (Tensor): Input tensor of shape (b, 3, h, w).
129
+ Values are expected to be in range (-1, 1). You can also input
130
+ (0, 1) with setting normalize_input = True.
131
+
132
+ Returns:
133
+ list[Tensor]: Corresponding to the selected output block, sorted
134
+ ascending by index.
135
+ """
136
+ output = []
137
+
138
+ if self.resize_input:
139
+ x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
140
+
141
+ if self.normalize_input:
142
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
143
+
144
+ for idx, block in enumerate(self.blocks):
145
+ x = block(x)
146
+ if idx in self.output_blocks:
147
+ output.append(x)
148
+
149
+ if idx == self.last_needed_block:
150
+ break
151
+
152
+ return output
153
+
154
+
155
+ def fid_inception_v3():
156
+ """Build pretrained Inception model for FID computation.
157
+
158
+ The Inception model for FID computation uses a different set of weights
159
+ and has a slightly different structure than torchvision's Inception.
160
+
161
+ This method first constructs torchvision's Inception and then patches the
162
+ necessary parts that are different in the FID Inception model.
163
+ """
164
+ try:
165
+ inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
166
+ except TypeError:
167
+ # pytorch < 1.5 does not have init_weights for inception_v3
168
+ inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
169
+
170
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
171
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
172
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
173
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
174
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
175
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
176
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
177
+ inception.Mixed_7b = FIDInceptionE_1(1280)
178
+ inception.Mixed_7c = FIDInceptionE_2(2048)
179
+
180
+ if os.path.exists(LOCAL_FID_WEIGHTS):
181
+ state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
182
+ else:
183
+ state_dict = load_url(FID_WEIGHTS_URL, progress=True)
184
+
185
+ inception.load_state_dict(state_dict)
186
+ return inception
187
+
188
+
189
+ class FIDInceptionA(models.inception.InceptionA):
190
+ """InceptionA block patched for FID computation"""
191
+
192
+ def __init__(self, in_channels, pool_features):
193
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
194
+
195
+ def forward(self, x):
196
+ branch1x1 = self.branch1x1(x)
197
+
198
+ branch5x5 = self.branch5x5_1(x)
199
+ branch5x5 = self.branch5x5_2(branch5x5)
200
+
201
+ branch3x3dbl = self.branch3x3dbl_1(x)
202
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
203
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
204
+
205
+ # Patch: Tensorflow's average pool does not use the padded zero's in
206
+ # its average calculation
207
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
208
+ branch_pool = self.branch_pool(branch_pool)
209
+
210
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
211
+ return torch.cat(outputs, 1)
212
+
213
+
214
+ class FIDInceptionC(models.inception.InceptionC):
215
+ """InceptionC block patched for FID computation"""
216
+
217
+ def __init__(self, in_channels, channels_7x7):
218
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
219
+
220
+ def forward(self, x):
221
+ branch1x1 = self.branch1x1(x)
222
+
223
+ branch7x7 = self.branch7x7_1(x)
224
+ branch7x7 = self.branch7x7_2(branch7x7)
225
+ branch7x7 = self.branch7x7_3(branch7x7)
226
+
227
+ branch7x7dbl = self.branch7x7dbl_1(x)
228
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
229
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
230
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
231
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
232
+
233
+ # Patch: Tensorflow's average pool does not use the padded zero's in
234
+ # its average calculation
235
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
236
+ branch_pool = self.branch_pool(branch_pool)
237
+
238
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
239
+ return torch.cat(outputs, 1)
240
+
241
+
242
+ class FIDInceptionE_1(models.inception.InceptionE):
243
+ """First InceptionE block patched for FID computation"""
244
+
245
+ def __init__(self, in_channels):
246
+ super(FIDInceptionE_1, self).__init__(in_channels)
247
+
248
+ def forward(self, x):
249
+ branch1x1 = self.branch1x1(x)
250
+
251
+ branch3x3 = self.branch3x3_1(x)
252
+ branch3x3 = [
253
+ self.branch3x3_2a(branch3x3),
254
+ self.branch3x3_2b(branch3x3),
255
+ ]
256
+ branch3x3 = torch.cat(branch3x3, 1)
257
+
258
+ branch3x3dbl = self.branch3x3dbl_1(x)
259
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
260
+ branch3x3dbl = [
261
+ self.branch3x3dbl_3a(branch3x3dbl),
262
+ self.branch3x3dbl_3b(branch3x3dbl),
263
+ ]
264
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
265
+
266
+ # Patch: Tensorflow's average pool does not use the padded zero's in
267
+ # its average calculation
268
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
269
+ branch_pool = self.branch_pool(branch_pool)
270
+
271
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
272
+ return torch.cat(outputs, 1)
273
+
274
+
275
+ class FIDInceptionE_2(models.inception.InceptionE):
276
+ """Second InceptionE block patched for FID computation"""
277
+
278
+ def __init__(self, in_channels):
279
+ super(FIDInceptionE_2, self).__init__(in_channels)
280
+
281
+ def forward(self, x):
282
+ branch1x1 = self.branch1x1(x)
283
+
284
+ branch3x3 = self.branch3x3_1(x)
285
+ branch3x3 = [
286
+ self.branch3x3_2a(branch3x3),
287
+ self.branch3x3_2b(branch3x3),
288
+ ]
289
+ branch3x3 = torch.cat(branch3x3, 1)
290
+
291
+ branch3x3dbl = self.branch3x3dbl_1(x)
292
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
293
+ branch3x3dbl = [
294
+ self.branch3x3dbl_3a(branch3x3dbl),
295
+ self.branch3x3dbl_3b(branch3x3dbl),
296
+ ]
297
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
298
+
299
+ # Patch: The FID Inception model uses max pooling instead of average
300
+ # pooling. This is likely an error in this specific Inception
301
+ # implementation, as other Inception models use average pooling here
302
+ # (which matches the description in the paper).
303
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
304
+ branch_pool = self.branch_pool(branch_pool)
305
+
306
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
307
+ return torch.cat(outputs, 1)
basicsr/archs/rcan_arch.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+
4
+ from basicsr.utils.registry import ARCH_REGISTRY
5
+ from .arch_util import Upsample, make_layer
6
+
7
+
8
+ class ChannelAttention(nn.Module):
9
+ """Channel attention used in RCAN.
10
+
11
+ Args:
12
+ num_feat (int): Channel number of intermediate features.
13
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
14
+ """
15
+
16
+ def __init__(self, num_feat, squeeze_factor=16):
17
+ super(ChannelAttention, self).__init__()
18
+ self.attention = nn.Sequential(
19
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
20
+ nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
21
+
22
+ def forward(self, x):
23
+ y = self.attention(x)
24
+ return x * y
25
+
26
+
27
+ class RCAB(nn.Module):
28
+ """Residual Channel Attention Block (RCAB) used in RCAN.
29
+
30
+ Args:
31
+ num_feat (int): Channel number of intermediate features.
32
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
33
+ res_scale (float): Scale the residual. Default: 1.
34
+ """
35
+
36
+ def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
37
+ super(RCAB, self).__init__()
38
+ self.res_scale = res_scale
39
+
40
+ self.rcab = nn.Sequential(
41
+ nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
42
+ ChannelAttention(num_feat, squeeze_factor))
43
+
44
+ def forward(self, x):
45
+ res = self.rcab(x) * self.res_scale
46
+ return res + x
47
+
48
+
49
+ class ResidualGroup(nn.Module):
50
+ """Residual Group of RCAB.
51
+
52
+ Args:
53
+ num_feat (int): Channel number of intermediate features.
54
+ num_block (int): Block number in the body network.
55
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
56
+ res_scale (float): Scale the residual. Default: 1.
57
+ """
58
+
59
+ def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
60
+ super(ResidualGroup, self).__init__()
61
+
62
+ self.residual_group = make_layer(
63
+ RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
64
+ self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
65
+
66
+ def forward(self, x):
67
+ res = self.conv(self.residual_group(x))
68
+ return res + x
69
+
70
+
71
+ @ARCH_REGISTRY.register()
72
+ class RCAN(nn.Module):
73
+ """Residual Channel Attention Networks.
74
+
75
+ ``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks``
76
+
77
+ Reference: https://github.com/yulunzhang/RCAN
78
+
79
+ Args:
80
+ num_in_ch (int): Channel number of inputs.
81
+ num_out_ch (int): Channel number of outputs.
82
+ num_feat (int): Channel number of intermediate features.
83
+ Default: 64.
84
+ num_group (int): Number of ResidualGroup. Default: 10.
85
+ num_block (int): Number of RCAB in ResidualGroup. Default: 16.
86
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
87
+ upscale (int): Upsampling factor. Support 2^n and 3.
88
+ Default: 4.
89
+ res_scale (float): Used to scale the residual in residual block.
90
+ Default: 1.
91
+ img_range (float): Image range. Default: 255.
92
+ rgb_mean (tuple[float]): Image mean in RGB orders.
93
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
94
+ """
95
+
96
+ def __init__(self,
97
+ num_in_ch,
98
+ num_out_ch,
99
+ num_feat=64,
100
+ num_group=10,
101
+ num_block=16,
102
+ squeeze_factor=16,
103
+ upscale=4,
104
+ res_scale=1,
105
+ img_range=255.,
106
+ rgb_mean=(0.4488, 0.4371, 0.4040)):
107
+ super(RCAN, self).__init__()
108
+
109
+ self.img_range = img_range
110
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
111
+
112
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
113
+ self.body = make_layer(
114
+ ResidualGroup,
115
+ num_group,
116
+ num_feat=num_feat,
117
+ num_block=num_block,
118
+ squeeze_factor=squeeze_factor,
119
+ res_scale=res_scale)
120
+ self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
121
+ self.upsample = Upsample(upscale, num_feat)
122
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
123
+
124
+ def forward(self, x):
125
+ self.mean = self.mean.type_as(x)
126
+
127
+ x = (x - self.mean) * self.img_range
128
+ x = self.conv_first(x)
129
+ res = self.conv_after_body(self.body(x))
130
+ res += x
131
+
132
+ x = self.conv_last(self.upsample(res))
133
+ x = x / self.img_range + self.mean
134
+
135
+ return x
basicsr/archs/ridnet_arch.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from basicsr.utils.registry import ARCH_REGISTRY
5
+ from .arch_util import ResidualBlockNoBN, make_layer
6
+
7
+
8
+ class MeanShift(nn.Conv2d):
9
+ """ Data normalization with mean and std.
10
+
11
+ Args:
12
+ rgb_range (int): Maximum value of RGB.
13
+ rgb_mean (list[float]): Mean for RGB channels.
14
+ rgb_std (list[float]): Std for RGB channels.
15
+ sign (int): For subtraction, sign is -1, for addition, sign is 1.
16
+ Default: -1.
17
+ requires_grad (bool): Whether to update the self.weight and self.bias.
18
+ Default: True.
19
+ """
20
+
21
+ def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
22
+ super(MeanShift, self).__init__(3, 3, kernel_size=1)
23
+ std = torch.Tensor(rgb_std)
24
+ self.weight.data = torch.eye(3).view(3, 3, 1, 1)
25
+ self.weight.data.div_(std.view(3, 1, 1, 1))
26
+ self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
27
+ self.bias.data.div_(std)
28
+ self.requires_grad = requires_grad
29
+
30
+
31
+ class EResidualBlockNoBN(nn.Module):
32
+ """Enhanced Residual block without BN.
33
+
34
+ There are three convolution layers in residual branch.
35
+ """
36
+
37
+ def __init__(self, in_channels, out_channels):
38
+ super(EResidualBlockNoBN, self).__init__()
39
+
40
+ self.body = nn.Sequential(
41
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1),
42
+ nn.ReLU(inplace=True),
43
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
44
+ nn.ReLU(inplace=True),
45
+ nn.Conv2d(out_channels, out_channels, 1, 1, 0),
46
+ )
47
+ self.relu = nn.ReLU(inplace=True)
48
+
49
+ def forward(self, x):
50
+ out = self.body(x)
51
+ out = self.relu(out + x)
52
+ return out
53
+
54
+
55
+ class MergeRun(nn.Module):
56
+ """ Merge-and-run unit.
57
+
58
+ This unit contains two branches with different dilated convolutions,
59
+ followed by a convolution to process the concatenated features.
60
+
61
+ Paper: Real Image Denoising with Feature Attention
62
+ Ref git repo: https://github.com/saeed-anwar/RIDNet
63
+ """
64
+
65
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
66
+ super(MergeRun, self).__init__()
67
+
68
+ self.dilation1 = nn.Sequential(
69
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
70
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
71
+ self.dilation2 = nn.Sequential(
72
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
73
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
74
+
75
+ self.aggregation = nn.Sequential(
76
+ nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
77
+
78
+ def forward(self, x):
79
+ dilation1 = self.dilation1(x)
80
+ dilation2 = self.dilation2(x)
81
+ out = torch.cat([dilation1, dilation2], dim=1)
82
+ out = self.aggregation(out)
83
+ out = out + x
84
+ return out
85
+
86
+
87
+ class ChannelAttention(nn.Module):
88
+ """Channel attention.
89
+
90
+ Args:
91
+ num_feat (int): Channel number of intermediate features.
92
+ squeeze_factor (int): Channel squeeze factor. Default:
93
+ """
94
+
95
+ def __init__(self, mid_channels, squeeze_factor=16):
96
+ super(ChannelAttention, self).__init__()
97
+ self.attention = nn.Sequential(
98
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
99
+ nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
100
+
101
+ def forward(self, x):
102
+ y = self.attention(x)
103
+ return x * y
104
+
105
+
106
+ class EAM(nn.Module):
107
+ """Enhancement attention modules (EAM) in RIDNet.
108
+
109
+ This module contains a merge-and-run unit, a residual block,
110
+ an enhanced residual block and a feature attention unit.
111
+
112
+ Attributes:
113
+ merge: The merge-and-run unit.
114
+ block1: The residual block.
115
+ block2: The enhanced residual block.
116
+ ca: The feature/channel attention unit.
117
+ """
118
+
119
+ def __init__(self, in_channels, mid_channels, out_channels):
120
+ super(EAM, self).__init__()
121
+
122
+ self.merge = MergeRun(in_channels, mid_channels)
123
+ self.block1 = ResidualBlockNoBN(mid_channels)
124
+ self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
125
+ self.ca = ChannelAttention(out_channels)
126
+ # The residual block in the paper contains a relu after addition.
127
+ self.relu = nn.ReLU(inplace=True)
128
+
129
+ def forward(self, x):
130
+ out = self.merge(x)
131
+ out = self.relu(self.block1(out))
132
+ out = self.block2(out)
133
+ out = self.ca(out)
134
+ return out
135
+
136
+
137
+ @ARCH_REGISTRY.register()
138
+ class RIDNet(nn.Module):
139
+ """RIDNet: Real Image Denoising with Feature Attention.
140
+
141
+ Ref git repo: https://github.com/saeed-anwar/RIDNet
142
+
143
+ Args:
144
+ in_channels (int): Channel number of inputs.
145
+ mid_channels (int): Channel number of EAM modules.
146
+ Default: 64.
147
+ out_channels (int): Channel number of outputs.
148
+ num_block (int): Number of EAM. Default: 4.
149
+ img_range (float): Image range. Default: 255.
150
+ rgb_mean (tuple[float]): Image mean in RGB orders.
151
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
152
+ """
153
+
154
+ def __init__(self,
155
+ in_channels,
156
+ mid_channels,
157
+ out_channels,
158
+ num_block=4,
159
+ img_range=255.,
160
+ rgb_mean=(0.4488, 0.4371, 0.4040),
161
+ rgb_std=(1.0, 1.0, 1.0)):
162
+ super(RIDNet, self).__init__()
163
+
164
+ self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
165
+ self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)
166
+
167
+ self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
168
+ self.body = make_layer(
169
+ EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
170
+ self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
171
+
172
+ self.relu = nn.ReLU(inplace=True)
173
+
174
+ def forward(self, x):
175
+ res = self.sub_mean(x)
176
+ res = self.tail(self.body(self.relu(self.head(res))))
177
+ res = self.add_mean(res)
178
+
179
+ out = x + res
180
+ return out