Spaces:
Running
Running
add SRFlow with srflow.py
Browse files- models/SRFlow/35000_G.pth +3 -0
- models/SRFlow/code/Measure.py +134 -0
- models/SRFlow/code/a.py +27 -0
- models/SRFlow/code/confs/RRDB_CelebA_8X.yml +83 -0
- models/SRFlow/code/confs/RRDB_DF2K_4X.yml +85 -0
- models/SRFlow/code/confs/RRDB_DF2K_8X.yml +85 -0
- models/SRFlow/code/confs/SRFlow_CelebA_8X.yml +107 -0
- models/SRFlow/code/confs/SRFlow_DF2K_4X.yml +106 -0
- models/SRFlow/code/confs/SRFlow_DF2K_8X.yml +112 -0
- models/SRFlow/code/data/LRHR_PKL_dataset.py +179 -0
- models/SRFlow/code/data/__init__.py +51 -0
- models/SRFlow/code/demo_on_pretrained.ipynb +0 -0
- models/SRFlow/code/imresize.py +180 -0
- models/SRFlow/code/models/SRFlow_model.py +278 -0
- models/SRFlow/code/models/SR_model.py +217 -0
- models/SRFlow/code/models/__init__.py +52 -0
- models/SRFlow/code/models/base_model.py +154 -0
- models/SRFlow/code/models/lr_scheduler.py +163 -0
- models/SRFlow/code/models/modules/FlowActNorms.py +141 -0
- models/SRFlow/code/models/modules/FlowAffineCouplingsAblation.py +135 -0
- models/SRFlow/code/models/modules/FlowStep.py +137 -0
- models/SRFlow/code/models/modules/FlowUpsamplerNet.py +309 -0
- models/SRFlow/code/models/modules/Permutations.py +58 -0
- models/SRFlow/code/models/modules/RRDBNet_arch.py +148 -0
- models/SRFlow/code/models/modules/SRFlowNet_arch.py +158 -0
- models/SRFlow/code/models/modules/Split.py +86 -0
- models/SRFlow/code/models/modules/__init__.py +0 -0
- models/SRFlow/code/models/modules/flow.py +166 -0
- models/SRFlow/code/models/modules/glow_arch.py +28 -0
- models/SRFlow/code/models/modules/loss.py +90 -0
- models/SRFlow/code/models/modules/module_util.py +95 -0
- models/SRFlow/code/models/modules/thops.py +68 -0
- models/SRFlow/code/models/networks.py +105 -0
- models/SRFlow/code/options/__init__.py +0 -0
- models/SRFlow/code/options/options.py +146 -0
- models/SRFlow/code/prepare_data.py +118 -0
- models/SRFlow/code/test.py +192 -0
- models/SRFlow/code/train.py +328 -0
- models/SRFlow/code/utils/__init__.py +0 -0
- models/SRFlow/code/utils/timer.py +78 -0
- models/SRFlow/code/utils/util.py +174 -0
- models/SRFlow/srflow.py +27 -0
models/SRFlow/35000_G.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:040fcffde66ec3ef658a843d58832b8aa153734a2f04342d841e3018b498a511
|
3 |
+
size 158819348
|
models/SRFlow/code/Measure.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import glob
|
16 |
+
import os
|
17 |
+
import time
|
18 |
+
from collections import OrderedDict
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import cv2
|
23 |
+
import argparse
|
24 |
+
|
25 |
+
from natsort import natsort
|
26 |
+
from skimage.metrics import structural_similarity as ssim
|
27 |
+
from skimage.metrics import peak_signal_noise_ratio as psnr
|
28 |
+
import lpips
|
29 |
+
|
30 |
+
|
31 |
+
class Measure():
|
32 |
+
def __init__(self, net='alex', use_gpu=False):
|
33 |
+
self.device = 'cuda' if use_gpu else 'cpu'
|
34 |
+
self.model = lpips.LPIPS(net=net)
|
35 |
+
self.model.to(self.device)
|
36 |
+
|
37 |
+
def measure(self, imgA, imgB):
|
38 |
+
return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]]
|
39 |
+
|
40 |
+
def lpips(self, imgA, imgB, model=None):
|
41 |
+
tA = t(imgA).to(self.device)
|
42 |
+
tB = t(imgB).to(self.device)
|
43 |
+
dist01 = self.model.forward(tA, tB).item()
|
44 |
+
return dist01
|
45 |
+
|
46 |
+
def ssim(self, imgA, imgB):
|
47 |
+
# multichannel: If True, treat the last dimension of the array as channels. Similarity calculations are done independently for each channel then averaged.
|
48 |
+
score, diff = ssim(imgA, imgB, full=True, multichannel=True, channel_axis=-1)
|
49 |
+
return score
|
50 |
+
|
51 |
+
def psnr(self, imgA, imgB):
|
52 |
+
psnr_val = psnr(imgA, imgB)
|
53 |
+
return psnr_val
|
54 |
+
|
55 |
+
|
56 |
+
def t(img):
|
57 |
+
def to_4d(img):
|
58 |
+
assert len(img.shape) == 3
|
59 |
+
assert img.dtype == np.uint8
|
60 |
+
img_new = np.expand_dims(img, axis=0)
|
61 |
+
assert len(img_new.shape) == 4
|
62 |
+
return img_new
|
63 |
+
|
64 |
+
def to_CHW(img):
|
65 |
+
return np.transpose(img, [2, 0, 1])
|
66 |
+
|
67 |
+
def to_tensor(img):
|
68 |
+
return torch.Tensor(img)
|
69 |
+
|
70 |
+
return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
|
71 |
+
|
72 |
+
|
73 |
+
def fiFindByWildcard(wildcard):
|
74 |
+
return natsort.natsorted(glob.glob(wildcard, recursive=True))
|
75 |
+
|
76 |
+
|
77 |
+
def imread(path):
|
78 |
+
return cv2.imread(path)[:, :, [2, 1, 0]]
|
79 |
+
|
80 |
+
|
81 |
+
def format_result(psnr, ssim, lpips):
|
82 |
+
return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}'
|
83 |
+
|
84 |
+
def measure_dirs(dirA, dirB, use_gpu, verbose=False):
|
85 |
+
if verbose:
|
86 |
+
vprint = lambda x: print(x)
|
87 |
+
else:
|
88 |
+
vprint = lambda x: None
|
89 |
+
|
90 |
+
|
91 |
+
t_init = time.time()
|
92 |
+
|
93 |
+
paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}'))
|
94 |
+
paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}'))
|
95 |
+
|
96 |
+
vprint("Comparing: ")
|
97 |
+
vprint(dirA)
|
98 |
+
vprint(dirB)
|
99 |
+
|
100 |
+
measure = Measure(use_gpu=use_gpu)
|
101 |
+
|
102 |
+
results = []
|
103 |
+
for pathA, pathB in zip(paths_A, paths_B):
|
104 |
+
result = OrderedDict()
|
105 |
+
|
106 |
+
t = time.time()
|
107 |
+
result['psnr'], result['ssim'], result['lpips'] = measure.measure(imread(pathA), imread(pathB))
|
108 |
+
d = time.time() - t
|
109 |
+
vprint(f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}")
|
110 |
+
|
111 |
+
results.append(result)
|
112 |
+
|
113 |
+
psnr = np.mean([result['psnr'] for result in results])
|
114 |
+
ssim = np.mean([result['ssim'] for result in results])
|
115 |
+
lpips = np.mean([result['lpips'] for result in results])
|
116 |
+
|
117 |
+
vprint(f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s")
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == "__main__":
|
121 |
+
parser = argparse.ArgumentParser()
|
122 |
+
parser.add_argument('-dirA', default='', type=str)
|
123 |
+
parser.add_argument('-dirB', default='', type=str)
|
124 |
+
parser.add_argument('-type', default='png')
|
125 |
+
parser.add_argument('--use_gpu', action='store_true', default=False)
|
126 |
+
args = parser.parse_args()
|
127 |
+
|
128 |
+
dirA = args.dirA
|
129 |
+
dirB = args.dirB
|
130 |
+
type = args.type
|
131 |
+
use_gpu = args.use_gpu
|
132 |
+
|
133 |
+
if len(dirA) > 0 and len(dirB) > 0:
|
134 |
+
measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True)
|
models/SRFlow/code/a.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
def load_pkls(path):
|
7 |
+
assert os.path.isfile(path), path
|
8 |
+
images = []
|
9 |
+
with open(path, "rb") as f:
|
10 |
+
images += pickle.load(f)
|
11 |
+
assert len(images) > 0, path
|
12 |
+
images = [np.transpose(image, [2, 0, 1]) for image in images]
|
13 |
+
return images
|
14 |
+
|
15 |
+
path = 'datasets/DIV2K-va.pklv4'
|
16 |
+
loaded_images = load_pkls(path)
|
17 |
+
print(len(loaded_images))
|
18 |
+
# Display the first image
|
19 |
+
if loaded_images:
|
20 |
+
first_image = loaded_images[11]
|
21 |
+
plt.imshow(np.transpose(first_image, [1, 2, 0])) # Transpose image to original shape [height, width, channels]
|
22 |
+
plt.title('First Image')
|
23 |
+
plt.axis('off') # Hide axis
|
24 |
+
plt.show()
|
25 |
+
else:
|
26 |
+
print("No images loaded from the pickle file.")
|
27 |
+
print(loaded_images[11])
|
models/SRFlow/code/confs/RRDB_CelebA_8X.yml
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
#### general settings
|
18 |
+
name: train
|
19 |
+
use_tb_logger: true
|
20 |
+
model: SR
|
21 |
+
distortion: sr
|
22 |
+
scale: 8
|
23 |
+
#gpu_ids: [ 0 ]
|
24 |
+
|
25 |
+
#### datasets
|
26 |
+
datasets:
|
27 |
+
train:
|
28 |
+
name: CelebA_160_tr
|
29 |
+
mode: LRHR_PKL
|
30 |
+
dataroot_GT: ../datasets/celebA-train-gt_1pct.pklv4
|
31 |
+
dataroot_LQ: ../datasets/celebA-train-x8_1pct.pklv4
|
32 |
+
|
33 |
+
use_shuffle: true
|
34 |
+
n_workers: 0 # per GPU
|
35 |
+
batch_size: 16
|
36 |
+
GT_size: 160
|
37 |
+
use_flip: true
|
38 |
+
use_rot: true
|
39 |
+
color: RGB
|
40 |
+
val:
|
41 |
+
name: CelebA_160_va
|
42 |
+
mode: LRHR_PKL
|
43 |
+
dataroot_GT: ../datasets/celebA-valid-gt_1pct.pklv4
|
44 |
+
dataroot_LQ: ../datasets/celebA-valid-x8_1pct.pklv4
|
45 |
+
n_max: 10
|
46 |
+
|
47 |
+
#### network structures
|
48 |
+
network_G:
|
49 |
+
which_model_G: RRDBNet
|
50 |
+
in_nc: 3
|
51 |
+
out_nc: 3
|
52 |
+
nf: 64
|
53 |
+
nb: 23
|
54 |
+
|
55 |
+
#### path
|
56 |
+
path:
|
57 |
+
pretrain_model_G: ~
|
58 |
+
strict_load: true
|
59 |
+
resume_state: auto
|
60 |
+
|
61 |
+
#### training settings: learning rate scheme, loss
|
62 |
+
train:
|
63 |
+
lr_G: !!float 2e-4
|
64 |
+
lr_scheme: CosineAnnealingLR_Restart
|
65 |
+
beta1: 0.9
|
66 |
+
beta2: 0.99
|
67 |
+
niter: 200000
|
68 |
+
warmup_iter: -1 # no warm up
|
69 |
+
T_period: [ 50000, 50000, 50000, 50000 ]
|
70 |
+
restarts: [ 50000, 100000, 150000 ]
|
71 |
+
restart_weights: [ 1, 1, 1 ]
|
72 |
+
eta_min: !!float 1e-7
|
73 |
+
|
74 |
+
pixel_criterion: l1
|
75 |
+
pixel_weight: 1.0
|
76 |
+
|
77 |
+
manual_seed: 10
|
78 |
+
val_freq: !!float 5e3
|
79 |
+
|
80 |
+
#### logger
|
81 |
+
logger:
|
82 |
+
print_freq: 100
|
83 |
+
save_checkpoint_freq: !!float 1e3
|
models/SRFlow/code/confs/RRDB_DF2K_4X.yml
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
#### general settings
|
18 |
+
name: train
|
19 |
+
use_tb_logger: true
|
20 |
+
model: SR
|
21 |
+
distortion: sr
|
22 |
+
scale: 4
|
23 |
+
gpu_ids: [ 0 ]
|
24 |
+
|
25 |
+
#### datasets
|
26 |
+
datasets:
|
27 |
+
train:
|
28 |
+
name: CelebA_160_tr
|
29 |
+
mode: LRHR_PKL
|
30 |
+
dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4
|
31 |
+
dataroot_LQ: ../datasets/DF2K-train-x4_1pct.pklv4
|
32 |
+
quant: 32
|
33 |
+
|
34 |
+
use_shuffle: true
|
35 |
+
n_workers: 3 # per GPU
|
36 |
+
batch_size: 16
|
37 |
+
GT_size: 160
|
38 |
+
use_flip: true
|
39 |
+
color: RGB
|
40 |
+
val:
|
41 |
+
name: CelebA_160_va
|
42 |
+
mode: LRHR_PKL
|
43 |
+
dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4
|
44 |
+
dataroot_LQ: ../datasets/DF2K-valid-x4_1pct.pklv4
|
45 |
+
quant: 32
|
46 |
+
n_max: 20
|
47 |
+
|
48 |
+
#### network structures
|
49 |
+
network_G:
|
50 |
+
which_model_G: RRDBNet
|
51 |
+
use_orig: True
|
52 |
+
in_nc: 3
|
53 |
+
out_nc: 3
|
54 |
+
nf: 64
|
55 |
+
nb: 23
|
56 |
+
|
57 |
+
#### path
|
58 |
+
path:
|
59 |
+
pretrain_model_G: ~
|
60 |
+
strict_load: true
|
61 |
+
resume_state: auto
|
62 |
+
|
63 |
+
#### training settings: learning rate scheme, loss
|
64 |
+
train:
|
65 |
+
lr_G: !!float 2e-4
|
66 |
+
lr_scheme: CosineAnnealingLR_Restart
|
67 |
+
beta1: 0.9
|
68 |
+
beta2: 0.99
|
69 |
+
niter: 1000000
|
70 |
+
warmup_iter: -1 # no warm up
|
71 |
+
T_period: [ 50000, 50000, 50000, 50000 ]
|
72 |
+
restarts: [ 50000, 100000, 150000 ]
|
73 |
+
restart_weights: [ 1, 1, 1 ]
|
74 |
+
eta_min: !!float 1e-7
|
75 |
+
|
76 |
+
pixel_criterion: l1
|
77 |
+
pixel_weight: 1.0
|
78 |
+
|
79 |
+
manual_seed: 10
|
80 |
+
val_freq: !!float 5e3
|
81 |
+
|
82 |
+
#### logger
|
83 |
+
logger:
|
84 |
+
print_freq: 100
|
85 |
+
save_checkpoint_freq: !!float 1e3
|
models/SRFlow/code/confs/RRDB_DF2K_8X.yml
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
#### general settings
|
18 |
+
name: train
|
19 |
+
use_tb_logger: true
|
20 |
+
model: SR
|
21 |
+
distortion: sr
|
22 |
+
scale: 8
|
23 |
+
gpu_ids: [ 0 ]
|
24 |
+
|
25 |
+
#### datasets
|
26 |
+
datasets:
|
27 |
+
train:
|
28 |
+
name: CelebA_160_tr
|
29 |
+
mode: LRHR_PKL
|
30 |
+
dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4
|
31 |
+
dataroot_LQ: ../datasets/DF2K-train-x8_1pct.pklv4
|
32 |
+
quant: 32
|
33 |
+
|
34 |
+
use_shuffle: true
|
35 |
+
n_workers: 3 # per GPU
|
36 |
+
batch_size: 16
|
37 |
+
GT_size: 160
|
38 |
+
use_flip: true
|
39 |
+
color: RGB
|
40 |
+
|
41 |
+
val:
|
42 |
+
name: CelebA_160_va
|
43 |
+
mode: LRHR_PKL
|
44 |
+
dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4
|
45 |
+
dataroot_LQ: ../datasets/DF2K-valid-x8_1pct.pklv4
|
46 |
+
quant: 32
|
47 |
+
n_max: 20
|
48 |
+
|
49 |
+
#### network structures
|
50 |
+
network_G:
|
51 |
+
which_model_G: RRDBNet
|
52 |
+
in_nc: 3
|
53 |
+
out_nc: 3
|
54 |
+
nf: 64
|
55 |
+
nb: 23
|
56 |
+
|
57 |
+
#### path
|
58 |
+
path:
|
59 |
+
pretrain_model_G: ~
|
60 |
+
strict_load: true
|
61 |
+
resume_state: auto
|
62 |
+
|
63 |
+
#### training settings: learning rate scheme, loss
|
64 |
+
train:
|
65 |
+
lr_G: !!float 2e-4
|
66 |
+
lr_scheme: CosineAnnealingLR_Restart
|
67 |
+
beta1: 0.9
|
68 |
+
beta2: 0.99
|
69 |
+
niter: 200000
|
70 |
+
warmup_iter: -1 # no warm up
|
71 |
+
T_period: [ 50000, 50000, 50000, 50000 ]
|
72 |
+
restarts: [ 50000, 100000, 150000 ]
|
73 |
+
restart_weights: [ 1, 1, 1 ]
|
74 |
+
eta_min: !!float 1e-7
|
75 |
+
|
76 |
+
pixel_criterion: l1
|
77 |
+
pixel_weight: 1.0
|
78 |
+
|
79 |
+
manual_seed: 10
|
80 |
+
val_freq: !!float 5e3
|
81 |
+
|
82 |
+
#### logger
|
83 |
+
logger:
|
84 |
+
print_freq: 100
|
85 |
+
save_checkpoint_freq: !!float 1e3
|
models/SRFlow/code/confs/SRFlow_CelebA_8X.yml
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
#### general settings
|
18 |
+
name: train
|
19 |
+
use_tb_logger: true
|
20 |
+
model: SRFlow
|
21 |
+
distortion: sr
|
22 |
+
scale: 8
|
23 |
+
gpu_ids: [ 0 ]
|
24 |
+
|
25 |
+
#### datasets
|
26 |
+
datasets:
|
27 |
+
train:
|
28 |
+
name: CelebA_160_tr
|
29 |
+
mode: LRHR_PKL
|
30 |
+
dataroot_GT: ../datasets/celebA-train-gt.pklv4
|
31 |
+
dataroot_LQ: ../datasets/celebA-train-x8.pklv4
|
32 |
+
quant: 32
|
33 |
+
|
34 |
+
use_shuffle: true
|
35 |
+
n_workers: 3 # per GPU
|
36 |
+
batch_size: 16
|
37 |
+
GT_size: 160
|
38 |
+
use_flip: true
|
39 |
+
color: RGB
|
40 |
+
val:
|
41 |
+
name: CelebA_160_va
|
42 |
+
mode: LRHR_PKL
|
43 |
+
dataroot_GT: ../datasets/celebA-train-gt.pklv4
|
44 |
+
dataroot_LQ: ../datasets/celebA-train-x8.pklv4
|
45 |
+
quant: 32
|
46 |
+
n_max: 20
|
47 |
+
|
48 |
+
#### Test Settings
|
49 |
+
dataroot_GT: ../datasets/celebA-validation-gt
|
50 |
+
dataroot_LR: ../datasets/celebA-validation-x8
|
51 |
+
model_path: ../pretrained_models/SRFlow_CelebA_8X.pth
|
52 |
+
heat: 0.9 # This is the standard deviation of the latent vectors
|
53 |
+
|
54 |
+
#### network structures
|
55 |
+
network_G:
|
56 |
+
which_model_G: SRFlowNet
|
57 |
+
in_nc: 3
|
58 |
+
out_nc: 3
|
59 |
+
nf: 64
|
60 |
+
nb: 8
|
61 |
+
upscale: 8
|
62 |
+
train_RRDB: false
|
63 |
+
train_RRDB_delay: 0.5
|
64 |
+
|
65 |
+
flow:
|
66 |
+
K: 16
|
67 |
+
L: 4
|
68 |
+
noInitialInj: true
|
69 |
+
coupling: CondAffineSeparatedAndCond
|
70 |
+
additionalFlowNoAffine: 2
|
71 |
+
split:
|
72 |
+
enable: true
|
73 |
+
fea_up0: true
|
74 |
+
stackRRDB:
|
75 |
+
blocks: [ 1, 3, 5, 7 ]
|
76 |
+
concat: true
|
77 |
+
|
78 |
+
#### path
|
79 |
+
path:
|
80 |
+
pretrain_model_G: ../pretrained_models/RRDB_CelebA_8X.pth
|
81 |
+
strict_load: true
|
82 |
+
resume_state: auto
|
83 |
+
|
84 |
+
#### training settings: learning rate scheme, loss
|
85 |
+
train:
|
86 |
+
manual_seed: 10
|
87 |
+
lr_G: !!float 5e-4
|
88 |
+
weight_decay_G: 0
|
89 |
+
beta1: 0.9
|
90 |
+
beta2: 0.99
|
91 |
+
lr_scheme: MultiStepLR
|
92 |
+
warmup_iter: -1 # no warm up
|
93 |
+
lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
|
94 |
+
lr_gamma: 0.5
|
95 |
+
|
96 |
+
niter: 200000
|
97 |
+
val_freq: 40000
|
98 |
+
|
99 |
+
#### validation settings
|
100 |
+
val:
|
101 |
+
heats: [ 0.0, 0.5, 0.75, 1.0 ]
|
102 |
+
n_sample: 3
|
103 |
+
|
104 |
+
#### logger
|
105 |
+
logger:
|
106 |
+
print_freq: 100
|
107 |
+
save_checkpoint_freq: !!float 1e3
|
models/SRFlow/code/confs/SRFlow_DF2K_4X.yml
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
#### general settings
|
18 |
+
name: train
|
19 |
+
use_tb_logger: true
|
20 |
+
model: SRFlow
|
21 |
+
distortion: sr
|
22 |
+
scale: 4
|
23 |
+
gpu_ids: [ 0 ]
|
24 |
+
|
25 |
+
#### datasets
|
26 |
+
datasets:
|
27 |
+
train:
|
28 |
+
name: DF2K_256_tr
|
29 |
+
mode: LRHR_PKL
|
30 |
+
dataroot_GT: /kaggle/input/srflow0103/SRFlow/datasets/DF2K-tr.pklv4
|
31 |
+
dataroot_LQ: /kaggle/input/srflow0103/SRFlow/datasets/DF2K-tr_X4.pklv4
|
32 |
+
quant: 32
|
33 |
+
|
34 |
+
use_shuffle: true
|
35 |
+
n_workers: 3 # per GPU
|
36 |
+
batch_size: 12
|
37 |
+
GT_size: 256
|
38 |
+
use_flip: true
|
39 |
+
color: RGB
|
40 |
+
val:
|
41 |
+
name: DF2K_256_tr
|
42 |
+
mode: LRHR_PKL
|
43 |
+
dataroot_GT: ../datasets/DIV2K-va.pklv4
|
44 |
+
dataroot_LQ: ../datasets/DIV2K-va_X4.pklv4
|
45 |
+
quant: 32
|
46 |
+
n_max: 20
|
47 |
+
|
48 |
+
#### Test Settings
|
49 |
+
dataroot: /kaggle/input/test-set/test set
|
50 |
+
model_path: /models/SRFlow/35000_G
|
51 |
+
heat: 0.6 # This is the standard deviation of the latent vectors
|
52 |
+
|
53 |
+
#### network structures
|
54 |
+
network_G:
|
55 |
+
which_model_G: SRFlowNet
|
56 |
+
in_nc: 3
|
57 |
+
out_nc: 3
|
58 |
+
nf: 64
|
59 |
+
nb: 23
|
60 |
+
upscale: 4
|
61 |
+
train_RRDB: false
|
62 |
+
train_RRDB_delay: 0.5
|
63 |
+
|
64 |
+
flow:
|
65 |
+
K: 16
|
66 |
+
L: 3
|
67 |
+
noInitialInj: true
|
68 |
+
coupling: CondAffineSeparatedAndCond
|
69 |
+
additionalFlowNoAffine: 2
|
70 |
+
split:
|
71 |
+
enable: true
|
72 |
+
fea_up0: true
|
73 |
+
stackRRDB:
|
74 |
+
blocks: [ 1, 8, 15, 22 ]
|
75 |
+
concat: true
|
76 |
+
|
77 |
+
#### path
|
78 |
+
path:
|
79 |
+
pretrain_model_G:
|
80 |
+
strict_load: true
|
81 |
+
resume_state: auto
|
82 |
+
|
83 |
+
#### training settings: learning rate scheme, loss
|
84 |
+
train:
|
85 |
+
manual_seed: 10
|
86 |
+
lr_G: !!float 2.5e-4
|
87 |
+
weight_decay_G: 0
|
88 |
+
beta1: 0.9
|
89 |
+
beta2: 0.99
|
90 |
+
lr_scheme: MultiStepLR
|
91 |
+
warmup_iter: -1 # no warm up
|
92 |
+
lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
|
93 |
+
lr_gamma: 0.5
|
94 |
+
|
95 |
+
niter: 64185
|
96 |
+
val_freq: 40000
|
97 |
+
|
98 |
+
#### validation settings
|
99 |
+
val:
|
100 |
+
heats: [ 0.0, 0.5, 0.75, 1.0 ]
|
101 |
+
n_sample: 3
|
102 |
+
|
103 |
+
#### logger
|
104 |
+
logger:
|
105 |
+
print_freq: 100
|
106 |
+
save_checkpoint_freq: !!float 5e3
|
models/SRFlow/code/confs/SRFlow_DF2K_8X.yml
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
#### general settings
|
18 |
+
name: train
|
19 |
+
use_tb_logger: true
|
20 |
+
model: SRFlow
|
21 |
+
distortion: sr
|
22 |
+
scale: 8
|
23 |
+
gpu_ids: [ 0 ]
|
24 |
+
|
25 |
+
#### datasets
|
26 |
+
datasets:
|
27 |
+
train:
|
28 |
+
name: CelebA_160_tr
|
29 |
+
mode: LRHR_PKL
|
30 |
+
dataroot_GT: ../datasets/DF2K-tr.pklv4
|
31 |
+
dataroot_LQ: ../datasets/DF2K-tr_X8.pklv4
|
32 |
+
quant: 32
|
33 |
+
|
34 |
+
use_shuffle: true
|
35 |
+
n_workers: 3 # per GPU
|
36 |
+
batch_size: 16
|
37 |
+
GT_size: 160
|
38 |
+
use_flip: true
|
39 |
+
color: RGB
|
40 |
+
|
41 |
+
val:
|
42 |
+
name: CelebA_160_va
|
43 |
+
mode: LRHR_PKL
|
44 |
+
dataroot_GT: ../datasets/DIV2K-va.pklv4
|
45 |
+
dataroot_LQ: ../datasets/DIV2K-va_X8.pklv4
|
46 |
+
quant: 32
|
47 |
+
n_max: 20
|
48 |
+
|
49 |
+
#### Test Settings
|
50 |
+
dataroot_GT: ../datasets/div2k-validation-modcrop8-gt
|
51 |
+
dataroot_LR: ../datasets/div2k-validation-modcrop8-x8
|
52 |
+
model_path: ../pretrained_models/SRFlow_DF2K_8X.pth
|
53 |
+
heat: 0.9 # This is the standard deviation of the latent vectors
|
54 |
+
|
55 |
+
#### network structures
|
56 |
+
network_G:
|
57 |
+
which_model_G: SRFlowNet
|
58 |
+
in_nc: 3
|
59 |
+
out_nc: 3
|
60 |
+
nf: 64
|
61 |
+
nb: 23
|
62 |
+
upscale: 8
|
63 |
+
train_RRDB: false
|
64 |
+
train_RRDB_delay: 0.5
|
65 |
+
|
66 |
+
flow:
|
67 |
+
K: 16
|
68 |
+
L: 4
|
69 |
+
noInitialInj: true
|
70 |
+
coupling: CondAffineSeparatedAndCond
|
71 |
+
additionalFlowNoAffine: 2
|
72 |
+
split:
|
73 |
+
enable: true
|
74 |
+
fea_up0: true
|
75 |
+
stackRRDB:
|
76 |
+
blocks: [ 1, 3, 5, 7 ]
|
77 |
+
concat: true
|
78 |
+
|
79 |
+
#### path
|
80 |
+
path:
|
81 |
+
pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth
|
82 |
+
strict_load: true
|
83 |
+
resume_state: auto
|
84 |
+
|
85 |
+
#### training settings: learning rate scheme, loss
|
86 |
+
train:
|
87 |
+
manual_seed: 10
|
88 |
+
lr_G: !!float 5e-4
|
89 |
+
weight_decay_G: 0
|
90 |
+
beta1: 0.9
|
91 |
+
beta2: 0.99
|
92 |
+
lr_scheme: MultiStepLR
|
93 |
+
warmup_iter: -1 # no warm up
|
94 |
+
lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
|
95 |
+
lr_gamma: 0.5
|
96 |
+
|
97 |
+
niter: 200000
|
98 |
+
val_freq: 40000
|
99 |
+
|
100 |
+
#### validation settings
|
101 |
+
val:
|
102 |
+
heats: [ 0.0, 0.5, 0.75, 1.0 ]
|
103 |
+
n_sample: 3
|
104 |
+
|
105 |
+
test:
|
106 |
+
heats: [ 0.0, 0.7, 0.8, 0.9 ]
|
107 |
+
|
108 |
+
#### logger
|
109 |
+
logger:
|
110 |
+
# Debug print_freq: 100
|
111 |
+
print_freq: 100
|
112 |
+
save_checkpoint_freq: !!float 1e3
|
models/SRFlow/code/data/LRHR_PKL_dataset.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
import os
|
18 |
+
import subprocess
|
19 |
+
import torch.utils.data as data
|
20 |
+
import numpy as np
|
21 |
+
import time
|
22 |
+
import torch
|
23 |
+
|
24 |
+
import pickle
|
25 |
+
|
26 |
+
|
27 |
+
class LRHR_PKLDataset(data.Dataset):
|
28 |
+
def __init__(self, opt):
|
29 |
+
super(LRHR_PKLDataset, self).__init__()
|
30 |
+
self.opt = opt
|
31 |
+
self.crop_size = opt.get("GT_size", None)
|
32 |
+
self.scale = None
|
33 |
+
self.random_scale_list = [1]
|
34 |
+
|
35 |
+
hr_file_path = opt["dataroot_GT"]
|
36 |
+
lr_file_path = opt["dataroot_LQ"]
|
37 |
+
y_labels_file_path = opt['dataroot_y_labels']
|
38 |
+
|
39 |
+
gpu = True
|
40 |
+
augment = True
|
41 |
+
|
42 |
+
self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
|
43 |
+
self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
|
44 |
+
self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
|
45 |
+
self.center_crop_hr_size = opt.get("center_crop_hr_size", None)
|
46 |
+
|
47 |
+
n_max = opt["n_max"] if "n_max" in opt.keys() else int(1e8)
|
48 |
+
|
49 |
+
t = time.time()
|
50 |
+
self.lr_images = self.load_pkls(lr_file_path, n_max)
|
51 |
+
self.hr_images = self.load_pkls(hr_file_path, n_max)
|
52 |
+
|
53 |
+
min_val_hr = np.min([i.min() for i in self.hr_images[:20]])
|
54 |
+
max_val_hr = np.max([i.max() for i in self.hr_images[:20]])
|
55 |
+
|
56 |
+
min_val_lr = np.min([i.min() for i in self.lr_images[:20]])
|
57 |
+
max_val_lr = np.max([i.max() for i in self.lr_images[:20]])
|
58 |
+
|
59 |
+
t = time.time() - t
|
60 |
+
print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
|
61 |
+
format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path))
|
62 |
+
print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
|
63 |
+
format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path))
|
64 |
+
|
65 |
+
self.gpu = gpu
|
66 |
+
self.augment = augment
|
67 |
+
|
68 |
+
self.measures = None
|
69 |
+
|
70 |
+
def load_pkls(self, path, n_max):
|
71 |
+
assert os.path.isfile(path), path
|
72 |
+
images = []
|
73 |
+
with open(path, "rb") as f:
|
74 |
+
images += pickle.load(f)
|
75 |
+
assert len(images) > 0, path
|
76 |
+
images = images[:n_max]
|
77 |
+
images = [np.transpose(image, [2, 0, 1]) for image in images]
|
78 |
+
return images
|
79 |
+
|
80 |
+
def __len__(self):
|
81 |
+
return len(self.hr_images)
|
82 |
+
|
83 |
+
def __getitem__(self, item):
|
84 |
+
hr = self.hr_images[item]
|
85 |
+
lr = self.lr_images[item]
|
86 |
+
|
87 |
+
if self.scale == None:
|
88 |
+
self.scale = hr.shape[1] // lr.shape[1]
|
89 |
+
assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape)
|
90 |
+
|
91 |
+
if self.use_crop:
|
92 |
+
hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop)
|
93 |
+
|
94 |
+
if self.center_crop_hr_size:
|
95 |
+
hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale)
|
96 |
+
|
97 |
+
if self.use_flip:
|
98 |
+
hr, lr = random_flip(hr, lr)
|
99 |
+
|
100 |
+
if self.use_rot:
|
101 |
+
hr, lr = random_rotation(hr, lr)
|
102 |
+
|
103 |
+
hr = hr / 255.0
|
104 |
+
lr = lr / 255.0
|
105 |
+
|
106 |
+
if self.measures is None or np.random.random() < 0.05:
|
107 |
+
if self.measures is None:
|
108 |
+
self.measures = {}
|
109 |
+
self.measures['hr_means'] = np.mean(hr)
|
110 |
+
self.measures['hr_stds'] = np.std(hr)
|
111 |
+
self.measures['lr_means'] = np.mean(lr)
|
112 |
+
self.measures['lr_stds'] = np.std(lr)
|
113 |
+
|
114 |
+
hr = torch.Tensor(hr)
|
115 |
+
lr = torch.Tensor(lr)
|
116 |
+
|
117 |
+
# if self.gpu:
|
118 |
+
# hr = hr.cuda()
|
119 |
+
# lr = lr.cuda()
|
120 |
+
|
121 |
+
return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)}
|
122 |
+
|
123 |
+
def print_and_reset(self, tag):
|
124 |
+
m = self.measures
|
125 |
+
kvs = []
|
126 |
+
for k in sorted(m.keys()):
|
127 |
+
kvs.append("{}={:.2f}".format(k, m[k]))
|
128 |
+
print("[KPI] " + tag + ": " + ", ".join(kvs))
|
129 |
+
self.measures = None
|
130 |
+
|
131 |
+
|
132 |
+
def random_flip(img, seg):
|
133 |
+
random_choice = np.random.choice([True, False])
|
134 |
+
img = img if random_choice else np.flip(img, 2).copy()
|
135 |
+
seg = seg if random_choice else np.flip(seg, 2).copy()
|
136 |
+
return img, seg
|
137 |
+
|
138 |
+
|
139 |
+
def random_rotation(img, seg):
|
140 |
+
random_choice = np.random.choice([0, 1, 3])
|
141 |
+
img = np.rot90(img, random_choice, axes=(1, 2)).copy()
|
142 |
+
seg = np.rot90(seg, random_choice, axes=(1, 2)).copy()
|
143 |
+
return img, seg
|
144 |
+
|
145 |
+
|
146 |
+
def random_crop(hr, lr, size_hr, scale, random):
|
147 |
+
size_lr = size_hr // scale
|
148 |
+
|
149 |
+
size_lr_x = lr.shape[1]
|
150 |
+
size_lr_y = lr.shape[2]
|
151 |
+
|
152 |
+
start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0
|
153 |
+
start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0
|
154 |
+
|
155 |
+
# LR Patch
|
156 |
+
lr_patch = lr[:, start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr]
|
157 |
+
|
158 |
+
# HR Patch
|
159 |
+
start_x_hr = start_x_lr * scale
|
160 |
+
start_y_hr = start_y_lr * scale
|
161 |
+
hr_patch = hr[:, start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr]
|
162 |
+
|
163 |
+
return hr_patch, lr_patch
|
164 |
+
|
165 |
+
|
166 |
+
def center_crop(img, size):
|
167 |
+
assert img.shape[1] == img.shape[2], img.shape
|
168 |
+
border_double = img.shape[1] - size
|
169 |
+
assert border_double % 2 == 0, (img.shape, size)
|
170 |
+
border = border_double // 2
|
171 |
+
return img[:, border:-border, border:-border]
|
172 |
+
|
173 |
+
|
174 |
+
def center_crop_tensor(img, size):
|
175 |
+
assert img.shape[2] == img.shape[3], img.shape
|
176 |
+
border_double = img.shape[2] - size
|
177 |
+
assert border_double % 2 == 0, (img.shape, size)
|
178 |
+
border = border_double // 2
|
179 |
+
return img[:, :, border:-border, border:-border]
|
models/SRFlow/code/data/__init__.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
'''create dataset and dataloader'''
|
18 |
+
import logging
|
19 |
+
import torch
|
20 |
+
import torch.utils.data
|
21 |
+
|
22 |
+
|
23 |
+
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
24 |
+
phase = dataset_opt.get('phase', 'test')
|
25 |
+
if phase == 'train':
|
26 |
+
gpu_ids = opt.get('gpu_ids', None)
|
27 |
+
gpu_ids = gpu_ids if gpu_ids else []
|
28 |
+
num_workers = dataset_opt['n_workers'] * len(gpu_ids)
|
29 |
+
batch_size = dataset_opt['batch_size']
|
30 |
+
shuffle = True
|
31 |
+
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
32 |
+
num_workers=num_workers, sampler=sampler, drop_last=True,
|
33 |
+
pin_memory=False)
|
34 |
+
else:
|
35 |
+
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
|
36 |
+
pin_memory=True)
|
37 |
+
|
38 |
+
|
39 |
+
def create_dataset(dataset_opt):
|
40 |
+
print(dataset_opt)
|
41 |
+
mode = dataset_opt['mode']
|
42 |
+
if mode == 'LRHR_PKL':
|
43 |
+
from data.LRHR_PKL_dataset import LRHR_PKLDataset as D
|
44 |
+
else:
|
45 |
+
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
46 |
+
dataset = D(dataset_opt)
|
47 |
+
|
48 |
+
logger = logging.getLogger('base')
|
49 |
+
logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
|
50 |
+
dataset_opt['name']))
|
51 |
+
return dataset
|
models/SRFlow/code/demo_on_pretrained.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/SRFlow/code/imresize.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/fatheral/matlab_imresize
|
2 |
+
#
|
3 |
+
# MIT License
|
4 |
+
#
|
5 |
+
# Copyright (c) 2020 Alex
|
6 |
+
#
|
7 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
# of this software and associated documentation files (the "Software"), to deal
|
9 |
+
# in the Software without restriction, including without limitation the rights
|
10 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
# copies of the Software, and to permit persons to whom the Software is
|
12 |
+
# furnished to do so, subject to the following conditions:
|
13 |
+
#
|
14 |
+
# The above copyright notice and this permission notice shall be included in all
|
15 |
+
# copies or substantial portions of the Software.
|
16 |
+
#
|
17 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
# SOFTWARE.
|
24 |
+
|
25 |
+
|
26 |
+
from __future__ import print_function
|
27 |
+
import numpy as np
|
28 |
+
from math import ceil, floor
|
29 |
+
|
30 |
+
|
31 |
+
def deriveSizeFromScale(img_shape, scale):
|
32 |
+
output_shape = []
|
33 |
+
for k in range(2):
|
34 |
+
output_shape.append(int(ceil(scale[k] * img_shape[k])))
|
35 |
+
return output_shape
|
36 |
+
|
37 |
+
|
38 |
+
def deriveScaleFromSize(img_shape_in, img_shape_out):
|
39 |
+
scale = []
|
40 |
+
for k in range(2):
|
41 |
+
scale.append(1.0 * img_shape_out[k] / img_shape_in[k])
|
42 |
+
return scale
|
43 |
+
|
44 |
+
|
45 |
+
def triangle(x):
|
46 |
+
x = np.array(x).astype(np.float64)
|
47 |
+
lessthanzero = np.logical_and((x >= -1), x < 0)
|
48 |
+
greaterthanzero = np.logical_and((x <= 1), x >= 0)
|
49 |
+
f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero)
|
50 |
+
return f
|
51 |
+
|
52 |
+
|
53 |
+
def cubic(x):
|
54 |
+
x = np.array(x).astype(np.float64)
|
55 |
+
absx = np.absolute(x)
|
56 |
+
absx2 = np.multiply(absx, absx)
|
57 |
+
absx3 = np.multiply(absx2, absx)
|
58 |
+
f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2,
|
59 |
+
(1 < absx) & (absx <= 2))
|
60 |
+
return f
|
61 |
+
|
62 |
+
|
63 |
+
def contributions(in_length, out_length, scale, kernel, k_width):
|
64 |
+
if scale < 1:
|
65 |
+
h = lambda x: scale * kernel(scale * x)
|
66 |
+
kernel_width = 1.0 * k_width / scale
|
67 |
+
else:
|
68 |
+
h = kernel
|
69 |
+
kernel_width = k_width
|
70 |
+
x = np.arange(1, out_length + 1).astype(np.float64)
|
71 |
+
u = x / scale + 0.5 * (1 - 1 / scale)
|
72 |
+
left = np.floor(u - kernel_width / 2)
|
73 |
+
P = int(ceil(kernel_width)) + 2
|
74 |
+
ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0
|
75 |
+
indices = ind.astype(np.int32)
|
76 |
+
weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0
|
77 |
+
weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
|
78 |
+
aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
|
79 |
+
indices = aux[np.mod(indices, aux.size)]
|
80 |
+
ind2store = np.nonzero(np.any(weights, axis=0))
|
81 |
+
weights = weights[:, ind2store]
|
82 |
+
indices = indices[:, ind2store]
|
83 |
+
return weights, indices
|
84 |
+
|
85 |
+
|
86 |
+
def imresizemex(inimg, weights, indices, dim):
|
87 |
+
in_shape = inimg.shape
|
88 |
+
w_shape = weights.shape
|
89 |
+
out_shape = list(in_shape)
|
90 |
+
out_shape[dim] = w_shape[0]
|
91 |
+
outimg = np.zeros(out_shape)
|
92 |
+
if dim == 0:
|
93 |
+
for i_img in range(in_shape[1]):
|
94 |
+
for i_w in range(w_shape[0]):
|
95 |
+
w = weights[i_w, :]
|
96 |
+
ind = indices[i_w, :]
|
97 |
+
im_slice = inimg[ind, i_img].astype(np.float64)
|
98 |
+
outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
|
99 |
+
elif dim == 1:
|
100 |
+
for i_img in range(in_shape[0]):
|
101 |
+
for i_w in range(w_shape[0]):
|
102 |
+
w = weights[i_w, :]
|
103 |
+
ind = indices[i_w, :]
|
104 |
+
im_slice = inimg[i_img, ind].astype(np.float64)
|
105 |
+
outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
|
106 |
+
if inimg.dtype == np.uint8:
|
107 |
+
outimg = np.clip(outimg, 0, 255)
|
108 |
+
return np.around(outimg).astype(np.uint8)
|
109 |
+
else:
|
110 |
+
return outimg
|
111 |
+
|
112 |
+
|
113 |
+
def imresizevec(inimg, weights, indices, dim):
|
114 |
+
wshape = weights.shape
|
115 |
+
if dim == 0:
|
116 |
+
weights = weights.reshape((wshape[0], wshape[2], 1, 1))
|
117 |
+
outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
|
118 |
+
elif dim == 1:
|
119 |
+
weights = weights.reshape((1, wshape[0], wshape[2], 1))
|
120 |
+
outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
|
121 |
+
if inimg.dtype == np.uint8:
|
122 |
+
outimg = np.clip(outimg, 0, 255)
|
123 |
+
return np.around(outimg).astype(np.uint8)
|
124 |
+
else:
|
125 |
+
return outimg
|
126 |
+
|
127 |
+
|
128 |
+
def resizeAlongDim(A, dim, weights, indices, mode="vec"):
|
129 |
+
if mode == "org":
|
130 |
+
out = imresizemex(A, weights, indices, dim)
|
131 |
+
else:
|
132 |
+
out = imresizevec(A, weights, indices, dim)
|
133 |
+
return out
|
134 |
+
|
135 |
+
|
136 |
+
def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
|
137 |
+
if method is 'bicubic':
|
138 |
+
kernel = cubic
|
139 |
+
elif method is 'bilinear':
|
140 |
+
kernel = triangle
|
141 |
+
else:
|
142 |
+
print('Error: Unidentified method supplied')
|
143 |
+
|
144 |
+
kernel_width = 4.0
|
145 |
+
# Fill scale and output_size
|
146 |
+
if scalar_scale is not None:
|
147 |
+
scalar_scale = float(scalar_scale)
|
148 |
+
scale = [scalar_scale, scalar_scale]
|
149 |
+
output_size = deriveSizeFromScale(I.shape, scale)
|
150 |
+
elif output_shape is not None:
|
151 |
+
scale = deriveScaleFromSize(I.shape, output_shape)
|
152 |
+
output_size = list(output_shape)
|
153 |
+
else:
|
154 |
+
print('Error: scalar_scale OR output_shape should be defined!')
|
155 |
+
return
|
156 |
+
scale_np = np.array(scale)
|
157 |
+
order = np.argsort(scale_np)
|
158 |
+
weights = []
|
159 |
+
indices = []
|
160 |
+
for k in range(2):
|
161 |
+
w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
|
162 |
+
weights.append(w)
|
163 |
+
indices.append(ind)
|
164 |
+
B = np.copy(I)
|
165 |
+
flag2D = False
|
166 |
+
if B.ndim == 2:
|
167 |
+
B = np.expand_dims(B, axis=2)
|
168 |
+
flag2D = True
|
169 |
+
for k in range(2):
|
170 |
+
dim = order[k]
|
171 |
+
B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
|
172 |
+
if flag2D:
|
173 |
+
B = np.squeeze(B, axis=2)
|
174 |
+
return B
|
175 |
+
|
176 |
+
|
177 |
+
def convertDouble2Byte(I):
|
178 |
+
B = np.clip(I, 0.0, 1.0)
|
179 |
+
B = 255 * B
|
180 |
+
return np.around(B).astype(np.uint8)
|
models/SRFlow/code/models/SRFlow_model.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
import logging
|
18 |
+
from collections import OrderedDict
|
19 |
+
from utils.util import get_resume_paths, opt_get
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
24 |
+
import models.networks as networks
|
25 |
+
import models.lr_scheduler as lr_scheduler
|
26 |
+
from .base_model import BaseModel
|
27 |
+
|
28 |
+
logger = logging.getLogger('base')
|
29 |
+
|
30 |
+
|
31 |
+
class SRFlowModel(BaseModel):
|
32 |
+
def __init__(self, opt, step):
|
33 |
+
super(SRFlowModel, self).__init__(opt)
|
34 |
+
self.opt = opt
|
35 |
+
|
36 |
+
self.heats = opt['val']['heats']
|
37 |
+
self.n_sample = opt['val']['n_sample']
|
38 |
+
self.hr_size = opt_get(opt, ['datasets', 'train', 'center_crop_hr_size'])
|
39 |
+
self.hr_size = 256 if self.hr_size is None else self.hr_size
|
40 |
+
self.lr_size = self.hr_size // opt['scale']
|
41 |
+
|
42 |
+
if opt['dist']:
|
43 |
+
self.rank = torch.distributed.get_rank()
|
44 |
+
else:
|
45 |
+
self.rank = -1 # non dist training
|
46 |
+
train_opt = opt['train']
|
47 |
+
|
48 |
+
# define network and load pretrained models
|
49 |
+
self.netG = networks.define_Flow(opt, step).to(self.device)
|
50 |
+
if opt['dist']:
|
51 |
+
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
52 |
+
else:
|
53 |
+
self.netG = DataParallel(self.netG)
|
54 |
+
# print network
|
55 |
+
self.print_network()
|
56 |
+
|
57 |
+
if opt_get(opt, ['path', 'resume_state'], 1) is not None:
|
58 |
+
self.load()
|
59 |
+
else:
|
60 |
+
print("WARNING: skipping initial loading, due to resume_state None")
|
61 |
+
|
62 |
+
if self.is_train:
|
63 |
+
self.netG.train()
|
64 |
+
|
65 |
+
self.init_optimizer_and_scheduler(train_opt)
|
66 |
+
self.log_dict = OrderedDict()
|
67 |
+
|
68 |
+
def to(self, device):
|
69 |
+
self.device = device
|
70 |
+
self.netG.to(device)
|
71 |
+
|
72 |
+
def init_optimizer_and_scheduler(self, train_opt):
|
73 |
+
# optimizers
|
74 |
+
self.optimizers = []
|
75 |
+
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
|
76 |
+
optim_params_RRDB = []
|
77 |
+
optim_params_other = []
|
78 |
+
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
|
79 |
+
print(k, v.requires_grad)
|
80 |
+
if v.requires_grad:
|
81 |
+
if '.RRDB.' in k:
|
82 |
+
optim_params_RRDB.append(v)
|
83 |
+
print('opt', k)
|
84 |
+
else:
|
85 |
+
optim_params_other.append(v)
|
86 |
+
if self.rank <= 0:
|
87 |
+
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
88 |
+
|
89 |
+
print('rrdb params', len(optim_params_RRDB))
|
90 |
+
|
91 |
+
self.optimizer_G = torch.optim.Adam(
|
92 |
+
[
|
93 |
+
{"params": optim_params_other, "lr": train_opt['lr_G'], 'beta1': train_opt['beta1'],
|
94 |
+
'beta2': train_opt['beta2'], 'weight_decay': wd_G},
|
95 |
+
{"params": optim_params_RRDB, "lr": train_opt.get('lr_RRDB', train_opt['lr_G']),
|
96 |
+
'beta1': train_opt['beta1'],
|
97 |
+
'beta2': train_opt['beta2'], 'weight_decay': wd_G}
|
98 |
+
],
|
99 |
+
)
|
100 |
+
|
101 |
+
self.optimizers.append(self.optimizer_G)
|
102 |
+
# schedulers
|
103 |
+
if train_opt['lr_scheme'] == 'MultiStepLR':
|
104 |
+
for optimizer in self.optimizers:
|
105 |
+
self.schedulers.append(
|
106 |
+
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
|
107 |
+
restarts=train_opt['restarts'],
|
108 |
+
weights=train_opt['restart_weights'],
|
109 |
+
gamma=train_opt['lr_gamma'],
|
110 |
+
clear_state=train_opt['clear_state'],
|
111 |
+
lr_steps_invese=train_opt.get('lr_steps_inverse', [])))
|
112 |
+
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
|
113 |
+
for optimizer in self.optimizers:
|
114 |
+
self.schedulers.append(
|
115 |
+
lr_scheduler.CosineAnnealingLR_Restart(
|
116 |
+
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
|
117 |
+
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
|
118 |
+
else:
|
119 |
+
raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
|
120 |
+
|
121 |
+
def add_optimizer_and_scheduler_RRDB(self, train_opt):
|
122 |
+
# optimizers
|
123 |
+
assert len(self.optimizers) == 1, self.optimizers
|
124 |
+
assert len(self.optimizer_G.param_groups[1]['params']) == 0, self.optimizer_G.param_groups[1]
|
125 |
+
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
|
126 |
+
if v.requires_grad:
|
127 |
+
if '.RRDB.' in k:
|
128 |
+
self.optimizer_G.param_groups[1]['params'].append(v)
|
129 |
+
assert len(self.optimizer_G.param_groups[1]['params']) > 0
|
130 |
+
|
131 |
+
def feed_data(self, data, need_GT=True):
|
132 |
+
self.var_L = data['LQ'].to(self.device) # LQ
|
133 |
+
if need_GT:
|
134 |
+
self.real_H = data['GT'].to(self.device) # GT
|
135 |
+
|
136 |
+
def optimize_parameters(self, step):
|
137 |
+
|
138 |
+
train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
|
139 |
+
if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
|
140 |
+
and not self.netG.module.RRDB_training:
|
141 |
+
if self.netG.module.set_rrdb_training(True):
|
142 |
+
self.add_optimizer_and_scheduler_RRDB(self.opt['train'])
|
143 |
+
|
144 |
+
# self.print_rrdb_state()
|
145 |
+
|
146 |
+
self.netG.train()
|
147 |
+
self.log_dict = OrderedDict()
|
148 |
+
self.optimizer_G.zero_grad()
|
149 |
+
|
150 |
+
losses = {}
|
151 |
+
weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
|
152 |
+
weight_fl = 1 if weight_fl is None else weight_fl
|
153 |
+
if weight_fl > 0:
|
154 |
+
z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
|
155 |
+
nll_loss = torch.mean(nll)
|
156 |
+
losses['nll_loss'] = nll_loss * weight_fl
|
157 |
+
|
158 |
+
weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
|
159 |
+
if weight_l1 > 0:
|
160 |
+
z = self.get_z(heat=0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
|
161 |
+
sr, logdet = self.netG(lr=self.var_L, z=z, eps_std=0, reverse=True, reverse_with_grad=True)
|
162 |
+
l1_loss = (sr - self.real_H).abs().mean()
|
163 |
+
losses['l1_loss'] = l1_loss * weight_l1
|
164 |
+
|
165 |
+
total_loss = sum(losses.values())
|
166 |
+
total_loss.backward()
|
167 |
+
self.optimizer_G.step()
|
168 |
+
|
169 |
+
mean = total_loss.item()
|
170 |
+
return mean
|
171 |
+
|
172 |
+
def print_rrdb_state(self):
|
173 |
+
for name, param in self.netG.module.named_parameters():
|
174 |
+
if "RRDB.conv_first.weight" in name:
|
175 |
+
print(name, param.requires_grad, param.data.abs().sum())
|
176 |
+
print('params', [len(p['params']) for p in self.optimizer_G.param_groups])
|
177 |
+
|
178 |
+
def test(self):
|
179 |
+
self.netG.eval()
|
180 |
+
self.fake_H = {}
|
181 |
+
for heat in self.heats:
|
182 |
+
for i in range(self.n_sample):
|
183 |
+
z = self.get_z(heat, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
|
184 |
+
with torch.no_grad():
|
185 |
+
self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L, z=z, eps_std=heat, reverse=True)
|
186 |
+
with torch.no_grad():
|
187 |
+
_, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
|
188 |
+
self.netG.train()
|
189 |
+
return nll.mean().item()
|
190 |
+
|
191 |
+
def get_encode_nll(self, lq, gt):
|
192 |
+
self.netG.eval()
|
193 |
+
with torch.no_grad():
|
194 |
+
_, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
|
195 |
+
self.netG.train()
|
196 |
+
return nll.mean().item()
|
197 |
+
|
198 |
+
def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
|
199 |
+
return self.get_sr_with_z(lq, heat, seed, z, epses)[0]
|
200 |
+
|
201 |
+
def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
|
202 |
+
self.netG.eval()
|
203 |
+
with torch.no_grad():
|
204 |
+
z, _, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise)
|
205 |
+
self.netG.train()
|
206 |
+
return z
|
207 |
+
|
208 |
+
def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
|
209 |
+
self.netG.eval()
|
210 |
+
with torch.no_grad():
|
211 |
+
z, nll, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise)
|
212 |
+
self.netG.train()
|
213 |
+
return z, nll
|
214 |
+
|
215 |
+
def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
|
216 |
+
self.netG.eval()
|
217 |
+
|
218 |
+
z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape) if z is None and epses is None else z
|
219 |
+
|
220 |
+
with torch.no_grad():
|
221 |
+
sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses)
|
222 |
+
self.netG.train()
|
223 |
+
return sr, z
|
224 |
+
|
225 |
+
def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
|
226 |
+
if seed: torch.manual_seed(seed)
|
227 |
+
if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
|
228 |
+
C = self.netG.module.flowUpsamplerNet.C
|
229 |
+
H = int(self.opt['scale'] * lr_shape[2] // self.netG.module.flowUpsamplerNet.scaleH)
|
230 |
+
W = int(self.opt['scale'] * lr_shape[3] // self.netG.module.flowUpsamplerNet.scaleW)
|
231 |
+
z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W)) if heat > 0 else torch.zeros(
|
232 |
+
(batch_size, C, H, W))
|
233 |
+
else:
|
234 |
+
L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
|
235 |
+
fac = 2 ** (L - 3)
|
236 |
+
z_size = int(self.lr_size // (2 ** (L - 3)))
|
237 |
+
z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
|
238 |
+
return z
|
239 |
+
|
240 |
+
def get_current_log(self):
|
241 |
+
return self.log_dict
|
242 |
+
|
243 |
+
def get_current_visuals(self, need_GT=True):
|
244 |
+
out_dict = OrderedDict()
|
245 |
+
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
|
246 |
+
for heat in self.heats:
|
247 |
+
for i in range(self.n_sample):
|
248 |
+
out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu()
|
249 |
+
if need_GT:
|
250 |
+
out_dict['GT'] = self.real_H.detach()[0].float().cpu()
|
251 |
+
return out_dict
|
252 |
+
|
253 |
+
def print_network(self):
|
254 |
+
s, n = self.get_network_description(self.netG)
|
255 |
+
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
|
256 |
+
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
|
257 |
+
self.netG.module.__class__.__name__)
|
258 |
+
else:
|
259 |
+
net_struc_str = '{}'.format(self.netG.__class__.__name__)
|
260 |
+
if self.rank <= 0:
|
261 |
+
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
|
262 |
+
logger.info(s)
|
263 |
+
|
264 |
+
def load(self):
|
265 |
+
_, get_resume_model_path = get_resume_paths(self.opt)
|
266 |
+
if get_resume_model_path is not None:
|
267 |
+
self.load_network(get_resume_model_path, self.netG, strict=True, submodule=None)
|
268 |
+
return
|
269 |
+
|
270 |
+
load_path_G = self.opt['path']['pretrain_model_G']
|
271 |
+
load_submodule = self.opt['path']['load_submodule'] if 'load_submodule' in self.opt['path'].keys() else 'RRDB'
|
272 |
+
if load_path_G is not None:
|
273 |
+
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
|
274 |
+
self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True),
|
275 |
+
submodule=load_submodule)
|
276 |
+
|
277 |
+
def save(self, iter_label):
|
278 |
+
self.save_network(self.netG, 'G', iter_label)
|
models/SRFlow/code/models/SR_model.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
import logging
|
18 |
+
from collections import OrderedDict
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
23 |
+
import models.networks as networks
|
24 |
+
import models.lr_scheduler as lr_scheduler
|
25 |
+
from utils.util import opt_get
|
26 |
+
from .base_model import BaseModel
|
27 |
+
from models.modules.loss import CharbonnierLoss
|
28 |
+
|
29 |
+
logger = logging.getLogger('base')
|
30 |
+
|
31 |
+
|
32 |
+
class SRModel(BaseModel):
|
33 |
+
def __init__(self, opt, step):
|
34 |
+
super(SRModel, self).__init__(opt)
|
35 |
+
|
36 |
+
self.step = step
|
37 |
+
|
38 |
+
if opt['dist']:
|
39 |
+
self.rank = torch.distributed.get_rank()
|
40 |
+
else:
|
41 |
+
self.rank = -1 # non dist training
|
42 |
+
train_opt = opt['train']
|
43 |
+
|
44 |
+
# define network and load pretrained_models models
|
45 |
+
self.netG = networks.define_G(opt).to(self.device)
|
46 |
+
if opt['dist']:
|
47 |
+
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
48 |
+
else:
|
49 |
+
self.netG = DataParallel(self.netG)
|
50 |
+
# print network
|
51 |
+
self.print_network()
|
52 |
+
self.load()
|
53 |
+
|
54 |
+
if self.is_train:
|
55 |
+
self.netG.train()
|
56 |
+
|
57 |
+
# loss
|
58 |
+
loss_type = train_opt['pixel_criterion']
|
59 |
+
if loss_type == 'l1':
|
60 |
+
self.cri_pix = nn.L1Loss().to(self.device)
|
61 |
+
elif loss_type == 'l2':
|
62 |
+
self.cri_pix = nn.MSELoss().to(self.device)
|
63 |
+
elif loss_type == 'cb':
|
64 |
+
self.cri_pix = CharbonnierLoss().to(self.device)
|
65 |
+
else:
|
66 |
+
raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
|
67 |
+
self.l_pix_w = train_opt['pixel_weight']
|
68 |
+
|
69 |
+
# optimizers
|
70 |
+
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
|
71 |
+
optim_params = []
|
72 |
+
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
|
73 |
+
if v.requires_grad:
|
74 |
+
optim_params.append(v)
|
75 |
+
else:
|
76 |
+
if self.rank <= 0:
|
77 |
+
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
78 |
+
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
|
79 |
+
weight_decay=wd_G,
|
80 |
+
betas=(train_opt['beta1'], train_opt['beta2']))
|
81 |
+
self.optimizers.append(self.optimizer_G)
|
82 |
+
|
83 |
+
# schedulers
|
84 |
+
if train_opt['lr_scheme'] == 'MultiStepLR':
|
85 |
+
for optimizer in self.optimizers:
|
86 |
+
self.schedulers.append(
|
87 |
+
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
|
88 |
+
restarts=train_opt['restarts'],
|
89 |
+
weights=train_opt['restart_weights'],
|
90 |
+
gamma=train_opt['lr_gamma'],
|
91 |
+
clear_state=train_opt['clear_state']))
|
92 |
+
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
|
93 |
+
for optimizer in self.optimizers:
|
94 |
+
self.schedulers.append(
|
95 |
+
lr_scheduler.CosineAnnealingLR_Restart(
|
96 |
+
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
|
97 |
+
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
|
98 |
+
else:
|
99 |
+
raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
|
100 |
+
|
101 |
+
self.log_dict = OrderedDict()
|
102 |
+
|
103 |
+
def feed_data(self, data, need_GT=True):
|
104 |
+
self.var_L = data['LQ'].to(self.device) # LQ
|
105 |
+
if need_GT:
|
106 |
+
self.real_H = data['GT'].to(self.device) # GT
|
107 |
+
|
108 |
+
def to(self, device):
|
109 |
+
self.device = device
|
110 |
+
self.netG.to(device)
|
111 |
+
|
112 |
+
def optimize_parameters(self, step):
|
113 |
+
def getEnv(name): import os; return True if name in os.environ.keys() else False
|
114 |
+
|
115 |
+
if getEnv("DEBUG_FEED_IMAGES"):
|
116 |
+
import imageio
|
117 |
+
import random
|
118 |
+
i = random.randint(0, 10000)
|
119 |
+
label = self.var_L.cpu().numpy()[0].transpose([1, 2, 0])
|
120 |
+
print("var_L", label.min(), label.max(), label.shape)
|
121 |
+
imageio.imwrite("/tmp/{}_l.png".format(i), label)
|
122 |
+
image = self.real_H.cpu().numpy()[0].transpose([1, 2, 0])
|
123 |
+
print("self.real_H", image.min(), image.max(), image.shape)
|
124 |
+
imageio.imwrite("/tmp/{}_gt.png".format(i), image)
|
125 |
+
self.optimizer_G.zero_grad()
|
126 |
+
self.fake_H = self.netG(self.var_L)
|
127 |
+
l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H.to(self.fake_H.device))
|
128 |
+
l_pix.backward()
|
129 |
+
self.optimizer_G.step()
|
130 |
+
|
131 |
+
# set log
|
132 |
+
self.log_dict['l_pix'] = l_pix.item()
|
133 |
+
|
134 |
+
def test(self):
|
135 |
+
self.netG.eval()
|
136 |
+
with torch.no_grad():
|
137 |
+
self.fake_H = self.netG(self.var_L)
|
138 |
+
self.netG.train()
|
139 |
+
|
140 |
+
def get_encode_nll(self, lq, gt):
|
141 |
+
return torch.ones(1) * 1e14
|
142 |
+
|
143 |
+
def get_sr(self, lq, heat=None, seed=None):
|
144 |
+
self.netG.eval()
|
145 |
+
sr = self.netG(lq)
|
146 |
+
self.netG.train()
|
147 |
+
return sr
|
148 |
+
|
149 |
+
def test_x8(self):
|
150 |
+
# from https://github.com/thstkdgus35/EDSR-PyTorch
|
151 |
+
self.netG.eval()
|
152 |
+
|
153 |
+
def _transform(v, op):
|
154 |
+
# if self.precision != 'single': v = v.float()
|
155 |
+
v2np = v.data.cpu().numpy()
|
156 |
+
if op == 'v':
|
157 |
+
tfnp = v2np[:, :, :, ::-1].copy()
|
158 |
+
elif op == 'h':
|
159 |
+
tfnp = v2np[:, :, ::-1, :].copy()
|
160 |
+
elif op == 't':
|
161 |
+
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
|
162 |
+
|
163 |
+
ret = torch.Tensor(tfnp).to(self.device)
|
164 |
+
# if self.precision == 'half': ret = ret.half()
|
165 |
+
|
166 |
+
return ret
|
167 |
+
|
168 |
+
lr_list = [self.var_L]
|
169 |
+
for tf in 'v', 'h', 't':
|
170 |
+
lr_list.extend([_transform(t, tf) for t in lr_list])
|
171 |
+
with torch.no_grad():
|
172 |
+
sr_list = [self.netG(aug) for aug in lr_list]
|
173 |
+
for i in range(len(sr_list)):
|
174 |
+
if i > 3:
|
175 |
+
sr_list[i] = _transform(sr_list[i], 't')
|
176 |
+
if i % 4 > 1:
|
177 |
+
sr_list[i] = _transform(sr_list[i], 'h')
|
178 |
+
if (i % 4) % 2 == 1:
|
179 |
+
sr_list[i] = _transform(sr_list[i], 'v')
|
180 |
+
|
181 |
+
output_cat = torch.cat(sr_list, dim=0)
|
182 |
+
self.fake_H = output_cat.mean(dim=0, keepdim=True)
|
183 |
+
self.netG.train()
|
184 |
+
|
185 |
+
def get_current_log(self):
|
186 |
+
return self.log_dict
|
187 |
+
|
188 |
+
def get_current_visuals(self, need_GT=True):
|
189 |
+
out_dict = OrderedDict()
|
190 |
+
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
|
191 |
+
out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
|
192 |
+
if need_GT:
|
193 |
+
out_dict['GT'] = self.real_H.detach()[0].float().cpu()
|
194 |
+
return out_dict
|
195 |
+
|
196 |
+
def print_network(self):
|
197 |
+
s, n = self.get_network_description(self.netG)
|
198 |
+
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
|
199 |
+
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
|
200 |
+
self.netG.module.__class__.__name__)
|
201 |
+
else:
|
202 |
+
net_struc_str = '{}'.format(self.netG.__class__.__name__)
|
203 |
+
if self.rank <= 0:
|
204 |
+
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
|
205 |
+
logger.info(s)
|
206 |
+
|
207 |
+
def load(self):
|
208 |
+
load_path_G = self.opt['path']['pretrain_model_G']
|
209 |
+
if load_path_G is not None:
|
210 |
+
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
|
211 |
+
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
|
212 |
+
|
213 |
+
def save(self, iter_label):
|
214 |
+
self.save_network(self.netG, 'G', iter_label)
|
215 |
+
|
216 |
+
def get_encode_z_and_nll(self, *args, **kwargs):
|
217 |
+
return [], torch.zeros(1)
|
models/SRFlow/code/models/__init__.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
|
5 |
+
try:
|
6 |
+
import local_config
|
7 |
+
except:
|
8 |
+
local_config = None
|
9 |
+
|
10 |
+
|
11 |
+
logger = logging.getLogger('base')
|
12 |
+
|
13 |
+
|
14 |
+
def find_model_using_name(model_name):
|
15 |
+
# Given the option --model [modelname],
|
16 |
+
# the file "models/modelname_model.py"
|
17 |
+
# will be imported.
|
18 |
+
model_filename = "models." + model_name + "_model"
|
19 |
+
modellib = importlib.import_module(model_filename)
|
20 |
+
|
21 |
+
# In the file, the class called ModelNameModel() will
|
22 |
+
# be instantiated. It has to be a subclass of torch.nn.Module,
|
23 |
+
# and it is case-insensitive.
|
24 |
+
model = None
|
25 |
+
target_model_name = model_name.replace('_', '') + 'Model'
|
26 |
+
for name, cls in modellib.__dict__.items():
|
27 |
+
if name.lower() == target_model_name.lower():
|
28 |
+
model = cls
|
29 |
+
|
30 |
+
if model is None:
|
31 |
+
print(
|
32 |
+
"In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
|
33 |
+
model_filename, target_model_name))
|
34 |
+
exit(0)
|
35 |
+
|
36 |
+
return model
|
37 |
+
|
38 |
+
|
39 |
+
def create_model(opt, step=0, **opt_kwargs):
|
40 |
+
if local_config is not None:
|
41 |
+
opt['path']['pretrain_model_G'] = os.path.join(local_config.checkpoint_path, os.path.basename(opt['path']['results_root'] + '.pth'))
|
42 |
+
|
43 |
+
for k, v in opt_kwargs.items():
|
44 |
+
opt[k] = v
|
45 |
+
|
46 |
+
model = opt['model']
|
47 |
+
|
48 |
+
M = find_model_using_name(model)
|
49 |
+
|
50 |
+
m = M(opt, step)
|
51 |
+
logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
|
52 |
+
return m
|
models/SRFlow/code/models/base_model.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
import os
|
18 |
+
from collections import OrderedDict
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from torch.nn.parallel import DistributedDataParallel
|
22 |
+
import natsort
|
23 |
+
import glob
|
24 |
+
|
25 |
+
|
26 |
+
class BaseModel():
|
27 |
+
def __init__(self, opt):
|
28 |
+
self.opt = opt
|
29 |
+
self.device = torch.device('cuda' if opt.get('gpu_ids', None) is not None else 'cpu')
|
30 |
+
self.is_train = opt['is_train']
|
31 |
+
self.schedulers = []
|
32 |
+
self.optimizers = []
|
33 |
+
|
34 |
+
def feed_data(self, data):
|
35 |
+
pass
|
36 |
+
|
37 |
+
def optimize_parameters(self):
|
38 |
+
pass
|
39 |
+
|
40 |
+
def get_current_visuals(self):
|
41 |
+
pass
|
42 |
+
|
43 |
+
def get_current_losses(self):
|
44 |
+
pass
|
45 |
+
|
46 |
+
def print_network(self):
|
47 |
+
pass
|
48 |
+
|
49 |
+
def save(self, label):
|
50 |
+
pass
|
51 |
+
|
52 |
+
def load(self):
|
53 |
+
pass
|
54 |
+
|
55 |
+
def _set_lr(self, lr_groups_l):
|
56 |
+
''' set learning rate for warmup,
|
57 |
+
lr_groups_l: list for lr_groups. each for a optimizer'''
|
58 |
+
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
|
59 |
+
for param_group, lr in zip(optimizer.param_groups, lr_groups):
|
60 |
+
param_group['lr'] = lr
|
61 |
+
|
62 |
+
def _get_init_lr(self):
|
63 |
+
# get the initial lr, which is set by the scheduler
|
64 |
+
init_lr_groups_l = []
|
65 |
+
for optimizer in self.optimizers:
|
66 |
+
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
|
67 |
+
return init_lr_groups_l
|
68 |
+
|
69 |
+
def update_learning_rate(self, cur_iter, warmup_iter=-1):
|
70 |
+
for scheduler in self.schedulers:
|
71 |
+
scheduler.step()
|
72 |
+
#### set up warm up learning rate
|
73 |
+
if cur_iter < warmup_iter:
|
74 |
+
# get initial lr for each group
|
75 |
+
init_lr_g_l = self._get_init_lr()
|
76 |
+
# modify warming-up learning rates
|
77 |
+
warm_up_lr_l = []
|
78 |
+
for init_lr_g in init_lr_g_l:
|
79 |
+
warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
|
80 |
+
# set learning rate
|
81 |
+
self._set_lr(warm_up_lr_l)
|
82 |
+
|
83 |
+
def get_current_learning_rate(self):
|
84 |
+
# return self.schedulers[0].get_lr()[0]
|
85 |
+
return self.optimizers[0].param_groups[0]['lr']
|
86 |
+
|
87 |
+
def get_network_description(self, network):
|
88 |
+
'''Get the string and total parameters of the network'''
|
89 |
+
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
90 |
+
network = network.module
|
91 |
+
s = str(network)
|
92 |
+
n = sum(map(lambda x: x.numel(), network.parameters()))
|
93 |
+
return s, n
|
94 |
+
|
95 |
+
def save_network(self, network, network_label, iter_label):
|
96 |
+
paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['models'], "*_{}.pth".format(network_label))),
|
97 |
+
reverse=True)
|
98 |
+
paths = [p for p in paths if
|
99 |
+
"latest_" not in p and not any([str(i * 10000) in p.split("/")[-1].split("_") for i in range(101)])]
|
100 |
+
if len(paths) > 2:
|
101 |
+
for path in paths[2:]:
|
102 |
+
os.remove(path)
|
103 |
+
save_filename = '{}_{}.pth'.format(iter_label, network_label)
|
104 |
+
save_path = os.path.join(self.opt['path']['models'], save_filename)
|
105 |
+
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
106 |
+
network = network.module
|
107 |
+
state_dict = network.state_dict()
|
108 |
+
for key, param in state_dict.items():
|
109 |
+
state_dict[key] = param.cpu()
|
110 |
+
torch.save(state_dict, save_path)
|
111 |
+
|
112 |
+
def load_network(self, load_path, network, strict=True, submodule=None):
|
113 |
+
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
114 |
+
network = network.module
|
115 |
+
if not (submodule is None or submodule.lower() == 'none'.lower()):
|
116 |
+
network = network.__getattr__(submodule)
|
117 |
+
load_net = torch.load(load_path)
|
118 |
+
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
119 |
+
for k, v in load_net.items():
|
120 |
+
if k.startswith('module.'):
|
121 |
+
load_net_clean[k[7:]] = v
|
122 |
+
else:
|
123 |
+
load_net_clean[k] = v
|
124 |
+
network.load_state_dict(load_net_clean, strict=strict)
|
125 |
+
|
126 |
+
def save_training_state(self, epoch, iter_step):
|
127 |
+
'''Saves training state during training, which will be used for resuming'''
|
128 |
+
state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
|
129 |
+
for s in self.schedulers:
|
130 |
+
state['schedulers'].append(s.state_dict())
|
131 |
+
for o in self.optimizers:
|
132 |
+
state['optimizers'].append(o.state_dict())
|
133 |
+
save_filename = '{}.state'.format(iter_step)
|
134 |
+
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
|
135 |
+
|
136 |
+
paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['training_state'], "*.state")),
|
137 |
+
reverse=True)
|
138 |
+
paths = [p for p in paths if "latest_" not in p]
|
139 |
+
if len(paths) > 2:
|
140 |
+
for path in paths[2:]:
|
141 |
+
os.remove(path)
|
142 |
+
|
143 |
+
torch.save(state, save_path)
|
144 |
+
|
145 |
+
def resume_training(self, resume_state):
|
146 |
+
'''Resume the optimizers and schedulers for training'''
|
147 |
+
resume_optimizers = resume_state['optimizers']
|
148 |
+
resume_schedulers = resume_state['schedulers']
|
149 |
+
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
|
150 |
+
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
|
151 |
+
for i, o in enumerate(resume_optimizers):
|
152 |
+
self.optimizers[i].load_state_dict(o)
|
153 |
+
for i, s in enumerate(resume_schedulers):
|
154 |
+
self.schedulers[i].load_state_dict(s)
|
models/SRFlow/code/models/lr_scheduler.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
import math
|
18 |
+
from collections import Counter
|
19 |
+
from collections import defaultdict
|
20 |
+
import torch
|
21 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
22 |
+
|
23 |
+
|
24 |
+
class MultiStepLR_Restart(_LRScheduler):
|
25 |
+
def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
|
26 |
+
clear_state=False, last_epoch=-1, lr_steps_invese=None):
|
27 |
+
assert lr_steps_invese is not None, "Use empty list"
|
28 |
+
self.milestones = Counter(milestones)
|
29 |
+
self.lr_steps_inverse = Counter(lr_steps_invese)
|
30 |
+
self.gamma = gamma
|
31 |
+
self.clear_state = clear_state
|
32 |
+
self.restarts = restarts if restarts else [0]
|
33 |
+
self.restart_weights = weights if weights else [1]
|
34 |
+
assert len(self.restarts) == len(
|
35 |
+
self.restart_weights), 'restarts and their weights do not match.'
|
36 |
+
super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
|
37 |
+
|
38 |
+
def get_lr(self):
|
39 |
+
if self.last_epoch in self.restarts:
|
40 |
+
if self.clear_state:
|
41 |
+
self.optimizer.state = defaultdict(dict)
|
42 |
+
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
43 |
+
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
|
44 |
+
if self.last_epoch not in self.milestones and self.last_epoch not in self.lr_steps_inverse:
|
45 |
+
return [group['lr'] for group in self.optimizer.param_groups]
|
46 |
+
return [
|
47 |
+
group['lr'] * (self.gamma ** self.milestones[self.last_epoch]) *
|
48 |
+
(self.gamma ** (-self.lr_steps_inverse[self.last_epoch]))
|
49 |
+
for group in self.optimizer.param_groups
|
50 |
+
]
|
51 |
+
|
52 |
+
|
53 |
+
class CosineAnnealingLR_Restart(_LRScheduler):
|
54 |
+
def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
|
55 |
+
self.T_period = T_period
|
56 |
+
self.T_max = self.T_period[0] # current T period
|
57 |
+
self.eta_min = eta_min
|
58 |
+
self.restarts = restarts if restarts else [0]
|
59 |
+
self.restart_weights = weights if weights else [1]
|
60 |
+
self.last_restart = 0
|
61 |
+
assert len(self.restarts) == len(
|
62 |
+
self.restart_weights), 'restarts and their weights do not match.'
|
63 |
+
super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
|
64 |
+
|
65 |
+
def get_lr(self):
|
66 |
+
if self.last_epoch == 0:
|
67 |
+
return self.base_lrs
|
68 |
+
elif self.last_epoch in self.restarts:
|
69 |
+
self.last_restart = self.last_epoch
|
70 |
+
self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
|
71 |
+
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
72 |
+
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
|
73 |
+
elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
|
74 |
+
return [
|
75 |
+
group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
|
76 |
+
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
|
77 |
+
]
|
78 |
+
return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
|
79 |
+
(1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
|
80 |
+
(group['lr'] - self.eta_min) + self.eta_min
|
81 |
+
for group in self.optimizer.param_groups]
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == "__main__":
|
85 |
+
optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
|
86 |
+
betas=(0.9, 0.99))
|
87 |
+
##############################
|
88 |
+
# MultiStepLR_Restart
|
89 |
+
##############################
|
90 |
+
## Original
|
91 |
+
lr_steps = [200000, 400000, 600000, 800000]
|
92 |
+
restarts = None
|
93 |
+
restart_weights = None
|
94 |
+
|
95 |
+
## two
|
96 |
+
lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
|
97 |
+
restarts = [500000]
|
98 |
+
restart_weights = [1]
|
99 |
+
|
100 |
+
## four
|
101 |
+
lr_steps = [
|
102 |
+
50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
|
103 |
+
600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
|
104 |
+
]
|
105 |
+
restarts = [250000, 500000, 750000]
|
106 |
+
restart_weights = [1, 1, 1]
|
107 |
+
|
108 |
+
scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
|
109 |
+
clear_state=False)
|
110 |
+
|
111 |
+
##############################
|
112 |
+
# Cosine Annealing Restart
|
113 |
+
##############################
|
114 |
+
## two
|
115 |
+
T_period = [500000, 500000]
|
116 |
+
restarts = [500000]
|
117 |
+
restart_weights = [1]
|
118 |
+
|
119 |
+
## four
|
120 |
+
T_period = [250000, 250000, 250000, 250000]
|
121 |
+
restarts = [250000, 500000, 750000]
|
122 |
+
restart_weights = [1, 1, 1]
|
123 |
+
|
124 |
+
scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
|
125 |
+
weights=restart_weights)
|
126 |
+
|
127 |
+
##############################
|
128 |
+
# Draw figure
|
129 |
+
##############################
|
130 |
+
N_iter = 1000000
|
131 |
+
lr_l = list(range(N_iter))
|
132 |
+
for i in range(N_iter):
|
133 |
+
scheduler.step()
|
134 |
+
current_lr = optimizer.param_groups[0]['lr']
|
135 |
+
lr_l[i] = current_lr
|
136 |
+
|
137 |
+
import matplotlib as mpl
|
138 |
+
from matplotlib import pyplot as plt
|
139 |
+
import matplotlib.ticker as mtick
|
140 |
+
|
141 |
+
mpl.style.use('default')
|
142 |
+
import seaborn
|
143 |
+
|
144 |
+
seaborn.set(style='whitegrid')
|
145 |
+
seaborn.set_context('paper')
|
146 |
+
|
147 |
+
plt.figure(1)
|
148 |
+
plt.subplot(111)
|
149 |
+
plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
|
150 |
+
plt.title('Title', fontsize=16, color='k')
|
151 |
+
plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
|
152 |
+
legend = plt.legend(loc='upper right', shadow=False)
|
153 |
+
ax = plt.gca()
|
154 |
+
labels = ax.get_xticks().tolist()
|
155 |
+
for k, v in enumerate(labels):
|
156 |
+
labels[k] = str(int(v / 1000)) + 'K'
|
157 |
+
ax.set_xticklabels(labels)
|
158 |
+
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
|
159 |
+
|
160 |
+
ax.set_ylabel('Learning rate')
|
161 |
+
ax.set_xlabel('Iteration')
|
162 |
+
fig = plt.gcf()
|
163 |
+
plt.show()
|
models/SRFlow/code/models/modules/FlowActNorms.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn as nn
|
19 |
+
|
20 |
+
from models.modules import thops
|
21 |
+
|
22 |
+
|
23 |
+
class _ActNorm(nn.Module):
|
24 |
+
"""
|
25 |
+
Activation Normalization
|
26 |
+
Initialize the bias and scale with a given minibatch,
|
27 |
+
so that the output per-channel have zero mean and unit variance for that.
|
28 |
+
|
29 |
+
After initialization, `bias` and `logs` will be trained as parameters.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, num_features, scale=1.):
|
33 |
+
super().__init__()
|
34 |
+
# register mean and scale
|
35 |
+
size = [1, num_features, 1, 1]
|
36 |
+
self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
|
37 |
+
self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
|
38 |
+
self.num_features = num_features
|
39 |
+
self.scale = float(scale)
|
40 |
+
self.inited = False
|
41 |
+
|
42 |
+
def _check_input_dim(self, input):
|
43 |
+
return NotImplemented
|
44 |
+
|
45 |
+
def initialize_parameters(self, input):
|
46 |
+
self._check_input_dim(input)
|
47 |
+
if not self.training:
|
48 |
+
return
|
49 |
+
if (self.bias != 0).any():
|
50 |
+
self.inited = True
|
51 |
+
return
|
52 |
+
assert input.device == self.bias.device, (input.device, self.bias.device)
|
53 |
+
with torch.no_grad():
|
54 |
+
bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
|
55 |
+
vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
|
56 |
+
logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
|
57 |
+
self.bias.data.copy_(bias.data)
|
58 |
+
self.logs.data.copy_(logs.data)
|
59 |
+
self.inited = True
|
60 |
+
|
61 |
+
def _center(self, input, reverse=False, offset=None):
|
62 |
+
bias = self.bias
|
63 |
+
|
64 |
+
if offset is not None:
|
65 |
+
bias = bias + offset
|
66 |
+
|
67 |
+
if not reverse:
|
68 |
+
return input + bias
|
69 |
+
else:
|
70 |
+
return input - bias
|
71 |
+
|
72 |
+
def _scale(self, input, logdet=None, reverse=False, offset=None):
|
73 |
+
logs = self.logs
|
74 |
+
|
75 |
+
if offset is not None:
|
76 |
+
logs = logs + offset
|
77 |
+
|
78 |
+
if not reverse:
|
79 |
+
input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1
|
80 |
+
# input = input * torch.exp(logs+logs_offset)
|
81 |
+
else:
|
82 |
+
input = input * torch.exp(-logs)
|
83 |
+
if logdet is not None:
|
84 |
+
"""
|
85 |
+
logs is log_std of `mean of channels`
|
86 |
+
so we need to multiply pixels
|
87 |
+
"""
|
88 |
+
dlogdet = thops.sum(logs) * thops.pixels(input)
|
89 |
+
if reverse:
|
90 |
+
dlogdet *= -1
|
91 |
+
logdet = logdet + dlogdet
|
92 |
+
return input, logdet
|
93 |
+
|
94 |
+
def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
|
95 |
+
if not self.inited:
|
96 |
+
self.initialize_parameters(input)
|
97 |
+
self._check_input_dim(input)
|
98 |
+
|
99 |
+
if offset_mask is not None:
|
100 |
+
logs_offset *= offset_mask
|
101 |
+
bias_offset *= offset_mask
|
102 |
+
# no need to permute dims as old version
|
103 |
+
if not reverse:
|
104 |
+
# center and scale
|
105 |
+
|
106 |
+
# self.input = input
|
107 |
+
input = self._center(input, reverse, bias_offset)
|
108 |
+
input, logdet = self._scale(input, logdet, reverse, logs_offset)
|
109 |
+
else:
|
110 |
+
# scale and center
|
111 |
+
input, logdet = self._scale(input, logdet, reverse, logs_offset)
|
112 |
+
input = self._center(input, reverse, bias_offset)
|
113 |
+
return input, logdet
|
114 |
+
|
115 |
+
|
116 |
+
class ActNorm2d(_ActNorm):
|
117 |
+
def __init__(self, num_features, scale=1.):
|
118 |
+
super().__init__(num_features, scale)
|
119 |
+
|
120 |
+
def _check_input_dim(self, input):
|
121 |
+
assert len(input.size()) == 4
|
122 |
+
assert input.size(1) == self.num_features, (
|
123 |
+
"[ActNorm]: input should be in shape as `BCHW`,"
|
124 |
+
" channels should be {} rather than {}".format(
|
125 |
+
self.num_features, input.size()))
|
126 |
+
|
127 |
+
|
128 |
+
class MaskedActNorm2d(ActNorm2d):
|
129 |
+
def __init__(self, num_features, scale=1.):
|
130 |
+
super().__init__(num_features, scale)
|
131 |
+
|
132 |
+
def forward(self, input, mask, logdet=None, reverse=False):
|
133 |
+
|
134 |
+
assert mask.dtype == torch.bool
|
135 |
+
output, logdet_out = super().forward(input, logdet, reverse)
|
136 |
+
|
137 |
+
input[mask] = output[mask]
|
138 |
+
logdet[mask] = logdet_out[mask]
|
139 |
+
|
140 |
+
return input, logdet
|
141 |
+
|
models/SRFlow/code/models/modules/FlowAffineCouplingsAblation.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn as nn
|
19 |
+
|
20 |
+
from models.modules import thops
|
21 |
+
from models.modules.flow import Conv2d, Conv2dZeros
|
22 |
+
from utils.util import opt_get
|
23 |
+
|
24 |
+
|
25 |
+
class CondAffineSeparatedAndCond(nn.Module):
|
26 |
+
def __init__(self, in_channels, opt):
|
27 |
+
super().__init__()
|
28 |
+
self.need_features = True
|
29 |
+
self.in_channels = in_channels
|
30 |
+
self.in_channels_rrdb = 320
|
31 |
+
self.kernel_hidden = 1
|
32 |
+
self.affine_eps = 0.0001
|
33 |
+
self.n_hidden_layers = 1
|
34 |
+
hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
|
35 |
+
self.hidden_channels = 64 if hidden_channels is None else hidden_channels
|
36 |
+
|
37 |
+
self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
|
38 |
+
|
39 |
+
self.channels_for_nn = self.in_channels // 2
|
40 |
+
self.channels_for_co = self.in_channels - self.channels_for_nn
|
41 |
+
|
42 |
+
if self.channels_for_nn is None:
|
43 |
+
self.channels_for_nn = self.in_channels // 2
|
44 |
+
|
45 |
+
self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
|
46 |
+
out_channels=self.channels_for_co * 2,
|
47 |
+
hidden_channels=self.hidden_channels,
|
48 |
+
kernel_hidden=self.kernel_hidden,
|
49 |
+
n_hidden_layers=self.n_hidden_layers)
|
50 |
+
|
51 |
+
self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
|
52 |
+
out_channels=self.in_channels * 2,
|
53 |
+
hidden_channels=self.hidden_channels,
|
54 |
+
kernel_hidden=self.kernel_hidden,
|
55 |
+
n_hidden_layers=self.n_hidden_layers)
|
56 |
+
|
57 |
+
def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None):
|
58 |
+
if not reverse:
|
59 |
+
z = input
|
60 |
+
assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels)
|
61 |
+
|
62 |
+
# Feature Conditional
|
63 |
+
scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
|
64 |
+
z = z + shiftFt
|
65 |
+
z = z * scaleFt
|
66 |
+
logdet = logdet + self.get_logdet(scaleFt)
|
67 |
+
|
68 |
+
# Self Conditional
|
69 |
+
z1, z2 = self.split(z)
|
70 |
+
scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
|
71 |
+
self.asserts(scale, shift, z1, z2)
|
72 |
+
z2 = z2 + shift
|
73 |
+
z2 = z2 * scale
|
74 |
+
|
75 |
+
logdet = logdet + self.get_logdet(scale)
|
76 |
+
z = thops.cat_feature(z1, z2)
|
77 |
+
output = z
|
78 |
+
else:
|
79 |
+
z = input
|
80 |
+
|
81 |
+
# Self Conditional
|
82 |
+
z1, z2 = self.split(z)
|
83 |
+
scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
|
84 |
+
self.asserts(scale, shift, z1, z2)
|
85 |
+
z2 = z2 / scale
|
86 |
+
z2 = z2 - shift
|
87 |
+
z = thops.cat_feature(z1, z2)
|
88 |
+
logdet = logdet - self.get_logdet(scale)
|
89 |
+
|
90 |
+
# Feature Conditional
|
91 |
+
scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
|
92 |
+
z = z / scaleFt
|
93 |
+
z = z - shiftFt
|
94 |
+
logdet = logdet - self.get_logdet(scaleFt)
|
95 |
+
|
96 |
+
output = z
|
97 |
+
return output, logdet
|
98 |
+
|
99 |
+
def asserts(self, scale, shift, z1, z2):
|
100 |
+
assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn)
|
101 |
+
assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co)
|
102 |
+
assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1])
|
103 |
+
assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1])
|
104 |
+
|
105 |
+
def get_logdet(self, scale):
|
106 |
+
return thops.sum(torch.log(scale), dim=[1, 2, 3])
|
107 |
+
|
108 |
+
def feature_extract(self, z, f):
|
109 |
+
h = f(z)
|
110 |
+
shift, scale = thops.split_feature(h, "cross")
|
111 |
+
scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
|
112 |
+
return scale, shift
|
113 |
+
|
114 |
+
def feature_extract_aff(self, z1, ft, f):
|
115 |
+
z = torch.cat([z1, ft], dim=1)
|
116 |
+
h = f(z)
|
117 |
+
shift, scale = thops.split_feature(h, "cross")
|
118 |
+
scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
|
119 |
+
return scale, shift
|
120 |
+
|
121 |
+
def split(self, z):
|
122 |
+
z1 = z[:, :self.channels_for_nn]
|
123 |
+
z2 = z[:, self.channels_for_nn:]
|
124 |
+
assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1])
|
125 |
+
return z1, z2
|
126 |
+
|
127 |
+
def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1):
|
128 |
+
layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)]
|
129 |
+
|
130 |
+
for _ in range(n_hidden_layers):
|
131 |
+
layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden]))
|
132 |
+
layers.append(nn.ReLU(inplace=False))
|
133 |
+
layers.append(Conv2dZeros(hidden_channels, out_channels))
|
134 |
+
|
135 |
+
return nn.Sequential(*layers)
|
models/SRFlow/code/models/modules/FlowStep.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn as nn
|
19 |
+
|
20 |
+
import models.modules
|
21 |
+
import models.modules.Permutations
|
22 |
+
from models.modules import flow, thops, FlowAffineCouplingsAblation
|
23 |
+
from utils.util import opt_get
|
24 |
+
|
25 |
+
|
26 |
+
def getConditional(rrdbResults, position):
|
27 |
+
img_ft = rrdbResults if isinstance(rrdbResults, torch.Tensor) else rrdbResults[position]
|
28 |
+
return img_ft
|
29 |
+
|
30 |
+
|
31 |
+
class FlowStep(nn.Module):
|
32 |
+
FlowPermutation = {
|
33 |
+
"reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
|
34 |
+
"shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
|
35 |
+
"invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
|
36 |
+
"squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
|
37 |
+
"resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
|
38 |
+
"resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
|
39 |
+
"InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
|
40 |
+
"InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
|
41 |
+
"InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
|
42 |
+
"InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
|
43 |
+
}
|
44 |
+
|
45 |
+
def __init__(self, in_channels, hidden_channels,
|
46 |
+
actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
|
47 |
+
LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
|
48 |
+
position=None):
|
49 |
+
# check configures
|
50 |
+
assert flow_permutation in FlowStep.FlowPermutation, \
|
51 |
+
"float_permutation should be in `{}`".format(
|
52 |
+
FlowStep.FlowPermutation.keys())
|
53 |
+
super().__init__()
|
54 |
+
self.flow_permutation = flow_permutation
|
55 |
+
self.flow_coupling = flow_coupling
|
56 |
+
self.image_injector = image_injector
|
57 |
+
|
58 |
+
self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d'
|
59 |
+
self.position = normOpt['position'] if normOpt else None
|
60 |
+
|
61 |
+
self.in_shape = in_shape
|
62 |
+
self.position = position
|
63 |
+
self.acOpt = acOpt
|
64 |
+
|
65 |
+
# 1. actnorm
|
66 |
+
self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
|
67 |
+
|
68 |
+
# 2. permute
|
69 |
+
if flow_permutation == "invconv":
|
70 |
+
self.invconv = models.modules.Permutations.InvertibleConv1x1(
|
71 |
+
in_channels, LU_decomposed=LU_decomposed)
|
72 |
+
|
73 |
+
# 3. coupling
|
74 |
+
if flow_coupling == "CondAffineSeparatedAndCond":
|
75 |
+
self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
|
76 |
+
opt=opt)
|
77 |
+
elif flow_coupling == "noCoupling":
|
78 |
+
pass
|
79 |
+
else:
|
80 |
+
raise RuntimeError("coupling not Found:", flow_coupling)
|
81 |
+
|
82 |
+
def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
|
83 |
+
if not reverse:
|
84 |
+
return self.normal_flow(input, logdet, rrdbResults)
|
85 |
+
else:
|
86 |
+
return self.reverse_flow(input, logdet, rrdbResults)
|
87 |
+
|
88 |
+
def normal_flow(self, z, logdet, rrdbResults=None):
|
89 |
+
if self.flow_coupling == "bentIdentityPreAct":
|
90 |
+
z, logdet = self.bentIdentPar(z, logdet, reverse=False)
|
91 |
+
|
92 |
+
# 1. actnorm
|
93 |
+
if self.norm_type == "ConditionalActNormImageInjector":
|
94 |
+
img_ft = getConditional(rrdbResults, self.position)
|
95 |
+
z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False)
|
96 |
+
elif self.norm_type == "noNorm":
|
97 |
+
pass
|
98 |
+
else:
|
99 |
+
z, logdet = self.actnorm(z, logdet=logdet, reverse=False)
|
100 |
+
|
101 |
+
# 2. permute
|
102 |
+
z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
|
103 |
+
self, z, logdet, False)
|
104 |
+
|
105 |
+
need_features = self.affine_need_features()
|
106 |
+
|
107 |
+
# 3. coupling
|
108 |
+
if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
|
109 |
+
img_ft = getConditional(rrdbResults, self.position)
|
110 |
+
z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft)
|
111 |
+
return z, logdet
|
112 |
+
|
113 |
+
def reverse_flow(self, z, logdet, rrdbResults=None):
|
114 |
+
|
115 |
+
need_features = self.affine_need_features()
|
116 |
+
|
117 |
+
# 1.coupling
|
118 |
+
if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
|
119 |
+
img_ft = getConditional(rrdbResults, self.position)
|
120 |
+
z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft)
|
121 |
+
|
122 |
+
# 2. permute
|
123 |
+
z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
|
124 |
+
self, z, logdet, True)
|
125 |
+
|
126 |
+
# 3. actnorm
|
127 |
+
z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
|
128 |
+
|
129 |
+
return z, logdet
|
130 |
+
|
131 |
+
def affine_need_features(self):
|
132 |
+
need_features = False
|
133 |
+
try:
|
134 |
+
need_features = self.affine.need_features
|
135 |
+
except:
|
136 |
+
pass
|
137 |
+
return need_features
|
models/SRFlow/code/models/modules/FlowUpsamplerNet.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from torch import nn as nn
|
20 |
+
|
21 |
+
import models.modules.Split
|
22 |
+
from models.modules import flow, thops
|
23 |
+
from models.modules.Split import Split2d
|
24 |
+
from models.modules.glow_arch import f_conv2d_bias
|
25 |
+
from models.modules.FlowStep import FlowStep
|
26 |
+
from utils.util import opt_get
|
27 |
+
|
28 |
+
|
29 |
+
class FlowUpsamplerNet(nn.Module):
|
30 |
+
def __init__(self, image_shape, hidden_channels, K, L=None,
|
31 |
+
actnorm_scale=1.0,
|
32 |
+
flow_permutation=None,
|
33 |
+
flow_coupling="affine",
|
34 |
+
LU_decomposed=False, opt=None):
|
35 |
+
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.layers = nn.ModuleList()
|
39 |
+
self.output_shapes = []
|
40 |
+
self.L = opt_get(opt, ['network_G', 'flow', 'L'])
|
41 |
+
self.K = opt_get(opt, ['network_G', 'flow', 'K'])
|
42 |
+
if isinstance(self.K, int):
|
43 |
+
self.K = [K for K in [K, ] * (self.L + 1)]
|
44 |
+
|
45 |
+
self.opt = opt
|
46 |
+
H, W, self.C = image_shape
|
47 |
+
self.check_image_shape()
|
48 |
+
|
49 |
+
if opt['scale'] == 16:
|
50 |
+
self.levelToName = {
|
51 |
+
0: 'fea_up16',
|
52 |
+
1: 'fea_up8',
|
53 |
+
2: 'fea_up4',
|
54 |
+
3: 'fea_up2',
|
55 |
+
4: 'fea_up1',
|
56 |
+
}
|
57 |
+
|
58 |
+
if opt['scale'] == 8:
|
59 |
+
self.levelToName = {
|
60 |
+
0: 'fea_up8',
|
61 |
+
1: 'fea_up4',
|
62 |
+
2: 'fea_up2',
|
63 |
+
3: 'fea_up1',
|
64 |
+
4: 'fea_up0'
|
65 |
+
}
|
66 |
+
|
67 |
+
elif opt['scale'] == 4:
|
68 |
+
self.levelToName = {
|
69 |
+
0: 'fea_up4',
|
70 |
+
1: 'fea_up2',
|
71 |
+
2: 'fea_up1',
|
72 |
+
3: 'fea_up0',
|
73 |
+
4: 'fea_up-1'
|
74 |
+
}
|
75 |
+
|
76 |
+
affineInCh = self.get_affineInCh(opt_get)
|
77 |
+
flow_permutation = self.get_flow_permutation(flow_permutation, opt)
|
78 |
+
|
79 |
+
normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
|
80 |
+
|
81 |
+
conditional_channels = {}
|
82 |
+
n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
|
83 |
+
n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
|
84 |
+
conditional_channels[0] = n_rrdb
|
85 |
+
for level in range(1, self.L + 1):
|
86 |
+
# Level 1 gets conditionals from 2, 3, 4 => L - level
|
87 |
+
# Level 2 gets conditionals from 3, 4
|
88 |
+
# Level 3 gets conditionals from 4
|
89 |
+
# Level 4 gets conditionals from None
|
90 |
+
n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
|
91 |
+
conditional_channels[level] = n_rrdb + n_bypass
|
92 |
+
|
93 |
+
# Upsampler
|
94 |
+
for level in range(1, self.L + 1):
|
95 |
+
# 1. Squeeze
|
96 |
+
H, W = self.arch_squeeze(H, W)
|
97 |
+
|
98 |
+
# 2. K FlowStep
|
99 |
+
self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt)
|
100 |
+
self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
|
101 |
+
flow_permutation,
|
102 |
+
hidden_channels, normOpt, opt, opt_get,
|
103 |
+
n_conditinal_channels=conditional_channels[level])
|
104 |
+
# Split
|
105 |
+
self.arch_split(H, W, level, self.L, opt, opt_get)
|
106 |
+
|
107 |
+
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
|
108 |
+
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
|
109 |
+
else:
|
110 |
+
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
|
111 |
+
|
112 |
+
self.H = H
|
113 |
+
self.W = W
|
114 |
+
self.scaleH = 160 / H
|
115 |
+
self.scaleW = 160 / W
|
116 |
+
|
117 |
+
def get_n_rrdb_channels(self, opt, opt_get):
|
118 |
+
blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
|
119 |
+
n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
|
120 |
+
return n_rrdb
|
121 |
+
|
122 |
+
def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
|
123 |
+
hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None):
|
124 |
+
condAff = self.get_condAffSetting(opt, opt_get)
|
125 |
+
if condAff is not None:
|
126 |
+
condAff['in_channels_rrdb'] = n_conditinal_channels
|
127 |
+
|
128 |
+
for k in range(K):
|
129 |
+
position_name = get_position_name(H, self.opt['scale'])
|
130 |
+
if normOpt: normOpt['position'] = position_name
|
131 |
+
|
132 |
+
self.layers.append(
|
133 |
+
FlowStep(in_channels=self.C,
|
134 |
+
hidden_channels=hidden_channels,
|
135 |
+
actnorm_scale=actnorm_scale,
|
136 |
+
flow_permutation=flow_permutation,
|
137 |
+
flow_coupling=flow_coupling,
|
138 |
+
acOpt=condAff,
|
139 |
+
position=position_name,
|
140 |
+
LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt))
|
141 |
+
self.output_shapes.append(
|
142 |
+
[-1, self.C, H, W])
|
143 |
+
|
144 |
+
def get_condAffSetting(self, opt, opt_get):
|
145 |
+
condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
|
146 |
+
condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
|
147 |
+
return condAff
|
148 |
+
|
149 |
+
def arch_split(self, H, W, L, levels, opt, opt_get):
|
150 |
+
correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
|
151 |
+
correction = 0 if correct_splits else 1
|
152 |
+
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
|
153 |
+
logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
|
154 |
+
consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
|
155 |
+
position_name = get_position_name(H, self.opt['scale'])
|
156 |
+
position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
|
157 |
+
cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
|
158 |
+
cond_channels = 0 if cond_channels is None else cond_channels
|
159 |
+
|
160 |
+
t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
|
161 |
+
|
162 |
+
if t == 'Split2d':
|
163 |
+
split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
|
164 |
+
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
|
165 |
+
self.layers.append(split)
|
166 |
+
self.output_shapes.append([-1, split.num_channels_pass, H, W])
|
167 |
+
self.C = split.num_channels_pass
|
168 |
+
|
169 |
+
def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
|
170 |
+
if 'additionalFlowNoAffine' in opt['network_G']['flow']:
|
171 |
+
n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
|
172 |
+
for _ in range(n_additionalFlowNoAffine):
|
173 |
+
self.layers.append(
|
174 |
+
FlowStep(in_channels=self.C,
|
175 |
+
hidden_channels=hidden_channels,
|
176 |
+
actnorm_scale=actnorm_scale,
|
177 |
+
flow_permutation='invconv',
|
178 |
+
flow_coupling='noCoupling',
|
179 |
+
LU_decomposed=LU_decomposed, opt=opt))
|
180 |
+
self.output_shapes.append(
|
181 |
+
[-1, self.C, H, W])
|
182 |
+
|
183 |
+
def arch_squeeze(self, H, W):
|
184 |
+
self.C, H, W = self.C * 4, H // 2, W // 2
|
185 |
+
self.layers.append(flow.SqueezeLayer(factor=2))
|
186 |
+
self.output_shapes.append([-1, self.C, H, W])
|
187 |
+
return H, W
|
188 |
+
|
189 |
+
def get_flow_permutation(self, flow_permutation, opt):
|
190 |
+
flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
|
191 |
+
return flow_permutation
|
192 |
+
|
193 |
+
def get_affineInCh(self, opt_get):
|
194 |
+
affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
|
195 |
+
affineInCh = (len(affineInCh) + 1) * 64
|
196 |
+
return affineInCh
|
197 |
+
|
198 |
+
def check_image_shape(self):
|
199 |
+
assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3)"
|
200 |
+
"self.C == 1 or self.C == 3")
|
201 |
+
|
202 |
+
def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
|
203 |
+
y_onehot=None):
|
204 |
+
|
205 |
+
if reverse:
|
206 |
+
epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses
|
207 |
+
|
208 |
+
sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
|
209 |
+
return sr, logdet
|
210 |
+
else:
|
211 |
+
assert gt is not None
|
212 |
+
assert rrdbResults is not None
|
213 |
+
z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
|
214 |
+
|
215 |
+
return z, logdet
|
216 |
+
|
217 |
+
def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
|
218 |
+
fl_fea = gt
|
219 |
+
reverse = False
|
220 |
+
level_conditionals = {}
|
221 |
+
bypasses = {}
|
222 |
+
|
223 |
+
L = opt_get(self.opt, ['network_G', 'flow', 'L'])
|
224 |
+
|
225 |
+
for level in range(1, L + 1):
|
226 |
+
bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
|
227 |
+
|
228 |
+
for layer, shape in zip(self.layers, self.output_shapes):
|
229 |
+
size = shape[2]
|
230 |
+
level = int(np.log(160 / size) / np.log(2))
|
231 |
+
|
232 |
+
if level > 0 and level not in level_conditionals.keys():
|
233 |
+
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
234 |
+
|
235 |
+
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
236 |
+
|
237 |
+
if isinstance(layer, FlowStep):
|
238 |
+
fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
|
239 |
+
elif isinstance(layer, Split2d):
|
240 |
+
fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
|
241 |
+
y_onehot=y_onehot)
|
242 |
+
else:
|
243 |
+
fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)
|
244 |
+
|
245 |
+
z = fl_fea
|
246 |
+
|
247 |
+
if not isinstance(epses, list):
|
248 |
+
return z, logdet
|
249 |
+
|
250 |
+
epses.append(z)
|
251 |
+
return epses, logdet
|
252 |
+
|
253 |
+
def forward_preFlow(self, fl_fea, logdet, reverse):
|
254 |
+
if hasattr(self, 'preFlow'):
|
255 |
+
for l in self.preFlow:
|
256 |
+
fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
|
257 |
+
return fl_fea, logdet
|
258 |
+
|
259 |
+
def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
|
260 |
+
ft = None if layer.position is None else rrdbResults[layer.position]
|
261 |
+
fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
|
262 |
+
|
263 |
+
if isinstance(epses, list):
|
264 |
+
epses.append(eps)
|
265 |
+
return fl_fea, logdet
|
266 |
+
|
267 |
+
def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
|
268 |
+
z = epses.pop() if isinstance(epses, list) else z
|
269 |
+
|
270 |
+
fl_fea = z
|
271 |
+
# debug.imwrite("fl_fea", fl_fea)
|
272 |
+
bypasses = {}
|
273 |
+
level_conditionals = {}
|
274 |
+
if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
|
275 |
+
for level in range(self.L + 1):
|
276 |
+
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
277 |
+
|
278 |
+
for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
|
279 |
+
size = shape[2]
|
280 |
+
level = int(np.log(160 / size) / np.log(2))
|
281 |
+
# size = fl_fea.shape[2]
|
282 |
+
# level = int(np.log(160 / size) / np.log(2))
|
283 |
+
|
284 |
+
if isinstance(layer, Split2d):
|
285 |
+
fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
|
286 |
+
rrdbResults[self.levelToName[level]], logdet=logdet,
|
287 |
+
y_onehot=y_onehot)
|
288 |
+
elif isinstance(layer, FlowStep):
|
289 |
+
fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
|
290 |
+
else:
|
291 |
+
fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)
|
292 |
+
|
293 |
+
sr = fl_fea
|
294 |
+
|
295 |
+
assert sr.shape[1] == 3
|
296 |
+
return sr, logdet
|
297 |
+
|
298 |
+
def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
|
299 |
+
ft = None if layer.position is None else rrdbResults[layer.position]
|
300 |
+
fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
|
301 |
+
eps=epses.pop() if isinstance(epses, list) else None,
|
302 |
+
eps_std=eps_std, ft=ft, y_onehot=y_onehot)
|
303 |
+
return fl_fea, logdet
|
304 |
+
|
305 |
+
|
306 |
+
def get_position_name(H, scale):
|
307 |
+
downscale_factor = 160 // H
|
308 |
+
position_name = 'fea_up{}'.format(scale / downscale_factor)
|
309 |
+
return position_name
|
models/SRFlow/code/models/modules/Permutations.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from torch import nn as nn
|
20 |
+
from torch.nn import functional as F
|
21 |
+
|
22 |
+
from models.modules import thops
|
23 |
+
|
24 |
+
|
25 |
+
class InvertibleConv1x1(nn.Module):
|
26 |
+
def __init__(self, num_channels, LU_decomposed=False):
|
27 |
+
super().__init__()
|
28 |
+
w_shape = [num_channels, num_channels]
|
29 |
+
w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
|
30 |
+
self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
|
31 |
+
self.w_shape = w_shape
|
32 |
+
self.LU = LU_decomposed
|
33 |
+
|
34 |
+
def get_weight(self, input, reverse):
|
35 |
+
w_shape = self.w_shape
|
36 |
+
pixels = thops.pixels(input)
|
37 |
+
dlogdet = torch.slogdet(self.weight)[1] * pixels
|
38 |
+
if not reverse:
|
39 |
+
weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
|
40 |
+
else:
|
41 |
+
weight = torch.inverse(self.weight.double()).float() \
|
42 |
+
.view(w_shape[0], w_shape[1], 1, 1)
|
43 |
+
return weight, dlogdet
|
44 |
+
def forward(self, input, logdet=None, reverse=False):
|
45 |
+
"""
|
46 |
+
log-det = log|abs(|W|)| * pixels
|
47 |
+
"""
|
48 |
+
weight, dlogdet = self.get_weight(input, reverse)
|
49 |
+
if not reverse:
|
50 |
+
z = F.conv2d(input, weight)
|
51 |
+
if logdet is not None:
|
52 |
+
logdet = logdet + dlogdet
|
53 |
+
return z, logdet
|
54 |
+
else:
|
55 |
+
z = F.conv2d(input, weight)
|
56 |
+
if logdet is not None:
|
57 |
+
logdet = logdet - dlogdet
|
58 |
+
return z, logdet
|
models/SRFlow/code/models/modules/RRDBNet_arch.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import functools
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import models.modules.module_util as mutil
|
22 |
+
from utils.util import opt_get
|
23 |
+
|
24 |
+
|
25 |
+
class ResidualDenseBlock_5C(nn.Module):
|
26 |
+
def __init__(self, nf=64, gc=32, bias=True):
|
27 |
+
super(ResidualDenseBlock_5C, self).__init__()
|
28 |
+
# gc: growth channel, i.e. intermediate channels
|
29 |
+
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
30 |
+
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
31 |
+
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
32 |
+
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
33 |
+
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
34 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
35 |
+
|
36 |
+
# initialization
|
37 |
+
mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
x1 = self.lrelu(self.conv1(x))
|
41 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
42 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
43 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
44 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
45 |
+
return x5 * 0.2 + x
|
46 |
+
|
47 |
+
|
48 |
+
class RRDB(nn.Module):
|
49 |
+
'''Residual in Residual Dense Block'''
|
50 |
+
|
51 |
+
def __init__(self, nf, gc=32):
|
52 |
+
super(RRDB, self).__init__()
|
53 |
+
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
54 |
+
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
55 |
+
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
out = self.RDB1(x)
|
59 |
+
out = self.RDB2(out)
|
60 |
+
out = self.RDB3(out)
|
61 |
+
return out * 0.2 + x
|
62 |
+
|
63 |
+
|
64 |
+
class RRDBNet(nn.Module):
|
65 |
+
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
|
66 |
+
self.opt = opt
|
67 |
+
super(RRDBNet, self).__init__()
|
68 |
+
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
69 |
+
self.scale = scale
|
70 |
+
|
71 |
+
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
72 |
+
self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
|
73 |
+
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
74 |
+
#### upsampling
|
75 |
+
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
76 |
+
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
77 |
+
if self.scale >= 8:
|
78 |
+
self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
79 |
+
if self.scale >= 16:
|
80 |
+
self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
81 |
+
if self.scale >= 32:
|
82 |
+
self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
83 |
+
|
84 |
+
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
85 |
+
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
86 |
+
|
87 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
88 |
+
|
89 |
+
def forward(self, x, get_steps=False):
|
90 |
+
fea = self.conv_first(x)
|
91 |
+
|
92 |
+
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
|
93 |
+
block_results = {}
|
94 |
+
|
95 |
+
for idx, m in enumerate(self.RRDB_trunk.children()):
|
96 |
+
fea = m(fea)
|
97 |
+
for b in block_idxs:
|
98 |
+
if b == idx:
|
99 |
+
block_results["block_{}".format(idx)] = fea
|
100 |
+
|
101 |
+
trunk = self.trunk_conv(fea)
|
102 |
+
|
103 |
+
last_lr_fea = fea + trunk
|
104 |
+
|
105 |
+
fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
|
106 |
+
fea = self.lrelu(fea_up2)
|
107 |
+
|
108 |
+
fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
109 |
+
fea = self.lrelu(fea_up4)
|
110 |
+
|
111 |
+
fea_up8 = None
|
112 |
+
fea_up16 = None
|
113 |
+
fea_up32 = None
|
114 |
+
|
115 |
+
if self.scale >= 8:
|
116 |
+
fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
117 |
+
fea = self.lrelu(fea_up8)
|
118 |
+
if self.scale >= 16:
|
119 |
+
fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
120 |
+
fea = self.lrelu(fea_up16)
|
121 |
+
if self.scale >= 32:
|
122 |
+
fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
123 |
+
fea = self.lrelu(fea_up32)
|
124 |
+
|
125 |
+
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
126 |
+
|
127 |
+
results = {'last_lr_fea': last_lr_fea,
|
128 |
+
'fea_up1': last_lr_fea,
|
129 |
+
'fea_up2': fea_up2,
|
130 |
+
'fea_up4': fea_up4,
|
131 |
+
'fea_up8': fea_up8,
|
132 |
+
'fea_up16': fea_up16,
|
133 |
+
'fea_up32': fea_up32,
|
134 |
+
'out': out}
|
135 |
+
|
136 |
+
fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
|
137 |
+
if fea_up0_en:
|
138 |
+
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
139 |
+
fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
|
140 |
+
if fea_upn1_en:
|
141 |
+
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
142 |
+
|
143 |
+
if get_steps:
|
144 |
+
for k, v in block_results.items():
|
145 |
+
results[k] = v
|
146 |
+
return results
|
147 |
+
else:
|
148 |
+
return out
|
models/SRFlow/code/models/modules/SRFlowNet_arch.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import math
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import numpy as np
|
23 |
+
from models.modules.RRDBNet_arch import RRDBNet
|
24 |
+
from models.modules.FlowUpsamplerNet import FlowUpsamplerNet
|
25 |
+
import models.modules.thops as thops
|
26 |
+
import models.modules.flow as flow
|
27 |
+
from utils.util import opt_get
|
28 |
+
|
29 |
+
|
30 |
+
class SRFlowNet(nn.Module):
|
31 |
+
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
|
32 |
+
super(SRFlowNet, self).__init__()
|
33 |
+
|
34 |
+
self.opt = opt
|
35 |
+
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
|
36 |
+
None else opt_get(opt, ['datasets', 'train', 'quant'])
|
37 |
+
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
|
38 |
+
hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
|
39 |
+
hidden_channels = hidden_channels or 64
|
40 |
+
self.RRDB_training = True # Default is true
|
41 |
+
|
42 |
+
train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
|
43 |
+
set_RRDB_to_train = False
|
44 |
+
if set_RRDB_to_train:
|
45 |
+
self.set_rrdb_training(True)
|
46 |
+
|
47 |
+
self.flowUpsamplerNet = \
|
48 |
+
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
|
49 |
+
flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
|
50 |
+
self.i = 0
|
51 |
+
|
52 |
+
def set_rrdb_training(self, trainable):
|
53 |
+
if self.RRDB_training != trainable:
|
54 |
+
for p in self.RRDB.parameters():
|
55 |
+
p.requires_grad = trainable
|
56 |
+
self.RRDB_training = trainable
|
57 |
+
return True
|
58 |
+
return False
|
59 |
+
|
60 |
+
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
|
61 |
+
lr_enc=None,
|
62 |
+
add_gt_noise=False, step=None, y_label=None):
|
63 |
+
if not reverse:
|
64 |
+
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
|
65 |
+
y_onehot=y_label)
|
66 |
+
else:
|
67 |
+
# assert lr.shape[0] == 1
|
68 |
+
assert lr.shape[1] == 3
|
69 |
+
# assert lr.shape[2] == 20
|
70 |
+
# assert lr.shape[3] == 20
|
71 |
+
# assert z.shape[0] == 1
|
72 |
+
# assert z.shape[1] == 3 * 8 * 8
|
73 |
+
# assert z.shape[2] == 20
|
74 |
+
# assert z.shape[3] == 20
|
75 |
+
if reverse_with_grad:
|
76 |
+
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
|
77 |
+
add_gt_noise=add_gt_noise)
|
78 |
+
else:
|
79 |
+
with torch.no_grad():
|
80 |
+
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
|
81 |
+
add_gt_noise=add_gt_noise)
|
82 |
+
|
83 |
+
def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
|
84 |
+
if lr_enc is None:
|
85 |
+
lr_enc = self.rrdbPreprocessing(lr)
|
86 |
+
|
87 |
+
logdet = torch.zeros_like(gt[:, 0, 0, 0])
|
88 |
+
pixels = thops.pixels(gt)
|
89 |
+
|
90 |
+
z = gt
|
91 |
+
|
92 |
+
if add_gt_noise:
|
93 |
+
# Setup
|
94 |
+
noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
|
95 |
+
if noiseQuant:
|
96 |
+
z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
|
97 |
+
logdet = logdet + float(-np.log(self.quant) * pixels)
|
98 |
+
|
99 |
+
# Encode
|
100 |
+
epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
|
101 |
+
y_onehot=y_onehot)
|
102 |
+
|
103 |
+
objective = logdet.clone()
|
104 |
+
|
105 |
+
if isinstance(epses, (list, tuple)):
|
106 |
+
z = epses[-1]
|
107 |
+
else:
|
108 |
+
z = epses
|
109 |
+
|
110 |
+
objective = objective + flow.GaussianDiag.logp(None, None, z)
|
111 |
+
|
112 |
+
nll = (-objective) / float(np.log(2.) * pixels)
|
113 |
+
|
114 |
+
if isinstance(epses, list):
|
115 |
+
return epses, nll, logdet
|
116 |
+
return z, nll, logdet
|
117 |
+
|
118 |
+
def rrdbPreprocessing(self, lr):
|
119 |
+
rrdbResults = self.RRDB(lr, get_steps=True)
|
120 |
+
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
|
121 |
+
if len(block_idxs) > 0:
|
122 |
+
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
|
123 |
+
|
124 |
+
if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
|
125 |
+
keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
|
126 |
+
if 'fea_up0' in rrdbResults.keys():
|
127 |
+
keys.append('fea_up0')
|
128 |
+
if 'fea_up-1' in rrdbResults.keys():
|
129 |
+
keys.append('fea_up-1')
|
130 |
+
if self.opt['scale'] >= 8:
|
131 |
+
keys.append('fea_up8')
|
132 |
+
if self.opt['scale'] == 16:
|
133 |
+
keys.append('fea_up16')
|
134 |
+
for k in keys:
|
135 |
+
h = rrdbResults[k].shape[2]
|
136 |
+
w = rrdbResults[k].shape[3]
|
137 |
+
rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
|
138 |
+
return rrdbResults
|
139 |
+
|
140 |
+
def get_score(self, disc_loss_sigma, z):
|
141 |
+
score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \
|
142 |
+
z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma)
|
143 |
+
return -score_real
|
144 |
+
|
145 |
+
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
|
146 |
+
logdet = torch.zeros_like(lr[:, 0, 0, 0])
|
147 |
+
pixels = thops.pixels(lr) * self.opt['scale'] ** 2
|
148 |
+
|
149 |
+
if add_gt_noise:
|
150 |
+
logdet = logdet - float(-np.log(self.quant) * pixels)
|
151 |
+
|
152 |
+
if lr_enc is None:
|
153 |
+
lr_enc = self.rrdbPreprocessing(lr)
|
154 |
+
|
155 |
+
x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
|
156 |
+
logdet=logdet)
|
157 |
+
|
158 |
+
return x, logdet
|
models/SRFlow/code/models/modules/Split.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn as nn
|
19 |
+
|
20 |
+
from models.modules import thops
|
21 |
+
from models.modules.FlowStep import FlowStep
|
22 |
+
from models.modules.flow import Conv2dZeros, GaussianDiag
|
23 |
+
from utils.util import opt_get
|
24 |
+
|
25 |
+
|
26 |
+
class Split2d(nn.Module):
|
27 |
+
def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.num_channels_consume = int(round(num_channels * consume_ratio))
|
31 |
+
self.num_channels_pass = num_channels - self.num_channels_consume
|
32 |
+
|
33 |
+
self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
|
34 |
+
out_channels=self.num_channels_consume * 2)
|
35 |
+
self.logs_eps = logs_eps
|
36 |
+
self.position = position
|
37 |
+
self.opt = opt
|
38 |
+
|
39 |
+
def split2d_prior(self, z, ft):
|
40 |
+
if ft is not None:
|
41 |
+
z = torch.cat([z, ft], dim=1)
|
42 |
+
h = self.conv(z)
|
43 |
+
return thops.split_feature(h, "cross")
|
44 |
+
|
45 |
+
def exp_eps(self, logs):
|
46 |
+
return torch.exp(logs) + self.logs_eps
|
47 |
+
|
48 |
+
def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
|
49 |
+
if not reverse:
|
50 |
+
# self.input = input
|
51 |
+
z1, z2 = self.split_ratio(input)
|
52 |
+
mean, logs = self.split2d_prior(z1, ft)
|
53 |
+
|
54 |
+
eps = (z2 - mean) / self.exp_eps(logs)
|
55 |
+
|
56 |
+
logdet = logdet + self.get_logdet(logs, mean, z2)
|
57 |
+
|
58 |
+
# print(logs.shape, mean.shape, z2.shape)
|
59 |
+
# self.eps = eps
|
60 |
+
# print('split, enc eps:', eps)
|
61 |
+
return z1, logdet, eps
|
62 |
+
else:
|
63 |
+
z1 = input
|
64 |
+
mean, logs = self.split2d_prior(z1, ft)
|
65 |
+
|
66 |
+
if eps is None:
|
67 |
+
#print("WARNING: eps is None, generating eps untested functionality!")
|
68 |
+
eps = GaussianDiag.sample_eps(mean.shape, eps_std)
|
69 |
+
|
70 |
+
eps = eps.to(mean.device)
|
71 |
+
z2 = mean + self.exp_eps(logs) * eps
|
72 |
+
|
73 |
+
z = thops.cat_feature(z1, z2)
|
74 |
+
logdet = logdet - self.get_logdet(logs, mean, z2)
|
75 |
+
|
76 |
+
return z, logdet
|
77 |
+
# return z, logdet, eps
|
78 |
+
|
79 |
+
def get_logdet(self, logs, mean, z2):
|
80 |
+
logdet_diff = GaussianDiag.logp(mean, logs, z2)
|
81 |
+
# print("Split2D: logdet diff", logdet_diff.item())
|
82 |
+
return logdet_diff
|
83 |
+
|
84 |
+
def split_ratio(self, input):
|
85 |
+
z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
|
86 |
+
return z1, z2
|
models/SRFlow/code/models/modules/__init__.py
ADDED
File without changes
|
models/SRFlow/code/models/modules/flow.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
from models.modules.FlowActNorms import ActNorm2d
|
23 |
+
from . import thops
|
24 |
+
|
25 |
+
|
26 |
+
class Conv2d(nn.Conv2d):
|
27 |
+
pad_dict = {
|
28 |
+
"same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
|
29 |
+
"valid": lambda kernel, stride: [0 for _ in kernel]
|
30 |
+
}
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def get_padding(padding, kernel_size, stride):
|
34 |
+
# make paddding
|
35 |
+
if isinstance(padding, str):
|
36 |
+
if isinstance(kernel_size, int):
|
37 |
+
kernel_size = [kernel_size, kernel_size]
|
38 |
+
if isinstance(stride, int):
|
39 |
+
stride = [stride, stride]
|
40 |
+
padding = padding.lower()
|
41 |
+
try:
|
42 |
+
padding = Conv2d.pad_dict[padding](kernel_size, stride)
|
43 |
+
except KeyError:
|
44 |
+
raise ValueError("{} is not supported".format(padding))
|
45 |
+
return padding
|
46 |
+
|
47 |
+
def __init__(self, in_channels, out_channels,
|
48 |
+
kernel_size=[3, 3], stride=[1, 1],
|
49 |
+
padding="same", do_actnorm=True, weight_std=0.05):
|
50 |
+
padding = Conv2d.get_padding(padding, kernel_size, stride)
|
51 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
52 |
+
padding, bias=(not do_actnorm))
|
53 |
+
# init weight with std
|
54 |
+
self.weight.data.normal_(mean=0.0, std=weight_std)
|
55 |
+
if not do_actnorm:
|
56 |
+
self.bias.data.zero_()
|
57 |
+
else:
|
58 |
+
self.actnorm = ActNorm2d(out_channels)
|
59 |
+
self.do_actnorm = do_actnorm
|
60 |
+
|
61 |
+
def forward(self, input):
|
62 |
+
x = super().forward(input)
|
63 |
+
if self.do_actnorm:
|
64 |
+
x, _ = self.actnorm(x)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
class Conv2dZeros(nn.Conv2d):
|
69 |
+
def __init__(self, in_channels, out_channels,
|
70 |
+
kernel_size=[3, 3], stride=[1, 1],
|
71 |
+
padding="same", logscale_factor=3):
|
72 |
+
padding = Conv2d.get_padding(padding, kernel_size, stride)
|
73 |
+
super().__init__(in_channels, out_channels, kernel_size, stride, padding)
|
74 |
+
# logscale_factor
|
75 |
+
self.logscale_factor = logscale_factor
|
76 |
+
self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
|
77 |
+
# init
|
78 |
+
self.weight.data.zero_()
|
79 |
+
self.bias.data.zero_()
|
80 |
+
|
81 |
+
def forward(self, input):
|
82 |
+
output = super().forward(input)
|
83 |
+
return output * torch.exp(self.logs * self.logscale_factor)
|
84 |
+
|
85 |
+
|
86 |
+
class GaussianDiag:
|
87 |
+
Log2PI = float(np.log(2 * np.pi))
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
def likelihood(mean, logs, x):
|
91 |
+
"""
|
92 |
+
lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) }
|
93 |
+
k = 1 (Independent)
|
94 |
+
Var = logs ** 2
|
95 |
+
"""
|
96 |
+
if mean is None and logs is None:
|
97 |
+
return -0.5 * (x ** 2 + GaussianDiag.Log2PI)
|
98 |
+
else:
|
99 |
+
return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) + GaussianDiag.Log2PI)
|
100 |
+
|
101 |
+
@staticmethod
|
102 |
+
def logp(mean, logs, x):
|
103 |
+
likelihood = GaussianDiag.likelihood(mean, logs, x)
|
104 |
+
return thops.sum(likelihood, dim=[1, 2, 3])
|
105 |
+
|
106 |
+
@staticmethod
|
107 |
+
def sample(mean, logs, eps_std=None):
|
108 |
+
eps_std = eps_std or 1
|
109 |
+
eps = torch.normal(mean=torch.zeros_like(mean),
|
110 |
+
std=torch.ones_like(logs) * eps_std)
|
111 |
+
return mean + torch.exp(logs) * eps
|
112 |
+
|
113 |
+
@staticmethod
|
114 |
+
def sample_eps(shape, eps_std, seed=None):
|
115 |
+
if seed is not None:
|
116 |
+
torch.manual_seed(seed)
|
117 |
+
eps = torch.normal(mean=torch.zeros(shape),
|
118 |
+
std=torch.ones(shape) * eps_std)
|
119 |
+
return eps
|
120 |
+
|
121 |
+
|
122 |
+
def squeeze2d(input, factor=2):
|
123 |
+
assert factor >= 1 and isinstance(factor, int)
|
124 |
+
if factor == 1:
|
125 |
+
return input
|
126 |
+
size = input.size()
|
127 |
+
B = size[0]
|
128 |
+
C = size[1]
|
129 |
+
H = size[2]
|
130 |
+
W = size[3]
|
131 |
+
assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
|
132 |
+
x = input.view(B, C, H // factor, factor, W // factor, factor)
|
133 |
+
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
|
134 |
+
x = x.view(B, C * factor * factor, H // factor, W // factor)
|
135 |
+
return x
|
136 |
+
|
137 |
+
|
138 |
+
def unsqueeze2d(input, factor=2):
|
139 |
+
assert factor >= 1 and isinstance(factor, int)
|
140 |
+
factor2 = factor ** 2
|
141 |
+
if factor == 1:
|
142 |
+
return input
|
143 |
+
size = input.size()
|
144 |
+
B = size[0]
|
145 |
+
C = size[1]
|
146 |
+
H = size[2]
|
147 |
+
W = size[3]
|
148 |
+
assert C % (factor2) == 0, "{}".format(C)
|
149 |
+
x = input.view(B, C // factor2, factor, factor, H, W)
|
150 |
+
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
|
151 |
+
x = x.view(B, C // (factor2), H * factor, W * factor)
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class SqueezeLayer(nn.Module):
|
156 |
+
def __init__(self, factor):
|
157 |
+
super().__init__()
|
158 |
+
self.factor = factor
|
159 |
+
|
160 |
+
def forward(self, input, logdet=None, reverse=False):
|
161 |
+
if not reverse:
|
162 |
+
output = squeeze2d(input, self.factor) # Squeeze in forward
|
163 |
+
return output, logdet
|
164 |
+
else:
|
165 |
+
output = unsqueeze2d(input, self.factor)
|
166 |
+
return output, logdet
|
models/SRFlow/code/models/modules/glow_arch.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
def f_conv2d_bias(in_channels, out_channels):
|
21 |
+
def padding_same(kernel, stride):
|
22 |
+
return [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)]
|
23 |
+
|
24 |
+
padding = padding_same([3, 3], [1, 1])
|
25 |
+
assert padding == [1, 1], padding
|
26 |
+
return nn.Sequential(
|
27 |
+
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1,
|
28 |
+
bias=True))
|
models/SRFlow/code/models/modules/loss.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
|
21 |
+
class CharbonnierLoss(nn.Module):
|
22 |
+
"""Charbonnier Loss (L1)"""
|
23 |
+
|
24 |
+
def __init__(self, eps=1e-6):
|
25 |
+
super(CharbonnierLoss, self).__init__()
|
26 |
+
self.eps = eps
|
27 |
+
|
28 |
+
def forward(self, x, y):
|
29 |
+
diff = x - y
|
30 |
+
loss = torch.sum(torch.sqrt(diff * diff + self.eps))
|
31 |
+
return loss
|
32 |
+
|
33 |
+
|
34 |
+
# Define GAN loss: [vanilla | lsgan | wgan-gp]
|
35 |
+
class GANLoss(nn.Module):
|
36 |
+
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
|
37 |
+
super(GANLoss, self).__init__()
|
38 |
+
self.gan_type = gan_type.lower()
|
39 |
+
self.real_label_val = real_label_val
|
40 |
+
self.fake_label_val = fake_label_val
|
41 |
+
|
42 |
+
if self.gan_type == 'gan' or self.gan_type == 'ragan':
|
43 |
+
self.loss = nn.BCEWithLogitsLoss()
|
44 |
+
elif self.gan_type == 'lsgan':
|
45 |
+
self.loss = nn.MSELoss()
|
46 |
+
elif self.gan_type == 'wgan-gp':
|
47 |
+
|
48 |
+
def wgan_loss(input, target):
|
49 |
+
# target is boolean
|
50 |
+
return -1 * input.mean() if target else input.mean()
|
51 |
+
|
52 |
+
self.loss = wgan_loss
|
53 |
+
else:
|
54 |
+
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
|
55 |
+
|
56 |
+
def get_target_label(self, input, target_is_real):
|
57 |
+
if self.gan_type == 'wgan-gp':
|
58 |
+
return target_is_real
|
59 |
+
if target_is_real:
|
60 |
+
return torch.empty_like(input).fill_(self.real_label_val)
|
61 |
+
else:
|
62 |
+
return torch.empty_like(input).fill_(self.fake_label_val)
|
63 |
+
|
64 |
+
def forward(self, input, target_is_real):
|
65 |
+
target_label = self.get_target_label(input, target_is_real)
|
66 |
+
loss = self.loss(input, target_label)
|
67 |
+
return loss
|
68 |
+
|
69 |
+
|
70 |
+
class GradientPenaltyLoss(nn.Module):
|
71 |
+
def __init__(self, device=torch.device('cpu')):
|
72 |
+
super(GradientPenaltyLoss, self).__init__()
|
73 |
+
self.register_buffer('grad_outputs', torch.Tensor())
|
74 |
+
self.grad_outputs = self.grad_outputs.to(device)
|
75 |
+
|
76 |
+
def get_grad_outputs(self, input):
|
77 |
+
if self.grad_outputs.size() != input.size():
|
78 |
+
self.grad_outputs.resize_(input.size()).fill_(1.0)
|
79 |
+
return self.grad_outputs
|
80 |
+
|
81 |
+
def forward(self, interp, interp_crit):
|
82 |
+
grad_outputs = self.get_grad_outputs(interp_crit)
|
83 |
+
grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
|
84 |
+
grad_outputs=grad_outputs, create_graph=True,
|
85 |
+
retain_graph=True, only_inputs=True)[0]
|
86 |
+
grad_interp = grad_interp.view(grad_interp.size(0), -1)
|
87 |
+
grad_interp_norm = grad_interp.norm(2, dim=1)
|
88 |
+
|
89 |
+
loss = ((grad_interp_norm - 1)**2).mean()
|
90 |
+
return loss
|
models/SRFlow/code/models/modules/module_util.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.init as init
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
|
23 |
+
def initialize_weights(net_l, scale=1):
|
24 |
+
if not isinstance(net_l, list):
|
25 |
+
net_l = [net_l]
|
26 |
+
for net in net_l:
|
27 |
+
for m in net.modules():
|
28 |
+
if isinstance(m, nn.Conv2d):
|
29 |
+
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
30 |
+
m.weight.data *= scale # for residual block
|
31 |
+
if m.bias is not None:
|
32 |
+
m.bias.data.zero_()
|
33 |
+
elif isinstance(m, nn.Linear):
|
34 |
+
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
35 |
+
m.weight.data *= scale
|
36 |
+
if m.bias is not None:
|
37 |
+
m.bias.data.zero_()
|
38 |
+
elif isinstance(m, nn.BatchNorm2d):
|
39 |
+
init.constant_(m.weight, 1)
|
40 |
+
init.constant_(m.bias.data, 0.0)
|
41 |
+
|
42 |
+
|
43 |
+
def make_layer(block, n_layers):
|
44 |
+
layers = []
|
45 |
+
for _ in range(n_layers):
|
46 |
+
layers.append(block())
|
47 |
+
return nn.Sequential(*layers)
|
48 |
+
|
49 |
+
|
50 |
+
class ResidualBlock_noBN(nn.Module):
|
51 |
+
'''Residual block w/o BN
|
52 |
+
---Conv-ReLU-Conv-+-
|
53 |
+
|________________|
|
54 |
+
'''
|
55 |
+
|
56 |
+
def __init__(self, nf=64):
|
57 |
+
super(ResidualBlock_noBN, self).__init__()
|
58 |
+
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
59 |
+
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
60 |
+
|
61 |
+
# initialization
|
62 |
+
initialize_weights([self.conv1, self.conv2], 0.1)
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
identity = x
|
66 |
+
out = F.relu(self.conv1(x), inplace=True)
|
67 |
+
out = self.conv2(out)
|
68 |
+
return identity + out
|
69 |
+
|
70 |
+
|
71 |
+
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
|
72 |
+
"""Warp an image or feature map with optical flow
|
73 |
+
Args:
|
74 |
+
x (Tensor): size (N, C, H, W)
|
75 |
+
flow (Tensor): size (N, H, W, 2), normal value
|
76 |
+
interp_mode (str): 'nearest' or 'bilinear'
|
77 |
+
padding_mode (str): 'zeros' or 'border' or 'reflection'
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
Tensor: warped image or feature map
|
81 |
+
"""
|
82 |
+
assert x.size()[-2:] == flow.size()[1:3]
|
83 |
+
B, C, H, W = x.size()
|
84 |
+
# mesh grid
|
85 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
|
86 |
+
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
|
87 |
+
grid.requires_grad = False
|
88 |
+
grid = grid.type_as(x)
|
89 |
+
vgrid = grid + flow
|
90 |
+
# scale grid to [-1,1]
|
91 |
+
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
|
92 |
+
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
|
93 |
+
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
|
94 |
+
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
|
95 |
+
return output
|
models/SRFlow/code/models/modules/thops.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
def sum(tensor, dim=None, keepdim=False):
|
21 |
+
if dim is None:
|
22 |
+
# sum up all dim
|
23 |
+
return torch.sum(tensor)
|
24 |
+
else:
|
25 |
+
if isinstance(dim, int):
|
26 |
+
dim = [dim]
|
27 |
+
dim = sorted(dim)
|
28 |
+
for d in dim:
|
29 |
+
tensor = tensor.sum(dim=d, keepdim=True)
|
30 |
+
if not keepdim:
|
31 |
+
for i, d in enumerate(dim):
|
32 |
+
tensor.squeeze_(d-i)
|
33 |
+
return tensor
|
34 |
+
|
35 |
+
|
36 |
+
def mean(tensor, dim=None, keepdim=False):
|
37 |
+
if dim is None:
|
38 |
+
# mean all dim
|
39 |
+
return torch.mean(tensor)
|
40 |
+
else:
|
41 |
+
if isinstance(dim, int):
|
42 |
+
dim = [dim]
|
43 |
+
dim = sorted(dim)
|
44 |
+
for d in dim:
|
45 |
+
tensor = tensor.mean(dim=d, keepdim=True)
|
46 |
+
if not keepdim:
|
47 |
+
for i, d in enumerate(dim):
|
48 |
+
tensor.squeeze_(d-i)
|
49 |
+
return tensor
|
50 |
+
|
51 |
+
|
52 |
+
def split_feature(tensor, type="split"):
|
53 |
+
"""
|
54 |
+
type = ["split", "cross"]
|
55 |
+
"""
|
56 |
+
C = tensor.size(1)
|
57 |
+
if type == "split":
|
58 |
+
return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
|
59 |
+
elif type == "cross":
|
60 |
+
return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
|
61 |
+
|
62 |
+
|
63 |
+
def cat_feature(tensor_a, tensor_b):
|
64 |
+
return torch.cat((tensor_a, tensor_b), dim=1)
|
65 |
+
|
66 |
+
|
67 |
+
def pixels(tensor):
|
68 |
+
return int(tensor.size(2) * tensor.size(3))
|
models/SRFlow/code/models/networks.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
import importlib
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import logging
|
21 |
+
import models.modules.RRDBNet_arch as RRDBNet_arch
|
22 |
+
|
23 |
+
logger = logging.getLogger('base')
|
24 |
+
|
25 |
+
|
26 |
+
def find_model_using_name(model_name):
|
27 |
+
model_filename = "models.modules." + model_name + "_arch"
|
28 |
+
modellib = importlib.import_module(model_filename)
|
29 |
+
|
30 |
+
model = None
|
31 |
+
target_model_name = model_name.replace('_Net', '')
|
32 |
+
for name, cls in modellib.__dict__.items():
|
33 |
+
if name.lower() == target_model_name.lower():
|
34 |
+
model = cls
|
35 |
+
|
36 |
+
if model is None:
|
37 |
+
print(
|
38 |
+
"In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
|
39 |
+
model_filename, target_model_name))
|
40 |
+
exit(0)
|
41 |
+
|
42 |
+
return model
|
43 |
+
|
44 |
+
|
45 |
+
####################
|
46 |
+
# define network
|
47 |
+
####################
|
48 |
+
#### Generator
|
49 |
+
def define_G(opt):
|
50 |
+
opt_net = opt['network_G']
|
51 |
+
which_model = opt_net['which_model_G']
|
52 |
+
|
53 |
+
if which_model == 'RRDBNet':
|
54 |
+
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
55 |
+
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], opt=opt)
|
56 |
+
elif which_model == 'EDSRNet':
|
57 |
+
Arch = find_model_using_name(which_model)
|
58 |
+
netG = Arch(scale=opt['scale'])
|
59 |
+
elif which_model == 'rankSRGAN':
|
60 |
+
Arch = find_model_using_name(which_model)
|
61 |
+
netG = Arch(upscale=opt['scale'])
|
62 |
+
# elif which_model == 'sft_arch': # SFT-GAN
|
63 |
+
# netG = sft_arch.SFT_Net()
|
64 |
+
else:
|
65 |
+
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
66 |
+
return netG
|
67 |
+
|
68 |
+
|
69 |
+
def define_Flow(opt, step):
|
70 |
+
opt_net = opt['network_G']
|
71 |
+
which_model = opt_net['which_model_G']
|
72 |
+
|
73 |
+
Arch = find_model_using_name(which_model)
|
74 |
+
netG = Arch(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
75 |
+
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step)
|
76 |
+
|
77 |
+
return netG
|
78 |
+
|
79 |
+
|
80 |
+
#### Discriminator
|
81 |
+
def define_D(opt):
|
82 |
+
opt_net = opt['network_D']
|
83 |
+
which_model = opt_net['which_model_D']
|
84 |
+
|
85 |
+
if which_model == 'discriminator_vgg_128':
|
86 |
+
hidden_units = opt_net.get('hidden_units', 8192)
|
87 |
+
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], hidden_units=hidden_units)
|
88 |
+
else:
|
89 |
+
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
90 |
+
return netD
|
91 |
+
|
92 |
+
|
93 |
+
#### Define Network used for Perceptual Loss
|
94 |
+
def define_F(opt, use_bn=False):
|
95 |
+
gpu_ids = opt.get('gpu_ids', None)
|
96 |
+
device = torch.device('cuda' if gpu_ids else 'cpu')
|
97 |
+
# PyTorch pretrained_models VGG19-54, before ReLU.
|
98 |
+
if use_bn:
|
99 |
+
feature_layer = 49
|
100 |
+
else:
|
101 |
+
feature_layer = 34
|
102 |
+
netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
|
103 |
+
use_input_norm=True, device=device)
|
104 |
+
netF.eval() # No need to train
|
105 |
+
return netF
|
models/SRFlow/code/options/__init__.py
ADDED
File without changes
|
models/SRFlow/code/options/options.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
import os
|
18 |
+
import os.path as osp
|
19 |
+
import logging
|
20 |
+
import yaml
|
21 |
+
from utils.util import OrderedYaml
|
22 |
+
|
23 |
+
Loader, Dumper = OrderedYaml()
|
24 |
+
|
25 |
+
|
26 |
+
def parse(opt_path, is_train=True):
|
27 |
+
with open(opt_path, mode='r') as f:
|
28 |
+
opt = yaml.load(f, Loader=Loader)
|
29 |
+
# export CUDA_VISIBLE_DEVICES
|
30 |
+
gpu_list = ','.join(str(x) for x in opt.get('gpu_ids', []))
|
31 |
+
# os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
32 |
+
# print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
33 |
+
opt['is_train'] = is_train
|
34 |
+
if opt['distortion'] == 'sr':
|
35 |
+
scale = opt['scale']
|
36 |
+
|
37 |
+
# datasets
|
38 |
+
for phase, dataset in opt['datasets'].items():
|
39 |
+
phase = phase.split('_')[0]
|
40 |
+
dataset['phase'] = phase
|
41 |
+
if opt['distortion'] == 'sr':
|
42 |
+
dataset['scale'] = scale
|
43 |
+
is_lmdb = False
|
44 |
+
if dataset.get('dataroot_GT', None) is not None:
|
45 |
+
dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
|
46 |
+
if dataset['dataroot_GT'].endswith('lmdb'):
|
47 |
+
is_lmdb = True
|
48 |
+
if dataset.get('dataroot_LQ', None) is not None:
|
49 |
+
dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
|
50 |
+
if dataset['dataroot_LQ'].endswith('lmdb'):
|
51 |
+
is_lmdb = True
|
52 |
+
dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
|
53 |
+
if dataset['mode'].endswith('mc'): # for memcached
|
54 |
+
dataset['data_type'] = 'mc'
|
55 |
+
dataset['mode'] = dataset['mode'].replace('_mc', '')
|
56 |
+
|
57 |
+
# path
|
58 |
+
for key, path in opt['path'].items():
|
59 |
+
if path and key in opt['path'] and key != 'strict_load':
|
60 |
+
opt['path'][key] = osp.expanduser(path)
|
61 |
+
opt['path']['root'] = '/kaggle/working/'
|
62 |
+
if is_train:
|
63 |
+
experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
|
64 |
+
opt['path']['experiments_root'] = experiments_root
|
65 |
+
opt['path']['models'] = osp.join(experiments_root, 'models')
|
66 |
+
opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
|
67 |
+
opt['path']['log'] = experiments_root
|
68 |
+
opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
|
69 |
+
|
70 |
+
# change some options for debug mode
|
71 |
+
if 'debug' in opt['name']:
|
72 |
+
opt['train']['val_freq'] = 8
|
73 |
+
opt['logger']['print_freq'] = 1
|
74 |
+
opt['logger']['save_checkpoint_freq'] = 8
|
75 |
+
else: # test
|
76 |
+
if not opt['path'].get('results_root', None):
|
77 |
+
results_root = osp.join(opt['path']['root'], 'results', opt['name'])
|
78 |
+
opt['path']['results_root'] = results_root
|
79 |
+
opt['path']['log'] = opt['path']['results_root']
|
80 |
+
|
81 |
+
# network
|
82 |
+
if opt['distortion'] == 'sr':
|
83 |
+
opt['network_G']['scale'] = scale
|
84 |
+
|
85 |
+
# relative learning rate
|
86 |
+
if 'train' in opt:
|
87 |
+
niter = opt['train']['niter']
|
88 |
+
if 'T_period_rel' in opt['train']:
|
89 |
+
opt['train']['T_period'] = [int(x * niter) for x in opt['train']['T_period_rel']]
|
90 |
+
if 'restarts_rel' in opt['train']:
|
91 |
+
opt['train']['restarts'] = [int(x * niter) for x in opt['train']['restarts_rel']]
|
92 |
+
if 'lr_steps_rel' in opt['train']:
|
93 |
+
opt['train']['lr_steps'] = [int(x * niter) for x in opt['train']['lr_steps_rel']]
|
94 |
+
if 'lr_steps_inverse_rel' in opt['train']:
|
95 |
+
opt['train']['lr_steps_inverse'] = [int(x * niter) for x in opt['train']['lr_steps_inverse_rel']]
|
96 |
+
print(opt['train'])
|
97 |
+
|
98 |
+
return opt
|
99 |
+
|
100 |
+
|
101 |
+
def dict2str(opt, indent_l=1):
|
102 |
+
'''dict to string for logger'''
|
103 |
+
msg = ''
|
104 |
+
for k, v in opt.items():
|
105 |
+
if isinstance(v, dict):
|
106 |
+
msg += ' ' * (indent_l * 2) + k + ':[\n'
|
107 |
+
msg += dict2str(v, indent_l + 1)
|
108 |
+
msg += ' ' * (indent_l * 2) + ']\n'
|
109 |
+
else:
|
110 |
+
msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
|
111 |
+
return msg
|
112 |
+
|
113 |
+
|
114 |
+
class NoneDict(dict):
|
115 |
+
def __missing__(self, key):
|
116 |
+
return None
|
117 |
+
|
118 |
+
|
119 |
+
# convert to NoneDict, which return None for missing key.
|
120 |
+
def dict_to_nonedict(opt):
|
121 |
+
if isinstance(opt, dict):
|
122 |
+
new_opt = dict()
|
123 |
+
for key, sub_opt in opt.items():
|
124 |
+
new_opt[key] = dict_to_nonedict(sub_opt)
|
125 |
+
return NoneDict(**new_opt)
|
126 |
+
elif isinstance(opt, list):
|
127 |
+
return [dict_to_nonedict(sub_opt) for sub_opt in opt]
|
128 |
+
else:
|
129 |
+
return opt
|
130 |
+
|
131 |
+
|
132 |
+
def check_resume(opt, resume_iter):
|
133 |
+
'''Check resume states and pretrain_model paths'''
|
134 |
+
logger = logging.getLogger('base')
|
135 |
+
if opt['path']['resume_state']:
|
136 |
+
if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
|
137 |
+
'pretrain_model_D', None) is not None:
|
138 |
+
logger.warning('pretrain_model path will be ignored when resuming training.')
|
139 |
+
|
140 |
+
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
|
141 |
+
'{}_G.pth'.format(resume_iter))
|
142 |
+
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
|
143 |
+
if 'gan' in opt['model']:
|
144 |
+
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
|
145 |
+
'{}_D.pth'.format(resume_iter))
|
146 |
+
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
|
models/SRFlow/code/prepare_data.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import glob
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import random
|
21 |
+
import imageio
|
22 |
+
import pickle
|
23 |
+
|
24 |
+
from natsort import natsort
|
25 |
+
from tqdm import tqdm
|
26 |
+
|
27 |
+
def get_img_paths(dir_path, wildcard='*.png'):
|
28 |
+
return natsort.natsorted(glob.glob(dir_path + '/' + wildcard))
|
29 |
+
|
30 |
+
def create_all_dirs(path):
|
31 |
+
if "." in path.split("/")[-1]:
|
32 |
+
dirs = os.path.dirname(path)
|
33 |
+
else:
|
34 |
+
dirs = path
|
35 |
+
os.makedirs(dirs, exist_ok=True)
|
36 |
+
|
37 |
+
def to_pklv4(obj, path, vebose=False):
|
38 |
+
create_all_dirs(path)
|
39 |
+
with open(path, 'wb') as f:
|
40 |
+
pickle.dump(obj, f, protocol=4)
|
41 |
+
if vebose:
|
42 |
+
print("Wrote {}".format(path))
|
43 |
+
|
44 |
+
|
45 |
+
from imresize import imresize
|
46 |
+
|
47 |
+
def random_crop(img, size):
|
48 |
+
h, w, c = img.shape
|
49 |
+
|
50 |
+
h_start = np.random.randint(0, h - size)
|
51 |
+
h_end = h_start + size
|
52 |
+
|
53 |
+
w_start = np.random.randint(0, w - size)
|
54 |
+
w_end = w_start + size
|
55 |
+
|
56 |
+
return img[h_start:h_end, w_start:w_end]
|
57 |
+
|
58 |
+
|
59 |
+
def imread(img_path):
|
60 |
+
img = imageio.imread(img_path)
|
61 |
+
if len(img.shape) == 2:
|
62 |
+
img = np.stack([img, ] * 3, axis=2)
|
63 |
+
return img
|
64 |
+
|
65 |
+
|
66 |
+
def to_pklv4_1pct(obj, path, vebose):
|
67 |
+
n = int(round(len(obj) * 0.01))
|
68 |
+
path = path.replace(".", "_1pct.")
|
69 |
+
to_pklv4(obj[:n], path, vebose=True)
|
70 |
+
|
71 |
+
|
72 |
+
def main(dir_path):
|
73 |
+
hrs = []
|
74 |
+
lqs = []
|
75 |
+
|
76 |
+
img_paths = get_img_paths(dir_path)
|
77 |
+
for img_path in tqdm(img_paths):
|
78 |
+
img = imread(img_path)
|
79 |
+
|
80 |
+
for i in range(47):
|
81 |
+
crop = random_crop(img, 256)
|
82 |
+
cropX4 = imresize(crop, scalar_scale=0.25)
|
83 |
+
hrs.append(crop)
|
84 |
+
lqs.append(cropX4)
|
85 |
+
|
86 |
+
shuffle_combined(hrs, lqs)
|
87 |
+
|
88 |
+
hrs_path = get_hrs_path(dir_path)
|
89 |
+
to_pklv4(hrs, hrs_path, vebose=True)
|
90 |
+
|
91 |
+
lqs_path = get_lqs_path(dir_path)
|
92 |
+
to_pklv4(lqs, lqs_path, vebose=True)
|
93 |
+
|
94 |
+
|
95 |
+
def get_hrs_path(dir_path):
|
96 |
+
base_dir = '/kaggle/working/'
|
97 |
+
name = os.path.basename(dir_path)
|
98 |
+
hrs_path = os.path.join(base_dir, 'pkls', name + '.pklv4')
|
99 |
+
return hrs_path
|
100 |
+
|
101 |
+
|
102 |
+
def get_lqs_path(dir_path):
|
103 |
+
base_dir = '/kaggle/working/'
|
104 |
+
name = os.path.basename(dir_path)
|
105 |
+
hrs_path = os.path.join(base_dir, 'pkls', name + '_X4.pklv4')
|
106 |
+
return hrs_path
|
107 |
+
|
108 |
+
|
109 |
+
def shuffle_combined(hrs, lqs):
|
110 |
+
combined = list(zip(hrs, lqs))
|
111 |
+
random.shuffle(combined)
|
112 |
+
hrs[:], lqs[:] = zip(*combined)
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
dir_path = sys.argv[1]
|
117 |
+
assert os.path.isdir(dir_path)
|
118 |
+
main(dir_path)
|
models/SRFlow/code/test.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
|
18 |
+
import glob
|
19 |
+
import sys
|
20 |
+
from collections import OrderedDict
|
21 |
+
|
22 |
+
from natsort import natsort
|
23 |
+
|
24 |
+
import options.options as option
|
25 |
+
from Measure import Measure, psnr
|
26 |
+
from imresize import imresize
|
27 |
+
from models import create_model
|
28 |
+
import torch
|
29 |
+
from utils.util import opt_get
|
30 |
+
import numpy as np
|
31 |
+
import pandas as pd
|
32 |
+
import os
|
33 |
+
import cv2
|
34 |
+
|
35 |
+
|
36 |
+
def fiFindByWildcard(wildcard):
|
37 |
+
return natsort.natsorted(glob.glob(wildcard, recursive=True))
|
38 |
+
|
39 |
+
|
40 |
+
def load_model(conf_path):
|
41 |
+
opt = option.parse(conf_path, is_train=False)
|
42 |
+
opt['gpu_ids'] = None
|
43 |
+
opt = option.dict_to_nonedict(opt)
|
44 |
+
model = create_model(opt)
|
45 |
+
|
46 |
+
model_path = opt_get(opt, ['model_path'], None)
|
47 |
+
model.load_network(load_path=model_path, network=model.netG)
|
48 |
+
return model, opt
|
49 |
+
|
50 |
+
|
51 |
+
def predict(model, lr):
|
52 |
+
model.feed_data({"LQ": t(lr)}, need_GT=False)
|
53 |
+
model.test()
|
54 |
+
visuals = model.get_current_visuals(need_GT=False)
|
55 |
+
return visuals.get('rlt', visuals.get("SR"))
|
56 |
+
|
57 |
+
|
58 |
+
def t(array): return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255
|
59 |
+
|
60 |
+
|
61 |
+
def rgb(t): return (
|
62 |
+
np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype(
|
63 |
+
np.uint8)
|
64 |
+
|
65 |
+
|
66 |
+
def imread(path):
|
67 |
+
return cv2.imread(path)[:, :, [2, 1, 0]]
|
68 |
+
|
69 |
+
|
70 |
+
def imwrite(path, img):
|
71 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
72 |
+
cv2.imwrite(path, img[:, :, [2, 1, 0]])
|
73 |
+
|
74 |
+
|
75 |
+
def imCropCenter(img, size):
|
76 |
+
h, w, c = img.shape
|
77 |
+
|
78 |
+
h_start = max(h // 2 - size // 2, 0)
|
79 |
+
h_end = min(h_start + size, h)
|
80 |
+
|
81 |
+
w_start = max(w // 2 - size // 2, 0)
|
82 |
+
w_end = min(w_start + size, w)
|
83 |
+
|
84 |
+
return img[h_start:h_end, w_start:w_end]
|
85 |
+
|
86 |
+
|
87 |
+
def impad(img, top=0, bottom=0, left=0, right=0, color=255):
|
88 |
+
return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect')
|
89 |
+
|
90 |
+
|
91 |
+
def main():
|
92 |
+
conf_path = sys.argv[1]
|
93 |
+
conf = conf_path.split('/')[-1].replace('.yml', '')
|
94 |
+
model, opt = load_model(conf_path)
|
95 |
+
|
96 |
+
data_dir = opt['dataroot']
|
97 |
+
|
98 |
+
# this_dir = os.path.dirname(os.path.realpath(__file__))
|
99 |
+
test_dir = os.path.join('/kaggle/working/', 'results', conf)
|
100 |
+
print(f"Out dir: {test_dir}")
|
101 |
+
|
102 |
+
measure = Measure(use_gpu=False)
|
103 |
+
|
104 |
+
fname = f'measure_full.csv'
|
105 |
+
fname_tmp = fname + "_"
|
106 |
+
path_out_measures = os.path.join(test_dir, fname_tmp)
|
107 |
+
path_out_measures_final = os.path.join(test_dir, fname)
|
108 |
+
|
109 |
+
if os.path.isfile(path_out_measures_final):
|
110 |
+
df = pd.read_csv(path_out_measures_final)
|
111 |
+
elif os.path.isfile(path_out_measures):
|
112 |
+
df = pd.read_csv(path_out_measures)
|
113 |
+
else:
|
114 |
+
df = None
|
115 |
+
|
116 |
+
scale = opt['scale']
|
117 |
+
|
118 |
+
pad_factor = 2
|
119 |
+
|
120 |
+
data_sets = [
|
121 |
+
'Set5',
|
122 |
+
'Set14',
|
123 |
+
'Urban100',
|
124 |
+
'BSD100'
|
125 |
+
]
|
126 |
+
|
127 |
+
final_df = pd.DataFrame()
|
128 |
+
|
129 |
+
for data_set in data_sets:
|
130 |
+
lr_paths = fiFindByWildcard(os.path.join(data_dir, data_set, '*LR.png'))
|
131 |
+
hr_paths = fiFindByWildcard(os.path.join(data_dir, data_set, '*HR.png'))
|
132 |
+
|
133 |
+
df = pd.DataFrame(columns=['conf', 'heat', 'data_set', 'name', 'PSNR', 'SSIM', 'LPIPS', 'LRC PSNR'])
|
134 |
+
|
135 |
+
for lr_path, hr_path, idx_test in zip(lr_paths, hr_paths, range(len(lr_paths))):
|
136 |
+
with torch.no_grad(), torch.cuda.amp.autocast():
|
137 |
+
lr = imread(lr_path)
|
138 |
+
hr = imread(hr_path)
|
139 |
+
|
140 |
+
# Pad image to be % 2
|
141 |
+
h, w, c = lr.shape
|
142 |
+
lq_orig = lr.copy()
|
143 |
+
lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
|
144 |
+
right=int(np.ceil(w / pad_factor) * pad_factor - w))
|
145 |
+
|
146 |
+
lr_t = t(lr)
|
147 |
+
|
148 |
+
heat = opt['heat']
|
149 |
+
|
150 |
+
if df is not None and len(df[(df['heat'] == heat) & (df['name'] == idx_test)]) == 1:
|
151 |
+
continue
|
152 |
+
|
153 |
+
sr_t = model.get_sr(lq=lr_t, heat=heat)
|
154 |
+
|
155 |
+
sr = rgb(torch.clamp(sr_t, 0, 1))
|
156 |
+
sr = sr[:h * scale, :w * scale]
|
157 |
+
|
158 |
+
path_out_sr = os.path.join(test_dir, data_set, "{:0.2f}".format(heat).replace('.', ''), "{:06d}.png".format(idx_test))
|
159 |
+
imwrite(path_out_sr, sr)
|
160 |
+
|
161 |
+
meas = OrderedDict(conf=conf, heat=heat, data_set=data_set, name=idx_test)
|
162 |
+
meas['PSNR'], meas['SSIM'], meas['LPIPS'] = measure.measure(sr, hr)
|
163 |
+
|
164 |
+
lr_reconstruct_rgb = imresize(sr, 1 / opt['scale'])
|
165 |
+
meas['LRC PSNR'] = psnr(lq_orig, lr_reconstruct_rgb)
|
166 |
+
|
167 |
+
str_out = format_measurements(meas)
|
168 |
+
print(str_out)
|
169 |
+
|
170 |
+
df = df._append(pd.DataFrame([meas]), ignore_index=True)
|
171 |
+
|
172 |
+
final_df = pd.concat([final_df, df])
|
173 |
+
|
174 |
+
final_df.to_csv(path_out_measures, index=False)
|
175 |
+
os.rename(path_out_measures, path_out_measures_final)
|
176 |
+
|
177 |
+
# str_out = format_measurements(df.mean())
|
178 |
+
# print(f"Results in: {path_out_measures_final}")
|
179 |
+
# print('Mean: ' + str_out)
|
180 |
+
|
181 |
+
|
182 |
+
def format_measurements(meas):
|
183 |
+
s_out = []
|
184 |
+
for k, v in meas.items():
|
185 |
+
v = f"{v:0.2f}" if isinstance(v, float) else v
|
186 |
+
s_out.append(f"{k}: {v}")
|
187 |
+
str_out = ", ".join(s_out)
|
188 |
+
return str_out
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == "__main__":
|
192 |
+
main()
|
models/SRFlow/code/train.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
import os
|
18 |
+
from os.path import basename
|
19 |
+
import math
|
20 |
+
import argparse
|
21 |
+
import random
|
22 |
+
import logging
|
23 |
+
import cv2
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.distributed as dist
|
27 |
+
import torch.multiprocessing as mp
|
28 |
+
|
29 |
+
import options.options as option
|
30 |
+
from utils import util
|
31 |
+
from data import create_dataloader, create_dataset
|
32 |
+
from models import create_model
|
33 |
+
from utils.timer import Timer, TickTock
|
34 |
+
from utils.util import get_resume_paths
|
35 |
+
|
36 |
+
import wandb
|
37 |
+
|
38 |
+
def getEnv(name): import os; return True if name in os.environ.keys() else False
|
39 |
+
|
40 |
+
|
41 |
+
def init_dist(backend='nccl', **kwargs):
|
42 |
+
''' initialization for distributed training'''
|
43 |
+
# if mp.get_start_method(allow_none=True) is None:
|
44 |
+
if mp.get_start_method(allow_none=True) != 'spawn':
|
45 |
+
mp.set_start_method('spawn')
|
46 |
+
rank = int(os.environ['RANK'])
|
47 |
+
num_gpus = torch.cuda.device_count()
|
48 |
+
torch.cuda.set_deviceDistIterSampler(rank % num_gpus)
|
49 |
+
dist.init_process_group(backend=backend, **kwargs)
|
50 |
+
|
51 |
+
|
52 |
+
def main():
|
53 |
+
wandb.init(project='srflow')
|
54 |
+
#### options
|
55 |
+
parser = argparse.ArgumentParser()
|
56 |
+
parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
|
57 |
+
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
58 |
+
help='job launcher')
|
59 |
+
parser.add_argument('--local_rank', type=int, default=0)
|
60 |
+
args = parser.parse_args()
|
61 |
+
opt = option.parse(args.opt, is_train=True)
|
62 |
+
|
63 |
+
#### distributed training settings
|
64 |
+
opt['dist'] = False
|
65 |
+
rank = -1
|
66 |
+
print('Disabled distributed training.')
|
67 |
+
|
68 |
+
#### loading resume state if exists
|
69 |
+
if opt['path'].get('resume_state', None):
|
70 |
+
resume_state_path, _ = get_resume_paths(opt)
|
71 |
+
|
72 |
+
# distributed resuming: all load into default GPU
|
73 |
+
if resume_state_path is None:
|
74 |
+
resume_state = None
|
75 |
+
else:
|
76 |
+
device_id = torch.cuda.current_device()
|
77 |
+
resume_state = torch.load(resume_state_path,
|
78 |
+
map_location=lambda storage, loc: storage.cuda(device_id))
|
79 |
+
option.check_resume(opt, resume_state['iter']) # check resume options
|
80 |
+
else:
|
81 |
+
resume_state = None
|
82 |
+
|
83 |
+
#### mkdir and loggers
|
84 |
+
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
|
85 |
+
if resume_state is None:
|
86 |
+
util.mkdir_and_rename(
|
87 |
+
opt['path']['experiments_root']) # rename experiment folder if exists
|
88 |
+
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
|
89 |
+
and 'pretrain_model' not in key and 'resume' not in key))
|
90 |
+
|
91 |
+
# config loggers. Before it, the log will not work
|
92 |
+
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
93 |
+
screen=True, tofile=True)
|
94 |
+
util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
|
95 |
+
screen=True, tofile=True)
|
96 |
+
logger = logging.getLogger('base')
|
97 |
+
logger.info(option.dict2str(opt))
|
98 |
+
|
99 |
+
# tensorboard logger
|
100 |
+
if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
|
101 |
+
version = float(torch.__version__[0:3])
|
102 |
+
if version >= 1.1: # PyTorch 1.1
|
103 |
+
from torch.utils.tensorboard import SummaryWriter
|
104 |
+
else:
|
105 |
+
logger.info(
|
106 |
+
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
107 |
+
from tensorboardX import SummaryWriter
|
108 |
+
conf_name = basename(args.opt).replace(".yml", "")
|
109 |
+
exp_dir = opt['path']['experiments_root']
|
110 |
+
log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
|
111 |
+
log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
|
112 |
+
tb_logger_train = SummaryWriter(log_dir=log_dir_train)
|
113 |
+
tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
|
114 |
+
else:
|
115 |
+
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
116 |
+
logger = logging.getLogger('base')
|
117 |
+
|
118 |
+
# convert to NoneDict, which returns None for missing keys
|
119 |
+
opt = option.dict_to_nonedict(opt)
|
120 |
+
|
121 |
+
#### random seed
|
122 |
+
seed = opt['train']['manual_seed']
|
123 |
+
if seed is None:
|
124 |
+
seed = random.randint(1, 10000)
|
125 |
+
if rank <= 0:
|
126 |
+
logger.info('Random seed: {}'.format(seed))
|
127 |
+
util.set_random_seed(seed)
|
128 |
+
|
129 |
+
torch.backends.cudnn.benchmark = True
|
130 |
+
# torch.backends.cudnn.deterministic = True
|
131 |
+
|
132 |
+
#### create train and val dataloader
|
133 |
+
dataset_ratio = 200 # enlarge the size of each epoch
|
134 |
+
for phase, dataset_opt in opt['datasets'].items():
|
135 |
+
if phase == 'train':
|
136 |
+
full_dataset = create_dataset(dataset_opt)
|
137 |
+
print('Dataset created')
|
138 |
+
train_len = int(len(full_dataset) * 0.95)
|
139 |
+
val_len = len(full_dataset) - train_len
|
140 |
+
train_set, val_set = torch.utils.data.random_split(full_dataset, [train_len, val_len])
|
141 |
+
train_size = int(math.ceil(train_len / dataset_opt['batch_size']))
|
142 |
+
total_iters = int(opt['train']['niter'])
|
143 |
+
total_epochs = int(math.ceil(total_iters / train_size))
|
144 |
+
train_sampler = None
|
145 |
+
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
|
146 |
+
if rank <= 0:
|
147 |
+
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
148 |
+
len(train_set), train_size))
|
149 |
+
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
150 |
+
total_epochs, total_iters))
|
151 |
+
val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1,
|
152 |
+
pin_memory=True)
|
153 |
+
elif phase == 'val':
|
154 |
+
continue
|
155 |
+
else:
|
156 |
+
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
157 |
+
assert train_loader is not None
|
158 |
+
|
159 |
+
#### create model
|
160 |
+
current_step = 0 if resume_state is None else resume_state['iter']
|
161 |
+
model = create_model(opt, current_step)
|
162 |
+
|
163 |
+
#### resume training
|
164 |
+
if resume_state:
|
165 |
+
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
166 |
+
resume_state['epoch'], resume_state['iter']))
|
167 |
+
|
168 |
+
start_epoch = resume_state['epoch']
|
169 |
+
current_step = resume_state['iter']
|
170 |
+
model.resume_training(resume_state) # handle optimizers and schedulers
|
171 |
+
else:
|
172 |
+
current_step = 0
|
173 |
+
start_epoch = 0
|
174 |
+
|
175 |
+
#### training
|
176 |
+
timer = Timer()
|
177 |
+
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
178 |
+
timerData = TickTock()
|
179 |
+
|
180 |
+
for epoch in range(start_epoch, total_epochs + 1):
|
181 |
+
if opt['dist']:
|
182 |
+
train_sampler.set_epoch(epoch)
|
183 |
+
|
184 |
+
timerData.tick()
|
185 |
+
for _, train_data in enumerate(train_loader):
|
186 |
+
timerData.tock()
|
187 |
+
current_step += 1
|
188 |
+
if current_step > total_iters:
|
189 |
+
break
|
190 |
+
|
191 |
+
#### training
|
192 |
+
model.feed_data(train_data)
|
193 |
+
|
194 |
+
#### update learning rate
|
195 |
+
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
|
196 |
+
|
197 |
+
try:
|
198 |
+
nll = model.optimize_parameters(current_step)
|
199 |
+
except RuntimeError as e:
|
200 |
+
print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ")
|
201 |
+
print(e)
|
202 |
+
|
203 |
+
if nll is None:
|
204 |
+
nll = 0
|
205 |
+
|
206 |
+
wandb.log({"loss": nll})
|
207 |
+
#### log
|
208 |
+
def eta(t_iter):
|
209 |
+
return (t_iter * (opt['train']['niter'] - current_step)) / 3600
|
210 |
+
|
211 |
+
if current_step % opt['logger']['print_freq'] == 0 \
|
212 |
+
or current_step - (resume_state['iter'] if resume_state else 0) < 25:
|
213 |
+
avg_time = timer.get_average_and_reset()
|
214 |
+
avg_data_time = timerData.get_average_and_reset()
|
215 |
+
message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, eta:{:.2e}, nll:{:.3e}> '.format(
|
216 |
+
epoch, current_step, model.get_current_learning_rate(), avg_time, avg_data_time,
|
217 |
+
eta(avg_time), nll)
|
218 |
+
print(message)
|
219 |
+
timer.tick()
|
220 |
+
# Reduce number of logs
|
221 |
+
if current_step % 5 == 0:
|
222 |
+
tb_logger_train.add_scalar('loss/nll', nll, current_step)
|
223 |
+
tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
|
224 |
+
tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
|
225 |
+
tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step)
|
226 |
+
tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step)
|
227 |
+
for k, v in model.get_current_log().items():
|
228 |
+
tb_logger_train.add_scalar(k, v, current_step)
|
229 |
+
|
230 |
+
# validation
|
231 |
+
if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
|
232 |
+
avg_psnr = 0.0
|
233 |
+
idx = 0
|
234 |
+
nlls = []
|
235 |
+
for val_data in val_loader:
|
236 |
+
idx += 1
|
237 |
+
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
|
238 |
+
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
239 |
+
util.mkdir(img_dir)
|
240 |
+
|
241 |
+
model.feed_data(val_data)
|
242 |
+
|
243 |
+
nll = model.test()
|
244 |
+
if nll is None:
|
245 |
+
nll = 0
|
246 |
+
nlls.append(nll)
|
247 |
+
|
248 |
+
visuals = model.get_current_visuals()
|
249 |
+
|
250 |
+
sr_img = None
|
251 |
+
# Save SR images for reference
|
252 |
+
if hasattr(model, 'heats'):
|
253 |
+
for heat in model.heats:
|
254 |
+
for i in range(model.n_sample):
|
255 |
+
sr_img = util.tensor2img(visuals['SR', heat, i]) # uint8
|
256 |
+
save_img_path = os.path.join(img_dir,
|
257 |
+
'{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
|
258 |
+
current_step,
|
259 |
+
int(heat * 100), i))
|
260 |
+
util.save_img(sr_img, save_img_path)
|
261 |
+
else:
|
262 |
+
sr_img = util.tensor2img(visuals['SR']) # uint8
|
263 |
+
save_img_path = os.path.join(img_dir,
|
264 |
+
'{:s}_{:d}.png'.format(img_name, current_step))
|
265 |
+
util.save_img(sr_img, save_img_path)
|
266 |
+
assert sr_img is not None
|
267 |
+
|
268 |
+
# Save LQ images for reference
|
269 |
+
save_img_path_lq = os.path.join(img_dir,
|
270 |
+
'{:s}_LQ.png'.format(img_name))
|
271 |
+
if not os.path.isfile(save_img_path_lq):
|
272 |
+
lq_img = util.tensor2img(visuals['LQ']) # uint8
|
273 |
+
util.save_img(
|
274 |
+
cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'],
|
275 |
+
interpolation=cv2.INTER_NEAREST),
|
276 |
+
save_img_path_lq)
|
277 |
+
|
278 |
+
# Save GT images for reference
|
279 |
+
gt_img = util.tensor2img(visuals['GT']) # uint8
|
280 |
+
save_img_path_gt = os.path.join(img_dir,
|
281 |
+
'{:s}_GT.png'.format(img_name))
|
282 |
+
if not os.path.isfile(save_img_path_gt):
|
283 |
+
util.save_img(gt_img, save_img_path_gt)
|
284 |
+
|
285 |
+
# calculate PSNR
|
286 |
+
crop_size = opt['scale']
|
287 |
+
gt_img = gt_img / 255.
|
288 |
+
sr_img = sr_img / 255.
|
289 |
+
cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
|
290 |
+
cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
|
291 |
+
avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
|
292 |
+
|
293 |
+
avg_psnr = avg_psnr / idx
|
294 |
+
avg_nll = sum(nlls) / len(nlls)
|
295 |
+
|
296 |
+
# log
|
297 |
+
logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
|
298 |
+
logger_val = logging.getLogger('val') # validation logger
|
299 |
+
logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
|
300 |
+
epoch, current_step, avg_psnr))
|
301 |
+
|
302 |
+
# tensorboard logger
|
303 |
+
tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step)
|
304 |
+
tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step)
|
305 |
+
|
306 |
+
tb_logger_train.flush()
|
307 |
+
tb_logger_valid.flush()
|
308 |
+
|
309 |
+
#### save models and training states
|
310 |
+
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
311 |
+
if rank <= 0:
|
312 |
+
logger.info('Saving models and training states.')
|
313 |
+
model.save(current_step)
|
314 |
+
model.save_training_state(epoch, current_step)
|
315 |
+
|
316 |
+
timerData.tick()
|
317 |
+
|
318 |
+
with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f:
|
319 |
+
f.write("TRAIN_DONE")
|
320 |
+
|
321 |
+
if rank <= 0:
|
322 |
+
logger.info('Saving the final model.')
|
323 |
+
model.save('latest')
|
324 |
+
logger.info('End of training.')
|
325 |
+
|
326 |
+
|
327 |
+
if __name__ == '__main__':
|
328 |
+
main()
|
models/SRFlow/code/utils/__init__.py
ADDED
File without changes
|
models/SRFlow/code/utils/timer.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
+
#
|
8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
+
|
17 |
+
import time
|
18 |
+
|
19 |
+
|
20 |
+
class ScopeTimer:
|
21 |
+
def __init__(self, name):
|
22 |
+
self.name = name
|
23 |
+
|
24 |
+
def __enter__(self):
|
25 |
+
self.start = time.time()
|
26 |
+
return self
|
27 |
+
|
28 |
+
def __exit__(self, *args):
|
29 |
+
self.end = time.time()
|
30 |
+
self.interval = self.end - self.start
|
31 |
+
print("{} {:.3E}".format(self.name, self.interval))
|
32 |
+
|
33 |
+
|
34 |
+
class Timer:
|
35 |
+
def __init__(self):
|
36 |
+
self.times = []
|
37 |
+
|
38 |
+
def tick(self):
|
39 |
+
self.times.append(time.time())
|
40 |
+
|
41 |
+
def get_average_and_reset(self):
|
42 |
+
if len(self.times) < 2:
|
43 |
+
return -1
|
44 |
+
avg = (self.times[-1] - self.times[0]) / (len(self.times) - 1)
|
45 |
+
self.times = [self.times[-1]]
|
46 |
+
return avg
|
47 |
+
|
48 |
+
def get_last_iteration(self):
|
49 |
+
if len(self.times) < 2:
|
50 |
+
return 0
|
51 |
+
return self.times[-1] - self.times[-2]
|
52 |
+
|
53 |
+
|
54 |
+
class TickTock:
|
55 |
+
def __init__(self):
|
56 |
+
self.time_pairs = []
|
57 |
+
self.current_time = None
|
58 |
+
|
59 |
+
def tick(self):
|
60 |
+
self.current_time = time.time()
|
61 |
+
|
62 |
+
def tock(self):
|
63 |
+
assert self.current_time is not None, self.current_time
|
64 |
+
self.time_pairs.append([self.current_time, time.time()])
|
65 |
+
self.current_time = None
|
66 |
+
|
67 |
+
def get_average_and_reset(self):
|
68 |
+
if len(self.time_pairs) == 0:
|
69 |
+
return -1
|
70 |
+
deltas = [t2 - t1 for t1, t2 in self.time_pairs]
|
71 |
+
avg = sum(deltas) / len(deltas)
|
72 |
+
self.time_pairs = []
|
73 |
+
return avg
|
74 |
+
|
75 |
+
def get_last_iteration(self):
|
76 |
+
if len(self.time_pairs) == 0:
|
77 |
+
return -1
|
78 |
+
return self.time_pairs[-1][1] - self.time_pairs[-1][0]
|
models/SRFlow/code/utils/util.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import math
|
6 |
+
from datetime import datetime
|
7 |
+
import random
|
8 |
+
import logging
|
9 |
+
from collections import OrderedDict
|
10 |
+
|
11 |
+
import natsort
|
12 |
+
import numpy as np
|
13 |
+
import cv2
|
14 |
+
import torch
|
15 |
+
from torchvision.utils import make_grid
|
16 |
+
from shutil import get_terminal_size
|
17 |
+
|
18 |
+
import yaml
|
19 |
+
|
20 |
+
try:
|
21 |
+
from yaml import CLoader as Loader, CDumper as Dumper
|
22 |
+
except ImportError:
|
23 |
+
from yaml import Loader, Dumper
|
24 |
+
|
25 |
+
|
26 |
+
def OrderedYaml():
|
27 |
+
'''yaml orderedDict support'''
|
28 |
+
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
|
29 |
+
|
30 |
+
def dict_representer(dumper, data):
|
31 |
+
return dumper.represent_dict(data.items())
|
32 |
+
|
33 |
+
def dict_constructor(loader, node):
|
34 |
+
return OrderedDict(loader.construct_pairs(node))
|
35 |
+
|
36 |
+
Dumper.add_representer(OrderedDict, dict_representer)
|
37 |
+
Loader.add_constructor(_mapping_tag, dict_constructor)
|
38 |
+
return Loader, Dumper
|
39 |
+
|
40 |
+
|
41 |
+
####################
|
42 |
+
# miscellaneous
|
43 |
+
####################
|
44 |
+
|
45 |
+
|
46 |
+
def get_timestamp():
|
47 |
+
return datetime.now().strftime('%y%m%d-%H%M%S')
|
48 |
+
|
49 |
+
|
50 |
+
def mkdir(path):
|
51 |
+
if not os.path.exists(path):
|
52 |
+
os.makedirs(path)
|
53 |
+
|
54 |
+
|
55 |
+
def mkdirs(paths):
|
56 |
+
if isinstance(paths, str):
|
57 |
+
mkdir(paths)
|
58 |
+
else:
|
59 |
+
for path in paths:
|
60 |
+
mkdir(path)
|
61 |
+
|
62 |
+
|
63 |
+
def mkdir_and_rename(path):
|
64 |
+
if os.path.exists(path):
|
65 |
+
new_name = path + '_archived_' + get_timestamp()
|
66 |
+
print('Path already exists. Rename it to [{:s}]'.format(new_name))
|
67 |
+
logger = logging.getLogger('base')
|
68 |
+
logger.info('Path already exists. Rename it to [{:s}]'.format(new_name))
|
69 |
+
os.rename(path, new_name)
|
70 |
+
os.makedirs(path)
|
71 |
+
|
72 |
+
|
73 |
+
def set_random_seed(seed):
|
74 |
+
random.seed(seed)
|
75 |
+
np.random.seed(seed)
|
76 |
+
torch.manual_seed(seed)
|
77 |
+
torch.cuda.manual_seed_all(seed)
|
78 |
+
|
79 |
+
|
80 |
+
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
|
81 |
+
'''set up logger'''
|
82 |
+
lg = logging.getLogger(logger_name)
|
83 |
+
formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
|
84 |
+
datefmt='%y-%m-%d %H:%M:%S')
|
85 |
+
lg.setLevel(level)
|
86 |
+
if tofile:
|
87 |
+
log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
|
88 |
+
fh = logging.FileHandler(log_file, mode='w')
|
89 |
+
fh.setFormatter(formatter)
|
90 |
+
lg.addHandler(fh)
|
91 |
+
if screen:
|
92 |
+
sh = logging.StreamHandler()
|
93 |
+
sh.setFormatter(formatter)
|
94 |
+
lg.addHandler(sh)
|
95 |
+
|
96 |
+
|
97 |
+
####################
|
98 |
+
# image convert
|
99 |
+
####################
|
100 |
+
|
101 |
+
|
102 |
+
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
103 |
+
'''
|
104 |
+
Converts a torch Tensor into an image Numpy array
|
105 |
+
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
|
106 |
+
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
|
107 |
+
'''
|
108 |
+
if hasattr(tensor, 'detach'):
|
109 |
+
tensor = tensor.detach()
|
110 |
+
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
|
111 |
+
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
|
112 |
+
n_dim = tensor.dim()
|
113 |
+
if n_dim == 4:
|
114 |
+
n_img = len(tensor)
|
115 |
+
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
|
116 |
+
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
117 |
+
elif n_dim == 3:
|
118 |
+
img_np = tensor.numpy()
|
119 |
+
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
120 |
+
elif n_dim == 2:
|
121 |
+
img_np = tensor.numpy()
|
122 |
+
else:
|
123 |
+
raise TypeError(
|
124 |
+
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
|
125 |
+
if out_type == np.uint8:
|
126 |
+
img_np = (img_np * 255.0).round()
|
127 |
+
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
|
128 |
+
return img_np.astype(out_type)
|
129 |
+
|
130 |
+
|
131 |
+
def save_img(img, img_path, mode='RGB'):
|
132 |
+
cv2.imwrite(img_path, img)
|
133 |
+
|
134 |
+
|
135 |
+
####################
|
136 |
+
# metric
|
137 |
+
####################
|
138 |
+
|
139 |
+
|
140 |
+
def calculate_psnr(img1, img2):
|
141 |
+
# img1 and img2 have range [0, 255]
|
142 |
+
img1 = img1.astype(np.float64)
|
143 |
+
img2 = img2.astype(np.float64)
|
144 |
+
mse = np.mean((img1 - img2) ** 2)
|
145 |
+
if mse == 0:
|
146 |
+
return float('inf')
|
147 |
+
return 20 * math.log10(255.0 / math.sqrt(mse))
|
148 |
+
|
149 |
+
|
150 |
+
def get_resume_paths(opt):
|
151 |
+
resume_state_path = None
|
152 |
+
resume_model_path = None
|
153 |
+
ts = opt_get(opt, ['path', 'training_state'])
|
154 |
+
if opt.get('path', {}).get('resume_state', None) == "auto" and ts is not None:
|
155 |
+
wildcard = os.path.join(ts, "*")
|
156 |
+
paths = natsort.natsorted(glob.glob(wildcard))
|
157 |
+
if len(paths) > 0:
|
158 |
+
resume_state_path = paths[-1]
|
159 |
+
resume_model_path = resume_state_path.replace('training_state', 'models').replace('.state', '_G.pth')
|
160 |
+
else:
|
161 |
+
resume_state_path = opt.get('path', {}).get('resume_state')
|
162 |
+
return resume_state_path, resume_model_path
|
163 |
+
|
164 |
+
|
165 |
+
def opt_get(opt, keys, default=None):
|
166 |
+
if opt is None:
|
167 |
+
return default
|
168 |
+
ret = opt
|
169 |
+
for k in keys:
|
170 |
+
ret = ret.get(k, None)
|
171 |
+
if ret is None:
|
172 |
+
return default
|
173 |
+
return ret
|
174 |
+
|
models/SRFlow/srflow.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from code.test import imread, impad, t, load_model, rgb
|
5 |
+
|
6 |
+
def return_SRFlow_result(lr_path, conf_path='/models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.6):
|
7 |
+
model, opt = load_model(conf_path)
|
8 |
+
lr = imread(lr_path)
|
9 |
+
|
10 |
+
scale = opt['scale']
|
11 |
+
pad_factor = 2
|
12 |
+
|
13 |
+
h, w, c = lr.shape
|
14 |
+
lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
|
15 |
+
right=int(np.ceil(w / pad_factor) * pad_factor - w))
|
16 |
+
|
17 |
+
lr_t = t(lr)
|
18 |
+
heat = opt[heat]
|
19 |
+
|
20 |
+
sr_t = model.get_sr(lq=lr_t, heat=heat)
|
21 |
+
|
22 |
+
sr = rgb(torch.clamp(sr_t, 0, 1))
|
23 |
+
sr = sr[:h * scale, :w * scale]
|
24 |
+
|
25 |
+
return sr
|
26 |
+
|
27 |
+
|