add SRFlow with srflow.py

#2
Files changed (42) hide show
  1. models/SRFlow/35000_G.pth +3 -0
  2. models/SRFlow/code/Measure.py +134 -0
  3. models/SRFlow/code/a.py +27 -0
  4. models/SRFlow/code/confs/RRDB_CelebA_8X.yml +83 -0
  5. models/SRFlow/code/confs/RRDB_DF2K_4X.yml +85 -0
  6. models/SRFlow/code/confs/RRDB_DF2K_8X.yml +85 -0
  7. models/SRFlow/code/confs/SRFlow_CelebA_8X.yml +107 -0
  8. models/SRFlow/code/confs/SRFlow_DF2K_4X.yml +106 -0
  9. models/SRFlow/code/confs/SRFlow_DF2K_8X.yml +112 -0
  10. models/SRFlow/code/data/LRHR_PKL_dataset.py +179 -0
  11. models/SRFlow/code/data/__init__.py +51 -0
  12. models/SRFlow/code/demo_on_pretrained.ipynb +0 -0
  13. models/SRFlow/code/imresize.py +180 -0
  14. models/SRFlow/code/models/SRFlow_model.py +278 -0
  15. models/SRFlow/code/models/SR_model.py +217 -0
  16. models/SRFlow/code/models/__init__.py +52 -0
  17. models/SRFlow/code/models/base_model.py +154 -0
  18. models/SRFlow/code/models/lr_scheduler.py +163 -0
  19. models/SRFlow/code/models/modules/FlowActNorms.py +141 -0
  20. models/SRFlow/code/models/modules/FlowAffineCouplingsAblation.py +135 -0
  21. models/SRFlow/code/models/modules/FlowStep.py +137 -0
  22. models/SRFlow/code/models/modules/FlowUpsamplerNet.py +309 -0
  23. models/SRFlow/code/models/modules/Permutations.py +58 -0
  24. models/SRFlow/code/models/modules/RRDBNet_arch.py +148 -0
  25. models/SRFlow/code/models/modules/SRFlowNet_arch.py +158 -0
  26. models/SRFlow/code/models/modules/Split.py +86 -0
  27. models/SRFlow/code/models/modules/__init__.py +0 -0
  28. models/SRFlow/code/models/modules/flow.py +166 -0
  29. models/SRFlow/code/models/modules/glow_arch.py +28 -0
  30. models/SRFlow/code/models/modules/loss.py +90 -0
  31. models/SRFlow/code/models/modules/module_util.py +95 -0
  32. models/SRFlow/code/models/modules/thops.py +68 -0
  33. models/SRFlow/code/models/networks.py +105 -0
  34. models/SRFlow/code/options/__init__.py +0 -0
  35. models/SRFlow/code/options/options.py +146 -0
  36. models/SRFlow/code/prepare_data.py +118 -0
  37. models/SRFlow/code/test.py +192 -0
  38. models/SRFlow/code/train.py +328 -0
  39. models/SRFlow/code/utils/__init__.py +0 -0
  40. models/SRFlow/code/utils/timer.py +78 -0
  41. models/SRFlow/code/utils/util.py +174 -0
  42. models/SRFlow/srflow.py +27 -0
models/SRFlow/35000_G.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:040fcffde66ec3ef658a843d58832b8aa153734a2f04342d841e3018b498a511
3
+ size 158819348
models/SRFlow/code/Measure.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import glob
16
+ import os
17
+ import time
18
+ from collections import OrderedDict
19
+
20
+ import numpy as np
21
+ import torch
22
+ import cv2
23
+ import argparse
24
+
25
+ from natsort import natsort
26
+ from skimage.metrics import structural_similarity as ssim
27
+ from skimage.metrics import peak_signal_noise_ratio as psnr
28
+ import lpips
29
+
30
+
31
+ class Measure():
32
+ def __init__(self, net='alex', use_gpu=False):
33
+ self.device = 'cuda' if use_gpu else 'cpu'
34
+ self.model = lpips.LPIPS(net=net)
35
+ self.model.to(self.device)
36
+
37
+ def measure(self, imgA, imgB):
38
+ return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]]
39
+
40
+ def lpips(self, imgA, imgB, model=None):
41
+ tA = t(imgA).to(self.device)
42
+ tB = t(imgB).to(self.device)
43
+ dist01 = self.model.forward(tA, tB).item()
44
+ return dist01
45
+
46
+ def ssim(self, imgA, imgB):
47
+ # multichannel: If True, treat the last dimension of the array as channels. Similarity calculations are done independently for each channel then averaged.
48
+ score, diff = ssim(imgA, imgB, full=True, multichannel=True, channel_axis=-1)
49
+ return score
50
+
51
+ def psnr(self, imgA, imgB):
52
+ psnr_val = psnr(imgA, imgB)
53
+ return psnr_val
54
+
55
+
56
+ def t(img):
57
+ def to_4d(img):
58
+ assert len(img.shape) == 3
59
+ assert img.dtype == np.uint8
60
+ img_new = np.expand_dims(img, axis=0)
61
+ assert len(img_new.shape) == 4
62
+ return img_new
63
+
64
+ def to_CHW(img):
65
+ return np.transpose(img, [2, 0, 1])
66
+
67
+ def to_tensor(img):
68
+ return torch.Tensor(img)
69
+
70
+ return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
71
+
72
+
73
+ def fiFindByWildcard(wildcard):
74
+ return natsort.natsorted(glob.glob(wildcard, recursive=True))
75
+
76
+
77
+ def imread(path):
78
+ return cv2.imread(path)[:, :, [2, 1, 0]]
79
+
80
+
81
+ def format_result(psnr, ssim, lpips):
82
+ return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}'
83
+
84
+ def measure_dirs(dirA, dirB, use_gpu, verbose=False):
85
+ if verbose:
86
+ vprint = lambda x: print(x)
87
+ else:
88
+ vprint = lambda x: None
89
+
90
+
91
+ t_init = time.time()
92
+
93
+ paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}'))
94
+ paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}'))
95
+
96
+ vprint("Comparing: ")
97
+ vprint(dirA)
98
+ vprint(dirB)
99
+
100
+ measure = Measure(use_gpu=use_gpu)
101
+
102
+ results = []
103
+ for pathA, pathB in zip(paths_A, paths_B):
104
+ result = OrderedDict()
105
+
106
+ t = time.time()
107
+ result['psnr'], result['ssim'], result['lpips'] = measure.measure(imread(pathA), imread(pathB))
108
+ d = time.time() - t
109
+ vprint(f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}")
110
+
111
+ results.append(result)
112
+
113
+ psnr = np.mean([result['psnr'] for result in results])
114
+ ssim = np.mean([result['ssim'] for result in results])
115
+ lpips = np.mean([result['lpips'] for result in results])
116
+
117
+ vprint(f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s")
118
+
119
+
120
+ if __name__ == "__main__":
121
+ parser = argparse.ArgumentParser()
122
+ parser.add_argument('-dirA', default='', type=str)
123
+ parser.add_argument('-dirB', default='', type=str)
124
+ parser.add_argument('-type', default='png')
125
+ parser.add_argument('--use_gpu', action='store_true', default=False)
126
+ args = parser.parse_args()
127
+
128
+ dirA = args.dirA
129
+ dirB = args.dirB
130
+ type = args.type
131
+ use_gpu = args.use_gpu
132
+
133
+ if len(dirA) > 0 and len(dirB) > 0:
134
+ measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True)
models/SRFlow/code/a.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import numpy as np
3
+ import os
4
+ import matplotlib.pyplot as plt
5
+
6
+ def load_pkls(path):
7
+ assert os.path.isfile(path), path
8
+ images = []
9
+ with open(path, "rb") as f:
10
+ images += pickle.load(f)
11
+ assert len(images) > 0, path
12
+ images = [np.transpose(image, [2, 0, 1]) for image in images]
13
+ return images
14
+
15
+ path = 'datasets/DIV2K-va.pklv4'
16
+ loaded_images = load_pkls(path)
17
+ print(len(loaded_images))
18
+ # Display the first image
19
+ if loaded_images:
20
+ first_image = loaded_images[11]
21
+ plt.imshow(np.transpose(first_image, [1, 2, 0])) # Transpose image to original shape [height, width, channels]
22
+ plt.title('First Image')
23
+ plt.axis('off') # Hide axis
24
+ plt.show()
25
+ else:
26
+ print("No images loaded from the pickle file.")
27
+ print(loaded_images[11])
models/SRFlow/code/confs/RRDB_CelebA_8X.yml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ #### general settings
18
+ name: train
19
+ use_tb_logger: true
20
+ model: SR
21
+ distortion: sr
22
+ scale: 8
23
+ #gpu_ids: [ 0 ]
24
+
25
+ #### datasets
26
+ datasets:
27
+ train:
28
+ name: CelebA_160_tr
29
+ mode: LRHR_PKL
30
+ dataroot_GT: ../datasets/celebA-train-gt_1pct.pklv4
31
+ dataroot_LQ: ../datasets/celebA-train-x8_1pct.pklv4
32
+
33
+ use_shuffle: true
34
+ n_workers: 0 # per GPU
35
+ batch_size: 16
36
+ GT_size: 160
37
+ use_flip: true
38
+ use_rot: true
39
+ color: RGB
40
+ val:
41
+ name: CelebA_160_va
42
+ mode: LRHR_PKL
43
+ dataroot_GT: ../datasets/celebA-valid-gt_1pct.pklv4
44
+ dataroot_LQ: ../datasets/celebA-valid-x8_1pct.pklv4
45
+ n_max: 10
46
+
47
+ #### network structures
48
+ network_G:
49
+ which_model_G: RRDBNet
50
+ in_nc: 3
51
+ out_nc: 3
52
+ nf: 64
53
+ nb: 23
54
+
55
+ #### path
56
+ path:
57
+ pretrain_model_G: ~
58
+ strict_load: true
59
+ resume_state: auto
60
+
61
+ #### training settings: learning rate scheme, loss
62
+ train:
63
+ lr_G: !!float 2e-4
64
+ lr_scheme: CosineAnnealingLR_Restart
65
+ beta1: 0.9
66
+ beta2: 0.99
67
+ niter: 200000
68
+ warmup_iter: -1 # no warm up
69
+ T_period: [ 50000, 50000, 50000, 50000 ]
70
+ restarts: [ 50000, 100000, 150000 ]
71
+ restart_weights: [ 1, 1, 1 ]
72
+ eta_min: !!float 1e-7
73
+
74
+ pixel_criterion: l1
75
+ pixel_weight: 1.0
76
+
77
+ manual_seed: 10
78
+ val_freq: !!float 5e3
79
+
80
+ #### logger
81
+ logger:
82
+ print_freq: 100
83
+ save_checkpoint_freq: !!float 1e3
models/SRFlow/code/confs/RRDB_DF2K_4X.yml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ #### general settings
18
+ name: train
19
+ use_tb_logger: true
20
+ model: SR
21
+ distortion: sr
22
+ scale: 4
23
+ gpu_ids: [ 0 ]
24
+
25
+ #### datasets
26
+ datasets:
27
+ train:
28
+ name: CelebA_160_tr
29
+ mode: LRHR_PKL
30
+ dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4
31
+ dataroot_LQ: ../datasets/DF2K-train-x4_1pct.pklv4
32
+ quant: 32
33
+
34
+ use_shuffle: true
35
+ n_workers: 3 # per GPU
36
+ batch_size: 16
37
+ GT_size: 160
38
+ use_flip: true
39
+ color: RGB
40
+ val:
41
+ name: CelebA_160_va
42
+ mode: LRHR_PKL
43
+ dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4
44
+ dataroot_LQ: ../datasets/DF2K-valid-x4_1pct.pklv4
45
+ quant: 32
46
+ n_max: 20
47
+
48
+ #### network structures
49
+ network_G:
50
+ which_model_G: RRDBNet
51
+ use_orig: True
52
+ in_nc: 3
53
+ out_nc: 3
54
+ nf: 64
55
+ nb: 23
56
+
57
+ #### path
58
+ path:
59
+ pretrain_model_G: ~
60
+ strict_load: true
61
+ resume_state: auto
62
+
63
+ #### training settings: learning rate scheme, loss
64
+ train:
65
+ lr_G: !!float 2e-4
66
+ lr_scheme: CosineAnnealingLR_Restart
67
+ beta1: 0.9
68
+ beta2: 0.99
69
+ niter: 1000000
70
+ warmup_iter: -1 # no warm up
71
+ T_period: [ 50000, 50000, 50000, 50000 ]
72
+ restarts: [ 50000, 100000, 150000 ]
73
+ restart_weights: [ 1, 1, 1 ]
74
+ eta_min: !!float 1e-7
75
+
76
+ pixel_criterion: l1
77
+ pixel_weight: 1.0
78
+
79
+ manual_seed: 10
80
+ val_freq: !!float 5e3
81
+
82
+ #### logger
83
+ logger:
84
+ print_freq: 100
85
+ save_checkpoint_freq: !!float 1e3
models/SRFlow/code/confs/RRDB_DF2K_8X.yml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ #### general settings
18
+ name: train
19
+ use_tb_logger: true
20
+ model: SR
21
+ distortion: sr
22
+ scale: 8
23
+ gpu_ids: [ 0 ]
24
+
25
+ #### datasets
26
+ datasets:
27
+ train:
28
+ name: CelebA_160_tr
29
+ mode: LRHR_PKL
30
+ dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4
31
+ dataroot_LQ: ../datasets/DF2K-train-x8_1pct.pklv4
32
+ quant: 32
33
+
34
+ use_shuffle: true
35
+ n_workers: 3 # per GPU
36
+ batch_size: 16
37
+ GT_size: 160
38
+ use_flip: true
39
+ color: RGB
40
+
41
+ val:
42
+ name: CelebA_160_va
43
+ mode: LRHR_PKL
44
+ dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4
45
+ dataroot_LQ: ../datasets/DF2K-valid-x8_1pct.pklv4
46
+ quant: 32
47
+ n_max: 20
48
+
49
+ #### network structures
50
+ network_G:
51
+ which_model_G: RRDBNet
52
+ in_nc: 3
53
+ out_nc: 3
54
+ nf: 64
55
+ nb: 23
56
+
57
+ #### path
58
+ path:
59
+ pretrain_model_G: ~
60
+ strict_load: true
61
+ resume_state: auto
62
+
63
+ #### training settings: learning rate scheme, loss
64
+ train:
65
+ lr_G: !!float 2e-4
66
+ lr_scheme: CosineAnnealingLR_Restart
67
+ beta1: 0.9
68
+ beta2: 0.99
69
+ niter: 200000
70
+ warmup_iter: -1 # no warm up
71
+ T_period: [ 50000, 50000, 50000, 50000 ]
72
+ restarts: [ 50000, 100000, 150000 ]
73
+ restart_weights: [ 1, 1, 1 ]
74
+ eta_min: !!float 1e-7
75
+
76
+ pixel_criterion: l1
77
+ pixel_weight: 1.0
78
+
79
+ manual_seed: 10
80
+ val_freq: !!float 5e3
81
+
82
+ #### logger
83
+ logger:
84
+ print_freq: 100
85
+ save_checkpoint_freq: !!float 1e3
models/SRFlow/code/confs/SRFlow_CelebA_8X.yml ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ #### general settings
18
+ name: train
19
+ use_tb_logger: true
20
+ model: SRFlow
21
+ distortion: sr
22
+ scale: 8
23
+ gpu_ids: [ 0 ]
24
+
25
+ #### datasets
26
+ datasets:
27
+ train:
28
+ name: CelebA_160_tr
29
+ mode: LRHR_PKL
30
+ dataroot_GT: ../datasets/celebA-train-gt.pklv4
31
+ dataroot_LQ: ../datasets/celebA-train-x8.pklv4
32
+ quant: 32
33
+
34
+ use_shuffle: true
35
+ n_workers: 3 # per GPU
36
+ batch_size: 16
37
+ GT_size: 160
38
+ use_flip: true
39
+ color: RGB
40
+ val:
41
+ name: CelebA_160_va
42
+ mode: LRHR_PKL
43
+ dataroot_GT: ../datasets/celebA-train-gt.pklv4
44
+ dataroot_LQ: ../datasets/celebA-train-x8.pklv4
45
+ quant: 32
46
+ n_max: 20
47
+
48
+ #### Test Settings
49
+ dataroot_GT: ../datasets/celebA-validation-gt
50
+ dataroot_LR: ../datasets/celebA-validation-x8
51
+ model_path: ../pretrained_models/SRFlow_CelebA_8X.pth
52
+ heat: 0.9 # This is the standard deviation of the latent vectors
53
+
54
+ #### network structures
55
+ network_G:
56
+ which_model_G: SRFlowNet
57
+ in_nc: 3
58
+ out_nc: 3
59
+ nf: 64
60
+ nb: 8
61
+ upscale: 8
62
+ train_RRDB: false
63
+ train_RRDB_delay: 0.5
64
+
65
+ flow:
66
+ K: 16
67
+ L: 4
68
+ noInitialInj: true
69
+ coupling: CondAffineSeparatedAndCond
70
+ additionalFlowNoAffine: 2
71
+ split:
72
+ enable: true
73
+ fea_up0: true
74
+ stackRRDB:
75
+ blocks: [ 1, 3, 5, 7 ]
76
+ concat: true
77
+
78
+ #### path
79
+ path:
80
+ pretrain_model_G: ../pretrained_models/RRDB_CelebA_8X.pth
81
+ strict_load: true
82
+ resume_state: auto
83
+
84
+ #### training settings: learning rate scheme, loss
85
+ train:
86
+ manual_seed: 10
87
+ lr_G: !!float 5e-4
88
+ weight_decay_G: 0
89
+ beta1: 0.9
90
+ beta2: 0.99
91
+ lr_scheme: MultiStepLR
92
+ warmup_iter: -1 # no warm up
93
+ lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
94
+ lr_gamma: 0.5
95
+
96
+ niter: 200000
97
+ val_freq: 40000
98
+
99
+ #### validation settings
100
+ val:
101
+ heats: [ 0.0, 0.5, 0.75, 1.0 ]
102
+ n_sample: 3
103
+
104
+ #### logger
105
+ logger:
106
+ print_freq: 100
107
+ save_checkpoint_freq: !!float 1e3
models/SRFlow/code/confs/SRFlow_DF2K_4X.yml ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ #### general settings
18
+ name: train
19
+ use_tb_logger: true
20
+ model: SRFlow
21
+ distortion: sr
22
+ scale: 4
23
+ gpu_ids: [ 0 ]
24
+
25
+ #### datasets
26
+ datasets:
27
+ train:
28
+ name: DF2K_256_tr
29
+ mode: LRHR_PKL
30
+ dataroot_GT: /kaggle/input/srflow0103/SRFlow/datasets/DF2K-tr.pklv4
31
+ dataroot_LQ: /kaggle/input/srflow0103/SRFlow/datasets/DF2K-tr_X4.pklv4
32
+ quant: 32
33
+
34
+ use_shuffle: true
35
+ n_workers: 3 # per GPU
36
+ batch_size: 12
37
+ GT_size: 256
38
+ use_flip: true
39
+ color: RGB
40
+ val:
41
+ name: DF2K_256_tr
42
+ mode: LRHR_PKL
43
+ dataroot_GT: ../datasets/DIV2K-va.pklv4
44
+ dataroot_LQ: ../datasets/DIV2K-va_X4.pklv4
45
+ quant: 32
46
+ n_max: 20
47
+
48
+ #### Test Settings
49
+ dataroot: /kaggle/input/test-set/test set
50
+ model_path: /models/SRFlow/35000_G
51
+ heat: 0.6 # This is the standard deviation of the latent vectors
52
+
53
+ #### network structures
54
+ network_G:
55
+ which_model_G: SRFlowNet
56
+ in_nc: 3
57
+ out_nc: 3
58
+ nf: 64
59
+ nb: 23
60
+ upscale: 4
61
+ train_RRDB: false
62
+ train_RRDB_delay: 0.5
63
+
64
+ flow:
65
+ K: 16
66
+ L: 3
67
+ noInitialInj: true
68
+ coupling: CondAffineSeparatedAndCond
69
+ additionalFlowNoAffine: 2
70
+ split:
71
+ enable: true
72
+ fea_up0: true
73
+ stackRRDB:
74
+ blocks: [ 1, 8, 15, 22 ]
75
+ concat: true
76
+
77
+ #### path
78
+ path:
79
+ pretrain_model_G:
80
+ strict_load: true
81
+ resume_state: auto
82
+
83
+ #### training settings: learning rate scheme, loss
84
+ train:
85
+ manual_seed: 10
86
+ lr_G: !!float 2.5e-4
87
+ weight_decay_G: 0
88
+ beta1: 0.9
89
+ beta2: 0.99
90
+ lr_scheme: MultiStepLR
91
+ warmup_iter: -1 # no warm up
92
+ lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
93
+ lr_gamma: 0.5
94
+
95
+ niter: 64185
96
+ val_freq: 40000
97
+
98
+ #### validation settings
99
+ val:
100
+ heats: [ 0.0, 0.5, 0.75, 1.0 ]
101
+ n_sample: 3
102
+
103
+ #### logger
104
+ logger:
105
+ print_freq: 100
106
+ save_checkpoint_freq: !!float 5e3
models/SRFlow/code/confs/SRFlow_DF2K_8X.yml ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ #### general settings
18
+ name: train
19
+ use_tb_logger: true
20
+ model: SRFlow
21
+ distortion: sr
22
+ scale: 8
23
+ gpu_ids: [ 0 ]
24
+
25
+ #### datasets
26
+ datasets:
27
+ train:
28
+ name: CelebA_160_tr
29
+ mode: LRHR_PKL
30
+ dataroot_GT: ../datasets/DF2K-tr.pklv4
31
+ dataroot_LQ: ../datasets/DF2K-tr_X8.pklv4
32
+ quant: 32
33
+
34
+ use_shuffle: true
35
+ n_workers: 3 # per GPU
36
+ batch_size: 16
37
+ GT_size: 160
38
+ use_flip: true
39
+ color: RGB
40
+
41
+ val:
42
+ name: CelebA_160_va
43
+ mode: LRHR_PKL
44
+ dataroot_GT: ../datasets/DIV2K-va.pklv4
45
+ dataroot_LQ: ../datasets/DIV2K-va_X8.pklv4
46
+ quant: 32
47
+ n_max: 20
48
+
49
+ #### Test Settings
50
+ dataroot_GT: ../datasets/div2k-validation-modcrop8-gt
51
+ dataroot_LR: ../datasets/div2k-validation-modcrop8-x8
52
+ model_path: ../pretrained_models/SRFlow_DF2K_8X.pth
53
+ heat: 0.9 # This is the standard deviation of the latent vectors
54
+
55
+ #### network structures
56
+ network_G:
57
+ which_model_G: SRFlowNet
58
+ in_nc: 3
59
+ out_nc: 3
60
+ nf: 64
61
+ nb: 23
62
+ upscale: 8
63
+ train_RRDB: false
64
+ train_RRDB_delay: 0.5
65
+
66
+ flow:
67
+ K: 16
68
+ L: 4
69
+ noInitialInj: true
70
+ coupling: CondAffineSeparatedAndCond
71
+ additionalFlowNoAffine: 2
72
+ split:
73
+ enable: true
74
+ fea_up0: true
75
+ stackRRDB:
76
+ blocks: [ 1, 3, 5, 7 ]
77
+ concat: true
78
+
79
+ #### path
80
+ path:
81
+ pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth
82
+ strict_load: true
83
+ resume_state: auto
84
+
85
+ #### training settings: learning rate scheme, loss
86
+ train:
87
+ manual_seed: 10
88
+ lr_G: !!float 5e-4
89
+ weight_decay_G: 0
90
+ beta1: 0.9
91
+ beta2: 0.99
92
+ lr_scheme: MultiStepLR
93
+ warmup_iter: -1 # no warm up
94
+ lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
95
+ lr_gamma: 0.5
96
+
97
+ niter: 200000
98
+ val_freq: 40000
99
+
100
+ #### validation settings
101
+ val:
102
+ heats: [ 0.0, 0.5, 0.75, 1.0 ]
103
+ n_sample: 3
104
+
105
+ test:
106
+ heats: [ 0.0, 0.7, 0.8, 0.9 ]
107
+
108
+ #### logger
109
+ logger:
110
+ # Debug print_freq: 100
111
+ print_freq: 100
112
+ save_checkpoint_freq: !!float 1e3
models/SRFlow/code/data/LRHR_PKL_dataset.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import os
18
+ import subprocess
19
+ import torch.utils.data as data
20
+ import numpy as np
21
+ import time
22
+ import torch
23
+
24
+ import pickle
25
+
26
+
27
+ class LRHR_PKLDataset(data.Dataset):
28
+ def __init__(self, opt):
29
+ super(LRHR_PKLDataset, self).__init__()
30
+ self.opt = opt
31
+ self.crop_size = opt.get("GT_size", None)
32
+ self.scale = None
33
+ self.random_scale_list = [1]
34
+
35
+ hr_file_path = opt["dataroot_GT"]
36
+ lr_file_path = opt["dataroot_LQ"]
37
+ y_labels_file_path = opt['dataroot_y_labels']
38
+
39
+ gpu = True
40
+ augment = True
41
+
42
+ self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
43
+ self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
44
+ self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
45
+ self.center_crop_hr_size = opt.get("center_crop_hr_size", None)
46
+
47
+ n_max = opt["n_max"] if "n_max" in opt.keys() else int(1e8)
48
+
49
+ t = time.time()
50
+ self.lr_images = self.load_pkls(lr_file_path, n_max)
51
+ self.hr_images = self.load_pkls(hr_file_path, n_max)
52
+
53
+ min_val_hr = np.min([i.min() for i in self.hr_images[:20]])
54
+ max_val_hr = np.max([i.max() for i in self.hr_images[:20]])
55
+
56
+ min_val_lr = np.min([i.min() for i in self.lr_images[:20]])
57
+ max_val_lr = np.max([i.max() for i in self.lr_images[:20]])
58
+
59
+ t = time.time() - t
60
+ print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
61
+ format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path))
62
+ print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
63
+ format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path))
64
+
65
+ self.gpu = gpu
66
+ self.augment = augment
67
+
68
+ self.measures = None
69
+
70
+ def load_pkls(self, path, n_max):
71
+ assert os.path.isfile(path), path
72
+ images = []
73
+ with open(path, "rb") as f:
74
+ images += pickle.load(f)
75
+ assert len(images) > 0, path
76
+ images = images[:n_max]
77
+ images = [np.transpose(image, [2, 0, 1]) for image in images]
78
+ return images
79
+
80
+ def __len__(self):
81
+ return len(self.hr_images)
82
+
83
+ def __getitem__(self, item):
84
+ hr = self.hr_images[item]
85
+ lr = self.lr_images[item]
86
+
87
+ if self.scale == None:
88
+ self.scale = hr.shape[1] // lr.shape[1]
89
+ assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape)
90
+
91
+ if self.use_crop:
92
+ hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop)
93
+
94
+ if self.center_crop_hr_size:
95
+ hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale)
96
+
97
+ if self.use_flip:
98
+ hr, lr = random_flip(hr, lr)
99
+
100
+ if self.use_rot:
101
+ hr, lr = random_rotation(hr, lr)
102
+
103
+ hr = hr / 255.0
104
+ lr = lr / 255.0
105
+
106
+ if self.measures is None or np.random.random() < 0.05:
107
+ if self.measures is None:
108
+ self.measures = {}
109
+ self.measures['hr_means'] = np.mean(hr)
110
+ self.measures['hr_stds'] = np.std(hr)
111
+ self.measures['lr_means'] = np.mean(lr)
112
+ self.measures['lr_stds'] = np.std(lr)
113
+
114
+ hr = torch.Tensor(hr)
115
+ lr = torch.Tensor(lr)
116
+
117
+ # if self.gpu:
118
+ # hr = hr.cuda()
119
+ # lr = lr.cuda()
120
+
121
+ return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)}
122
+
123
+ def print_and_reset(self, tag):
124
+ m = self.measures
125
+ kvs = []
126
+ for k in sorted(m.keys()):
127
+ kvs.append("{}={:.2f}".format(k, m[k]))
128
+ print("[KPI] " + tag + ": " + ", ".join(kvs))
129
+ self.measures = None
130
+
131
+
132
+ def random_flip(img, seg):
133
+ random_choice = np.random.choice([True, False])
134
+ img = img if random_choice else np.flip(img, 2).copy()
135
+ seg = seg if random_choice else np.flip(seg, 2).copy()
136
+ return img, seg
137
+
138
+
139
+ def random_rotation(img, seg):
140
+ random_choice = np.random.choice([0, 1, 3])
141
+ img = np.rot90(img, random_choice, axes=(1, 2)).copy()
142
+ seg = np.rot90(seg, random_choice, axes=(1, 2)).copy()
143
+ return img, seg
144
+
145
+
146
+ def random_crop(hr, lr, size_hr, scale, random):
147
+ size_lr = size_hr // scale
148
+
149
+ size_lr_x = lr.shape[1]
150
+ size_lr_y = lr.shape[2]
151
+
152
+ start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0
153
+ start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0
154
+
155
+ # LR Patch
156
+ lr_patch = lr[:, start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr]
157
+
158
+ # HR Patch
159
+ start_x_hr = start_x_lr * scale
160
+ start_y_hr = start_y_lr * scale
161
+ hr_patch = hr[:, start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr]
162
+
163
+ return hr_patch, lr_patch
164
+
165
+
166
+ def center_crop(img, size):
167
+ assert img.shape[1] == img.shape[2], img.shape
168
+ border_double = img.shape[1] - size
169
+ assert border_double % 2 == 0, (img.shape, size)
170
+ border = border_double // 2
171
+ return img[:, border:-border, border:-border]
172
+
173
+
174
+ def center_crop_tensor(img, size):
175
+ assert img.shape[2] == img.shape[3], img.shape
176
+ border_double = img.shape[2] - size
177
+ assert border_double % 2 == 0, (img.shape, size)
178
+ border = border_double // 2
179
+ return img[:, :, border:-border, border:-border]
models/SRFlow/code/data/__init__.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ '''create dataset and dataloader'''
18
+ import logging
19
+ import torch
20
+ import torch.utils.data
21
+
22
+
23
+ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
24
+ phase = dataset_opt.get('phase', 'test')
25
+ if phase == 'train':
26
+ gpu_ids = opt.get('gpu_ids', None)
27
+ gpu_ids = gpu_ids if gpu_ids else []
28
+ num_workers = dataset_opt['n_workers'] * len(gpu_ids)
29
+ batch_size = dataset_opt['batch_size']
30
+ shuffle = True
31
+ return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
32
+ num_workers=num_workers, sampler=sampler, drop_last=True,
33
+ pin_memory=False)
34
+ else:
35
+ return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
36
+ pin_memory=True)
37
+
38
+
39
+ def create_dataset(dataset_opt):
40
+ print(dataset_opt)
41
+ mode = dataset_opt['mode']
42
+ if mode == 'LRHR_PKL':
43
+ from data.LRHR_PKL_dataset import LRHR_PKLDataset as D
44
+ else:
45
+ raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
46
+ dataset = D(dataset_opt)
47
+
48
+ logger = logging.getLogger('base')
49
+ logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
50
+ dataset_opt['name']))
51
+ return dataset
models/SRFlow/code/demo_on_pretrained.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/SRFlow/code/imresize.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/fatheral/matlab_imresize
2
+ #
3
+ # MIT License
4
+ #
5
+ # Copyright (c) 2020 Alex
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+
26
+ from __future__ import print_function
27
+ import numpy as np
28
+ from math import ceil, floor
29
+
30
+
31
+ def deriveSizeFromScale(img_shape, scale):
32
+ output_shape = []
33
+ for k in range(2):
34
+ output_shape.append(int(ceil(scale[k] * img_shape[k])))
35
+ return output_shape
36
+
37
+
38
+ def deriveScaleFromSize(img_shape_in, img_shape_out):
39
+ scale = []
40
+ for k in range(2):
41
+ scale.append(1.0 * img_shape_out[k] / img_shape_in[k])
42
+ return scale
43
+
44
+
45
+ def triangle(x):
46
+ x = np.array(x).astype(np.float64)
47
+ lessthanzero = np.logical_and((x >= -1), x < 0)
48
+ greaterthanzero = np.logical_and((x <= 1), x >= 0)
49
+ f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero)
50
+ return f
51
+
52
+
53
+ def cubic(x):
54
+ x = np.array(x).astype(np.float64)
55
+ absx = np.absolute(x)
56
+ absx2 = np.multiply(absx, absx)
57
+ absx3 = np.multiply(absx2, absx)
58
+ f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2,
59
+ (1 < absx) & (absx <= 2))
60
+ return f
61
+
62
+
63
+ def contributions(in_length, out_length, scale, kernel, k_width):
64
+ if scale < 1:
65
+ h = lambda x: scale * kernel(scale * x)
66
+ kernel_width = 1.0 * k_width / scale
67
+ else:
68
+ h = kernel
69
+ kernel_width = k_width
70
+ x = np.arange(1, out_length + 1).astype(np.float64)
71
+ u = x / scale + 0.5 * (1 - 1 / scale)
72
+ left = np.floor(u - kernel_width / 2)
73
+ P = int(ceil(kernel_width)) + 2
74
+ ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0
75
+ indices = ind.astype(np.int32)
76
+ weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0
77
+ weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
78
+ aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
79
+ indices = aux[np.mod(indices, aux.size)]
80
+ ind2store = np.nonzero(np.any(weights, axis=0))
81
+ weights = weights[:, ind2store]
82
+ indices = indices[:, ind2store]
83
+ return weights, indices
84
+
85
+
86
+ def imresizemex(inimg, weights, indices, dim):
87
+ in_shape = inimg.shape
88
+ w_shape = weights.shape
89
+ out_shape = list(in_shape)
90
+ out_shape[dim] = w_shape[0]
91
+ outimg = np.zeros(out_shape)
92
+ if dim == 0:
93
+ for i_img in range(in_shape[1]):
94
+ for i_w in range(w_shape[0]):
95
+ w = weights[i_w, :]
96
+ ind = indices[i_w, :]
97
+ im_slice = inimg[ind, i_img].astype(np.float64)
98
+ outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
99
+ elif dim == 1:
100
+ for i_img in range(in_shape[0]):
101
+ for i_w in range(w_shape[0]):
102
+ w = weights[i_w, :]
103
+ ind = indices[i_w, :]
104
+ im_slice = inimg[i_img, ind].astype(np.float64)
105
+ outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
106
+ if inimg.dtype == np.uint8:
107
+ outimg = np.clip(outimg, 0, 255)
108
+ return np.around(outimg).astype(np.uint8)
109
+ else:
110
+ return outimg
111
+
112
+
113
+ def imresizevec(inimg, weights, indices, dim):
114
+ wshape = weights.shape
115
+ if dim == 0:
116
+ weights = weights.reshape((wshape[0], wshape[2], 1, 1))
117
+ outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
118
+ elif dim == 1:
119
+ weights = weights.reshape((1, wshape[0], wshape[2], 1))
120
+ outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
121
+ if inimg.dtype == np.uint8:
122
+ outimg = np.clip(outimg, 0, 255)
123
+ return np.around(outimg).astype(np.uint8)
124
+ else:
125
+ return outimg
126
+
127
+
128
+ def resizeAlongDim(A, dim, weights, indices, mode="vec"):
129
+ if mode == "org":
130
+ out = imresizemex(A, weights, indices, dim)
131
+ else:
132
+ out = imresizevec(A, weights, indices, dim)
133
+ return out
134
+
135
+
136
+ def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
137
+ if method is 'bicubic':
138
+ kernel = cubic
139
+ elif method is 'bilinear':
140
+ kernel = triangle
141
+ else:
142
+ print('Error: Unidentified method supplied')
143
+
144
+ kernel_width = 4.0
145
+ # Fill scale and output_size
146
+ if scalar_scale is not None:
147
+ scalar_scale = float(scalar_scale)
148
+ scale = [scalar_scale, scalar_scale]
149
+ output_size = deriveSizeFromScale(I.shape, scale)
150
+ elif output_shape is not None:
151
+ scale = deriveScaleFromSize(I.shape, output_shape)
152
+ output_size = list(output_shape)
153
+ else:
154
+ print('Error: scalar_scale OR output_shape should be defined!')
155
+ return
156
+ scale_np = np.array(scale)
157
+ order = np.argsort(scale_np)
158
+ weights = []
159
+ indices = []
160
+ for k in range(2):
161
+ w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
162
+ weights.append(w)
163
+ indices.append(ind)
164
+ B = np.copy(I)
165
+ flag2D = False
166
+ if B.ndim == 2:
167
+ B = np.expand_dims(B, axis=2)
168
+ flag2D = True
169
+ for k in range(2):
170
+ dim = order[k]
171
+ B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
172
+ if flag2D:
173
+ B = np.squeeze(B, axis=2)
174
+ return B
175
+
176
+
177
+ def convertDouble2Byte(I):
178
+ B = np.clip(I, 0.0, 1.0)
179
+ B = 255 * B
180
+ return np.around(B).astype(np.uint8)
models/SRFlow/code/models/SRFlow_model.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import logging
18
+ from collections import OrderedDict
19
+ from utils.util import get_resume_paths, opt_get
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
24
+ import models.networks as networks
25
+ import models.lr_scheduler as lr_scheduler
26
+ from .base_model import BaseModel
27
+
28
+ logger = logging.getLogger('base')
29
+
30
+
31
+ class SRFlowModel(BaseModel):
32
+ def __init__(self, opt, step):
33
+ super(SRFlowModel, self).__init__(opt)
34
+ self.opt = opt
35
+
36
+ self.heats = opt['val']['heats']
37
+ self.n_sample = opt['val']['n_sample']
38
+ self.hr_size = opt_get(opt, ['datasets', 'train', 'center_crop_hr_size'])
39
+ self.hr_size = 256 if self.hr_size is None else self.hr_size
40
+ self.lr_size = self.hr_size // opt['scale']
41
+
42
+ if opt['dist']:
43
+ self.rank = torch.distributed.get_rank()
44
+ else:
45
+ self.rank = -1 # non dist training
46
+ train_opt = opt['train']
47
+
48
+ # define network and load pretrained models
49
+ self.netG = networks.define_Flow(opt, step).to(self.device)
50
+ if opt['dist']:
51
+ self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
52
+ else:
53
+ self.netG = DataParallel(self.netG)
54
+ # print network
55
+ self.print_network()
56
+
57
+ if opt_get(opt, ['path', 'resume_state'], 1) is not None:
58
+ self.load()
59
+ else:
60
+ print("WARNING: skipping initial loading, due to resume_state None")
61
+
62
+ if self.is_train:
63
+ self.netG.train()
64
+
65
+ self.init_optimizer_and_scheduler(train_opt)
66
+ self.log_dict = OrderedDict()
67
+
68
+ def to(self, device):
69
+ self.device = device
70
+ self.netG.to(device)
71
+
72
+ def init_optimizer_and_scheduler(self, train_opt):
73
+ # optimizers
74
+ self.optimizers = []
75
+ wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
76
+ optim_params_RRDB = []
77
+ optim_params_other = []
78
+ for k, v in self.netG.named_parameters(): # can optimize for a part of the model
79
+ print(k, v.requires_grad)
80
+ if v.requires_grad:
81
+ if '.RRDB.' in k:
82
+ optim_params_RRDB.append(v)
83
+ print('opt', k)
84
+ else:
85
+ optim_params_other.append(v)
86
+ if self.rank <= 0:
87
+ logger.warning('Params [{:s}] will not optimize.'.format(k))
88
+
89
+ print('rrdb params', len(optim_params_RRDB))
90
+
91
+ self.optimizer_G = torch.optim.Adam(
92
+ [
93
+ {"params": optim_params_other, "lr": train_opt['lr_G'], 'beta1': train_opt['beta1'],
94
+ 'beta2': train_opt['beta2'], 'weight_decay': wd_G},
95
+ {"params": optim_params_RRDB, "lr": train_opt.get('lr_RRDB', train_opt['lr_G']),
96
+ 'beta1': train_opt['beta1'],
97
+ 'beta2': train_opt['beta2'], 'weight_decay': wd_G}
98
+ ],
99
+ )
100
+
101
+ self.optimizers.append(self.optimizer_G)
102
+ # schedulers
103
+ if train_opt['lr_scheme'] == 'MultiStepLR':
104
+ for optimizer in self.optimizers:
105
+ self.schedulers.append(
106
+ lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
107
+ restarts=train_opt['restarts'],
108
+ weights=train_opt['restart_weights'],
109
+ gamma=train_opt['lr_gamma'],
110
+ clear_state=train_opt['clear_state'],
111
+ lr_steps_invese=train_opt.get('lr_steps_inverse', [])))
112
+ elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
113
+ for optimizer in self.optimizers:
114
+ self.schedulers.append(
115
+ lr_scheduler.CosineAnnealingLR_Restart(
116
+ optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
117
+ restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
118
+ else:
119
+ raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
120
+
121
+ def add_optimizer_and_scheduler_RRDB(self, train_opt):
122
+ # optimizers
123
+ assert len(self.optimizers) == 1, self.optimizers
124
+ assert len(self.optimizer_G.param_groups[1]['params']) == 0, self.optimizer_G.param_groups[1]
125
+ for k, v in self.netG.named_parameters(): # can optimize for a part of the model
126
+ if v.requires_grad:
127
+ if '.RRDB.' in k:
128
+ self.optimizer_G.param_groups[1]['params'].append(v)
129
+ assert len(self.optimizer_G.param_groups[1]['params']) > 0
130
+
131
+ def feed_data(self, data, need_GT=True):
132
+ self.var_L = data['LQ'].to(self.device) # LQ
133
+ if need_GT:
134
+ self.real_H = data['GT'].to(self.device) # GT
135
+
136
+ def optimize_parameters(self, step):
137
+
138
+ train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
139
+ if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
140
+ and not self.netG.module.RRDB_training:
141
+ if self.netG.module.set_rrdb_training(True):
142
+ self.add_optimizer_and_scheduler_RRDB(self.opt['train'])
143
+
144
+ # self.print_rrdb_state()
145
+
146
+ self.netG.train()
147
+ self.log_dict = OrderedDict()
148
+ self.optimizer_G.zero_grad()
149
+
150
+ losses = {}
151
+ weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
152
+ weight_fl = 1 if weight_fl is None else weight_fl
153
+ if weight_fl > 0:
154
+ z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
155
+ nll_loss = torch.mean(nll)
156
+ losses['nll_loss'] = nll_loss * weight_fl
157
+
158
+ weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
159
+ if weight_l1 > 0:
160
+ z = self.get_z(heat=0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
161
+ sr, logdet = self.netG(lr=self.var_L, z=z, eps_std=0, reverse=True, reverse_with_grad=True)
162
+ l1_loss = (sr - self.real_H).abs().mean()
163
+ losses['l1_loss'] = l1_loss * weight_l1
164
+
165
+ total_loss = sum(losses.values())
166
+ total_loss.backward()
167
+ self.optimizer_G.step()
168
+
169
+ mean = total_loss.item()
170
+ return mean
171
+
172
+ def print_rrdb_state(self):
173
+ for name, param in self.netG.module.named_parameters():
174
+ if "RRDB.conv_first.weight" in name:
175
+ print(name, param.requires_grad, param.data.abs().sum())
176
+ print('params', [len(p['params']) for p in self.optimizer_G.param_groups])
177
+
178
+ def test(self):
179
+ self.netG.eval()
180
+ self.fake_H = {}
181
+ for heat in self.heats:
182
+ for i in range(self.n_sample):
183
+ z = self.get_z(heat, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
184
+ with torch.no_grad():
185
+ self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L, z=z, eps_std=heat, reverse=True)
186
+ with torch.no_grad():
187
+ _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
188
+ self.netG.train()
189
+ return nll.mean().item()
190
+
191
+ def get_encode_nll(self, lq, gt):
192
+ self.netG.eval()
193
+ with torch.no_grad():
194
+ _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
195
+ self.netG.train()
196
+ return nll.mean().item()
197
+
198
+ def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
199
+ return self.get_sr_with_z(lq, heat, seed, z, epses)[0]
200
+
201
+ def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
202
+ self.netG.eval()
203
+ with torch.no_grad():
204
+ z, _, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise)
205
+ self.netG.train()
206
+ return z
207
+
208
+ def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
209
+ self.netG.eval()
210
+ with torch.no_grad():
211
+ z, nll, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise)
212
+ self.netG.train()
213
+ return z, nll
214
+
215
+ def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
216
+ self.netG.eval()
217
+
218
+ z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape) if z is None and epses is None else z
219
+
220
+ with torch.no_grad():
221
+ sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses)
222
+ self.netG.train()
223
+ return sr, z
224
+
225
+ def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
226
+ if seed: torch.manual_seed(seed)
227
+ if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
228
+ C = self.netG.module.flowUpsamplerNet.C
229
+ H = int(self.opt['scale'] * lr_shape[2] // self.netG.module.flowUpsamplerNet.scaleH)
230
+ W = int(self.opt['scale'] * lr_shape[3] // self.netG.module.flowUpsamplerNet.scaleW)
231
+ z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W)) if heat > 0 else torch.zeros(
232
+ (batch_size, C, H, W))
233
+ else:
234
+ L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
235
+ fac = 2 ** (L - 3)
236
+ z_size = int(self.lr_size // (2 ** (L - 3)))
237
+ z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
238
+ return z
239
+
240
+ def get_current_log(self):
241
+ return self.log_dict
242
+
243
+ def get_current_visuals(self, need_GT=True):
244
+ out_dict = OrderedDict()
245
+ out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
246
+ for heat in self.heats:
247
+ for i in range(self.n_sample):
248
+ out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu()
249
+ if need_GT:
250
+ out_dict['GT'] = self.real_H.detach()[0].float().cpu()
251
+ return out_dict
252
+
253
+ def print_network(self):
254
+ s, n = self.get_network_description(self.netG)
255
+ if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
256
+ net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
257
+ self.netG.module.__class__.__name__)
258
+ else:
259
+ net_struc_str = '{}'.format(self.netG.__class__.__name__)
260
+ if self.rank <= 0:
261
+ logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
262
+ logger.info(s)
263
+
264
+ def load(self):
265
+ _, get_resume_model_path = get_resume_paths(self.opt)
266
+ if get_resume_model_path is not None:
267
+ self.load_network(get_resume_model_path, self.netG, strict=True, submodule=None)
268
+ return
269
+
270
+ load_path_G = self.opt['path']['pretrain_model_G']
271
+ load_submodule = self.opt['path']['load_submodule'] if 'load_submodule' in self.opt['path'].keys() else 'RRDB'
272
+ if load_path_G is not None:
273
+ logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
274
+ self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True),
275
+ submodule=load_submodule)
276
+
277
+ def save(self, iter_label):
278
+ self.save_network(self.netG, 'G', iter_label)
models/SRFlow/code/models/SR_model.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import logging
18
+ from collections import OrderedDict
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
23
+ import models.networks as networks
24
+ import models.lr_scheduler as lr_scheduler
25
+ from utils.util import opt_get
26
+ from .base_model import BaseModel
27
+ from models.modules.loss import CharbonnierLoss
28
+
29
+ logger = logging.getLogger('base')
30
+
31
+
32
+ class SRModel(BaseModel):
33
+ def __init__(self, opt, step):
34
+ super(SRModel, self).__init__(opt)
35
+
36
+ self.step = step
37
+
38
+ if opt['dist']:
39
+ self.rank = torch.distributed.get_rank()
40
+ else:
41
+ self.rank = -1 # non dist training
42
+ train_opt = opt['train']
43
+
44
+ # define network and load pretrained_models models
45
+ self.netG = networks.define_G(opt).to(self.device)
46
+ if opt['dist']:
47
+ self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
48
+ else:
49
+ self.netG = DataParallel(self.netG)
50
+ # print network
51
+ self.print_network()
52
+ self.load()
53
+
54
+ if self.is_train:
55
+ self.netG.train()
56
+
57
+ # loss
58
+ loss_type = train_opt['pixel_criterion']
59
+ if loss_type == 'l1':
60
+ self.cri_pix = nn.L1Loss().to(self.device)
61
+ elif loss_type == 'l2':
62
+ self.cri_pix = nn.MSELoss().to(self.device)
63
+ elif loss_type == 'cb':
64
+ self.cri_pix = CharbonnierLoss().to(self.device)
65
+ else:
66
+ raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
67
+ self.l_pix_w = train_opt['pixel_weight']
68
+
69
+ # optimizers
70
+ wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
71
+ optim_params = []
72
+ for k, v in self.netG.named_parameters(): # can optimize for a part of the model
73
+ if v.requires_grad:
74
+ optim_params.append(v)
75
+ else:
76
+ if self.rank <= 0:
77
+ logger.warning('Params [{:s}] will not optimize.'.format(k))
78
+ self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
79
+ weight_decay=wd_G,
80
+ betas=(train_opt['beta1'], train_opt['beta2']))
81
+ self.optimizers.append(self.optimizer_G)
82
+
83
+ # schedulers
84
+ if train_opt['lr_scheme'] == 'MultiStepLR':
85
+ for optimizer in self.optimizers:
86
+ self.schedulers.append(
87
+ lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
88
+ restarts=train_opt['restarts'],
89
+ weights=train_opt['restart_weights'],
90
+ gamma=train_opt['lr_gamma'],
91
+ clear_state=train_opt['clear_state']))
92
+ elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
93
+ for optimizer in self.optimizers:
94
+ self.schedulers.append(
95
+ lr_scheduler.CosineAnnealingLR_Restart(
96
+ optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
97
+ restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
98
+ else:
99
+ raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
100
+
101
+ self.log_dict = OrderedDict()
102
+
103
+ def feed_data(self, data, need_GT=True):
104
+ self.var_L = data['LQ'].to(self.device) # LQ
105
+ if need_GT:
106
+ self.real_H = data['GT'].to(self.device) # GT
107
+
108
+ def to(self, device):
109
+ self.device = device
110
+ self.netG.to(device)
111
+
112
+ def optimize_parameters(self, step):
113
+ def getEnv(name): import os; return True if name in os.environ.keys() else False
114
+
115
+ if getEnv("DEBUG_FEED_IMAGES"):
116
+ import imageio
117
+ import random
118
+ i = random.randint(0, 10000)
119
+ label = self.var_L.cpu().numpy()[0].transpose([1, 2, 0])
120
+ print("var_L", label.min(), label.max(), label.shape)
121
+ imageio.imwrite("/tmp/{}_l.png".format(i), label)
122
+ image = self.real_H.cpu().numpy()[0].transpose([1, 2, 0])
123
+ print("self.real_H", image.min(), image.max(), image.shape)
124
+ imageio.imwrite("/tmp/{}_gt.png".format(i), image)
125
+ self.optimizer_G.zero_grad()
126
+ self.fake_H = self.netG(self.var_L)
127
+ l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H.to(self.fake_H.device))
128
+ l_pix.backward()
129
+ self.optimizer_G.step()
130
+
131
+ # set log
132
+ self.log_dict['l_pix'] = l_pix.item()
133
+
134
+ def test(self):
135
+ self.netG.eval()
136
+ with torch.no_grad():
137
+ self.fake_H = self.netG(self.var_L)
138
+ self.netG.train()
139
+
140
+ def get_encode_nll(self, lq, gt):
141
+ return torch.ones(1) * 1e14
142
+
143
+ def get_sr(self, lq, heat=None, seed=None):
144
+ self.netG.eval()
145
+ sr = self.netG(lq)
146
+ self.netG.train()
147
+ return sr
148
+
149
+ def test_x8(self):
150
+ # from https://github.com/thstkdgus35/EDSR-PyTorch
151
+ self.netG.eval()
152
+
153
+ def _transform(v, op):
154
+ # if self.precision != 'single': v = v.float()
155
+ v2np = v.data.cpu().numpy()
156
+ if op == 'v':
157
+ tfnp = v2np[:, :, :, ::-1].copy()
158
+ elif op == 'h':
159
+ tfnp = v2np[:, :, ::-1, :].copy()
160
+ elif op == 't':
161
+ tfnp = v2np.transpose((0, 1, 3, 2)).copy()
162
+
163
+ ret = torch.Tensor(tfnp).to(self.device)
164
+ # if self.precision == 'half': ret = ret.half()
165
+
166
+ return ret
167
+
168
+ lr_list = [self.var_L]
169
+ for tf in 'v', 'h', 't':
170
+ lr_list.extend([_transform(t, tf) for t in lr_list])
171
+ with torch.no_grad():
172
+ sr_list = [self.netG(aug) for aug in lr_list]
173
+ for i in range(len(sr_list)):
174
+ if i > 3:
175
+ sr_list[i] = _transform(sr_list[i], 't')
176
+ if i % 4 > 1:
177
+ sr_list[i] = _transform(sr_list[i], 'h')
178
+ if (i % 4) % 2 == 1:
179
+ sr_list[i] = _transform(sr_list[i], 'v')
180
+
181
+ output_cat = torch.cat(sr_list, dim=0)
182
+ self.fake_H = output_cat.mean(dim=0, keepdim=True)
183
+ self.netG.train()
184
+
185
+ def get_current_log(self):
186
+ return self.log_dict
187
+
188
+ def get_current_visuals(self, need_GT=True):
189
+ out_dict = OrderedDict()
190
+ out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
191
+ out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
192
+ if need_GT:
193
+ out_dict['GT'] = self.real_H.detach()[0].float().cpu()
194
+ return out_dict
195
+
196
+ def print_network(self):
197
+ s, n = self.get_network_description(self.netG)
198
+ if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
199
+ net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
200
+ self.netG.module.__class__.__name__)
201
+ else:
202
+ net_struc_str = '{}'.format(self.netG.__class__.__name__)
203
+ if self.rank <= 0:
204
+ logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
205
+ logger.info(s)
206
+
207
+ def load(self):
208
+ load_path_G = self.opt['path']['pretrain_model_G']
209
+ if load_path_G is not None:
210
+ logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
211
+ self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
212
+
213
+ def save(self, iter_label):
214
+ self.save_network(self.netG, 'G', iter_label)
215
+
216
+ def get_encode_z_and_nll(self, *args, **kwargs):
217
+ return [], torch.zeros(1)
models/SRFlow/code/models/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import logging
3
+ import os
4
+
5
+ try:
6
+ import local_config
7
+ except:
8
+ local_config = None
9
+
10
+
11
+ logger = logging.getLogger('base')
12
+
13
+
14
+ def find_model_using_name(model_name):
15
+ # Given the option --model [modelname],
16
+ # the file "models/modelname_model.py"
17
+ # will be imported.
18
+ model_filename = "models." + model_name + "_model"
19
+ modellib = importlib.import_module(model_filename)
20
+
21
+ # In the file, the class called ModelNameModel() will
22
+ # be instantiated. It has to be a subclass of torch.nn.Module,
23
+ # and it is case-insensitive.
24
+ model = None
25
+ target_model_name = model_name.replace('_', '') + 'Model'
26
+ for name, cls in modellib.__dict__.items():
27
+ if name.lower() == target_model_name.lower():
28
+ model = cls
29
+
30
+ if model is None:
31
+ print(
32
+ "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
33
+ model_filename, target_model_name))
34
+ exit(0)
35
+
36
+ return model
37
+
38
+
39
+ def create_model(opt, step=0, **opt_kwargs):
40
+ if local_config is not None:
41
+ opt['path']['pretrain_model_G'] = os.path.join(local_config.checkpoint_path, os.path.basename(opt['path']['results_root'] + '.pth'))
42
+
43
+ for k, v in opt_kwargs.items():
44
+ opt[k] = v
45
+
46
+ model = opt['model']
47
+
48
+ M = find_model_using_name(model)
49
+
50
+ m = M(opt, step)
51
+ logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
52
+ return m
models/SRFlow/code/models/base_model.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import os
18
+ from collections import OrderedDict
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn.parallel import DistributedDataParallel
22
+ import natsort
23
+ import glob
24
+
25
+
26
+ class BaseModel():
27
+ def __init__(self, opt):
28
+ self.opt = opt
29
+ self.device = torch.device('cuda' if opt.get('gpu_ids', None) is not None else 'cpu')
30
+ self.is_train = opt['is_train']
31
+ self.schedulers = []
32
+ self.optimizers = []
33
+
34
+ def feed_data(self, data):
35
+ pass
36
+
37
+ def optimize_parameters(self):
38
+ pass
39
+
40
+ def get_current_visuals(self):
41
+ pass
42
+
43
+ def get_current_losses(self):
44
+ pass
45
+
46
+ def print_network(self):
47
+ pass
48
+
49
+ def save(self, label):
50
+ pass
51
+
52
+ def load(self):
53
+ pass
54
+
55
+ def _set_lr(self, lr_groups_l):
56
+ ''' set learning rate for warmup,
57
+ lr_groups_l: list for lr_groups. each for a optimizer'''
58
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
59
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
60
+ param_group['lr'] = lr
61
+
62
+ def _get_init_lr(self):
63
+ # get the initial lr, which is set by the scheduler
64
+ init_lr_groups_l = []
65
+ for optimizer in self.optimizers:
66
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
67
+ return init_lr_groups_l
68
+
69
+ def update_learning_rate(self, cur_iter, warmup_iter=-1):
70
+ for scheduler in self.schedulers:
71
+ scheduler.step()
72
+ #### set up warm up learning rate
73
+ if cur_iter < warmup_iter:
74
+ # get initial lr for each group
75
+ init_lr_g_l = self._get_init_lr()
76
+ # modify warming-up learning rates
77
+ warm_up_lr_l = []
78
+ for init_lr_g in init_lr_g_l:
79
+ warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
80
+ # set learning rate
81
+ self._set_lr(warm_up_lr_l)
82
+
83
+ def get_current_learning_rate(self):
84
+ # return self.schedulers[0].get_lr()[0]
85
+ return self.optimizers[0].param_groups[0]['lr']
86
+
87
+ def get_network_description(self, network):
88
+ '''Get the string and total parameters of the network'''
89
+ if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
90
+ network = network.module
91
+ s = str(network)
92
+ n = sum(map(lambda x: x.numel(), network.parameters()))
93
+ return s, n
94
+
95
+ def save_network(self, network, network_label, iter_label):
96
+ paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['models'], "*_{}.pth".format(network_label))),
97
+ reverse=True)
98
+ paths = [p for p in paths if
99
+ "latest_" not in p and not any([str(i * 10000) in p.split("/")[-1].split("_") for i in range(101)])]
100
+ if len(paths) > 2:
101
+ for path in paths[2:]:
102
+ os.remove(path)
103
+ save_filename = '{}_{}.pth'.format(iter_label, network_label)
104
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
105
+ if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
106
+ network = network.module
107
+ state_dict = network.state_dict()
108
+ for key, param in state_dict.items():
109
+ state_dict[key] = param.cpu()
110
+ torch.save(state_dict, save_path)
111
+
112
+ def load_network(self, load_path, network, strict=True, submodule=None):
113
+ if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
114
+ network = network.module
115
+ if not (submodule is None or submodule.lower() == 'none'.lower()):
116
+ network = network.__getattr__(submodule)
117
+ load_net = torch.load(load_path)
118
+ load_net_clean = OrderedDict() # remove unnecessary 'module.'
119
+ for k, v in load_net.items():
120
+ if k.startswith('module.'):
121
+ load_net_clean[k[7:]] = v
122
+ else:
123
+ load_net_clean[k] = v
124
+ network.load_state_dict(load_net_clean, strict=strict)
125
+
126
+ def save_training_state(self, epoch, iter_step):
127
+ '''Saves training state during training, which will be used for resuming'''
128
+ state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
129
+ for s in self.schedulers:
130
+ state['schedulers'].append(s.state_dict())
131
+ for o in self.optimizers:
132
+ state['optimizers'].append(o.state_dict())
133
+ save_filename = '{}.state'.format(iter_step)
134
+ save_path = os.path.join(self.opt['path']['training_state'], save_filename)
135
+
136
+ paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['training_state'], "*.state")),
137
+ reverse=True)
138
+ paths = [p for p in paths if "latest_" not in p]
139
+ if len(paths) > 2:
140
+ for path in paths[2:]:
141
+ os.remove(path)
142
+
143
+ torch.save(state, save_path)
144
+
145
+ def resume_training(self, resume_state):
146
+ '''Resume the optimizers and schedulers for training'''
147
+ resume_optimizers = resume_state['optimizers']
148
+ resume_schedulers = resume_state['schedulers']
149
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
150
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
151
+ for i, o in enumerate(resume_optimizers):
152
+ self.optimizers[i].load_state_dict(o)
153
+ for i, s in enumerate(resume_schedulers):
154
+ self.schedulers[i].load_state_dict(s)
models/SRFlow/code/models/lr_scheduler.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import math
18
+ from collections import Counter
19
+ from collections import defaultdict
20
+ import torch
21
+ from torch.optim.lr_scheduler import _LRScheduler
22
+
23
+
24
+ class MultiStepLR_Restart(_LRScheduler):
25
+ def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
26
+ clear_state=False, last_epoch=-1, lr_steps_invese=None):
27
+ assert lr_steps_invese is not None, "Use empty list"
28
+ self.milestones = Counter(milestones)
29
+ self.lr_steps_inverse = Counter(lr_steps_invese)
30
+ self.gamma = gamma
31
+ self.clear_state = clear_state
32
+ self.restarts = restarts if restarts else [0]
33
+ self.restart_weights = weights if weights else [1]
34
+ assert len(self.restarts) == len(
35
+ self.restart_weights), 'restarts and their weights do not match.'
36
+ super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
37
+
38
+ def get_lr(self):
39
+ if self.last_epoch in self.restarts:
40
+ if self.clear_state:
41
+ self.optimizer.state = defaultdict(dict)
42
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
43
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
44
+ if self.last_epoch not in self.milestones and self.last_epoch not in self.lr_steps_inverse:
45
+ return [group['lr'] for group in self.optimizer.param_groups]
46
+ return [
47
+ group['lr'] * (self.gamma ** self.milestones[self.last_epoch]) *
48
+ (self.gamma ** (-self.lr_steps_inverse[self.last_epoch]))
49
+ for group in self.optimizer.param_groups
50
+ ]
51
+
52
+
53
+ class CosineAnnealingLR_Restart(_LRScheduler):
54
+ def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
55
+ self.T_period = T_period
56
+ self.T_max = self.T_period[0] # current T period
57
+ self.eta_min = eta_min
58
+ self.restarts = restarts if restarts else [0]
59
+ self.restart_weights = weights if weights else [1]
60
+ self.last_restart = 0
61
+ assert len(self.restarts) == len(
62
+ self.restart_weights), 'restarts and their weights do not match.'
63
+ super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
64
+
65
+ def get_lr(self):
66
+ if self.last_epoch == 0:
67
+ return self.base_lrs
68
+ elif self.last_epoch in self.restarts:
69
+ self.last_restart = self.last_epoch
70
+ self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
71
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
72
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
73
+ elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
74
+ return [
75
+ group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
76
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
77
+ ]
78
+ return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
79
+ (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
80
+ (group['lr'] - self.eta_min) + self.eta_min
81
+ for group in self.optimizer.param_groups]
82
+
83
+
84
+ if __name__ == "__main__":
85
+ optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
86
+ betas=(0.9, 0.99))
87
+ ##############################
88
+ # MultiStepLR_Restart
89
+ ##############################
90
+ ## Original
91
+ lr_steps = [200000, 400000, 600000, 800000]
92
+ restarts = None
93
+ restart_weights = None
94
+
95
+ ## two
96
+ lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
97
+ restarts = [500000]
98
+ restart_weights = [1]
99
+
100
+ ## four
101
+ lr_steps = [
102
+ 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
103
+ 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
104
+ ]
105
+ restarts = [250000, 500000, 750000]
106
+ restart_weights = [1, 1, 1]
107
+
108
+ scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
109
+ clear_state=False)
110
+
111
+ ##############################
112
+ # Cosine Annealing Restart
113
+ ##############################
114
+ ## two
115
+ T_period = [500000, 500000]
116
+ restarts = [500000]
117
+ restart_weights = [1]
118
+
119
+ ## four
120
+ T_period = [250000, 250000, 250000, 250000]
121
+ restarts = [250000, 500000, 750000]
122
+ restart_weights = [1, 1, 1]
123
+
124
+ scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
125
+ weights=restart_weights)
126
+
127
+ ##############################
128
+ # Draw figure
129
+ ##############################
130
+ N_iter = 1000000
131
+ lr_l = list(range(N_iter))
132
+ for i in range(N_iter):
133
+ scheduler.step()
134
+ current_lr = optimizer.param_groups[0]['lr']
135
+ lr_l[i] = current_lr
136
+
137
+ import matplotlib as mpl
138
+ from matplotlib import pyplot as plt
139
+ import matplotlib.ticker as mtick
140
+
141
+ mpl.style.use('default')
142
+ import seaborn
143
+
144
+ seaborn.set(style='whitegrid')
145
+ seaborn.set_context('paper')
146
+
147
+ plt.figure(1)
148
+ plt.subplot(111)
149
+ plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
150
+ plt.title('Title', fontsize=16, color='k')
151
+ plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
152
+ legend = plt.legend(loc='upper right', shadow=False)
153
+ ax = plt.gca()
154
+ labels = ax.get_xticks().tolist()
155
+ for k, v in enumerate(labels):
156
+ labels[k] = str(int(v / 1000)) + 'K'
157
+ ax.set_xticklabels(labels)
158
+ ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
159
+
160
+ ax.set_ylabel('Learning rate')
161
+ ax.set_xlabel('Iteration')
162
+ fig = plt.gcf()
163
+ plt.show()
models/SRFlow/code/models/modules/FlowActNorms.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import torch
18
+ from torch import nn as nn
19
+
20
+ from models.modules import thops
21
+
22
+
23
+ class _ActNorm(nn.Module):
24
+ """
25
+ Activation Normalization
26
+ Initialize the bias and scale with a given minibatch,
27
+ so that the output per-channel have zero mean and unit variance for that.
28
+
29
+ After initialization, `bias` and `logs` will be trained as parameters.
30
+ """
31
+
32
+ def __init__(self, num_features, scale=1.):
33
+ super().__init__()
34
+ # register mean and scale
35
+ size = [1, num_features, 1, 1]
36
+ self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
37
+ self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
38
+ self.num_features = num_features
39
+ self.scale = float(scale)
40
+ self.inited = False
41
+
42
+ def _check_input_dim(self, input):
43
+ return NotImplemented
44
+
45
+ def initialize_parameters(self, input):
46
+ self._check_input_dim(input)
47
+ if not self.training:
48
+ return
49
+ if (self.bias != 0).any():
50
+ self.inited = True
51
+ return
52
+ assert input.device == self.bias.device, (input.device, self.bias.device)
53
+ with torch.no_grad():
54
+ bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
55
+ vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
56
+ logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
57
+ self.bias.data.copy_(bias.data)
58
+ self.logs.data.copy_(logs.data)
59
+ self.inited = True
60
+
61
+ def _center(self, input, reverse=False, offset=None):
62
+ bias = self.bias
63
+
64
+ if offset is not None:
65
+ bias = bias + offset
66
+
67
+ if not reverse:
68
+ return input + bias
69
+ else:
70
+ return input - bias
71
+
72
+ def _scale(self, input, logdet=None, reverse=False, offset=None):
73
+ logs = self.logs
74
+
75
+ if offset is not None:
76
+ logs = logs + offset
77
+
78
+ if not reverse:
79
+ input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1
80
+ # input = input * torch.exp(logs+logs_offset)
81
+ else:
82
+ input = input * torch.exp(-logs)
83
+ if logdet is not None:
84
+ """
85
+ logs is log_std of `mean of channels`
86
+ so we need to multiply pixels
87
+ """
88
+ dlogdet = thops.sum(logs) * thops.pixels(input)
89
+ if reverse:
90
+ dlogdet *= -1
91
+ logdet = logdet + dlogdet
92
+ return input, logdet
93
+
94
+ def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
95
+ if not self.inited:
96
+ self.initialize_parameters(input)
97
+ self._check_input_dim(input)
98
+
99
+ if offset_mask is not None:
100
+ logs_offset *= offset_mask
101
+ bias_offset *= offset_mask
102
+ # no need to permute dims as old version
103
+ if not reverse:
104
+ # center and scale
105
+
106
+ # self.input = input
107
+ input = self._center(input, reverse, bias_offset)
108
+ input, logdet = self._scale(input, logdet, reverse, logs_offset)
109
+ else:
110
+ # scale and center
111
+ input, logdet = self._scale(input, logdet, reverse, logs_offset)
112
+ input = self._center(input, reverse, bias_offset)
113
+ return input, logdet
114
+
115
+
116
+ class ActNorm2d(_ActNorm):
117
+ def __init__(self, num_features, scale=1.):
118
+ super().__init__(num_features, scale)
119
+
120
+ def _check_input_dim(self, input):
121
+ assert len(input.size()) == 4
122
+ assert input.size(1) == self.num_features, (
123
+ "[ActNorm]: input should be in shape as `BCHW`,"
124
+ " channels should be {} rather than {}".format(
125
+ self.num_features, input.size()))
126
+
127
+
128
+ class MaskedActNorm2d(ActNorm2d):
129
+ def __init__(self, num_features, scale=1.):
130
+ super().__init__(num_features, scale)
131
+
132
+ def forward(self, input, mask, logdet=None, reverse=False):
133
+
134
+ assert mask.dtype == torch.bool
135
+ output, logdet_out = super().forward(input, logdet, reverse)
136
+
137
+ input[mask] = output[mask]
138
+ logdet[mask] = logdet_out[mask]
139
+
140
+ return input, logdet
141
+
models/SRFlow/code/models/modules/FlowAffineCouplingsAblation.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import torch
18
+ from torch import nn as nn
19
+
20
+ from models.modules import thops
21
+ from models.modules.flow import Conv2d, Conv2dZeros
22
+ from utils.util import opt_get
23
+
24
+
25
+ class CondAffineSeparatedAndCond(nn.Module):
26
+ def __init__(self, in_channels, opt):
27
+ super().__init__()
28
+ self.need_features = True
29
+ self.in_channels = in_channels
30
+ self.in_channels_rrdb = 320
31
+ self.kernel_hidden = 1
32
+ self.affine_eps = 0.0001
33
+ self.n_hidden_layers = 1
34
+ hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
35
+ self.hidden_channels = 64 if hidden_channels is None else hidden_channels
36
+
37
+ self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
38
+
39
+ self.channels_for_nn = self.in_channels // 2
40
+ self.channels_for_co = self.in_channels - self.channels_for_nn
41
+
42
+ if self.channels_for_nn is None:
43
+ self.channels_for_nn = self.in_channels // 2
44
+
45
+ self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
46
+ out_channels=self.channels_for_co * 2,
47
+ hidden_channels=self.hidden_channels,
48
+ kernel_hidden=self.kernel_hidden,
49
+ n_hidden_layers=self.n_hidden_layers)
50
+
51
+ self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
52
+ out_channels=self.in_channels * 2,
53
+ hidden_channels=self.hidden_channels,
54
+ kernel_hidden=self.kernel_hidden,
55
+ n_hidden_layers=self.n_hidden_layers)
56
+
57
+ def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None):
58
+ if not reverse:
59
+ z = input
60
+ assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels)
61
+
62
+ # Feature Conditional
63
+ scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
64
+ z = z + shiftFt
65
+ z = z * scaleFt
66
+ logdet = logdet + self.get_logdet(scaleFt)
67
+
68
+ # Self Conditional
69
+ z1, z2 = self.split(z)
70
+ scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
71
+ self.asserts(scale, shift, z1, z2)
72
+ z2 = z2 + shift
73
+ z2 = z2 * scale
74
+
75
+ logdet = logdet + self.get_logdet(scale)
76
+ z = thops.cat_feature(z1, z2)
77
+ output = z
78
+ else:
79
+ z = input
80
+
81
+ # Self Conditional
82
+ z1, z2 = self.split(z)
83
+ scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
84
+ self.asserts(scale, shift, z1, z2)
85
+ z2 = z2 / scale
86
+ z2 = z2 - shift
87
+ z = thops.cat_feature(z1, z2)
88
+ logdet = logdet - self.get_logdet(scale)
89
+
90
+ # Feature Conditional
91
+ scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
92
+ z = z / scaleFt
93
+ z = z - shiftFt
94
+ logdet = logdet - self.get_logdet(scaleFt)
95
+
96
+ output = z
97
+ return output, logdet
98
+
99
+ def asserts(self, scale, shift, z1, z2):
100
+ assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn)
101
+ assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co)
102
+ assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1])
103
+ assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1])
104
+
105
+ def get_logdet(self, scale):
106
+ return thops.sum(torch.log(scale), dim=[1, 2, 3])
107
+
108
+ def feature_extract(self, z, f):
109
+ h = f(z)
110
+ shift, scale = thops.split_feature(h, "cross")
111
+ scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
112
+ return scale, shift
113
+
114
+ def feature_extract_aff(self, z1, ft, f):
115
+ z = torch.cat([z1, ft], dim=1)
116
+ h = f(z)
117
+ shift, scale = thops.split_feature(h, "cross")
118
+ scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
119
+ return scale, shift
120
+
121
+ def split(self, z):
122
+ z1 = z[:, :self.channels_for_nn]
123
+ z2 = z[:, self.channels_for_nn:]
124
+ assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1])
125
+ return z1, z2
126
+
127
+ def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1):
128
+ layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)]
129
+
130
+ for _ in range(n_hidden_layers):
131
+ layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden]))
132
+ layers.append(nn.ReLU(inplace=False))
133
+ layers.append(Conv2dZeros(hidden_channels, out_channels))
134
+
135
+ return nn.Sequential(*layers)
models/SRFlow/code/models/modules/FlowStep.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import torch
18
+ from torch import nn as nn
19
+
20
+ import models.modules
21
+ import models.modules.Permutations
22
+ from models.modules import flow, thops, FlowAffineCouplingsAblation
23
+ from utils.util import opt_get
24
+
25
+
26
+ def getConditional(rrdbResults, position):
27
+ img_ft = rrdbResults if isinstance(rrdbResults, torch.Tensor) else rrdbResults[position]
28
+ return img_ft
29
+
30
+
31
+ class FlowStep(nn.Module):
32
+ FlowPermutation = {
33
+ "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
34
+ "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
35
+ "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
36
+ "squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
37
+ "resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
38
+ "resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
39
+ "InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
40
+ "InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
41
+ "InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
42
+ "InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
43
+ }
44
+
45
+ def __init__(self, in_channels, hidden_channels,
46
+ actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
47
+ LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
48
+ position=None):
49
+ # check configures
50
+ assert flow_permutation in FlowStep.FlowPermutation, \
51
+ "float_permutation should be in `{}`".format(
52
+ FlowStep.FlowPermutation.keys())
53
+ super().__init__()
54
+ self.flow_permutation = flow_permutation
55
+ self.flow_coupling = flow_coupling
56
+ self.image_injector = image_injector
57
+
58
+ self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d'
59
+ self.position = normOpt['position'] if normOpt else None
60
+
61
+ self.in_shape = in_shape
62
+ self.position = position
63
+ self.acOpt = acOpt
64
+
65
+ # 1. actnorm
66
+ self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
67
+
68
+ # 2. permute
69
+ if flow_permutation == "invconv":
70
+ self.invconv = models.modules.Permutations.InvertibleConv1x1(
71
+ in_channels, LU_decomposed=LU_decomposed)
72
+
73
+ # 3. coupling
74
+ if flow_coupling == "CondAffineSeparatedAndCond":
75
+ self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
76
+ opt=opt)
77
+ elif flow_coupling == "noCoupling":
78
+ pass
79
+ else:
80
+ raise RuntimeError("coupling not Found:", flow_coupling)
81
+
82
+ def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
83
+ if not reverse:
84
+ return self.normal_flow(input, logdet, rrdbResults)
85
+ else:
86
+ return self.reverse_flow(input, logdet, rrdbResults)
87
+
88
+ def normal_flow(self, z, logdet, rrdbResults=None):
89
+ if self.flow_coupling == "bentIdentityPreAct":
90
+ z, logdet = self.bentIdentPar(z, logdet, reverse=False)
91
+
92
+ # 1. actnorm
93
+ if self.norm_type == "ConditionalActNormImageInjector":
94
+ img_ft = getConditional(rrdbResults, self.position)
95
+ z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False)
96
+ elif self.norm_type == "noNorm":
97
+ pass
98
+ else:
99
+ z, logdet = self.actnorm(z, logdet=logdet, reverse=False)
100
+
101
+ # 2. permute
102
+ z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
103
+ self, z, logdet, False)
104
+
105
+ need_features = self.affine_need_features()
106
+
107
+ # 3. coupling
108
+ if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
109
+ img_ft = getConditional(rrdbResults, self.position)
110
+ z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft)
111
+ return z, logdet
112
+
113
+ def reverse_flow(self, z, logdet, rrdbResults=None):
114
+
115
+ need_features = self.affine_need_features()
116
+
117
+ # 1.coupling
118
+ if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
119
+ img_ft = getConditional(rrdbResults, self.position)
120
+ z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft)
121
+
122
+ # 2. permute
123
+ z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
124
+ self, z, logdet, True)
125
+
126
+ # 3. actnorm
127
+ z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
128
+
129
+ return z, logdet
130
+
131
+ def affine_need_features(self):
132
+ need_features = False
133
+ try:
134
+ need_features = self.affine.need_features
135
+ except:
136
+ pass
137
+ return need_features
models/SRFlow/code/models/modules/FlowUpsamplerNet.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn as nn
20
+
21
+ import models.modules.Split
22
+ from models.modules import flow, thops
23
+ from models.modules.Split import Split2d
24
+ from models.modules.glow_arch import f_conv2d_bias
25
+ from models.modules.FlowStep import FlowStep
26
+ from utils.util import opt_get
27
+
28
+
29
+ class FlowUpsamplerNet(nn.Module):
30
+ def __init__(self, image_shape, hidden_channels, K, L=None,
31
+ actnorm_scale=1.0,
32
+ flow_permutation=None,
33
+ flow_coupling="affine",
34
+ LU_decomposed=False, opt=None):
35
+
36
+ super().__init__()
37
+
38
+ self.layers = nn.ModuleList()
39
+ self.output_shapes = []
40
+ self.L = opt_get(opt, ['network_G', 'flow', 'L'])
41
+ self.K = opt_get(opt, ['network_G', 'flow', 'K'])
42
+ if isinstance(self.K, int):
43
+ self.K = [K for K in [K, ] * (self.L + 1)]
44
+
45
+ self.opt = opt
46
+ H, W, self.C = image_shape
47
+ self.check_image_shape()
48
+
49
+ if opt['scale'] == 16:
50
+ self.levelToName = {
51
+ 0: 'fea_up16',
52
+ 1: 'fea_up8',
53
+ 2: 'fea_up4',
54
+ 3: 'fea_up2',
55
+ 4: 'fea_up1',
56
+ }
57
+
58
+ if opt['scale'] == 8:
59
+ self.levelToName = {
60
+ 0: 'fea_up8',
61
+ 1: 'fea_up4',
62
+ 2: 'fea_up2',
63
+ 3: 'fea_up1',
64
+ 4: 'fea_up0'
65
+ }
66
+
67
+ elif opt['scale'] == 4:
68
+ self.levelToName = {
69
+ 0: 'fea_up4',
70
+ 1: 'fea_up2',
71
+ 2: 'fea_up1',
72
+ 3: 'fea_up0',
73
+ 4: 'fea_up-1'
74
+ }
75
+
76
+ affineInCh = self.get_affineInCh(opt_get)
77
+ flow_permutation = self.get_flow_permutation(flow_permutation, opt)
78
+
79
+ normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
80
+
81
+ conditional_channels = {}
82
+ n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
83
+ n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
84
+ conditional_channels[0] = n_rrdb
85
+ for level in range(1, self.L + 1):
86
+ # Level 1 gets conditionals from 2, 3, 4 => L - level
87
+ # Level 2 gets conditionals from 3, 4
88
+ # Level 3 gets conditionals from 4
89
+ # Level 4 gets conditionals from None
90
+ n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
91
+ conditional_channels[level] = n_rrdb + n_bypass
92
+
93
+ # Upsampler
94
+ for level in range(1, self.L + 1):
95
+ # 1. Squeeze
96
+ H, W = self.arch_squeeze(H, W)
97
+
98
+ # 2. K FlowStep
99
+ self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt)
100
+ self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
101
+ flow_permutation,
102
+ hidden_channels, normOpt, opt, opt_get,
103
+ n_conditinal_channels=conditional_channels[level])
104
+ # Split
105
+ self.arch_split(H, W, level, self.L, opt, opt_get)
106
+
107
+ if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
108
+ self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
109
+ else:
110
+ self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
111
+
112
+ self.H = H
113
+ self.W = W
114
+ self.scaleH = 160 / H
115
+ self.scaleW = 160 / W
116
+
117
+ def get_n_rrdb_channels(self, opt, opt_get):
118
+ blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
119
+ n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
120
+ return n_rrdb
121
+
122
+ def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
123
+ hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None):
124
+ condAff = self.get_condAffSetting(opt, opt_get)
125
+ if condAff is not None:
126
+ condAff['in_channels_rrdb'] = n_conditinal_channels
127
+
128
+ for k in range(K):
129
+ position_name = get_position_name(H, self.opt['scale'])
130
+ if normOpt: normOpt['position'] = position_name
131
+
132
+ self.layers.append(
133
+ FlowStep(in_channels=self.C,
134
+ hidden_channels=hidden_channels,
135
+ actnorm_scale=actnorm_scale,
136
+ flow_permutation=flow_permutation,
137
+ flow_coupling=flow_coupling,
138
+ acOpt=condAff,
139
+ position=position_name,
140
+ LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt))
141
+ self.output_shapes.append(
142
+ [-1, self.C, H, W])
143
+
144
+ def get_condAffSetting(self, opt, opt_get):
145
+ condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
146
+ condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
147
+ return condAff
148
+
149
+ def arch_split(self, H, W, L, levels, opt, opt_get):
150
+ correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
151
+ correction = 0 if correct_splits else 1
152
+ if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
153
+ logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
154
+ consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
155
+ position_name = get_position_name(H, self.opt['scale'])
156
+ position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
157
+ cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
158
+ cond_channels = 0 if cond_channels is None else cond_channels
159
+
160
+ t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
161
+
162
+ if t == 'Split2d':
163
+ split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
164
+ cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
165
+ self.layers.append(split)
166
+ self.output_shapes.append([-1, split.num_channels_pass, H, W])
167
+ self.C = split.num_channels_pass
168
+
169
+ def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
170
+ if 'additionalFlowNoAffine' in opt['network_G']['flow']:
171
+ n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
172
+ for _ in range(n_additionalFlowNoAffine):
173
+ self.layers.append(
174
+ FlowStep(in_channels=self.C,
175
+ hidden_channels=hidden_channels,
176
+ actnorm_scale=actnorm_scale,
177
+ flow_permutation='invconv',
178
+ flow_coupling='noCoupling',
179
+ LU_decomposed=LU_decomposed, opt=opt))
180
+ self.output_shapes.append(
181
+ [-1, self.C, H, W])
182
+
183
+ def arch_squeeze(self, H, W):
184
+ self.C, H, W = self.C * 4, H // 2, W // 2
185
+ self.layers.append(flow.SqueezeLayer(factor=2))
186
+ self.output_shapes.append([-1, self.C, H, W])
187
+ return H, W
188
+
189
+ def get_flow_permutation(self, flow_permutation, opt):
190
+ flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
191
+ return flow_permutation
192
+
193
+ def get_affineInCh(self, opt_get):
194
+ affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
195
+ affineInCh = (len(affineInCh) + 1) * 64
196
+ return affineInCh
197
+
198
+ def check_image_shape(self):
199
+ assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3)"
200
+ "self.C == 1 or self.C == 3")
201
+
202
+ def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
203
+ y_onehot=None):
204
+
205
+ if reverse:
206
+ epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses
207
+
208
+ sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
209
+ return sr, logdet
210
+ else:
211
+ assert gt is not None
212
+ assert rrdbResults is not None
213
+ z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
214
+
215
+ return z, logdet
216
+
217
+ def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
218
+ fl_fea = gt
219
+ reverse = False
220
+ level_conditionals = {}
221
+ bypasses = {}
222
+
223
+ L = opt_get(self.opt, ['network_G', 'flow', 'L'])
224
+
225
+ for level in range(1, L + 1):
226
+ bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
227
+
228
+ for layer, shape in zip(self.layers, self.output_shapes):
229
+ size = shape[2]
230
+ level = int(np.log(160 / size) / np.log(2))
231
+
232
+ if level > 0 and level not in level_conditionals.keys():
233
+ level_conditionals[level] = rrdbResults[self.levelToName[level]]
234
+
235
+ level_conditionals[level] = rrdbResults[self.levelToName[level]]
236
+
237
+ if isinstance(layer, FlowStep):
238
+ fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
239
+ elif isinstance(layer, Split2d):
240
+ fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
241
+ y_onehot=y_onehot)
242
+ else:
243
+ fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)
244
+
245
+ z = fl_fea
246
+
247
+ if not isinstance(epses, list):
248
+ return z, logdet
249
+
250
+ epses.append(z)
251
+ return epses, logdet
252
+
253
+ def forward_preFlow(self, fl_fea, logdet, reverse):
254
+ if hasattr(self, 'preFlow'):
255
+ for l in self.preFlow:
256
+ fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
257
+ return fl_fea, logdet
258
+
259
+ def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
260
+ ft = None if layer.position is None else rrdbResults[layer.position]
261
+ fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
262
+
263
+ if isinstance(epses, list):
264
+ epses.append(eps)
265
+ return fl_fea, logdet
266
+
267
+ def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
268
+ z = epses.pop() if isinstance(epses, list) else z
269
+
270
+ fl_fea = z
271
+ # debug.imwrite("fl_fea", fl_fea)
272
+ bypasses = {}
273
+ level_conditionals = {}
274
+ if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
275
+ for level in range(self.L + 1):
276
+ level_conditionals[level] = rrdbResults[self.levelToName[level]]
277
+
278
+ for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
279
+ size = shape[2]
280
+ level = int(np.log(160 / size) / np.log(2))
281
+ # size = fl_fea.shape[2]
282
+ # level = int(np.log(160 / size) / np.log(2))
283
+
284
+ if isinstance(layer, Split2d):
285
+ fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
286
+ rrdbResults[self.levelToName[level]], logdet=logdet,
287
+ y_onehot=y_onehot)
288
+ elif isinstance(layer, FlowStep):
289
+ fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
290
+ else:
291
+ fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)
292
+
293
+ sr = fl_fea
294
+
295
+ assert sr.shape[1] == 3
296
+ return sr, logdet
297
+
298
+ def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
299
+ ft = None if layer.position is None else rrdbResults[layer.position]
300
+ fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
301
+ eps=epses.pop() if isinstance(epses, list) else None,
302
+ eps_std=eps_std, ft=ft, y_onehot=y_onehot)
303
+ return fl_fea, logdet
304
+
305
+
306
+ def get_position_name(H, scale):
307
+ downscale_factor = 160 // H
308
+ position_name = 'fea_up{}'.format(scale / downscale_factor)
309
+ return position_name
models/SRFlow/code/models/modules/Permutations.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn as nn
20
+ from torch.nn import functional as F
21
+
22
+ from models.modules import thops
23
+
24
+
25
+ class InvertibleConv1x1(nn.Module):
26
+ def __init__(self, num_channels, LU_decomposed=False):
27
+ super().__init__()
28
+ w_shape = [num_channels, num_channels]
29
+ w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
30
+ self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
31
+ self.w_shape = w_shape
32
+ self.LU = LU_decomposed
33
+
34
+ def get_weight(self, input, reverse):
35
+ w_shape = self.w_shape
36
+ pixels = thops.pixels(input)
37
+ dlogdet = torch.slogdet(self.weight)[1] * pixels
38
+ if not reverse:
39
+ weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
40
+ else:
41
+ weight = torch.inverse(self.weight.double()).float() \
42
+ .view(w_shape[0], w_shape[1], 1, 1)
43
+ return weight, dlogdet
44
+ def forward(self, input, logdet=None, reverse=False):
45
+ """
46
+ log-det = log|abs(|W|)| * pixels
47
+ """
48
+ weight, dlogdet = self.get_weight(input, reverse)
49
+ if not reverse:
50
+ z = F.conv2d(input, weight)
51
+ if logdet is not None:
52
+ logdet = logdet + dlogdet
53
+ return z, logdet
54
+ else:
55
+ z = F.conv2d(input, weight)
56
+ if logdet is not None:
57
+ logdet = logdet - dlogdet
58
+ return z, logdet
models/SRFlow/code/models/modules/RRDBNet_arch.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import functools
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import models.modules.module_util as mutil
22
+ from utils.util import opt_get
23
+
24
+
25
+ class ResidualDenseBlock_5C(nn.Module):
26
+ def __init__(self, nf=64, gc=32, bias=True):
27
+ super(ResidualDenseBlock_5C, self).__init__()
28
+ # gc: growth channel, i.e. intermediate channels
29
+ self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
30
+ self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
31
+ self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
32
+ self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
33
+ self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
34
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
35
+
36
+ # initialization
37
+ mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
38
+
39
+ def forward(self, x):
40
+ x1 = self.lrelu(self.conv1(x))
41
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
42
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
43
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
44
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
45
+ return x5 * 0.2 + x
46
+
47
+
48
+ class RRDB(nn.Module):
49
+ '''Residual in Residual Dense Block'''
50
+
51
+ def __init__(self, nf, gc=32):
52
+ super(RRDB, self).__init__()
53
+ self.RDB1 = ResidualDenseBlock_5C(nf, gc)
54
+ self.RDB2 = ResidualDenseBlock_5C(nf, gc)
55
+ self.RDB3 = ResidualDenseBlock_5C(nf, gc)
56
+
57
+ def forward(self, x):
58
+ out = self.RDB1(x)
59
+ out = self.RDB2(out)
60
+ out = self.RDB3(out)
61
+ return out * 0.2 + x
62
+
63
+
64
+ class RRDBNet(nn.Module):
65
+ def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
66
+ self.opt = opt
67
+ super(RRDBNet, self).__init__()
68
+ RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
69
+ self.scale = scale
70
+
71
+ self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
72
+ self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
73
+ self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
74
+ #### upsampling
75
+ self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
76
+ self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
77
+ if self.scale >= 8:
78
+ self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
79
+ if self.scale >= 16:
80
+ self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
81
+ if self.scale >= 32:
82
+ self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
83
+
84
+ self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
85
+ self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
86
+
87
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
88
+
89
+ def forward(self, x, get_steps=False):
90
+ fea = self.conv_first(x)
91
+
92
+ block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
93
+ block_results = {}
94
+
95
+ for idx, m in enumerate(self.RRDB_trunk.children()):
96
+ fea = m(fea)
97
+ for b in block_idxs:
98
+ if b == idx:
99
+ block_results["block_{}".format(idx)] = fea
100
+
101
+ trunk = self.trunk_conv(fea)
102
+
103
+ last_lr_fea = fea + trunk
104
+
105
+ fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
106
+ fea = self.lrelu(fea_up2)
107
+
108
+ fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
109
+ fea = self.lrelu(fea_up4)
110
+
111
+ fea_up8 = None
112
+ fea_up16 = None
113
+ fea_up32 = None
114
+
115
+ if self.scale >= 8:
116
+ fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
117
+ fea = self.lrelu(fea_up8)
118
+ if self.scale >= 16:
119
+ fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
120
+ fea = self.lrelu(fea_up16)
121
+ if self.scale >= 32:
122
+ fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
123
+ fea = self.lrelu(fea_up32)
124
+
125
+ out = self.conv_last(self.lrelu(self.HRconv(fea)))
126
+
127
+ results = {'last_lr_fea': last_lr_fea,
128
+ 'fea_up1': last_lr_fea,
129
+ 'fea_up2': fea_up2,
130
+ 'fea_up4': fea_up4,
131
+ 'fea_up8': fea_up8,
132
+ 'fea_up16': fea_up16,
133
+ 'fea_up32': fea_up32,
134
+ 'out': out}
135
+
136
+ fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
137
+ if fea_up0_en:
138
+ results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
139
+ fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
140
+ if fea_upn1_en:
141
+ results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
142
+
143
+ if get_steps:
144
+ for k, v in block_results.items():
145
+ results[k] = v
146
+ return results
147
+ else:
148
+ return out
models/SRFlow/code/models/modules/SRFlowNet_arch.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import math
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import numpy as np
23
+ from models.modules.RRDBNet_arch import RRDBNet
24
+ from models.modules.FlowUpsamplerNet import FlowUpsamplerNet
25
+ import models.modules.thops as thops
26
+ import models.modules.flow as flow
27
+ from utils.util import opt_get
28
+
29
+
30
+ class SRFlowNet(nn.Module):
31
+ def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
32
+ super(SRFlowNet, self).__init__()
33
+
34
+ self.opt = opt
35
+ self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
36
+ None else opt_get(opt, ['datasets', 'train', 'quant'])
37
+ self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
38
+ hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
39
+ hidden_channels = hidden_channels or 64
40
+ self.RRDB_training = True # Default is true
41
+
42
+ train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
43
+ set_RRDB_to_train = False
44
+ if set_RRDB_to_train:
45
+ self.set_rrdb_training(True)
46
+
47
+ self.flowUpsamplerNet = \
48
+ FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
49
+ flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
50
+ self.i = 0
51
+
52
+ def set_rrdb_training(self, trainable):
53
+ if self.RRDB_training != trainable:
54
+ for p in self.RRDB.parameters():
55
+ p.requires_grad = trainable
56
+ self.RRDB_training = trainable
57
+ return True
58
+ return False
59
+
60
+ def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
61
+ lr_enc=None,
62
+ add_gt_noise=False, step=None, y_label=None):
63
+ if not reverse:
64
+ return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
65
+ y_onehot=y_label)
66
+ else:
67
+ # assert lr.shape[0] == 1
68
+ assert lr.shape[1] == 3
69
+ # assert lr.shape[2] == 20
70
+ # assert lr.shape[3] == 20
71
+ # assert z.shape[0] == 1
72
+ # assert z.shape[1] == 3 * 8 * 8
73
+ # assert z.shape[2] == 20
74
+ # assert z.shape[3] == 20
75
+ if reverse_with_grad:
76
+ return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
77
+ add_gt_noise=add_gt_noise)
78
+ else:
79
+ with torch.no_grad():
80
+ return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
81
+ add_gt_noise=add_gt_noise)
82
+
83
+ def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
84
+ if lr_enc is None:
85
+ lr_enc = self.rrdbPreprocessing(lr)
86
+
87
+ logdet = torch.zeros_like(gt[:, 0, 0, 0])
88
+ pixels = thops.pixels(gt)
89
+
90
+ z = gt
91
+
92
+ if add_gt_noise:
93
+ # Setup
94
+ noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
95
+ if noiseQuant:
96
+ z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
97
+ logdet = logdet + float(-np.log(self.quant) * pixels)
98
+
99
+ # Encode
100
+ epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
101
+ y_onehot=y_onehot)
102
+
103
+ objective = logdet.clone()
104
+
105
+ if isinstance(epses, (list, tuple)):
106
+ z = epses[-1]
107
+ else:
108
+ z = epses
109
+
110
+ objective = objective + flow.GaussianDiag.logp(None, None, z)
111
+
112
+ nll = (-objective) / float(np.log(2.) * pixels)
113
+
114
+ if isinstance(epses, list):
115
+ return epses, nll, logdet
116
+ return z, nll, logdet
117
+
118
+ def rrdbPreprocessing(self, lr):
119
+ rrdbResults = self.RRDB(lr, get_steps=True)
120
+ block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
121
+ if len(block_idxs) > 0:
122
+ concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
123
+
124
+ if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
125
+ keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
126
+ if 'fea_up0' in rrdbResults.keys():
127
+ keys.append('fea_up0')
128
+ if 'fea_up-1' in rrdbResults.keys():
129
+ keys.append('fea_up-1')
130
+ if self.opt['scale'] >= 8:
131
+ keys.append('fea_up8')
132
+ if self.opt['scale'] == 16:
133
+ keys.append('fea_up16')
134
+ for k in keys:
135
+ h = rrdbResults[k].shape[2]
136
+ w = rrdbResults[k].shape[3]
137
+ rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
138
+ return rrdbResults
139
+
140
+ def get_score(self, disc_loss_sigma, z):
141
+ score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \
142
+ z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma)
143
+ return -score_real
144
+
145
+ def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
146
+ logdet = torch.zeros_like(lr[:, 0, 0, 0])
147
+ pixels = thops.pixels(lr) * self.opt['scale'] ** 2
148
+
149
+ if add_gt_noise:
150
+ logdet = logdet - float(-np.log(self.quant) * pixels)
151
+
152
+ if lr_enc is None:
153
+ lr_enc = self.rrdbPreprocessing(lr)
154
+
155
+ x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
156
+ logdet=logdet)
157
+
158
+ return x, logdet
models/SRFlow/code/models/modules/Split.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import torch
18
+ from torch import nn as nn
19
+
20
+ from models.modules import thops
21
+ from models.modules.FlowStep import FlowStep
22
+ from models.modules.flow import Conv2dZeros, GaussianDiag
23
+ from utils.util import opt_get
24
+
25
+
26
+ class Split2d(nn.Module):
27
+ def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
28
+ super().__init__()
29
+
30
+ self.num_channels_consume = int(round(num_channels * consume_ratio))
31
+ self.num_channels_pass = num_channels - self.num_channels_consume
32
+
33
+ self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
34
+ out_channels=self.num_channels_consume * 2)
35
+ self.logs_eps = logs_eps
36
+ self.position = position
37
+ self.opt = opt
38
+
39
+ def split2d_prior(self, z, ft):
40
+ if ft is not None:
41
+ z = torch.cat([z, ft], dim=1)
42
+ h = self.conv(z)
43
+ return thops.split_feature(h, "cross")
44
+
45
+ def exp_eps(self, logs):
46
+ return torch.exp(logs) + self.logs_eps
47
+
48
+ def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
49
+ if not reverse:
50
+ # self.input = input
51
+ z1, z2 = self.split_ratio(input)
52
+ mean, logs = self.split2d_prior(z1, ft)
53
+
54
+ eps = (z2 - mean) / self.exp_eps(logs)
55
+
56
+ logdet = logdet + self.get_logdet(logs, mean, z2)
57
+
58
+ # print(logs.shape, mean.shape, z2.shape)
59
+ # self.eps = eps
60
+ # print('split, enc eps:', eps)
61
+ return z1, logdet, eps
62
+ else:
63
+ z1 = input
64
+ mean, logs = self.split2d_prior(z1, ft)
65
+
66
+ if eps is None:
67
+ #print("WARNING: eps is None, generating eps untested functionality!")
68
+ eps = GaussianDiag.sample_eps(mean.shape, eps_std)
69
+
70
+ eps = eps.to(mean.device)
71
+ z2 = mean + self.exp_eps(logs) * eps
72
+
73
+ z = thops.cat_feature(z1, z2)
74
+ logdet = logdet - self.get_logdet(logs, mean, z2)
75
+
76
+ return z, logdet
77
+ # return z, logdet, eps
78
+
79
+ def get_logdet(self, logs, mean, z2):
80
+ logdet_diff = GaussianDiag.logp(mean, logs, z2)
81
+ # print("Split2D: logdet diff", logdet_diff.item())
82
+ return logdet_diff
83
+
84
+ def split_ratio(self, input):
85
+ z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
86
+ return z1, z2
models/SRFlow/code/models/modules/__init__.py ADDED
File without changes
models/SRFlow/code/models/modules/flow.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import numpy as np
21
+
22
+ from models.modules.FlowActNorms import ActNorm2d
23
+ from . import thops
24
+
25
+
26
+ class Conv2d(nn.Conv2d):
27
+ pad_dict = {
28
+ "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
29
+ "valid": lambda kernel, stride: [0 for _ in kernel]
30
+ }
31
+
32
+ @staticmethod
33
+ def get_padding(padding, kernel_size, stride):
34
+ # make paddding
35
+ if isinstance(padding, str):
36
+ if isinstance(kernel_size, int):
37
+ kernel_size = [kernel_size, kernel_size]
38
+ if isinstance(stride, int):
39
+ stride = [stride, stride]
40
+ padding = padding.lower()
41
+ try:
42
+ padding = Conv2d.pad_dict[padding](kernel_size, stride)
43
+ except KeyError:
44
+ raise ValueError("{} is not supported".format(padding))
45
+ return padding
46
+
47
+ def __init__(self, in_channels, out_channels,
48
+ kernel_size=[3, 3], stride=[1, 1],
49
+ padding="same", do_actnorm=True, weight_std=0.05):
50
+ padding = Conv2d.get_padding(padding, kernel_size, stride)
51
+ super().__init__(in_channels, out_channels, kernel_size, stride,
52
+ padding, bias=(not do_actnorm))
53
+ # init weight with std
54
+ self.weight.data.normal_(mean=0.0, std=weight_std)
55
+ if not do_actnorm:
56
+ self.bias.data.zero_()
57
+ else:
58
+ self.actnorm = ActNorm2d(out_channels)
59
+ self.do_actnorm = do_actnorm
60
+
61
+ def forward(self, input):
62
+ x = super().forward(input)
63
+ if self.do_actnorm:
64
+ x, _ = self.actnorm(x)
65
+ return x
66
+
67
+
68
+ class Conv2dZeros(nn.Conv2d):
69
+ def __init__(self, in_channels, out_channels,
70
+ kernel_size=[3, 3], stride=[1, 1],
71
+ padding="same", logscale_factor=3):
72
+ padding = Conv2d.get_padding(padding, kernel_size, stride)
73
+ super().__init__(in_channels, out_channels, kernel_size, stride, padding)
74
+ # logscale_factor
75
+ self.logscale_factor = logscale_factor
76
+ self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
77
+ # init
78
+ self.weight.data.zero_()
79
+ self.bias.data.zero_()
80
+
81
+ def forward(self, input):
82
+ output = super().forward(input)
83
+ return output * torch.exp(self.logs * self.logscale_factor)
84
+
85
+
86
+ class GaussianDiag:
87
+ Log2PI = float(np.log(2 * np.pi))
88
+
89
+ @staticmethod
90
+ def likelihood(mean, logs, x):
91
+ """
92
+ lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) }
93
+ k = 1 (Independent)
94
+ Var = logs ** 2
95
+ """
96
+ if mean is None and logs is None:
97
+ return -0.5 * (x ** 2 + GaussianDiag.Log2PI)
98
+ else:
99
+ return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) + GaussianDiag.Log2PI)
100
+
101
+ @staticmethod
102
+ def logp(mean, logs, x):
103
+ likelihood = GaussianDiag.likelihood(mean, logs, x)
104
+ return thops.sum(likelihood, dim=[1, 2, 3])
105
+
106
+ @staticmethod
107
+ def sample(mean, logs, eps_std=None):
108
+ eps_std = eps_std or 1
109
+ eps = torch.normal(mean=torch.zeros_like(mean),
110
+ std=torch.ones_like(logs) * eps_std)
111
+ return mean + torch.exp(logs) * eps
112
+
113
+ @staticmethod
114
+ def sample_eps(shape, eps_std, seed=None):
115
+ if seed is not None:
116
+ torch.manual_seed(seed)
117
+ eps = torch.normal(mean=torch.zeros(shape),
118
+ std=torch.ones(shape) * eps_std)
119
+ return eps
120
+
121
+
122
+ def squeeze2d(input, factor=2):
123
+ assert factor >= 1 and isinstance(factor, int)
124
+ if factor == 1:
125
+ return input
126
+ size = input.size()
127
+ B = size[0]
128
+ C = size[1]
129
+ H = size[2]
130
+ W = size[3]
131
+ assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
132
+ x = input.view(B, C, H // factor, factor, W // factor, factor)
133
+ x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
134
+ x = x.view(B, C * factor * factor, H // factor, W // factor)
135
+ return x
136
+
137
+
138
+ def unsqueeze2d(input, factor=2):
139
+ assert factor >= 1 and isinstance(factor, int)
140
+ factor2 = factor ** 2
141
+ if factor == 1:
142
+ return input
143
+ size = input.size()
144
+ B = size[0]
145
+ C = size[1]
146
+ H = size[2]
147
+ W = size[3]
148
+ assert C % (factor2) == 0, "{}".format(C)
149
+ x = input.view(B, C // factor2, factor, factor, H, W)
150
+ x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
151
+ x = x.view(B, C // (factor2), H * factor, W * factor)
152
+ return x
153
+
154
+
155
+ class SqueezeLayer(nn.Module):
156
+ def __init__(self, factor):
157
+ super().__init__()
158
+ self.factor = factor
159
+
160
+ def forward(self, input, logdet=None, reverse=False):
161
+ if not reverse:
162
+ output = squeeze2d(input, self.factor) # Squeeze in forward
163
+ return output, logdet
164
+ else:
165
+ output = unsqueeze2d(input, self.factor)
166
+ return output, logdet
models/SRFlow/code/models/modules/glow_arch.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import torch.nn as nn
18
+
19
+
20
+ def f_conv2d_bias(in_channels, out_channels):
21
+ def padding_same(kernel, stride):
22
+ return [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)]
23
+
24
+ padding = padding_same([3, 3], [1, 1])
25
+ assert padding == [1, 1], padding
26
+ return nn.Sequential(
27
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1,
28
+ bias=True))
models/SRFlow/code/models/modules/loss.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+
21
+ class CharbonnierLoss(nn.Module):
22
+ """Charbonnier Loss (L1)"""
23
+
24
+ def __init__(self, eps=1e-6):
25
+ super(CharbonnierLoss, self).__init__()
26
+ self.eps = eps
27
+
28
+ def forward(self, x, y):
29
+ diff = x - y
30
+ loss = torch.sum(torch.sqrt(diff * diff + self.eps))
31
+ return loss
32
+
33
+
34
+ # Define GAN loss: [vanilla | lsgan | wgan-gp]
35
+ class GANLoss(nn.Module):
36
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
37
+ super(GANLoss, self).__init__()
38
+ self.gan_type = gan_type.lower()
39
+ self.real_label_val = real_label_val
40
+ self.fake_label_val = fake_label_val
41
+
42
+ if self.gan_type == 'gan' or self.gan_type == 'ragan':
43
+ self.loss = nn.BCEWithLogitsLoss()
44
+ elif self.gan_type == 'lsgan':
45
+ self.loss = nn.MSELoss()
46
+ elif self.gan_type == 'wgan-gp':
47
+
48
+ def wgan_loss(input, target):
49
+ # target is boolean
50
+ return -1 * input.mean() if target else input.mean()
51
+
52
+ self.loss = wgan_loss
53
+ else:
54
+ raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
55
+
56
+ def get_target_label(self, input, target_is_real):
57
+ if self.gan_type == 'wgan-gp':
58
+ return target_is_real
59
+ if target_is_real:
60
+ return torch.empty_like(input).fill_(self.real_label_val)
61
+ else:
62
+ return torch.empty_like(input).fill_(self.fake_label_val)
63
+
64
+ def forward(self, input, target_is_real):
65
+ target_label = self.get_target_label(input, target_is_real)
66
+ loss = self.loss(input, target_label)
67
+ return loss
68
+
69
+
70
+ class GradientPenaltyLoss(nn.Module):
71
+ def __init__(self, device=torch.device('cpu')):
72
+ super(GradientPenaltyLoss, self).__init__()
73
+ self.register_buffer('grad_outputs', torch.Tensor())
74
+ self.grad_outputs = self.grad_outputs.to(device)
75
+
76
+ def get_grad_outputs(self, input):
77
+ if self.grad_outputs.size() != input.size():
78
+ self.grad_outputs.resize_(input.size()).fill_(1.0)
79
+ return self.grad_outputs
80
+
81
+ def forward(self, interp, interp_crit):
82
+ grad_outputs = self.get_grad_outputs(interp_crit)
83
+ grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
84
+ grad_outputs=grad_outputs, create_graph=True,
85
+ retain_graph=True, only_inputs=True)[0]
86
+ grad_interp = grad_interp.view(grad_interp.size(0), -1)
87
+ grad_interp_norm = grad_interp.norm(2, dim=1)
88
+
89
+ loss = ((grad_interp_norm - 1)**2).mean()
90
+ return loss
models/SRFlow/code/models/modules/module_util.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.init as init
20
+ import torch.nn.functional as F
21
+
22
+
23
+ def initialize_weights(net_l, scale=1):
24
+ if not isinstance(net_l, list):
25
+ net_l = [net_l]
26
+ for net in net_l:
27
+ for m in net.modules():
28
+ if isinstance(m, nn.Conv2d):
29
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
30
+ m.weight.data *= scale # for residual block
31
+ if m.bias is not None:
32
+ m.bias.data.zero_()
33
+ elif isinstance(m, nn.Linear):
34
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
35
+ m.weight.data *= scale
36
+ if m.bias is not None:
37
+ m.bias.data.zero_()
38
+ elif isinstance(m, nn.BatchNorm2d):
39
+ init.constant_(m.weight, 1)
40
+ init.constant_(m.bias.data, 0.0)
41
+
42
+
43
+ def make_layer(block, n_layers):
44
+ layers = []
45
+ for _ in range(n_layers):
46
+ layers.append(block())
47
+ return nn.Sequential(*layers)
48
+
49
+
50
+ class ResidualBlock_noBN(nn.Module):
51
+ '''Residual block w/o BN
52
+ ---Conv-ReLU-Conv-+-
53
+ |________________|
54
+ '''
55
+
56
+ def __init__(self, nf=64):
57
+ super(ResidualBlock_noBN, self).__init__()
58
+ self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
59
+ self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
60
+
61
+ # initialization
62
+ initialize_weights([self.conv1, self.conv2], 0.1)
63
+
64
+ def forward(self, x):
65
+ identity = x
66
+ out = F.relu(self.conv1(x), inplace=True)
67
+ out = self.conv2(out)
68
+ return identity + out
69
+
70
+
71
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
72
+ """Warp an image or feature map with optical flow
73
+ Args:
74
+ x (Tensor): size (N, C, H, W)
75
+ flow (Tensor): size (N, H, W, 2), normal value
76
+ interp_mode (str): 'nearest' or 'bilinear'
77
+ padding_mode (str): 'zeros' or 'border' or 'reflection'
78
+
79
+ Returns:
80
+ Tensor: warped image or feature map
81
+ """
82
+ assert x.size()[-2:] == flow.size()[1:3]
83
+ B, C, H, W = x.size()
84
+ # mesh grid
85
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
86
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
87
+ grid.requires_grad = False
88
+ grid = grid.type_as(x)
89
+ vgrid = grid + flow
90
+ # scale grid to [-1,1]
91
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
92
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
93
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
94
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
95
+ return output
models/SRFlow/code/models/modules/thops.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16
+
17
+ import torch
18
+
19
+
20
+ def sum(tensor, dim=None, keepdim=False):
21
+ if dim is None:
22
+ # sum up all dim
23
+ return torch.sum(tensor)
24
+ else:
25
+ if isinstance(dim, int):
26
+ dim = [dim]
27
+ dim = sorted(dim)
28
+ for d in dim:
29
+ tensor = tensor.sum(dim=d, keepdim=True)
30
+ if not keepdim:
31
+ for i, d in enumerate(dim):
32
+ tensor.squeeze_(d-i)
33
+ return tensor
34
+
35
+
36
+ def mean(tensor, dim=None, keepdim=False):
37
+ if dim is None:
38
+ # mean all dim
39
+ return torch.mean(tensor)
40
+ else:
41
+ if isinstance(dim, int):
42
+ dim = [dim]
43
+ dim = sorted(dim)
44
+ for d in dim:
45
+ tensor = tensor.mean(dim=d, keepdim=True)
46
+ if not keepdim:
47
+ for i, d in enumerate(dim):
48
+ tensor.squeeze_(d-i)
49
+ return tensor
50
+
51
+
52
+ def split_feature(tensor, type="split"):
53
+ """
54
+ type = ["split", "cross"]
55
+ """
56
+ C = tensor.size(1)
57
+ if type == "split":
58
+ return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
59
+ elif type == "cross":
60
+ return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
61
+
62
+
63
+ def cat_feature(tensor_a, tensor_b):
64
+ return torch.cat((tensor_a, tensor_b), dim=1)
65
+
66
+
67
+ def pixels(tensor):
68
+ return int(tensor.size(2) * tensor.size(3))
models/SRFlow/code/models/networks.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import importlib
18
+
19
+ import torch
20
+ import logging
21
+ import models.modules.RRDBNet_arch as RRDBNet_arch
22
+
23
+ logger = logging.getLogger('base')
24
+
25
+
26
+ def find_model_using_name(model_name):
27
+ model_filename = "models.modules." + model_name + "_arch"
28
+ modellib = importlib.import_module(model_filename)
29
+
30
+ model = None
31
+ target_model_name = model_name.replace('_Net', '')
32
+ for name, cls in modellib.__dict__.items():
33
+ if name.lower() == target_model_name.lower():
34
+ model = cls
35
+
36
+ if model is None:
37
+ print(
38
+ "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
39
+ model_filename, target_model_name))
40
+ exit(0)
41
+
42
+ return model
43
+
44
+
45
+ ####################
46
+ # define network
47
+ ####################
48
+ #### Generator
49
+ def define_G(opt):
50
+ opt_net = opt['network_G']
51
+ which_model = opt_net['which_model_G']
52
+
53
+ if which_model == 'RRDBNet':
54
+ netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
55
+ nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], opt=opt)
56
+ elif which_model == 'EDSRNet':
57
+ Arch = find_model_using_name(which_model)
58
+ netG = Arch(scale=opt['scale'])
59
+ elif which_model == 'rankSRGAN':
60
+ Arch = find_model_using_name(which_model)
61
+ netG = Arch(upscale=opt['scale'])
62
+ # elif which_model == 'sft_arch': # SFT-GAN
63
+ # netG = sft_arch.SFT_Net()
64
+ else:
65
+ raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
66
+ return netG
67
+
68
+
69
+ def define_Flow(opt, step):
70
+ opt_net = opt['network_G']
71
+ which_model = opt_net['which_model_G']
72
+
73
+ Arch = find_model_using_name(which_model)
74
+ netG = Arch(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
75
+ nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step)
76
+
77
+ return netG
78
+
79
+
80
+ #### Discriminator
81
+ def define_D(opt):
82
+ opt_net = opt['network_D']
83
+ which_model = opt_net['which_model_D']
84
+
85
+ if which_model == 'discriminator_vgg_128':
86
+ hidden_units = opt_net.get('hidden_units', 8192)
87
+ netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], hidden_units=hidden_units)
88
+ else:
89
+ raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
90
+ return netD
91
+
92
+
93
+ #### Define Network used for Perceptual Loss
94
+ def define_F(opt, use_bn=False):
95
+ gpu_ids = opt.get('gpu_ids', None)
96
+ device = torch.device('cuda' if gpu_ids else 'cpu')
97
+ # PyTorch pretrained_models VGG19-54, before ReLU.
98
+ if use_bn:
99
+ feature_layer = 49
100
+ else:
101
+ feature_layer = 34
102
+ netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
103
+ use_input_norm=True, device=device)
104
+ netF.eval() # No need to train
105
+ return netF
models/SRFlow/code/options/__init__.py ADDED
File without changes
models/SRFlow/code/options/options.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import os
18
+ import os.path as osp
19
+ import logging
20
+ import yaml
21
+ from utils.util import OrderedYaml
22
+
23
+ Loader, Dumper = OrderedYaml()
24
+
25
+
26
+ def parse(opt_path, is_train=True):
27
+ with open(opt_path, mode='r') as f:
28
+ opt = yaml.load(f, Loader=Loader)
29
+ # export CUDA_VISIBLE_DEVICES
30
+ gpu_list = ','.join(str(x) for x in opt.get('gpu_ids', []))
31
+ # os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
32
+ # print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
33
+ opt['is_train'] = is_train
34
+ if opt['distortion'] == 'sr':
35
+ scale = opt['scale']
36
+
37
+ # datasets
38
+ for phase, dataset in opt['datasets'].items():
39
+ phase = phase.split('_')[0]
40
+ dataset['phase'] = phase
41
+ if opt['distortion'] == 'sr':
42
+ dataset['scale'] = scale
43
+ is_lmdb = False
44
+ if dataset.get('dataroot_GT', None) is not None:
45
+ dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
46
+ if dataset['dataroot_GT'].endswith('lmdb'):
47
+ is_lmdb = True
48
+ if dataset.get('dataroot_LQ', None) is not None:
49
+ dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
50
+ if dataset['dataroot_LQ'].endswith('lmdb'):
51
+ is_lmdb = True
52
+ dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
53
+ if dataset['mode'].endswith('mc'): # for memcached
54
+ dataset['data_type'] = 'mc'
55
+ dataset['mode'] = dataset['mode'].replace('_mc', '')
56
+
57
+ # path
58
+ for key, path in opt['path'].items():
59
+ if path and key in opt['path'] and key != 'strict_load':
60
+ opt['path'][key] = osp.expanduser(path)
61
+ opt['path']['root'] = '/kaggle/working/'
62
+ if is_train:
63
+ experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
64
+ opt['path']['experiments_root'] = experiments_root
65
+ opt['path']['models'] = osp.join(experiments_root, 'models')
66
+ opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
67
+ opt['path']['log'] = experiments_root
68
+ opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
69
+
70
+ # change some options for debug mode
71
+ if 'debug' in opt['name']:
72
+ opt['train']['val_freq'] = 8
73
+ opt['logger']['print_freq'] = 1
74
+ opt['logger']['save_checkpoint_freq'] = 8
75
+ else: # test
76
+ if not opt['path'].get('results_root', None):
77
+ results_root = osp.join(opt['path']['root'], 'results', opt['name'])
78
+ opt['path']['results_root'] = results_root
79
+ opt['path']['log'] = opt['path']['results_root']
80
+
81
+ # network
82
+ if opt['distortion'] == 'sr':
83
+ opt['network_G']['scale'] = scale
84
+
85
+ # relative learning rate
86
+ if 'train' in opt:
87
+ niter = opt['train']['niter']
88
+ if 'T_period_rel' in opt['train']:
89
+ opt['train']['T_period'] = [int(x * niter) for x in opt['train']['T_period_rel']]
90
+ if 'restarts_rel' in opt['train']:
91
+ opt['train']['restarts'] = [int(x * niter) for x in opt['train']['restarts_rel']]
92
+ if 'lr_steps_rel' in opt['train']:
93
+ opt['train']['lr_steps'] = [int(x * niter) for x in opt['train']['lr_steps_rel']]
94
+ if 'lr_steps_inverse_rel' in opt['train']:
95
+ opt['train']['lr_steps_inverse'] = [int(x * niter) for x in opt['train']['lr_steps_inverse_rel']]
96
+ print(opt['train'])
97
+
98
+ return opt
99
+
100
+
101
+ def dict2str(opt, indent_l=1):
102
+ '''dict to string for logger'''
103
+ msg = ''
104
+ for k, v in opt.items():
105
+ if isinstance(v, dict):
106
+ msg += ' ' * (indent_l * 2) + k + ':[\n'
107
+ msg += dict2str(v, indent_l + 1)
108
+ msg += ' ' * (indent_l * 2) + ']\n'
109
+ else:
110
+ msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
111
+ return msg
112
+
113
+
114
+ class NoneDict(dict):
115
+ def __missing__(self, key):
116
+ return None
117
+
118
+
119
+ # convert to NoneDict, which return None for missing key.
120
+ def dict_to_nonedict(opt):
121
+ if isinstance(opt, dict):
122
+ new_opt = dict()
123
+ for key, sub_opt in opt.items():
124
+ new_opt[key] = dict_to_nonedict(sub_opt)
125
+ return NoneDict(**new_opt)
126
+ elif isinstance(opt, list):
127
+ return [dict_to_nonedict(sub_opt) for sub_opt in opt]
128
+ else:
129
+ return opt
130
+
131
+
132
+ def check_resume(opt, resume_iter):
133
+ '''Check resume states and pretrain_model paths'''
134
+ logger = logging.getLogger('base')
135
+ if opt['path']['resume_state']:
136
+ if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
137
+ 'pretrain_model_D', None) is not None:
138
+ logger.warning('pretrain_model path will be ignored when resuming training.')
139
+
140
+ opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
141
+ '{}_G.pth'.format(resume_iter))
142
+ logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
143
+ if 'gan' in opt['model']:
144
+ opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
145
+ '{}_D.pth'.format(resume_iter))
146
+ logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
models/SRFlow/code/prepare_data.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import glob
16
+ import os
17
+ import sys
18
+
19
+ import numpy as np
20
+ import random
21
+ import imageio
22
+ import pickle
23
+
24
+ from natsort import natsort
25
+ from tqdm import tqdm
26
+
27
+ def get_img_paths(dir_path, wildcard='*.png'):
28
+ return natsort.natsorted(glob.glob(dir_path + '/' + wildcard))
29
+
30
+ def create_all_dirs(path):
31
+ if "." in path.split("/")[-1]:
32
+ dirs = os.path.dirname(path)
33
+ else:
34
+ dirs = path
35
+ os.makedirs(dirs, exist_ok=True)
36
+
37
+ def to_pklv4(obj, path, vebose=False):
38
+ create_all_dirs(path)
39
+ with open(path, 'wb') as f:
40
+ pickle.dump(obj, f, protocol=4)
41
+ if vebose:
42
+ print("Wrote {}".format(path))
43
+
44
+
45
+ from imresize import imresize
46
+
47
+ def random_crop(img, size):
48
+ h, w, c = img.shape
49
+
50
+ h_start = np.random.randint(0, h - size)
51
+ h_end = h_start + size
52
+
53
+ w_start = np.random.randint(0, w - size)
54
+ w_end = w_start + size
55
+
56
+ return img[h_start:h_end, w_start:w_end]
57
+
58
+
59
+ def imread(img_path):
60
+ img = imageio.imread(img_path)
61
+ if len(img.shape) == 2:
62
+ img = np.stack([img, ] * 3, axis=2)
63
+ return img
64
+
65
+
66
+ def to_pklv4_1pct(obj, path, vebose):
67
+ n = int(round(len(obj) * 0.01))
68
+ path = path.replace(".", "_1pct.")
69
+ to_pklv4(obj[:n], path, vebose=True)
70
+
71
+
72
+ def main(dir_path):
73
+ hrs = []
74
+ lqs = []
75
+
76
+ img_paths = get_img_paths(dir_path)
77
+ for img_path in tqdm(img_paths):
78
+ img = imread(img_path)
79
+
80
+ for i in range(47):
81
+ crop = random_crop(img, 256)
82
+ cropX4 = imresize(crop, scalar_scale=0.25)
83
+ hrs.append(crop)
84
+ lqs.append(cropX4)
85
+
86
+ shuffle_combined(hrs, lqs)
87
+
88
+ hrs_path = get_hrs_path(dir_path)
89
+ to_pklv4(hrs, hrs_path, vebose=True)
90
+
91
+ lqs_path = get_lqs_path(dir_path)
92
+ to_pklv4(lqs, lqs_path, vebose=True)
93
+
94
+
95
+ def get_hrs_path(dir_path):
96
+ base_dir = '/kaggle/working/'
97
+ name = os.path.basename(dir_path)
98
+ hrs_path = os.path.join(base_dir, 'pkls', name + '.pklv4')
99
+ return hrs_path
100
+
101
+
102
+ def get_lqs_path(dir_path):
103
+ base_dir = '/kaggle/working/'
104
+ name = os.path.basename(dir_path)
105
+ hrs_path = os.path.join(base_dir, 'pkls', name + '_X4.pklv4')
106
+ return hrs_path
107
+
108
+
109
+ def shuffle_combined(hrs, lqs):
110
+ combined = list(zip(hrs, lqs))
111
+ random.shuffle(combined)
112
+ hrs[:], lqs[:] = zip(*combined)
113
+
114
+
115
+ if __name__ == "__main__":
116
+ dir_path = sys.argv[1]
117
+ assert os.path.isdir(dir_path)
118
+ main(dir_path)
models/SRFlow/code/test.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+
18
+ import glob
19
+ import sys
20
+ from collections import OrderedDict
21
+
22
+ from natsort import natsort
23
+
24
+ import options.options as option
25
+ from Measure import Measure, psnr
26
+ from imresize import imresize
27
+ from models import create_model
28
+ import torch
29
+ from utils.util import opt_get
30
+ import numpy as np
31
+ import pandas as pd
32
+ import os
33
+ import cv2
34
+
35
+
36
+ def fiFindByWildcard(wildcard):
37
+ return natsort.natsorted(glob.glob(wildcard, recursive=True))
38
+
39
+
40
+ def load_model(conf_path):
41
+ opt = option.parse(conf_path, is_train=False)
42
+ opt['gpu_ids'] = None
43
+ opt = option.dict_to_nonedict(opt)
44
+ model = create_model(opt)
45
+
46
+ model_path = opt_get(opt, ['model_path'], None)
47
+ model.load_network(load_path=model_path, network=model.netG)
48
+ return model, opt
49
+
50
+
51
+ def predict(model, lr):
52
+ model.feed_data({"LQ": t(lr)}, need_GT=False)
53
+ model.test()
54
+ visuals = model.get_current_visuals(need_GT=False)
55
+ return visuals.get('rlt', visuals.get("SR"))
56
+
57
+
58
+ def t(array): return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255
59
+
60
+
61
+ def rgb(t): return (
62
+ np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype(
63
+ np.uint8)
64
+
65
+
66
+ def imread(path):
67
+ return cv2.imread(path)[:, :, [2, 1, 0]]
68
+
69
+
70
+ def imwrite(path, img):
71
+ os.makedirs(os.path.dirname(path), exist_ok=True)
72
+ cv2.imwrite(path, img[:, :, [2, 1, 0]])
73
+
74
+
75
+ def imCropCenter(img, size):
76
+ h, w, c = img.shape
77
+
78
+ h_start = max(h // 2 - size // 2, 0)
79
+ h_end = min(h_start + size, h)
80
+
81
+ w_start = max(w // 2 - size // 2, 0)
82
+ w_end = min(w_start + size, w)
83
+
84
+ return img[h_start:h_end, w_start:w_end]
85
+
86
+
87
+ def impad(img, top=0, bottom=0, left=0, right=0, color=255):
88
+ return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect')
89
+
90
+
91
+ def main():
92
+ conf_path = sys.argv[1]
93
+ conf = conf_path.split('/')[-1].replace('.yml', '')
94
+ model, opt = load_model(conf_path)
95
+
96
+ data_dir = opt['dataroot']
97
+
98
+ # this_dir = os.path.dirname(os.path.realpath(__file__))
99
+ test_dir = os.path.join('/kaggle/working/', 'results', conf)
100
+ print(f"Out dir: {test_dir}")
101
+
102
+ measure = Measure(use_gpu=False)
103
+
104
+ fname = f'measure_full.csv'
105
+ fname_tmp = fname + "_"
106
+ path_out_measures = os.path.join(test_dir, fname_tmp)
107
+ path_out_measures_final = os.path.join(test_dir, fname)
108
+
109
+ if os.path.isfile(path_out_measures_final):
110
+ df = pd.read_csv(path_out_measures_final)
111
+ elif os.path.isfile(path_out_measures):
112
+ df = pd.read_csv(path_out_measures)
113
+ else:
114
+ df = None
115
+
116
+ scale = opt['scale']
117
+
118
+ pad_factor = 2
119
+
120
+ data_sets = [
121
+ 'Set5',
122
+ 'Set14',
123
+ 'Urban100',
124
+ 'BSD100'
125
+ ]
126
+
127
+ final_df = pd.DataFrame()
128
+
129
+ for data_set in data_sets:
130
+ lr_paths = fiFindByWildcard(os.path.join(data_dir, data_set, '*LR.png'))
131
+ hr_paths = fiFindByWildcard(os.path.join(data_dir, data_set, '*HR.png'))
132
+
133
+ df = pd.DataFrame(columns=['conf', 'heat', 'data_set', 'name', 'PSNR', 'SSIM', 'LPIPS', 'LRC PSNR'])
134
+
135
+ for lr_path, hr_path, idx_test in zip(lr_paths, hr_paths, range(len(lr_paths))):
136
+ with torch.no_grad(), torch.cuda.amp.autocast():
137
+ lr = imread(lr_path)
138
+ hr = imread(hr_path)
139
+
140
+ # Pad image to be % 2
141
+ h, w, c = lr.shape
142
+ lq_orig = lr.copy()
143
+ lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
144
+ right=int(np.ceil(w / pad_factor) * pad_factor - w))
145
+
146
+ lr_t = t(lr)
147
+
148
+ heat = opt['heat']
149
+
150
+ if df is not None and len(df[(df['heat'] == heat) & (df['name'] == idx_test)]) == 1:
151
+ continue
152
+
153
+ sr_t = model.get_sr(lq=lr_t, heat=heat)
154
+
155
+ sr = rgb(torch.clamp(sr_t, 0, 1))
156
+ sr = sr[:h * scale, :w * scale]
157
+
158
+ path_out_sr = os.path.join(test_dir, data_set, "{:0.2f}".format(heat).replace('.', ''), "{:06d}.png".format(idx_test))
159
+ imwrite(path_out_sr, sr)
160
+
161
+ meas = OrderedDict(conf=conf, heat=heat, data_set=data_set, name=idx_test)
162
+ meas['PSNR'], meas['SSIM'], meas['LPIPS'] = measure.measure(sr, hr)
163
+
164
+ lr_reconstruct_rgb = imresize(sr, 1 / opt['scale'])
165
+ meas['LRC PSNR'] = psnr(lq_orig, lr_reconstruct_rgb)
166
+
167
+ str_out = format_measurements(meas)
168
+ print(str_out)
169
+
170
+ df = df._append(pd.DataFrame([meas]), ignore_index=True)
171
+
172
+ final_df = pd.concat([final_df, df])
173
+
174
+ final_df.to_csv(path_out_measures, index=False)
175
+ os.rename(path_out_measures, path_out_measures_final)
176
+
177
+ # str_out = format_measurements(df.mean())
178
+ # print(f"Results in: {path_out_measures_final}")
179
+ # print('Mean: ' + str_out)
180
+
181
+
182
+ def format_measurements(meas):
183
+ s_out = []
184
+ for k, v in meas.items():
185
+ v = f"{v:0.2f}" if isinstance(v, float) else v
186
+ s_out.append(f"{k}: {v}")
187
+ str_out = ", ".join(s_out)
188
+ return str_out
189
+
190
+
191
+ if __name__ == "__main__":
192
+ main()
models/SRFlow/code/train.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import os
18
+ from os.path import basename
19
+ import math
20
+ import argparse
21
+ import random
22
+ import logging
23
+ import cv2
24
+
25
+ import torch
26
+ import torch.distributed as dist
27
+ import torch.multiprocessing as mp
28
+
29
+ import options.options as option
30
+ from utils import util
31
+ from data import create_dataloader, create_dataset
32
+ from models import create_model
33
+ from utils.timer import Timer, TickTock
34
+ from utils.util import get_resume_paths
35
+
36
+ import wandb
37
+
38
+ def getEnv(name): import os; return True if name in os.environ.keys() else False
39
+
40
+
41
+ def init_dist(backend='nccl', **kwargs):
42
+ ''' initialization for distributed training'''
43
+ # if mp.get_start_method(allow_none=True) is None:
44
+ if mp.get_start_method(allow_none=True) != 'spawn':
45
+ mp.set_start_method('spawn')
46
+ rank = int(os.environ['RANK'])
47
+ num_gpus = torch.cuda.device_count()
48
+ torch.cuda.set_deviceDistIterSampler(rank % num_gpus)
49
+ dist.init_process_group(backend=backend, **kwargs)
50
+
51
+
52
+ def main():
53
+ wandb.init(project='srflow')
54
+ #### options
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
57
+ parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
58
+ help='job launcher')
59
+ parser.add_argument('--local_rank', type=int, default=0)
60
+ args = parser.parse_args()
61
+ opt = option.parse(args.opt, is_train=True)
62
+
63
+ #### distributed training settings
64
+ opt['dist'] = False
65
+ rank = -1
66
+ print('Disabled distributed training.')
67
+
68
+ #### loading resume state if exists
69
+ if opt['path'].get('resume_state', None):
70
+ resume_state_path, _ = get_resume_paths(opt)
71
+
72
+ # distributed resuming: all load into default GPU
73
+ if resume_state_path is None:
74
+ resume_state = None
75
+ else:
76
+ device_id = torch.cuda.current_device()
77
+ resume_state = torch.load(resume_state_path,
78
+ map_location=lambda storage, loc: storage.cuda(device_id))
79
+ option.check_resume(opt, resume_state['iter']) # check resume options
80
+ else:
81
+ resume_state = None
82
+
83
+ #### mkdir and loggers
84
+ if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
85
+ if resume_state is None:
86
+ util.mkdir_and_rename(
87
+ opt['path']['experiments_root']) # rename experiment folder if exists
88
+ util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
89
+ and 'pretrain_model' not in key and 'resume' not in key))
90
+
91
+ # config loggers. Before it, the log will not work
92
+ util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
93
+ screen=True, tofile=True)
94
+ util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
95
+ screen=True, tofile=True)
96
+ logger = logging.getLogger('base')
97
+ logger.info(option.dict2str(opt))
98
+
99
+ # tensorboard logger
100
+ if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
101
+ version = float(torch.__version__[0:3])
102
+ if version >= 1.1: # PyTorch 1.1
103
+ from torch.utils.tensorboard import SummaryWriter
104
+ else:
105
+ logger.info(
106
+ 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
107
+ from tensorboardX import SummaryWriter
108
+ conf_name = basename(args.opt).replace(".yml", "")
109
+ exp_dir = opt['path']['experiments_root']
110
+ log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
111
+ log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
112
+ tb_logger_train = SummaryWriter(log_dir=log_dir_train)
113
+ tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
114
+ else:
115
+ util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
116
+ logger = logging.getLogger('base')
117
+
118
+ # convert to NoneDict, which returns None for missing keys
119
+ opt = option.dict_to_nonedict(opt)
120
+
121
+ #### random seed
122
+ seed = opt['train']['manual_seed']
123
+ if seed is None:
124
+ seed = random.randint(1, 10000)
125
+ if rank <= 0:
126
+ logger.info('Random seed: {}'.format(seed))
127
+ util.set_random_seed(seed)
128
+
129
+ torch.backends.cudnn.benchmark = True
130
+ # torch.backends.cudnn.deterministic = True
131
+
132
+ #### create train and val dataloader
133
+ dataset_ratio = 200 # enlarge the size of each epoch
134
+ for phase, dataset_opt in opt['datasets'].items():
135
+ if phase == 'train':
136
+ full_dataset = create_dataset(dataset_opt)
137
+ print('Dataset created')
138
+ train_len = int(len(full_dataset) * 0.95)
139
+ val_len = len(full_dataset) - train_len
140
+ train_set, val_set = torch.utils.data.random_split(full_dataset, [train_len, val_len])
141
+ train_size = int(math.ceil(train_len / dataset_opt['batch_size']))
142
+ total_iters = int(opt['train']['niter'])
143
+ total_epochs = int(math.ceil(total_iters / train_size))
144
+ train_sampler = None
145
+ train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
146
+ if rank <= 0:
147
+ logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
148
+ len(train_set), train_size))
149
+ logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
150
+ total_epochs, total_iters))
151
+ val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1,
152
+ pin_memory=True)
153
+ elif phase == 'val':
154
+ continue
155
+ else:
156
+ raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
157
+ assert train_loader is not None
158
+
159
+ #### create model
160
+ current_step = 0 if resume_state is None else resume_state['iter']
161
+ model = create_model(opt, current_step)
162
+
163
+ #### resume training
164
+ if resume_state:
165
+ logger.info('Resuming training from epoch: {}, iter: {}.'.format(
166
+ resume_state['epoch'], resume_state['iter']))
167
+
168
+ start_epoch = resume_state['epoch']
169
+ current_step = resume_state['iter']
170
+ model.resume_training(resume_state) # handle optimizers and schedulers
171
+ else:
172
+ current_step = 0
173
+ start_epoch = 0
174
+
175
+ #### training
176
+ timer = Timer()
177
+ logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
178
+ timerData = TickTock()
179
+
180
+ for epoch in range(start_epoch, total_epochs + 1):
181
+ if opt['dist']:
182
+ train_sampler.set_epoch(epoch)
183
+
184
+ timerData.tick()
185
+ for _, train_data in enumerate(train_loader):
186
+ timerData.tock()
187
+ current_step += 1
188
+ if current_step > total_iters:
189
+ break
190
+
191
+ #### training
192
+ model.feed_data(train_data)
193
+
194
+ #### update learning rate
195
+ model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
196
+
197
+ try:
198
+ nll = model.optimize_parameters(current_step)
199
+ except RuntimeError as e:
200
+ print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ")
201
+ print(e)
202
+
203
+ if nll is None:
204
+ nll = 0
205
+
206
+ wandb.log({"loss": nll})
207
+ #### log
208
+ def eta(t_iter):
209
+ return (t_iter * (opt['train']['niter'] - current_step)) / 3600
210
+
211
+ if current_step % opt['logger']['print_freq'] == 0 \
212
+ or current_step - (resume_state['iter'] if resume_state else 0) < 25:
213
+ avg_time = timer.get_average_and_reset()
214
+ avg_data_time = timerData.get_average_and_reset()
215
+ message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, eta:{:.2e}, nll:{:.3e}> '.format(
216
+ epoch, current_step, model.get_current_learning_rate(), avg_time, avg_data_time,
217
+ eta(avg_time), nll)
218
+ print(message)
219
+ timer.tick()
220
+ # Reduce number of logs
221
+ if current_step % 5 == 0:
222
+ tb_logger_train.add_scalar('loss/nll', nll, current_step)
223
+ tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
224
+ tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
225
+ tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step)
226
+ tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step)
227
+ for k, v in model.get_current_log().items():
228
+ tb_logger_train.add_scalar(k, v, current_step)
229
+
230
+ # validation
231
+ if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
232
+ avg_psnr = 0.0
233
+ idx = 0
234
+ nlls = []
235
+ for val_data in val_loader:
236
+ idx += 1
237
+ img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
238
+ img_dir = os.path.join(opt['path']['val_images'], img_name)
239
+ util.mkdir(img_dir)
240
+
241
+ model.feed_data(val_data)
242
+
243
+ nll = model.test()
244
+ if nll is None:
245
+ nll = 0
246
+ nlls.append(nll)
247
+
248
+ visuals = model.get_current_visuals()
249
+
250
+ sr_img = None
251
+ # Save SR images for reference
252
+ if hasattr(model, 'heats'):
253
+ for heat in model.heats:
254
+ for i in range(model.n_sample):
255
+ sr_img = util.tensor2img(visuals['SR', heat, i]) # uint8
256
+ save_img_path = os.path.join(img_dir,
257
+ '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
258
+ current_step,
259
+ int(heat * 100), i))
260
+ util.save_img(sr_img, save_img_path)
261
+ else:
262
+ sr_img = util.tensor2img(visuals['SR']) # uint8
263
+ save_img_path = os.path.join(img_dir,
264
+ '{:s}_{:d}.png'.format(img_name, current_step))
265
+ util.save_img(sr_img, save_img_path)
266
+ assert sr_img is not None
267
+
268
+ # Save LQ images for reference
269
+ save_img_path_lq = os.path.join(img_dir,
270
+ '{:s}_LQ.png'.format(img_name))
271
+ if not os.path.isfile(save_img_path_lq):
272
+ lq_img = util.tensor2img(visuals['LQ']) # uint8
273
+ util.save_img(
274
+ cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'],
275
+ interpolation=cv2.INTER_NEAREST),
276
+ save_img_path_lq)
277
+
278
+ # Save GT images for reference
279
+ gt_img = util.tensor2img(visuals['GT']) # uint8
280
+ save_img_path_gt = os.path.join(img_dir,
281
+ '{:s}_GT.png'.format(img_name))
282
+ if not os.path.isfile(save_img_path_gt):
283
+ util.save_img(gt_img, save_img_path_gt)
284
+
285
+ # calculate PSNR
286
+ crop_size = opt['scale']
287
+ gt_img = gt_img / 255.
288
+ sr_img = sr_img / 255.
289
+ cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
290
+ cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
291
+ avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
292
+
293
+ avg_psnr = avg_psnr / idx
294
+ avg_nll = sum(nlls) / len(nlls)
295
+
296
+ # log
297
+ logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
298
+ logger_val = logging.getLogger('val') # validation logger
299
+ logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
300
+ epoch, current_step, avg_psnr))
301
+
302
+ # tensorboard logger
303
+ tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step)
304
+ tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step)
305
+
306
+ tb_logger_train.flush()
307
+ tb_logger_valid.flush()
308
+
309
+ #### save models and training states
310
+ if current_step % opt['logger']['save_checkpoint_freq'] == 0:
311
+ if rank <= 0:
312
+ logger.info('Saving models and training states.')
313
+ model.save(current_step)
314
+ model.save_training_state(epoch, current_step)
315
+
316
+ timerData.tick()
317
+
318
+ with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f:
319
+ f.write("TRAIN_DONE")
320
+
321
+ if rank <= 0:
322
+ logger.info('Saving the final model.')
323
+ model.save('latest')
324
+ logger.info('End of training.')
325
+
326
+
327
+ if __name__ == '__main__':
328
+ main()
models/SRFlow/code/utils/__init__.py ADDED
File without changes
models/SRFlow/code/utils/timer.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import time
18
+
19
+
20
+ class ScopeTimer:
21
+ def __init__(self, name):
22
+ self.name = name
23
+
24
+ def __enter__(self):
25
+ self.start = time.time()
26
+ return self
27
+
28
+ def __exit__(self, *args):
29
+ self.end = time.time()
30
+ self.interval = self.end - self.start
31
+ print("{} {:.3E}".format(self.name, self.interval))
32
+
33
+
34
+ class Timer:
35
+ def __init__(self):
36
+ self.times = []
37
+
38
+ def tick(self):
39
+ self.times.append(time.time())
40
+
41
+ def get_average_and_reset(self):
42
+ if len(self.times) < 2:
43
+ return -1
44
+ avg = (self.times[-1] - self.times[0]) / (len(self.times) - 1)
45
+ self.times = [self.times[-1]]
46
+ return avg
47
+
48
+ def get_last_iteration(self):
49
+ if len(self.times) < 2:
50
+ return 0
51
+ return self.times[-1] - self.times[-2]
52
+
53
+
54
+ class TickTock:
55
+ def __init__(self):
56
+ self.time_pairs = []
57
+ self.current_time = None
58
+
59
+ def tick(self):
60
+ self.current_time = time.time()
61
+
62
+ def tock(self):
63
+ assert self.current_time is not None, self.current_time
64
+ self.time_pairs.append([self.current_time, time.time()])
65
+ self.current_time = None
66
+
67
+ def get_average_and_reset(self):
68
+ if len(self.time_pairs) == 0:
69
+ return -1
70
+ deltas = [t2 - t1 for t1, t2 in self.time_pairs]
71
+ avg = sum(deltas) / len(deltas)
72
+ self.time_pairs = []
73
+ return avg
74
+
75
+ def get_last_iteration(self):
76
+ if len(self.time_pairs) == 0:
77
+ return -1
78
+ return self.time_pairs[-1][1] - self.time_pairs[-1][0]
models/SRFlow/code/utils/util.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import sys
4
+ import time
5
+ import math
6
+ from datetime import datetime
7
+ import random
8
+ import logging
9
+ from collections import OrderedDict
10
+
11
+ import natsort
12
+ import numpy as np
13
+ import cv2
14
+ import torch
15
+ from torchvision.utils import make_grid
16
+ from shutil import get_terminal_size
17
+
18
+ import yaml
19
+
20
+ try:
21
+ from yaml import CLoader as Loader, CDumper as Dumper
22
+ except ImportError:
23
+ from yaml import Loader, Dumper
24
+
25
+
26
+ def OrderedYaml():
27
+ '''yaml orderedDict support'''
28
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
29
+
30
+ def dict_representer(dumper, data):
31
+ return dumper.represent_dict(data.items())
32
+
33
+ def dict_constructor(loader, node):
34
+ return OrderedDict(loader.construct_pairs(node))
35
+
36
+ Dumper.add_representer(OrderedDict, dict_representer)
37
+ Loader.add_constructor(_mapping_tag, dict_constructor)
38
+ return Loader, Dumper
39
+
40
+
41
+ ####################
42
+ # miscellaneous
43
+ ####################
44
+
45
+
46
+ def get_timestamp():
47
+ return datetime.now().strftime('%y%m%d-%H%M%S')
48
+
49
+
50
+ def mkdir(path):
51
+ if not os.path.exists(path):
52
+ os.makedirs(path)
53
+
54
+
55
+ def mkdirs(paths):
56
+ if isinstance(paths, str):
57
+ mkdir(paths)
58
+ else:
59
+ for path in paths:
60
+ mkdir(path)
61
+
62
+
63
+ def mkdir_and_rename(path):
64
+ if os.path.exists(path):
65
+ new_name = path + '_archived_' + get_timestamp()
66
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
67
+ logger = logging.getLogger('base')
68
+ logger.info('Path already exists. Rename it to [{:s}]'.format(new_name))
69
+ os.rename(path, new_name)
70
+ os.makedirs(path)
71
+
72
+
73
+ def set_random_seed(seed):
74
+ random.seed(seed)
75
+ np.random.seed(seed)
76
+ torch.manual_seed(seed)
77
+ torch.cuda.manual_seed_all(seed)
78
+
79
+
80
+ def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
81
+ '''set up logger'''
82
+ lg = logging.getLogger(logger_name)
83
+ formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
84
+ datefmt='%y-%m-%d %H:%M:%S')
85
+ lg.setLevel(level)
86
+ if tofile:
87
+ log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
88
+ fh = logging.FileHandler(log_file, mode='w')
89
+ fh.setFormatter(formatter)
90
+ lg.addHandler(fh)
91
+ if screen:
92
+ sh = logging.StreamHandler()
93
+ sh.setFormatter(formatter)
94
+ lg.addHandler(sh)
95
+
96
+
97
+ ####################
98
+ # image convert
99
+ ####################
100
+
101
+
102
+ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
103
+ '''
104
+ Converts a torch Tensor into an image Numpy array
105
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
106
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
107
+ '''
108
+ if hasattr(tensor, 'detach'):
109
+ tensor = tensor.detach()
110
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
111
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
112
+ n_dim = tensor.dim()
113
+ if n_dim == 4:
114
+ n_img = len(tensor)
115
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
116
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
117
+ elif n_dim == 3:
118
+ img_np = tensor.numpy()
119
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
120
+ elif n_dim == 2:
121
+ img_np = tensor.numpy()
122
+ else:
123
+ raise TypeError(
124
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
125
+ if out_type == np.uint8:
126
+ img_np = (img_np * 255.0).round()
127
+ # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
128
+ return img_np.astype(out_type)
129
+
130
+
131
+ def save_img(img, img_path, mode='RGB'):
132
+ cv2.imwrite(img_path, img)
133
+
134
+
135
+ ####################
136
+ # metric
137
+ ####################
138
+
139
+
140
+ def calculate_psnr(img1, img2):
141
+ # img1 and img2 have range [0, 255]
142
+ img1 = img1.astype(np.float64)
143
+ img2 = img2.astype(np.float64)
144
+ mse = np.mean((img1 - img2) ** 2)
145
+ if mse == 0:
146
+ return float('inf')
147
+ return 20 * math.log10(255.0 / math.sqrt(mse))
148
+
149
+
150
+ def get_resume_paths(opt):
151
+ resume_state_path = None
152
+ resume_model_path = None
153
+ ts = opt_get(opt, ['path', 'training_state'])
154
+ if opt.get('path', {}).get('resume_state', None) == "auto" and ts is not None:
155
+ wildcard = os.path.join(ts, "*")
156
+ paths = natsort.natsorted(glob.glob(wildcard))
157
+ if len(paths) > 0:
158
+ resume_state_path = paths[-1]
159
+ resume_model_path = resume_state_path.replace('training_state', 'models').replace('.state', '_G.pth')
160
+ else:
161
+ resume_state_path = opt.get('path', {}).get('resume_state')
162
+ return resume_state_path, resume_model_path
163
+
164
+
165
+ def opt_get(opt, keys, default=None):
166
+ if opt is None:
167
+ return default
168
+ ret = opt
169
+ for k in keys:
170
+ ret = ret.get(k, None)
171
+ if ret is None:
172
+ return default
173
+ return ret
174
+
models/SRFlow/srflow.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from code.test import imread, impad, t, load_model, rgb
5
+
6
+ def return_SRFlow_result(lr_path, conf_path='/models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.6):
7
+ model, opt = load_model(conf_path)
8
+ lr = imread(lr_path)
9
+
10
+ scale = opt['scale']
11
+ pad_factor = 2
12
+
13
+ h, w, c = lr.shape
14
+ lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
15
+ right=int(np.ceil(w / pad_factor) * pad_factor - w))
16
+
17
+ lr_t = t(lr)
18
+ heat = opt[heat]
19
+
20
+ sr_t = model.get_sr(lq=lr_t, heat=heat)
21
+
22
+ sr = rgb(torch.clamp(sr_t, 0, 1))
23
+ sr = sr[:h * scale, :w * scale]
24
+
25
+ return sr
26
+
27
+